t5 / app.py
jordigonzm's picture
Update app.py
25dc2bf verified
raw
history blame
No virus
2.47 kB
import torch
from transformers import T5TokenizerFast, T5ForConditionalGeneration
from torch.quantization import quantize_dynamic
import warnings
import gradio as gr
# Configuration constants
MODEL_NAME = 'google/flan-t5-base'
MODEL_PATH = 'models/flan-t5-base-quantized.pth'
QUANTIZE_DTYPE = torch.qint8
QUANTIZE_MODULES = {torch.nn.Linear}
MAX_INPUT_LENGTH = 512
SUMMARY_LENGTH = 250

# Silence only the TypedStorage deprecation UserWarning; other warnings stay visible.
warnings.filterwarnings("ignore", category=UserWarning, message=".*TypedStorage is deprecated.*")

# Load the tokenizer and the base model.
tokenizer = T5TokenizerFast.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
model.eval()  # Put the model in evaluation mode

# Apply dynamic quantization to the selected module types.
quantized_model = quantize_dynamic(model, QUANTIZE_MODULES, dtype=QUANTIZE_DTYPE)

# Restore the saved quantized state on top of the freshly quantized model.
try:
    # NOTE: torch.load unpickles the file, so MODEL_PATH must point to a
    # trusted checkpoint. map_location="cpu" keeps the restore working even
    # if the checkpoint was written from another device.
    quantized_model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
except Exception as e:
    # Deliberate best effort: on failure the app keeps running with the
    # model that was quantized in memory just above.
    print("Error al cargar el modelo cuantificado:", e)
else:
    print("Modelo cuantificado cargado correctamente.")
# Summarize text using the quantized model
def summarize(text, max_length):
    """Return a summary of ``text`` generated by the quantized model.

    Args:
        text: Text to summarize. It is truncated to ``MAX_INPUT_LENGTH``
            tokens before being fed to the model.
        max_length: Upper bound, in tokens, for the generated summary.
            Coerced to ``int`` because a UI slider may deliver a float.

    Returns:
        The decoded summary with special tokens removed, or an empty string
        when ``text`` is empty, ``None`` or only whitespace.
    """
    # Nothing to summarize: skip tokenization and generation entirely.
    if not text or not text.strip():
        return ""

    inputs = tokenizer([text], max_length=MAX_INPUT_LENGTH, return_tensors="pt", truncation=True)
    summary_ids = quantized_model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        num_beams=4,
        max_length=int(max_length),
        early_stopping=True,
    )
    # Only one input was submitted, so the first sequence is the summary.
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
# Startup smoke test: run summarize once on a fixed sample and report the
# outcome on stdout without ever aborting the application.
test_text = "The quick brown fox jumps over the lazy dog multiple times. This is a simple sentence used for illustration."
summary_length = 80  # Maximum summary length used for this check
try:
    print("Texto:", test_text)
    print("Resumen:", summarize(test_text, summary_length))
except Exception as exc:
    print("Error al generar el resumen:", exc)
# Build the Gradio interface: a text box for the input and a slider for the
# maximum summary length, both passed positionally to summarize.
interface = gr.Interface(
    fn=summarize,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter text to summarize..."),
        # The default comes from the SUMMARY_LENGTH module constant (250),
        # which lies on the slider's step grid (10, 20, ...); the previous
        # hard-coded 256 could not be selected with step=10.
        gr.Slider(10, 512, step=10, value=SUMMARY_LENGTH, label="Max summary length"),
    ],
    outputs="text",
    title="Text Summarizer",
    description=f"A simple text summarizer based on the quantized {MODEL_NAME}. Adjust the slider to set the summary length."
)

# Launch the Gradio app only when this file is run as a script.
if __name__ == "__main__":
    interface.launch(debug=True)