ttagu99 committed
Commit 294ad84 • 1 Parent(s): edadc60
Files changed (1)
  1. chat.py +121 -0
chat.py ADDED
@@ -0,0 +1,121 @@
+ # %%
+ import os
+ import torch
+ import transformers
+ import gradio as gr
+ import googletrans
+ from dotenv import load_dotenv
+
+ translator = googletrans.Translator()
+
+ load_dotenv()
+ model = None
+ tokenizer = None
+ generator = None
+
+ # Pin the process to a single GPU before any CUDA work happens.
+ os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+
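+ # Overview: this script wires a LLaMA-based ChatDoctor checkpoint into a
+ # Gradio chat UI. Korean input is translated to English with googletrans,
+ # answered by the model, and the reply is translated back to Korean.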
+ def load_model(model_name, eight_bit=0, device_map="auto"):
+     global model, tokenizer, generator
+     print("Loading " + model_name + "...")
+
+     if device_map == "zero":
+         device_map = "balanced_low_0"
+
+     gpu_count = torch.cuda.device_count()
+     print('gpu_count', gpu_count)
+
+     print(model_name)
+     # The released transformers classes are LlamaTokenizer / LlamaForCausalLM;
+     # the LLaMATokenizer spelling only existed in a pre-release branch.
+     tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name)
+     model = transformers.LlamaForCausalLM.from_pretrained(
+         model_name,
+         torch_dtype=torch.float16,
+         low_cpu_mem_usage=True,
+         load_in_8bit=False,
+         cache_dir="cache",
+     ).cuda()
+     generator = model.generate
+
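+ # Note: with CUDA_VISIBLE_DEVICES pinned to one GPU, the full fp16 model must
+ # fit on that device. Passing device_map="auto" to from_pretrained (backed by
+ # accelerate) is the usual way to shard the weights across several GPUs instead.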
+ # chat doctor
+ def chatdoctor(input, state):
+     print('state', state)
+
+     invitation = "ChatDoctor: "
+     human_invitation = "Patient: "
+     fulltext = "If you are a doctor, please answer the medical questions based on the patient's description. \n\n"
+
+     # state holds alternating English turns, patient first (see predict below),
+     # so even indices are patient turns and odd indices are doctor turns.
+     for i in range(len(state)):
+         if i % 2 == 0:
+             fulltext += human_invitation + state[i] + "\n\n"
+         else:
+             fulltext += invitation + state[i] + "\n\n"
+     fulltext += human_invitation + input + "\n\n"
+     fulltext += invitation
+     print('fulltext: ', fulltext)
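+     # Illustrative shape of the assembled prompt after one prior exchange:
+     #   If you are a doctor, ... the patient's description.
+     #   Patient: I have a headache.
+     #   ChatDoctor: How long have you had it?
+     #   Patient: <current input>
+     #   ChatDoctor: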
+
+     gen_in = tokenizer(fulltext, return_tensors="pt").input_ids.cuda()
+     # input_ids is a [batch, seq_len] tensor, so len() would give the batch
+     # size (1); the prompt length is the second dimension.
+     in_tokens = gen_in.shape[1]
+     print('len token', in_tokens)
+     with torch.no_grad():
+         generated_ids = generator(
+             gen_in,
+             max_new_tokens=200,
+             use_cache=True,
+             pad_token_id=tokenizer.eos_token_id,
+             num_return_sequences=1,
+             do_sample=True,
+             repetition_penalty=1.1,  # 1.0 means 'off'; higher values discourage verbatim repeats
+             temperature=0.5,  # default: 1.0
+             top_k=50,  # default: 50
+             top_p=1.0,  # default: 1.0
+             early_stopping=True,
+         )
+     # batch_decode returns one string per generated sequence; we asked for one.
+     generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+     # Drop the echoed prompt, then cut the reply off at the next "Patient: " turn.
+     response = generated_text[len(fulltext):].split(human_invitation)[0]
+     response = response.strip()  # str.strip() returns a new string, so reassign it
+     print(invitation + response)
+     print("")
+     return response
+
+
+ def predict(input, chatbot, state):
+     print('predict state: ', state)
+     # Translate the Korean question to English, generate, then translate back.
+     en_input = translator.translate(input, src='ko', dest='en').text
+     response = chatdoctor(en_input, state)
+     ko_response = translator.translate(response, src='en', dest='ko').text
+     # Record both sides of the exchange; appending only the response would
+     # drop the patient turns from the prompt that chatdoctor() rebuilds.
+     state.append(en_input)
+     state.append(response)
+     chatbot.append((input, ko_response))
+     return chatbot, state
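+ # After two exchanges, state would hold four English strings (illustrative):
+ #   ["I have a fever", "How long have you had it?",
+ #    "About two days", "Have you taken your temperature?"]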
+
+ load_model("./ChatDoctor/pretrained/")
+
+ with gr.Blocks() as demo:
+     # Header: "This is Chat Doctor. Where do you feel discomfort?"
+     gr.Markdown("""<h1><center>챗 닥터입니다. 어디가 불편하신가요?</center></h1>
+     """)
+     chatbot = gr.Chatbot()
+     state = gr.State([])
+     with gr.Row():
+         # Placeholder: "Type your question here and press Enter"
+         txt = gr.Textbox(show_label=False, placeholder="여기에 질문을 쓰고 엔터").style(container=False)
+         clear = gr.Button("상담 새로 시작")  # "Start a new consultation"
+     txt.submit(predict, inputs=[txt, chatbot, state], outputs=[chatbot, state])
+     # Reset both the visible chat window and the hidden English history;
+     # clearing only the chatbot would leave stale turns in state.
+     clear.click(lambda: ([], []), None, [chatbot, state], queue=False)
+
+ # share=True serves a temporary public Gradio link in addition to localhost.
+ demo.launch(share=True)