How do I get streaming token generation from mistral_common? Example needed

#146 by narai

I tried the following, but it appears that no tokens are ever generated. (I'm only using bitsandbytes at the moment because I can't get EETQ to work outside of Hugging Face TGI.)

from transformers import AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig  # EetqConfig, once EETQ works
from threading import Thread
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest

tokenizer = MistralTokenizer.v1()

streamer = TextIteratorStreamer(tokenizer)

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

# checkpoint and HF_TOKEN are defined elsewhere
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    device_map="auto",
    quantization_config=quantization_config,
    token=HF_TOKEN,
)

completion_request = ChatCompletionRequest(messages=[UserMessage(content=prompt)])  # prompt defined elsewhere
input_ids = tokenizer.encode_chat_completion(completion_request).tokens  # a plain Python list of ints

generation_kwargs = {
    "input_ids": input_ids,
    "streamer": streamer,
    "max_new_tokens": max_new_tokens,
    "do_sample": True,
    "temperature": TEMPERATURE,
    "top_k": TOP_K,
    "top_p": TOP_P,
}
        
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

return streamer  # excerpted from a method; the caller iterates over the streamer
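
For reference, here is the closest I've gotten to a sketch that I think should stream, based on reading the transformers source: TextIteratorStreamer only ever calls tokenizer.decode() on a list of token ids, and model.generate() needs a batched tensor on the model's device rather than the plain list that encode_chat_completion() returns. The MistralDecodeAdapter class below is my own invention, not part of either library, and I'm assuming MistralTokenizer.decode() accepts a plain list of ints. Is this the right approach?

import torch
from threading import Thread
from transformers import TextIteratorStreamer

# Minimal adapter so TextIteratorStreamer can call .decode() on the
# mistral_common tokenizer. (My own class, not part of either library.)
class MistralDecodeAdapter:
    def __init__(self, mistral_tokenizer):
        self.mistral_tokenizer = mistral_tokenizer

    def decode(self, token_ids, **kwargs):
        # transformers may pass kwargs like skip_special_tokens;
        # mistral_common's decode() doesn't take them, so drop them here.
        return self.mistral_tokenizer.decode(token_ids)

streamer = TextIteratorStreamer(MistralDecodeAdapter(tokenizer), skip_prompt=True)

# generate() expects a batched tensor on the model's device,
# not the plain list returned by encode_chat_completion().
input_tensor = torch.tensor([input_ids], device=model.device)

generation_kwargs = {
    "input_ids": input_tensor,
    "streamer": streamer,
    "max_new_tokens": max_new_tokens,
    "do_sample": True,
    "temperature": TEMPERATURE,
    "top_k": TOP_K,
    "top_p": TOP_P,
}

Thread(target=model.generate, kwargs=generation_kwargs).start()

# Consume the stream on the main thread.
for text_chunk in streamer:
    print(text_chunk, end="", flush=True)

The alternative would be to drop TextIteratorStreamer entirely and decode incrementally myself, but reusing the transformers streamer seems simpler if the decode call lines up.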
