# Text2Pandas / app.py — by zeyadusf
# (commit 45e6476: "Update app.py"; 2.13 kB)
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
def load_model(model_name="zeyadusf/text2pandas-T5", device=None):
    """Load the Text2Pandas seq2seq model together with its tokenizer.

    Args:
        model_name: Hub id or local path handed to ``from_pretrained`` for both
            the tokenizer and the model. Defaults to the Text2Pandas T5
            checkpoint this app was built for.
        device: Optional ``torch.device`` (or device string) to place the model
            on. When omitted, CUDA is used if available, otherwise the CPU.

    Returns:
        A ``(model, tokenizer)`` tuple. The model has been moved to ``device``
        and switched to evaluation mode.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    # from_pretrained already returns the model in eval mode; calling eval()
    # again keeps the inference-only intent (no dropout) explicit here.
    model = model.to(device)
    model.eval()
    return model, tokenizer
# Load the model and tokenizer once, when this script runs; generate_text uses
# these module-level objects for every request.
model, tokenizer = load_model()
def generate_text(
    question,
    context,
    max_length=512,
    num_beams=4,
    early_stopping=True,
    *,
    model_override=None,
    tokenizer_override=None,
):
    """Generate pandas code answering ``question`` about the given ``context``.

    The prompt is built as ``"<question> {question} <context> {context}"``,
    encoded, passed to ``generate`` with beam search, and the first returned
    sequence is decoded.

    Args:
        question: Natural-language question to answer.
        context: Text describing the data (e.g. the DataFrame construction code).
        max_length: Token limit used both to truncate the encoded prompt and as
            ``generate``'s ``max_length``. Coerced to ``int`` when not None,
            because UI sliders may deliver floats such as ``512.0``.
        num_beams: Beam width for ``generate``; coerced to ``int`` when not None.
        early_stopping: Forwarded unchanged to ``generate``.
        model_override: Optional model to use instead of the module-level
            ``model``.
        tokenizer_override: Optional tokenizer to use instead of the
            module-level ``tokenizer``.

    Returns:
        The decoded text of the first generated sequence, with special tokens
        removed.
    """
    # Only touch the module-level globals when no override was supplied.
    active_model = model if model_override is None else model_override
    active_tokenizer = tokenizer if tokenizer_override is None else tokenizer_override

    # Length limits and beam counts must be integers for the tokenizer/generate.
    if max_length is not None:
        max_length = int(max_length)
    if num_beams is not None:
        num_beams = int(num_beams)

    input_text = f"<question> {question} <context> {context}"
    # Calling the tokenizer (instead of .encode) also returns the attention
    # mask, so generate() receives it explicitly rather than having to infer it.
    encoded = active_tokenizer(
        input_text, return_tensors="pt", truncation=True, max_length=max_length
    )
    inputs = {name: tensor.to(active_model.device) for name, tensor in encoded.items()}

    with torch.no_grad():
        outputs = active_model.generate(
            **inputs,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=early_stopping,
        )
    return active_tokenizer.decode(outputs[0], skip_special_tokens=True)
def gradio_interface(question, context, max_length, num_beams, early_stopping):
    """Adapter that forwards the Gradio widget values to ``generate_text``.

    Gradio supplies one positional value per input component; the two text
    fields are passed positionally and the generation settings by keyword.
    Returns whatever ``generate_text`` returns.
    """
    generation_options = {
        "max_length": max_length,
        "num_beams": num_beams,
        "early_stopping": early_stopping,
    }
    return generate_text(question, context, **generation_options)
# --- Gradio UI: input widgets -------------------------------------------------
question_input = gr.Textbox(
    label="Enter the Question",
    value="what is the total amount of players for the rockets in 1998 only?",
)
context_input = gr.Textbox(
    label="Enter the Context",
    value="df = pd.DataFrame(columns=['player', 'years_for_rockets'])",
)
# step=1 keeps both sliders on whole-number positions.
max_length_input = gr.Slider(
    minimum=50, maximum=1024, value=512, step=1, label="Max Length"
)
num_beams_input = gr.Slider(
    minimum=1, maximum=10, value=4, step=1, label="Number of Beams"
)
early_stopping_input = gr.Checkbox(value=True, label="Early Stopping")

# --- Gradio UI: assemble and launch ------------------------------------------
# Values are passed to gradio_interface positionally, in this list's order.
demo = gr.Interface(
    fn=gradio_interface,
    inputs=[
        question_input,
        context_input,
        max_length_input,
        num_beams_input,
        early_stopping_input,
    ],
    outputs="text",
    title="Text to Pandas Code Generator",
    description="Generate Pandas code by providing a question and a context.",
)
demo.launch()