mfidabel's picture
Create app.py
9d6f79c verified
raw
history blame
No virus
1.55 kB
import gradio as gr
import numpy as np
import torch
from peft import PeftModel, PeftConfig
from transformers import WhisperForConditionalGeneration, WhisperTokenizer, WhisperProcessor, AutomaticSpeechRecognitionPipeline
# Identifiers for the fine-tuned adapter and the decoding options used below.
peft_model_id = "mfidabel/Modelo_1_Whisper_Large_V3"
language = "guarani"
task = "transcribe"

# The PEFT config names the base checkpoint that every other component is
# loaded from.
peft_config = PeftConfig.from_pretrained(peft_model_id)
_base_model_id = peft_config.base_model_name_or_path

# Load the base Whisper model (8-bit, automatic device placement) and wrap it
# with the PEFT weights in one step.
model = PeftModel.from_pretrained(
    WhisperForConditionalGeneration.from_pretrained(
        _base_model_id,
        load_in_8bit=True,
        device_map="auto",
    ),
    peft_model_id,
)

# Tokenizer and processor share the same language/task options.
_text_options = {"language": language, "task": task}
tokenizer = WhisperTokenizer.from_pretrained(_base_model_id, **_text_options)
processor = WhisperProcessor.from_pretrained(_base_model_id, **_text_options)
feature_extractor = processor.feature_extractor

# NOTE(review): the decoder prompt is requested for "english" although
# `language` above is "guarani" -- presumably deliberate (e.g. it matches the
# language token used during fine-tuning); confirm before changing.
forced_decoder_ids = processor.get_decoder_prompt_ids(language="english", task=task)

# Speech-recognition pipeline used by `transcribe`.
pipeline = AutomaticSpeechRecognitionPipeline(
    model=model,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
)
def transcribe(audio):
    """Transcribe a recording to text with the module-level ASR pipeline.

    Args:
        audio: ``None`` when no recording is available, otherwise a
            ``(sampling_rate, samples)`` pair where ``samples`` is a NumPy
            array. Multi-channel input is assumed to be laid out as
            ``(n_samples, n_channels)`` -- TODO confirm against the gradio
            version in use.

    Returns:
        The transcribed text; a fixed "please retry" message when ``audio``
        is ``None``; an empty string when the recording has no samples.
    """
    if audio is None:
        return "Espera a que la grabación termine de subirse al servidor !! Intentelo de nuevo en unos segundos"

    sr, y = audio
    y = np.asarray(y)

    # Down-mix multi-channel recordings: the pipeline is fed a single 1-D
    # waveform under the "raw" key.
    if y.ndim > 1:
        y = y.mean(axis=1)

    # Nothing to transcribe; also avoids np.max raising on an empty array.
    if y.size == 0:
        return ""

    y = y.astype(np.float32)

    # Peak-normalise, but leave an all-zero (silent) signal untouched so the
    # division cannot produce NaNs.
    peak = np.max(np.abs(y))
    if peak > 0:
        y /= peak

    with torch.cuda.amp.autocast():
        result = pipeline(
            {"sampling_rate": sr, "raw": y},
            generate_kwargs={"forced_decoder_ids": forced_decoder_ids},
            max_new_tokens=255,
        )
    return result["text"]
# Build the microphone -> text interface around `transcribe`, then start it
# with share=True.
demo = gr.Interface(fn=transcribe, inputs="microphone", outputs="text")
demo.launch(share=True)