# NOTE: captured from a Hugging Face Spaces page (status shown at capture time: "Sleeping").
import warnings

import gradio as gr
import torch
from torch.quantization import quantize_dynamic
from transformers import T5TokenizerFast, T5ForConditionalGeneration
# Configuration constants
MODEL_NAME = 'google/flan-t5-base'  # identifier passed to from_pretrained()
MODEL_PATH = 'models/flan-t5-base-quantized.pth'  # saved quantized state dict
QUANTIZE_DTYPE = torch.qint8  # dtype handed to quantize_dynamic()
QUANTIZE_MODULES = {torch.nn.Linear}  # module types to quantize dynamically
MAX_INPUT_LENGTH = 512  # tokenizer truncation limit for the input text
# NOTE(review): presumably the intended default summary length (in tokens);
# confirm against the places that pass max_length.
SUMMARY_LENGTH = 250

# Silence the specific "TypedStorage is deprecated" UserWarning
warnings.filterwarnings("ignore", category=UserWarning, message=".*TypedStorage is deprecated.*")
# Load the tokenizer and the base model
tokenizer = T5TokenizerFast.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
model.eval()  # put the model in evaluation mode before quantizing

# Apply dynamic quantization to the configured module types
quantized_model = quantize_dynamic(model, QUANTIZE_MODULES, dtype=QUANTIZE_DTYPE)

# Restore the saved quantized weights. This is best-effort: on failure the
# error is reported and the script carries on with `quantized_model` as it
# stands (NOTE(review): normally the freshly quantized pretrained weights).
try:
    # map_location="cpu" lets a checkpoint that was saved from a GPU session
    # load on a CPU-only host.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(MODEL_PATH, map_location="cpu")
    quantized_model.load_state_dict(state_dict)
except Exception as e:
    print("Error al cargar el modelo cuantificado:", e)
else:
    print("Modelo cuantificado cargado correctamente.")
def summarize(text, max_length=None):
    """Summarize *text* with the quantized model.

    Args:
        text: The text to summarize. ``None``, empty, or whitespace-only
            input yields an empty summary without invoking the model.
        max_length: Upper bound passed to ``generate`` as ``max_length``.
            Coerced to ``int`` because UI widgets such as a slider may
            supply a float. Defaults to ``SUMMARY_LENGTH`` when omitted.

    Returns:
        The decoded summary string with special tokens removed.
    """
    # Nothing to summarize: skip tokenization and generation entirely.
    if text is None or not text.strip():
        return ""

    if max_length is None:
        max_length = SUMMARY_LENGTH

    # Input is truncated to MAX_INPUT_LENGTH tokens.
    inputs = tokenizer(
        [text],
        max_length=MAX_INPUT_LENGTH,
        return_tensors="pt",
        truncation=True,
    )
    summary_ids = quantized_model.generate(
        inputs["input_ids"],
        # Pass the mask explicitly instead of leaving generate() to infer it.
        attention_mask=inputs["attention_mask"],
        num_beams=4,
        max_length=int(max_length),
        early_stopping=True,
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
# Direct smoke test of summarize(), run once when the script starts.
# Failures are reported but do not stop the script.
test_text = "The quick brown fox jumps over the lazy dog multiple times. This is a simple sentence used for illustration."
summary_length = 80  # summary length used for this smoke test
print("Texto:", test_text)
try:
    print("Resumen:", summarize(test_text, summary_length))
except Exception as e:
    print("Error al generar el resumen:", e)
# Build the Gradio interface: a text box for the input and a slider that is
# passed to summarize() as max_length.
interface = gr.Interface(
    fn=summarize,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter text to summarize..."),
        # The default comes from SUMMARY_LENGTH, which lies on the slider's
        # grid (minimum 10, step 10); the previous literal 256 did not.
        gr.Slider(10, 512, step=10, value=SUMMARY_LENGTH, label="Max summary length"),
    ],
    outputs="text",
    title="Text Summarizer",
    description=f"A simple text summarizer based on the quantized {MODEL_NAME}. Adjust the slider to set the summary length.",
)

# Launch the Gradio app only when executed as a script
if __name__ == "__main__":
    interface.launch(debug=True)