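# Voice-assistant demo: microphone audio is transcribed with NVIDIA Canary-1B
# (ASR), a reply is generated with Phi-3-mini (LLM), and the reply is spoken
# back with MMS VITS (TTS), all wired together in a Gradio interface.
# Assumed (not pinned in this file) dependencies: gradio, torch, librosa,
# soundfile, scipy, transformers, nemo_toolkit[asr].
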
import gradio as gr
import json
import librosa
import os
import soundfile as sf
import tempfile
import uuid
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    VitsModel,
    VitsTokenizer,
    pipeline,
    set_seed,
)
from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchMultiTaskAED
from nemo.collections.asr.parts.utils.transcribe_utils import get_buffered_pred_feat_multitaskAED
import scipy.io.wavfile as wav

# Constants
SAMPLE_RATE = 16000  # Hz

# Load ASR model (NVIDIA Canary-1B, multitask AED)
asr_model = ASRModel.from_pretrained("nvidia/canary-1b")
asr_model.eval()

# Use greedy decoding (beam size 1) for faster inference
asr_model.change_decoding_strategy(None)
decoding_cfg = asr_model.cfg.decoding
decoding_cfg.beam.beam_size = 1
asr_model.change_decoding_strategy(decoding_cfg)

# Make feature extraction deterministic and avoid padding artifacts
asr_model.cfg.preprocessor.dither = 0.0
asr_model.cfg.preprocessor.pad_to = 0

# Buffered (chunked) inference setup for long audio
feature_stride = asr_model.cfg.preprocessor['window_stride']
model_stride_in_secs = feature_stride * 8  # encoder downsamples features by 8x
frame_asr = FrameBatchMultiTaskAED(
    asr_model=asr_model,
    frame_len=40.0,      # seconds of new audio processed per step
    total_buffer=40.0,   # total context window in seconds
    batch_size=16,
)

# Load LLM model (microsoft/Phi-3-mini-128k-instruct)
torch.random.manual_seed(0)  # fixed seed for reproducibility
llm_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-128k-instruct", 
    device_map="auto", 
    torch_dtype="auto", 
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
pipe = pipeline("text-generation", model=llm_model, tokenizer=tokenizer)

# Load TTS model
tts_tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")

# Function to convert audio to text using ASR
def transcribe(audio_filepath):
    if audio_filepath is None:
        raise gr.Error("Please provide some input audio.")
    
    utt_id = uuid.uuid4()
    with tempfile.TemporaryDirectory() as tmpdir:
        # Load audio as mono and resample to 16 kHz if needed
        data, sr = librosa.load(audio_filepath, sr=None, mono=True)
        if sr != SAMPLE_RATE:
            data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)
        converted_audio_filepath = os.path.join(tmpdir, f"{utt_id}.wav")
        sf.write(converted_audio_filepath, data, SAMPLE_RATE)

        # Transcribe audio: build a one-entry NeMo manifest describing the task
        duration = len(data) / SAMPLE_RATE
        manifest_data = {
            "audio_filepath": converted_audio_filepath,
            "source_lang": "en",   # language spoken in the audio
            "target_lang": "en",   # output language (same as source => plain ASR)
            "taskname": "asr",
            "pnc": "no",           # no punctuation and capitalization
            "answer": "predict",
            "duration": str(duration),
        }
        manifest_filepath = os.path.join(tmpdir, f"{utt_id}.json")
        with open(manifest_filepath, 'w') as fout:
            fout.write(json.dumps(manifest_data))
        
        # Short clips are transcribed in a single pass; longer audio goes
        # through buffered chunked inference
        if duration < 40:
            transcription = asr_model.transcribe(manifest_filepath)[0]
        else:
            transcription = get_buffered_pred_feat_multitaskAED(
                frame_asr,
                asr_model.cfg.preprocessor,
                model_stride_in_secs,
                asr_model.device,
                manifest=manifest_filepath,
            )[0].text
    
    return transcription
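
# Illustrative usage outside the Gradio app (the file path is hypothetical):
#
#   text = transcribe("samples/question.wav")
#   print(text)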

# Function to generate a response using the LLM
def generate_text(input_text):
    # Wrap the transcription in a chat-style message so the pipeline applies
    # the instruct model's chat template
    messages = [{"role": "user", "content": input_text}]
    generation_args = {
        "max_new_tokens": 200,
        "return_full_text": False,  # return only the reply, not the prompt
        "do_sample": False,         # greedy (deterministic) decoding
    }
    generated_text = pipe(messages, **generation_args)[0]["generated_text"]
    return generated_text
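
# Optional variant (illustrative, not part of the original flow): prepend a
# system message to steer the assistant's tone, e.g.
#
#   messages = [
#       {"role": "system", "content": "You are a concise, friendly voice assistant."},
#       {"role": "user", "content": input_text},
#   ]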

# Function to convert text to speech using TTS
def gen_speech(text):
    set_seed(555)  # Make it deterministic
    input_text = tts_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = tts_model(**input_text)
    waveform_np = outputs.waveform[0].cpu().numpy()
    output_file = f"{str(uuid.uuid4())}.wav"
    wav.write(output_file, rate=tts_model.config.sampling_rate, data=waveform_np)
    return output_file

# Combined function for Gradio interface
def process_audio(audio_filepath):
    transcription = transcribe(audio_filepath)
    print("Done transcribing")
    generated_text = generate_text(transcription)
    print("Done generating")
    audio_output_filepath = gen_speech(generated_text)
    print("Done speaking")
    return transcription, generated_text, audio_output_filepath

# Create Gradio interface
gr.Interface(
    fn=process_audio,
    inputs=[gr.Audio(sources=["microphone"], type="filepath", label="Input Audio")],
    outputs=[
        gr.Textbox(label="Transcription"),
        gr.Textbox(label="Generated Text"),
        gr.Audio(type="filepath", label="Generated Speech")
    ],
    title="YOUR AWESOME AI ASSISTANT",
    description="Takes audio input from the user, transcribes it with the Canary-1B ASR model, generates a response with the Phi-3 LLM, and converts the response back to speech with VITS TTS."
).launch(inbrowser=True)
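
# Note: launch() serves the app locally (http://127.0.0.1:7860 by default);
# pass share=True to launch() to get a temporary public Gradio link.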