ImpactInsights committed on
Commit
8bc425b
1 Parent(s): 7c0fe22

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -0
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline,GemmaForCausalLM,AutoTokenizer,BitsAndBytesConfig
2
+ import gradio as gr
3
+ import spaces
4
+ import torch
5
# `google/gemma-2-9b` is a Gemma 2 checkpoint. Loading it through the Auto class
# lets the checkpoint's own config select the matching architecture, instead of
# forcing the first-generation `GemmaForCausalLM` class onto it.
from transformers import AutoModelForCausalLM

# Single source of truth for the checkpoint so tokenizer and model cannot drift.
MODEL_ID = 'google/gemma-2-9b'

# Load the weights in 4-bit precision (bitsandbytes) to reduce GPU memory use.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
)

# Bounds for the "Max new tokens" control: hard upper limit and initial value.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
15
+
16
@spaces.GPU(duration=120)
def generate(
    message: str,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> str:
    """Generate a continuation of ``message`` with the module-level model.

    Args:
        message: Prompt text to continue.
        max_new_tokens: Number of tokens to generate; clamped to the range
            ``1..MAX_MAX_NEW_TOKENS``.
        temperature: Sampling temperature. A value <= 0 selects greedy
            decoding instead of sampling.
        top_p: Nucleus-sampling probability mass (used only when sampling).
        top_k: Number of highest-probability tokens kept (used only when
            sampling).
        repetition_penalty: Penalty applied to already generated tokens.

    Returns:
        The decoded first output sequence with special tokens stripped.
    """
    # UI sliders may deliver floats; generation expects integers here. The
    # clamp also protects callers that bypass the slider limits.
    max_new_tokens = min(max(int(max_new_tokens), 1), MAX_MAX_NEW_TOKENS)

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
    }
    if temperature > 0:
        # Without do_sample=True the sampling controls below are not applied
        # and decoding silently stays greedy.
        generation_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=int(top_k),
        )
    else:
        generation_kwargs["do_sample"] = False

    inputs = tokenizer(message, return_tensors="pt").to("cuda")
    outputs = model.generate(**inputs, **generation_kwargs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
29
+
30
# Build the input widgets in the same order as `generate`'s parameters:
# the prompt box first, then one slider per generation setting.
prompt_box = gr.Text()

# (label, minimum, maximum, step, initial value) for each slider.
slider_specs = (
    ("Max new tokens", 1, MAX_MAX_NEW_TOKENS, 1, DEFAULT_MAX_NEW_TOKENS),
    ("Temperature", 0.1, 4.0, 0.1, 0.6),
    ("Top-p (nucleus sampling)", 0.05, 1.0, 0.05, 0.9),
    ("Top-k", 1, 1000, 1, 50),
    ("Repetition penalty", 1.0, 2.0, 0.05, 1.2),
)
setting_sliders = [
    gr.Slider(label=label, minimum=low, maximum=high, step=step, value=initial)
    for label, low, high, step, initial in slider_specs
]

demo = gr.Interface(
    fn=generate,
    inputs=[prompt_box, *setting_sliders],
    outputs="text",
    examples=[['Write me a poem about Machine Learning.']],
)
demo.launch()