How to introduce stop_strings in llama3?

#132
by Srinjoy - opened

Hey all, I want to stop my text generation when I encounter certain strings. The code I am using is as follows:

def get_tokenizer_model(config):
    
    global generator_model
    global model_tokenizer

    if model_tokenizer is None:
        load_dotenv(find_dotenv())
        HF_TOKEN = os.getenv('HF_TOKEN')
        model_tokenizer = AutoTokenizer.from_pretrained(
            config['model_name'],
            token=HF_TOKEN
        )

        print(f'======================>The model name is :{config["model_name"]}')

        bnb_config = None

        if config['quantisation']:
            print('======================>The model is quantised?', config["quantisation"])
            bnb_config = BitsAndBytesConfig(
                load_in_4bit = True,
                bnb_4bit_use_double_quant = True,
                bnb_4bit_quant_type = 'nf4',
                bnb_4bit_compute_dtype = torch.bfloat16
            )


        generator_model = AutoModelForCausalLM.from_pretrained(
            config['model_name'],
            device_map='auto',
            quantization_config = bnb_config            
        )

        model_tokenizer.pad_token = model_tokenizer.eos_token
        # text_generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device = 'auto')
        print(f'======================>The device used by pipeline is:{generator_model.device}')

    return model_tokenizer, generator_model

and the model is used here:

def get_hf_chat(prompt:str, model: Model = 'llama3-8b-8192', temperature: float = 0.0, max_tokens: int = 256, stop_strs: Optional[List[str]] = None, is_batched: bool = False, args = None) -> str:

    config = {
        'quantisation':args.quantised,
        'model_name': args.model,
        'stop_strings': stop_strs
    }
    
    global generator_model
    global model_tokenizer
    model_tokenizer, generator_model = get_tokenizer_model(config=config)

    inputs = model_tokenizer(prompt, return_tensors='pt')
    gen_out = generator_model.generate(**inputs,
                              temperature=temperature,
                              max_new_tokens = max_tokens,
                              stop_strings = stop_strs,
                              tokenizer=model_tokenizer)
    
    output_text = model_tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0][len(prompt):]
    return output_text

When I call the above function I am encountering the following error:

 File "/data/home/srinjoym/reflexion/alfworld_runs/demo.py", line 207, in <module>
    print(get_hf_chat(prompt=prompt1, model='meta-llama/Meta-Llama-3-8B-Instruct', temperature=0.2, max_tokens=512, stop_strs=['\n'], args = args))
  File "/data/home/srinjoym/reflexion/alfworld_runs/utils.py", line 146, in get_hf_chat
    gen_out = generator_model.generate(**inputs,
  File "/data/home/srinjoym/miniconda3/envs/alfworld/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/data/home/srinjoym/miniconda3/envs/alfworld/lib/python3.9/site-packages/transformers/generation/utils.py", line 1384, in generate
    self._validate_model_kwargs(model_kwargs.copy())
  File "/data/home/srinjoym/miniconda3/envs/alfworld/lib/python3.9/site-packages/transformers/generation/utils.py", line 1130, in _validate_model_kwargs
    raise ValueError(
ValueError: The following `model_kwargs` are not used by the model: ['stop_strings', 'tokenizer'] (note: typos in the generate arguments will also show up in this list)

I used the code from this place

My question is how can i use/introduce the stop strings in this model? Also, do we need to checked whether stop_strings argument can be used for each model or is there some function standard all pretrained huggingface models follow?

Sign up or log in to comment