# Hugging Face Space: text summarizer (quantized Flan-T5).
# (Removed "Spaces: Sleeping" page-status text left over from a web extraction.)
import warnings

import torch
from torch.quantization import quantize_dynamic

import gradio as gr
from transformers import T5TokenizerFast, T5ForConditionalGeneration

# Suppress only the TypedStorage deprecation warnings that torch.load emits
# for checkpoints saved with older PyTorch versions.
warnings.filterwarnings("ignore", category=UserWarning, message=".*TypedStorage is deprecated.*")
# Load the tokenizer and the base (fp32) model.
tokenizer = T5TokenizerFast.from_pretrained('google/flan-t5-small')
model = T5ForConditionalGeneration.from_pretrained('google/flan-t5-small')
model.eval()  # inference mode: disables dropout etc.

# Apply dynamic quantization BEFORE loading the quantized state dict, so the
# module structure (quantized Linear layers) matches the saved checkpoint.
quantized_model = quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

# Load the quantized weights. map_location='cpu' makes this robust even if the
# checkpoint was saved from a GPU session — dynamic quantization runs on CPU.
quantized_model.load_state_dict(
    torch.load('flan-t5-small-quantized.pth', map_location='cpu')
)
# Summarization function backed by the quantized Flan-T5 model.
def summarize(text):
    """Summarize *text* with the quantized Flan-T5 model.

    Args:
        text: Input text; it is tokenized and truncated to 1024 tokens.

    Returns:
        The generated summary as a plain string (special tokens stripped).
    """
    inputs = tokenizer([text], max_length=1024, return_tensors="pt", truncation=True)
    summary_ids = quantized_model.generate(
        inputs["input_ids"],
        num_beams=4,          # beam search for higher-quality summaries
        max_length=200,
        early_stopping=True,
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
# Build the Gradio UI. NOTE: the gr.inputs.* namespace was deprecated in
# Gradio 3.x and removed in 4.x — use gr.Textbox directly (same arguments).
interface = gr.Interface(
    fn=summarize,
    inputs=gr.Textbox(lines=2, placeholder="Enter text to summarize..."),
    outputs="text",
    title="Text Summarizer",
    description="A simple text summarizer based on the quantized Flan-T5 model.",
)
# Launch the Gradio app; works both locally and on Hugging Face Spaces.
if __name__ == "__main__":
    interface.launch()