dmar1313 committed on
Commit 9f09c86
1 Parent(s): b757e6b

Update app.py

Files changed (1)
  1. app.py +56 -2
app.py CHANGED
@@ -1,5 +1,59 @@
  import gradio as gr
 
- from transformers import pipeline
+ from transformers import AutoTokenizer, pipeline, logging
+ from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
 
- pipe = pipeline("text-generation", model="TheBloke/Wizard-Vicuna-30B-Uncensored-GPTQ")
+ model_name_or_path = "TheBloke/Llama-2-13B-GPTQ"
+ model_basename = "gptq_model-4bit-128g"
+
+ use_triton = False
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
+
+ model = AutoGPTQForCausalLM.from_quantized(model_name_or_path,
+         model_basename=model_basename,
+         use_safetensors=True,
+         trust_remote_code=True,
+         device="cuda:0",
+         use_triton=use_triton,
+         quantize_config=None)
+
+ """
+ To download from a specific branch, use the revision parameter, as in this example:
+
+ model = AutoGPTQForCausalLM.from_quantized(model_name_or_path,
+         revision="gptq-4bit-32g-actorder_True",
+         model_basename=model_basename,
+         use_safetensors=True,
+         trust_remote_code=True,
+         device="cuda:0",
+         quantize_config=None)
+ """
+
+ prompt = "Tell me about AI"
+ prompt_template=f'''{prompt}
+ '''
+
+ print("\n\n*** Generate:")
+
+ input_ids = tokenizer(prompt_template, return_tensors='pt').input_ids.cuda()
+ output = model.generate(inputs=input_ids, temperature=0.7, max_new_tokens=512)
+ print(tokenizer.decode(output[0]))
+
+ # Inference can also be done using transformers' pipeline
+
+ # Prevent printing spurious transformers error when using pipeline with AutoGPTQ
+ logging.set_verbosity(logging.CRITICAL)
+
+ print("*** Pipeline:")
+ pipe = pipeline(
+     "text-generation",
+     model=model,
+     tokenizer=tokenizer,
+     max_new_tokens=512,
+     temperature=0.7,
+     top_p=0.95,
+     repetition_penalty=1.15
+ )
+
+ print(pipe(prompt_template)[0]['generated_text'])
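
As committed, app.py imports gradio but never builds an interface, so nothing in the new code is wired to a UI, and running it assumes the auto-gptq package and a CUDA GPU are available in the Space. A minimal sketch of how the resulting pipe could be exposed through Gradio follows; the generate_text wrapper and the text-in / text-out layout are illustrative assumptions, not part of this commit.

# Hypothetical glue code (not in the commit): wrap the pipeline in a function
# that takes a user prompt and returns the generated text.
def generate_text(prompt):
    return pipe(prompt)[0]['generated_text']

# Expose it through a simple text-in / text-out Gradio interface.
demo = gr.Interface(fn=generate_text, inputs="text", outputs="text")
demo.launch()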