# my-alexa / app.py
# Author: jiuuee
# Last commit: 8eaeb08 "Update app.py" (verified)
# Size: 4.87 kB; Hugging Face file scan: no virus found
import gradio as gr
import json
import librosa
import os
import soundfile as sf
import tempfile
import uuid
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchMultiTaskAED
from nemo.collections.asr.parts.utils.transcribe_utils import get_buffered_pred_feat_multitaskAED
from transformers import VitsTokenizer, VitsModel, set_seed
import scipy.io.wavfile as wav
# Constants
# Sample rate (Hz) that transcribe() resamples every input recording to before
# handing it to the ASR model.
SAMPLE_RATE = 16000 # Hz
# Load ASR model
# NVIDIA Canary-1B speech model, fetched through NeMo's from_pretrained.
asr_model = ASRModel.from_pretrained("nvidia/canary-1b")
asr_model.eval()  # inference mode (no training-time behaviour)
# Decoding setup: reset the decoding strategy, then re-apply the model's own
# decoding config with beam size 1 (a single hypothesis kept per step).
# NOTE(review): the purpose of the first call with None is not visible here;
# presumably it (re)initializes asr_model.cfg.decoding before it is read below —
# confirm against the installed NeMo version.
asr_model.change_decoding_strategy(None)
decoding_cfg = asr_model.cfg.decoding
decoding_cfg.beam.beam_size = 1
asr_model.change_decoding_strategy(decoding_cfg)
# Disable dither and padding in the stored preprocessor config. transcribe()
# passes asr_model.cfg.preprocessor to get_buffered_pred_feat_multitaskAED for
# long clips. NOTE(review): whether these edits also affect the preprocessor
# already built inside asr_model (used by asr_model.transcribe) is not visible
# here — confirm.
asr_model.cfg.preprocessor.dither = 0.0
asr_model.cfg.preprocessor.pad_to = 0
# Hop between consecutive feature frames, presumably in seconds — confirm.
feature_stride = asr_model.cfg.preprocessor['window_stride']
# Time span of one encoder output step; the factor 8 is presumably the encoder's
# subsampling factor for canary-1b — TODO confirm.
model_stride_in_secs = feature_stride * 8
# Chunked (buffered) decoder that transcribe() uses for recordings of 40 s or
# longer. frame_len/total_buffer of 40.0 are presumably seconds, matching that
# cutoff — confirm. Chunks are decoded in batches of 16.
frame_asr = FrameBatchMultiTaskAED(
    asr_model=asr_model,
    frame_len=40.0,
    total_buffer=40.0,
    batch_size=16,
)
# Load LLM model
# Phi-3 mini (128k-context, instruction-tuned) generates the assistant's replies.
# Model and tokenizer come from the same Hub checkpoint.
_LLM_CHECKPOINT = "microsoft/Phi-3-mini-128k-instruct"

# Seed torch's global RNG once at start-up.
torch.random.manual_seed(0)

llm_model = AutoModelForCausalLM.from_pretrained(
    _LLM_CHECKPOINT,
    trust_remote_code=True,
    torch_dtype="auto",   # keep the checkpoint's native dtype
    device_map="auto",    # let transformers place weights on available devices
)
tokenizer = AutoTokenizer.from_pretrained(_LLM_CHECKPOINT)

# Text-generation pipeline wrapping the already-loaded model and tokenizer.
pipe = pipeline(task="text-generation", model=llm_model, tokenizer=tokenizer)
# Load TTS model
# Meta MMS English VITS voice; the tokenizer and model share one checkpoint.
_TTS_CHECKPOINT = "facebook/mms-tts-eng"
tts_tokenizer = VitsTokenizer.from_pretrained(_TTS_CHECKPOINT)
tts_model = VitsModel.from_pretrained(_TTS_CHECKPOINT)
# Function to convert audio to text using ASR
def transcribe(audio_filepath):
    """Transcribe an English recording with the Canary ASR model.

    The recording is loaded as mono and resampled to SAMPLE_RATE if needed. It
    is then written as a WAV file plus a one-line NeMo manifest inside a
    temporary directory, which is removed on return. Clips shorter than 40 s
    go through ``asr_model.transcribe``. Longer clips are decoded in chunks
    with ``get_buffered_pred_feat_multitaskAED``.

    Args:
        audio_filepath: Path to the recorded audio (from the Gradio
            ``gr.Audio(type="filepath")`` input), or ``None`` if nothing was
            recorded.

    Returns:
        The transcription as a string.

    Raises:
        gr.Error: If no audio was provided or the recording contains no samples.
    """
    if audio_filepath is None:
        raise gr.Error("Please provide some input audio.")
    utt_id = uuid.uuid4()
    with tempfile.TemporaryDirectory() as tmpdir:
        # Downmix to mono and resample to the rate the ASR model expects.
        data, sr = librosa.load(audio_filepath, sr=None, mono=True)
        if data.size == 0:
            # Nothing to transcribe; fail with a readable message instead of
            # letting the model choke on a zero-length input.
            raise gr.Error("The input audio is empty.")
        if sr != SAMPLE_RATE:
            data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)
        converted_audio_filepath = os.path.join(tmpdir, f"{utt_id}.wav")
        sf.write(converted_audio_filepath, data, SAMPLE_RATE)

        # NeMo reads its inputs from a manifest. The multitask fields request
        # English->English ASR with pnc (punctuation/capitalization) disabled.
        duration = len(data) / SAMPLE_RATE
        manifest_data = {
            "audio_filepath": converted_audio_filepath,
            "source_lang": "en",
            "target_lang": "en",
            "taskname": "asr",
            "pnc": "no",
            "answer": "predict",
            "duration": str(duration),
        }
        manifest_filepath = os.path.join(tmpdir, f"{utt_id}.json")
        with open(manifest_filepath, "w", encoding="utf-8") as fout:
            fout.write(json.dumps(manifest_data))

        if duration < 40:
            result = asr_model.transcribe(manifest_filepath)[0]
            # Depending on the NeMo version this is either a plain string or a
            # Hypothesis object; normalize to text like the buffered branch does.
            transcription = getattr(result, "text", result)
        else:
            transcription = get_buffered_pred_feat_multitaskAED(
                frame_asr,
                asr_model.cfg.preprocessor,
                model_stride_in_secs,
                asr_model.device,
                manifest=manifest_filepath,
            )[0].text
    return transcription
# Function to generate text using LLM
def generate_text(input_text, max_new_tokens=200):
    """Generate the assistant's reply to *input_text* with the Phi-3 pipeline.

    The text is sent as a single user turn in chat format, so the pipeline
    applies the model's chat template. Decoding is greedy (``do_sample=False``),
    so the same input yields the same reply. Only the newly generated text is
    returned; the prompt is not echoed back.

    Args:
        input_text: The user's utterance (e.g. an ASR transcription).
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The reply text with surrounding whitespace stripped.
    """
    messages = [{"role": "user", "content": input_text}]
    generation_args = {
        "max_new_tokens": max_new_tokens,
        # False: return only the completion. With True the user's prompt
        # would be included and then read aloud by the TTS stage.
        "return_full_text": False,
        # Greedy decoding; a temperature setting would be ignored here.
        "do_sample": False,
    }
    generated_text = pipe(messages, **generation_args)[0]["generated_text"]
    return generated_text.strip()
# Function to convert text to speech using TTS
def gen_speech(text):
    """Speak *text* with the VITS voice and save the result as a WAV file.

    The RNG is re-seeded on every call so the same text yields the same audio.
    The file gets a random UUID name in the current working directory, and its
    sample rate is taken from the model config.

    Returns:
        The relative path of the newly written ``.wav`` file.
    """
    set_seed(555)  # fixed seed for reproducible synthesis
    encoded = tts_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        waveform = tts_model(**encoded).waveform[0].cpu().numpy()
    wav_path = "{}.wav".format(uuid.uuid4())
    wav.write(wav_path, rate=tts_model.config.sampling_rate, data=waveform)
    return wav_path
# Combined function for Gradio interface
def process_audio(audio_filepath):
    """Run the whole voice-assistant chain on one recording.

    Speech is transcribed, a reply is generated from the transcription, and the
    reply is synthesized to audio. A progress line is printed after each stage.

    Returns:
        A ``(transcription, reply_text, reply_audio_path)`` tuple, in the order
        of the interface's three output components.
    """
    heard = transcribe(audio_filepath)
    print("Done transcribing")

    reply = generate_text(heard)
    print("Done generating")

    spoken_path = gen_speech(reply)
    print("Done speaking")

    return heard, reply, spoken_path
# Create Gradio interface
# One microphone input feeds process_audio. The three outputs mirror its return
# tuple: transcription text, generated reply text, and the spoken reply audio.
# Bound to `demo` so tooling such as Gradio's reload mode can locate the app.
demo = gr.Interface(
    fn=process_audio,
    inputs=[gr.Audio(sources=["microphone"], type="filepath", label="Input Audio")],
    outputs=[
        gr.Textbox(label="Transcription"),
        gr.Textbox(label="Generated Text"),
        gr.Audio(type="filepath", label="Generated Speech"),
    ],
    title="YOUR AWESOME AI ASSISTANT",
    description=(
        "Gets input audio from the user, transcribes it with ASR Canary1b, "
        "generates text with Phi3LLM, and converts it back to speech with VITS TTS."
    ),
)

# Launch only when run as a script (`python app.py`), not on plain import.
if __name__ == "__main__":
    demo.launch(inbrowser=True)