import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import traceback
# Load model and tokenizer
model_name = "lei-HuggingFace/SFT_Level_Measurement_Guide_07182024"
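# Wrap model loading so a failure is logged instead of crashing the app at startup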
def load_model():
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True
        )
        print(f"Model loaded successfully. Device: {model.device}")
        return model
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        traceback.print_exc()
        return None
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    print("Tokenizer loaded successfully")
except Exception as e:
    print(f"Error loading tokenizer: {str(e)}")
    traceback.print_exc()
    tokenizer = None
model = load_model()
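# gr.ChatInterface calls this with the new user message and the chat history,
# where history is a list of [user_message, assistant_message] pairs (Gradio 4.x tuple format)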
def generate_response(message, history):
    try:
        # Rebuild the full conversation from the Gradio history plus the new user message
        messages = []
        for h in history:
            messages.append({"role": "user", "content": h[0]})
            messages.append({"role": "assistant", "content": h[1]})
        messages.append({"role": "user", "content": message})
        print(f"Prepared messages: {messages}")
        # Convert messages to model input format; add_generation_prompt=True appends the
        # assistant turn marker so the model replies instead of continuing the user turn
        input_ids = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)
        print(f"Input shape: {input_ids.shape}")
        # Generate response
        with torch.no_grad():
            output = model.generate(
                input_ids,
                max_new_tokens=500,
                do_sample=True,
                temperature=0.7,
                pad_token_id=tokenizer.eos_token_id
            )
        # Decode only the newly generated tokens, not the prompt
        response = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
        print(f"Generated response: {response}")
        return response.strip()
    except Exception as e:
        error_message = f"Error generating response: {str(e)}"
        print(error_message)
        traceback.print_exc()
        return error_message
# Create Gradio interface
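# Note: retry_btn, undo_btn, and clear_btn are gr.ChatInterface options from Gradio 4.x;
# they are not available in Gradio 5, so this script assumes a 4.x install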
iface = gr.ChatInterface(
    generate_response,
    chatbot=gr.Chatbot(height=300),
    textbox=gr.Textbox(placeholder="Type your message here...", container=False, scale=7),
    title="Level Measurement Guide Chatbot",
    description="Chat with the fine-tuned Level Measurement Guide model.",
    theme="soft",
    examples=[
        "What are the key considerations for level measurement in industrial settings?",
        "Can you explain the principle behind ultrasonic level sensors?",
        "What are the advantages of using radar level sensors?",
    ],
    cache_examples=False,
    retry_btn=None,
    undo_btn="Delete Previous",
    clear_btn="Clear",
)
# Launch the interface
iface.launch(share=True, debug=True)