
Feature Request: allow same GrammarConstrainedLogitsProcessor to be reused across multiple generations #49

Open
Saibo-creator opened this issue Jun 4, 2024 · 1 comment

Comments

@Saibo-creator
Collaborator

Reproduce

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor


if __name__ == "__main__":

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model_id = "mistralai/Mistral-7B-v0.1"

    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True, device_map="auto")
    model.generation_config.pad_token_id = model.generation_config.eos_token_id

    grammar_str = """
    # Grammar for subset of JSON
    # String doesn't support unicode and escape yet
    # If you don't need to generate unicode and escape, you can use this grammar
    # We are working to support unicode and escape

    root   ::= object

    object ::= "{" ws ( string ":" ws value ("," ws string ":" ws value)* )? "}"

    value  ::= object | array | string | number | ("true" | "false" | "null") ws

    array  ::= "[" ws ( value ("," ws value)* )? "]" ws

    string ::= "\"" [ \t!#-\[\]-~]* "\"" ws

    number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws


    ws ::= ([ \t\n] ws)?
    """
    grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
    grammar_processor = GrammarConstrainedLogitsProcessor(grammar)  # created once and reused for both prompts below

    # Generate
    prefix1 = "This is a valid json string for http request:"
    prefix2 = "This is a valid json string for shopping cart:"

    for prefix in [prefix1, prefix2]:
        input_ids = tokenizer(
            [prefix], add_special_tokens=False, return_tensors="pt", padding=True
        )["input_ids"]

        output = model.generate(
            input_ids,
            do_sample=False,
            max_new_tokens=60,
            logits_processor=[grammar_processor],
            repetition_penalty=1.1,
            num_return_sequences=1,
        )
        # decode output
        generations = tokenizer.batch_decode(output, skip_special_tokens=True)
        print(generations)

        """
        'This is a valid json string for http request:{ "request": { "method": "GET", "headers": [], "content": "Content","type": "application" }}
        'This is a valid json string for shopping cart:This is a valid json string for shopping cart:{ "name": "MyCart", "price": 0, "value": 1 }
        """

Error message

Traceback (most recent call last):
  File "/home/saibo/Dev/SGCD-new/scripts/reproduce_tcfg_bug1.py", line 54, in <module>
    output = model.generate(
  File "/home/saibo/.virtualenvs/python3.10/SGCD-new/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/saibo/.virtualenvs/python3.10/SGCD-new/lib/python3.10/site-packages/transformers/generation/utils.py", line 1736, in generate
    result = self._sample(
  File "/home/saibo/.virtualenvs/python3.10/SGCD-new/lib/python3.10/site-packages/transformers/generation/utils.py", line 2388, in _sample
    next_token_scores = logits_processor(input_ids, next_token_logits)
  File "/home/saibo/.virtualenvs/python3.10/SGCD-new/lib/python3.10/site-packages/transformers/generation/logits_process.py", line 98, in __call__
    scores = processor(input_ids, scores)
  File "/home/saibo/.virtualenvs/python3.10/SGCD-new/lib/python3.10/site-packages/transformers_cfg/generation/logits_process.py", line 106, in __call__
    return self.process_logits(input_ids, scores)
  File "/home/saibo/.virtualenvs/python3.10/SGCD-new/lib/python3.10/site-packages/transformers_cfg/generation/logits_process.py", line 93, in process_logits
    self.batch_accept_states = self.grammar_constraint.consume_token_ids(
  File "/home/saibo/.virtualenvs/python3.10/SGCD-new/lib/python3.10/site-packages/transformers_cfg/token_grammar_recognizer.py", line 211, in consume_token_ids
    raise RuntimeError(
RuntimeError: Input ID's length is inconsistent with the current state of the GrammarConstrainedLogitsProcessor. If you want to process another input sequence, please instantiate a new GrammarConstrainedLogitsProcessor.
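
Until reuse is supported, the error message's own suggestion works as a stopgap: build a fresh processor (and, to be safe, a fresh grammar constraint) for every prompt. A minimal sketch, assuming the same model, tokenizer, and grammar_str as in the script above:

from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

for prefix in [prefix1, prefix2]:
    # Re-instantiate per prompt so each generation starts from a clean
    # parser state; `model`, `tokenizer`, `grammar_str`, `prefix1`, and
    # `prefix2` are assumed to be set up exactly as in the script above.
    grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
    grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

    input_ids = tokenizer(
        [prefix], add_special_tokens=False, return_tensors="pt", padding=True
    )["input_ids"]

    output = model.generate(
        input_ids,
        do_sample=False,
        max_new_tokens=60,
        logits_processor=[grammar_processor],
        repetition_penalty=1.1,
    )
    print(tokenizer.batch_decode(output, skip_special_tokens=True))

The feature request is to make this per-prompt reconstruction unnecessary, e.g. by letting the processor detect a new input sequence or by exposing an explicit reset, as the comment below does with grammar_processor.reset().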
@nathanrchn
Contributor

Very strange. I ran this and got an odd, but still valid, JSON.

import torch
import numpy as np
import mlx.core as mx
from transformers import AutoTokenizer
from mlx_lm import load, stream_generate
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

model, _ = load("mlx-community/Phi-3.5-mini-instruct-4bit")

tokenizer = AutoTokenizer.from_pretrained("mlx-community/Phi-3.5-mini-instruct-4bit")
tokenizer.pad_token = tokenizer.eos_token   

with open("examples/grammars/json.ebnf", "r") as f:
    grammar_str = f.read()

grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

def logits_processor(input_ids: mx.array, logits: mx.array) -> mx.array:
    torch_input_ids = torch.tensor(np.array(input_ids[None, :]), device="mps")
    torch_logits = torch.tensor(np.array(logits), device="mps")

    torch_processed_logits = grammar_processor(torch_input_ids, torch_logits)
    return mx.array(torch_processed_logits.cpu().numpy())

prefix1 = "This is a valid json string for http request:"
prefix2 = "This is a valid json string for shopping cart:"

for prefix in [prefix2, prefix1]:
    generation_stream = stream_generate(
        model,
        tokenizer,
        prompt=prefix,
        max_tokens=500,
        repetition_penalty=1.1,
        logits_processor=logits_processor
    )

    print("\033[92m" + "Prompt:" + prefix + "\033[0m")

    for token in generation_stream:
        print(token, end="", flush=True)

    print()
    grammar_processor.reset()  # clear the processor's internal state so it can be reused for the next prompt

Console output:

transformers-CFG % python3 debug.py
Fetching 11 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 45680.54it/s]
Prompt:This is a valid json string for shopping cart:
{"items":[
   {"item":"milk","quantity":2,"price":{"$type":"NumberInt", "value":3},
   "item":"bread","quantity":1,"price":{"$type":"NumberInt", "value":2},
   "item":"eggs","quantity":1,"price":{"$type":"NumberInt", "value":5}},
   {"item":"flour","quantity":2,"price":{"$type":"NumberInt", "value":1}}]
}
Prompt:This is a valid json string for http request:
{"name":"John","age":30,"city":"New York"}

When I ran the output through a JSON beautifier website to validate it, I got:

{
  "items": [
    {
      "item": "eggs",
      "quantity": 1,
      "price": {
        "$type": "NumberInt",
        "value": 5
      }
    },
    {
      "item": "flour",
      "quantity": 2,
      "price": {
        "$type": "NumberInt",
        "value": 1
      }
    }
  ]
}

I think it is very strange that the model didn't generate the following JSON:

{"items":[
   {"item":"milk","quantity":2,"price":{"$type":"NumberInt", "value":3}},
   {"item":"bread","quantity":1,"price":{"$type":"NumberInt", "value":2}},
   {"item":"eggs","quantity":1,"price":{"$type":"NumberInt", "value":5}},
   {"item":"flour","quantity":2,"price":{"$type":"NumberInt", "value":1}}]
}

Maybe there is a bug here...
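
A likely explanation for the collapsed beautifier output, for what it's worth: the JSON grammar allows a key to repeat inside a single object, and the first generated object repeats "item"/"quantity"/"price" three times instead of closing each entry. Duplicate-key objects are still parseable JSON, but most parsers keep only the last value for each key, which is exactly why "milk" and "bread" vanish:

import json

# The first generated object never closes between items; it just repeats
# the same keys. Parsers keep only the last value for a duplicated key.
raw = '{"item":"milk","quantity":2,"item":"bread","quantity":1,"item":"eggs","quantity":1}'
print(json.loads(raw))  # -> {'item': 'eggs', 'quantity': 1}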
