chat-doctor-kr / app.py
# %%
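# Gradio demo: a Korean-language front-end for the ChatDoctor LLaMA model.
# Questions are translated ko -> en with googletrans, answered by the model,
# and the answer is translated back en -> ko for display.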
import os, json, itertools, bisect, gc
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import transformers
import torch
from accelerate import Accelerator
import accelerate
import time
import gradio as gr
import requests
import random
# from dotenv import load_dotenv
import googletrans
translator = googletrans.Translator()
# load_dotenv()
model = None
tokenizer = None
generator = None
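# Pin the process to the second GPU; must be set before CUDA is initialised.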
os.environ["CUDA_VISIBLE_DEVICES"]="1"
def load_model(model_name, eight_bit=0, device_map="auto"):
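    """Load the ChatDoctor LLaMA checkpoint in float16 on the GPU and set the
    module-level model, tokenizer and generator globals."""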
    global model, tokenizer, generator
    print("Loading " + model_name + "...")
    if device_map == "zero":
        device_map = "balanced_low_0"

    # config
    gpu_count = torch.cuda.device_count()
    print('gpu_count', gpu_count)
    print(model_name)

    # Released transformers versions use LlamaTokenizer / LlamaForCausalLM
    # (the LLaMATokenizer spelling only existed in a pre-release branch).
    tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name)
    model = transformers.LlamaForCausalLM.from_pretrained(
        model_name,
        #device_map=device_map,
        #device_map="auto",
        torch_dtype=torch.float16,
        #max_memory = {0: "14GB", 1: "14GB", 2: "14GB", 3: "14GB", 4: "14GB", 5: "14GB", 6: "14GB", 7: "14GB"},
        #load_in_8bit=eight_bit,
        #from_tf=True,
        low_cpu_mem_usage=True,
        load_in_8bit=False,
        cache_dir="cache"
    ).cuda()

    generator = model.generate
# chat doctor
def chatdoctor(input, state):
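    """Build the ChatDoctor prompt from the English conversation history in
    `state` plus the new `input`, generate a reply and return it."""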
    # print('input', input)
    # history = history or []
    print('state', state)

    invitation = "ChatDoctor: "
    human_invitation = "Patient: "

    # Rebuild the conversation so far; `state` holds alternating
    # patient / doctor turns in English (see predict below).
    fulltext = "If you are a doctor, please answer the medical questions based on the patient's description. \n\n"
    for i in range(len(state)):
        if i % 2 == 0:
            fulltext += human_invitation + state[i] + "\n\n"
        else:
            fulltext += invitation + state[i] + "\n\n"

    fulltext += human_invitation + input + "\n\n"
    fulltext += invitation

    print('fulltext: ', fulltext)
    generated_text = ""
    gen_in = tokenizer(fulltext, return_tensors="pt").input_ids.cuda()
    in_tokens = gen_in.shape[1]  # gen_in is [batch, seq_len]; count the prompt tokens
    print('len token', in_tokens)
    with torch.no_grad():
        generated_ids = generator(
            gen_in,
            max_new_tokens=200,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1,
            do_sample=True,
            repetition_penalty=1.1,  # 1.0 means 'off'; a mild penalty reduces verbatim repetition
            temperature=0.5,  # default: 1.0
            top_k=50,  # default: 50
            top_p=1.0,  # default: 1.0
            early_stopping=True,
        )
        generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]  # batch_decode returns a one-element list

    # Keep only the newly generated text and stop at the next "Patient:" turn
    # in case the model keeps writing past its own answer.
    text_without_prompt = generated_text[len(fulltext):]
    response = text_without_prompt.split(human_invitation)[0]
    response = response.strip()

    print(invitation + response)
    print("")

    return response
def predict(input, chatbot, state):
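    """Gradio callback: translate the Korean question to English, ask the model,
    translate the answer back to Korean, and update the chat display and state."""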
    print('predict state: ', state)

    # Translate the Korean question to English, query the model, translate back.
    en_input = translator.translate(input, src='ko', dest='en').text
    response = chatdoctor(en_input, state)
    ko_response = translator.translate(response, src='en', dest='ko').text

    # Keep the English turns so chatdoctor can rebuild the prompt next time.
    state.append(en_input)
    state.append(response)

    chatbot.append((input, ko_response))
    return chatbot, state
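
# Load the checkpoint once at startup.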
load_model("zl111/ChatDoctor")
with gr.Blocks() as demo:
gr.Markdown("""<h1><center>μ±— λ‹₯ν„°μž…λ‹ˆλ‹€. μ–΄λ””κ°€ λΆˆνŽΈν•˜μ‹ κ°€μš”?</center></h1>
""")
chatbot = gr.Chatbot()
state = gr.State([])
with gr.Row():
txt = gr.Textbox(show_label=False, placeholder="여기에 μ§ˆλ¬Έμ„ μ“°κ³  μ—”ν„°").style(container=False)
clear = gr.Button("상담 μƒˆλ‘œ μ‹œμž‘")
txt.submit(predict, inputs=[txt, chatbot, state], outputs=[chatbot, state]
)
clear.click(lambda: None, None, chatbot, queue=False)
demo.launch(share=True)