codewithRiz's picture
Update app.py
d04a377 verified
raw
history blame contribute delete
No virus
4.28 kB
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import re
from gtts import gTTS
import os
import logging
# Set up logging: INFO and above go to app.log, each record prefixed with
# its timestamp and level name.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(
    filename='app.log',
    level=logging.INFO,
    format=_LOG_FORMAT,
)
# Function to set up the model and tokenizer
def setup_model(model_name):
    """Load a causal language model and its tokenizer, ready for inference.

    Args:
        model_name: Hugging Face Hub id (or local path) of the checkpoint.

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been put in eval mode.
    """
    logging.info('Setting up model and tokenizer.')
    pretrained_kwargs = {
        "device_map": "auto",        # let transformers decide device placement
        "trust_remote_code": False,  # never execute code bundled with the checkpoint
        "revision": "main",          # repository revision to download
    }
    lm = AutoModelForCausalLM.from_pretrained(model_name, **pretrained_kwargs)
    tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    # Inference only: switch off training-time behaviour such as dropout.
    lm.eval()
    logging.info('Model and tokenizer setup completed.')
    return lm, tok
# Function to generate a response from the model
def generate_response(model, tokenizer, prompt, max_new_tokens=140):
    """Run *prompt* through *model* and return only the newly generated text.

    Args:
        model: A Hugging Face causal LM exposing ``generate`` and ``device``.
        tokenizer: The tokenizer that matches *model*.
        prompt: Full prompt string (instructions plus the user's comment).
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The decoded continuation with special tokens removed and surrounding
        whitespace stripped. The prompt is not included, and multi-line
        replies are returned in full (not just their last line).
    """
    logging.info('Generating response for the prompt.')
    inputs = tokenizer(prompt, return_tensors="pt")
    # Follow wherever the model actually lives instead of hard-coding "cuda",
    # which raises on CPU-only hosts.
    device = model.device
    input_ids = inputs["input_ids"].to(device)
    generate_kwargs = {"input_ids": input_ids, "max_new_tokens": max_new_tokens}
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        # Forward the tokenizer's mask so generate() does not have to infer it.
        generate_kwargs["attention_mask"] = attention_mask.to(device)
    outputs = model.generate(**generate_kwargs)
    # For decoder-only models the output sequence starts with the prompt
    # tokens; slice them off so only the model's reply is decoded.
    new_tokens = outputs[0][input_ids.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    logging.info('Response generated.')
    return response.strip()
# Function to remove various tags using regular expressions
def remove_tags(text):
    """Delete markup-like spans from *text* and return what remains.

    Removed spans:
      * anything from ``<`` to the next ``>`` (HTML tags, tokens like ``</s>``),
      * anything between ``[`` and ``]`` on one line (e.g. ``[/INST]``),
      * ``{ (...) }`` groups, with optional whitespace around the parentheses.
    """
    logging.info('Removing tags from the text.')
    patterns = (
        r"<[^>]*>",  # Standard HTML tags
        r"<.*?>|\[.*?\]|{\s*?\(.*?\)\s*}",  # Custom, non-standard tags (may need adjustments)
    )
    combined = "|".join(patterns)
    stripped = re.sub(combined, "", text)
    logging.info('Tags removed.')
    return stripped
# Function to generate the audio file
def text_to_speech(text, filename="response.mp3"):
    """Synthesize *text* with gTTS and save the result as an MP3.

    Args:
        text: The text to speak.
        filename: Path the audio is saved to.

    Returns:
        ``filename``, so callers can hand the path straight to the UI.
    """
    logging.info('Generating speech audio file.')
    gTTS(text).save(filename)
    logging.info('Speech audio file saved.')
    return filename
# Loaded (model, tokenizer) pairs keyed by model name, so the multi-gigabyte
# model is loaded once per process instead of on every request.
_MODEL_CACHE = {}


def _get_model(model_name):
    """Return a cached ``(model, tokenizer)`` pair, loading it on first use."""
    if model_name not in _MODEL_CACHE:
        _MODEL_CACHE[model_name] = setup_model(model_name)
    return _MODEL_CACHE[model_name]


# Main function for the Gradio app
def main(comment):
    """Generate a text reply and an MP3 rendering of it for *comment*.

    Args:
        comment: The user's text from the Gradio textbox.

    Returns:
        A ``(text, audio_path)`` tuple for the two Gradio outputs.
        ``audio_path`` is ``None`` when no comment was entered or when any
        step (model loading, generation, speech synthesis) fails.
    """
    logging.info('Main function triggered.')
    instructions_string = (
        "virtual marketer assistant, communicates in business, focused on services, "
        "escalating to technical depth upon request. It reacts to feedback aptly and ends responses "
        "with its signature Mr.jon will tailor the length of its responses to match the individual's comment, "
        "providing concise acknowledgments to brief expressions of gratitude or feedback, thus keeping the interaction natural and supportive.\n"
    )
    model_name = "TheBloke/Mistral-7B-Instruct-v0.2-GPTQ"
    # Validate the input before touching the model, so an empty or
    # whitespace-only submission never triggers a model load.
    if not comment or not comment.strip():
        logging.warning('No comment entered.')
        return "Please enter a comment to generate a response.", None
    try:
        model, tokenizer = _get_model(model_name)
        prompt = f"[INST] {instructions_string} \n{comment} \n[/INST]"
        response = generate_response(model, tokenizer, prompt)
        # Apply tag removal before displaying the response
        cleaned = remove_tags(response)
        # Drop a trailing "[/INST]" marker if one survived. str.rstrip("[/INST]")
        # would instead strip any trailing run of the characters [ / I N S T ],
        # eating real letters at the end of the reply.
        inst_marker = "[/INST]"
        if cleaned.endswith(inst_marker):
            cleaned = cleaned[:-len(inst_marker)]
        response_without_inst = cleaned.strip()
        # Generate and return the response and the audio file
        audio_file = text_to_speech(response_without_inst)
        logging.info('Response and audio file generated.')
        return response_without_inst, audio_file
    except Exception as e:
        # UI boundary: log the full traceback, show the user a generic message.
        logging.exception('Error occurred: %s', e)
        return "An error occurred. Please try again later.", None
# Gradio UI: a two-line textbox feeds `main`, whose (text, audio path) tuple
# fills a text output and a file output.
_comment_input = gr.Textbox(lines=2, placeholder="Enter a comment...")
iface = gr.Interface(
    fn=main,
    inputs=_comment_input,
    outputs=["text", "file"],
    title="Virtual Marketer Assistant",
    description=(
        "Enter a comment and get a response from the virtual marketer "
        "assistant. Download the response as an MP3 file."
    ),
)

if __name__ == "__main__":
    # share=True requests a public Gradio share link besides the local server.
    iface.launch(share=True)