import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import gradio as gr
import traceback

# Quantized version of the fine-tuned model on the Hugging Face Hub
model_name = "lei-HuggingFace/Qwen2-7B-4it-Chat_Level_Measurement_Guide_07222024"

# Configure 4-bit NF4 quantization with double quantization and bfloat16 compute
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)


def load_model():
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            quantization_config=bnb_config,
            trust_remote_code=True,
        )
        model.config.use_cache = True  # Enable the KV cache for faster generation
        print(f"Model loaded successfully. Device: {model.device}")
        return model
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        traceback.print_exc()
        return None


try:
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=True)
    print("Tokenizer loaded successfully")
except Exception as e:
    print(f"Error loading tokenizer: {str(e)}")
    traceback.print_exc()
    tokenizer = None

model = load_model()


def generate_response(message, history):
    try:
        if model is None or tokenizer is None:
            return "Model or tokenizer failed to load. Please check the logs and try again."

        # Rebuild the conversation as a list of chat messages
        messages = []
        for user_msg, assistant_msg in history:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": assistant_msg})
        messages.append({"role": "user", "content": message})

        # Apply the model's chat template; add_generation_prompt=True appends
        # the assistant prefix so the model continues with its own turn
        input_ids = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)

        # Generate the response
        with torch.no_grad():
            output = model.generate(
                input_ids,
                attention_mask=torch.ones_like(input_ids),  # single unpadded prompt
                max_new_tokens=256,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                top_k=40,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens, skipping the prompt
        response = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
        return response.strip()
    except Exception as e:
        error_message = f"Error generating response: {str(e)}"
        print(error_message)
        traceback.print_exc()
        return error_message


# Create the Gradio interface
# Note: retry_btn, undo_btn, and clear_btn are Gradio 4.x ChatInterface arguments
iface = gr.ChatInterface(
    generate_response,
    chatbot=gr.Chatbot(height=300),
    textbox=gr.Textbox(placeholder="Type your message here...", container=False, scale=7),
    title="Level Measurement Guide Chatbot (Optimized Quantized Model)",
    description="Chat with the optimized, quantized, fine-tuned Level Measurement Guide model.",
    theme="soft",
    examples=[
        "What are the key considerations for level measurement in industrial settings?",
        "Can you explain the principle behind ultrasonic level sensors?",
        "What are the advantages of using radar level sensors?",
    ],
    cache_examples=False,
    retry_btn=None,
    undo_btn="Delete Previous",
    clear_btn="Clear",
)

# Launch the interface
iface.launch()
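
# Optional smoke test (a minimal sketch, not part of the original app): run this
# in place of iface.launch() above to verify generation without the Gradio UI.
# The example question is hypothetical; any prompt and an empty history work.
#
# if model is not None and tokenizer is not None:
#     print(generate_response("How does a guided-wave radar level sensor work?", []))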