moraxgiga committed on
Commit
a62cd08
1 Parent(s): 5f1e15d

Upload 2 files

Files changed (2)
  1. app (1).py +47 -0
  2. requirements (2).txt +4 -0
app (1).py ADDED
@@ -0,0 +1,47 @@
+ import gradio as gr
+ import torch, os
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers import StoppingCriteria, TextIteratorStreamer
+ from threading import Thread
+
+ torch.set_num_threads(2)
+ HF_TOKEN = os.environ.get("HF_TOKEN")
+
+ # Load the tokenizer and model from the Hugging Face Hub (the Gemma repo is gated, so a token is needed).
+ tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=HF_TOKEN)
+ model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", token=HF_TOKEN)
+
+ def count_tokens(text):
+     return len(tokenizer.tokenize(text))
+
+ # Function to generate model predictions, streamed token by token.
+ def predict(message, history):
+
+     formatted_prompt = f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
+     model_inputs = tokenizer(formatted_prompt, return_tensors="pt")
+
+     streamer = TextIteratorStreamer(tokenizer, timeout=120., skip_prompt=True, skip_special_tokens=True)
+
+     generate_kwargs = dict(
+         model_inputs,
+         streamer=streamer,
+         max_new_tokens=2048 - count_tokens(formatted_prompt),
+         top_p=0.2,
+         top_k=20,
+         temperature=0.1,
+         repetition_penalty=2.0,
+         do_sample=True,  # sampling must be enabled for top_p/top_k/temperature to take effect
+         num_beams=1
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()  # Start generation in a separate thread so tokens can be streamed.
+     partial_message = ""
+     for new_token in streamer:
+         partial_message += new_token
+         yield partial_message  # Yield the accumulated reply so the UI updates incrementally.
+
+ # Setting up the Gradio chat interface.
+ gr.ChatInterface(predict,
+                  title="Gemma 2b Instruct Chat",
+                  description=None
+ ).launch()  # Launching the web interface.
requirements (2).txt ADDED
@@ -0,0 +1,4 @@
+ torch>=2.0
+ transformers>=4.38.0
+ gradio>=4.13.0
+ sentencepiece
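
For reference, a minimal sketch (not part of this commit) of exercising the streaming predict generator outside the Gradio UI, assuming the definitions from app (1).py are already in scope, e.g. in a notebook before the ChatInterface(...).launch() call; the prompt text and variable names below are illustrative only:

# Consume the generator exactly as Gradio would: each yield is the reply so far.
last = ""
for partial in predict("In one sentence, what does a tokenizer do?", history=[]):
    last = partial
print(last)  # the full reply once the streamer is exhausted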