File size: 1,993 Bytes
c71fd53
b3fa900
c71fd53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3fa900
c71fd53
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#Importing all the necessary packages
import gradio as gr
import torch, librosa, torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from pyctcdecode import build_ctcdecoder

# Define the ASR (speech-to-text) model wrapper
class Speech2Text:
    """Callable wrapper that transcribes an audio file with a CTC wav2vec 2.0 model.

    Instances are passed to ``gr.Interface`` as the prediction function:
    ``Speech2Text()(path)`` returns the decoded transcription string.

    The module-level ``processor``, ``wav2vec_model`` and ``device`` objects
    must exist before the class is instantiated. They are bound to the
    instance so every method uses the same processor, model and device.
    """

    def __init__(self):
        self.processor = processor
        self.model = wav2vec_model
        self.device = device
        # pyctcdecode maps logit column i to labels[i], so the labels must be
        # ordered by token id rather than by dict insertion order.
        self.vocab = self._ordered_vocab(self.processor.tokenizer.get_vocab())
        # Beam-search CTC decoder without an external KenLM language model.
        self.decoder = build_ctcdecoder(self.vocab, kenlm_model_path=None)

    @staticmethod
    def _ordered_vocab(token_to_id):
        """Return the tokens of the ``token -> id`` mapping sorted by id."""
        return [token for token, _ in sorted(token_to_id.items(), key=lambda item: item[1])]

    def wav2feature(self, path):
        """Load the audio file at *path* and turn it into model input features.

        The waveform is downmixed to mono and resampled to the feature
        extractor's sampling rate before being passed to the processor.
        Returns the processor output (PyTorch tensors).
        """
        speech_array, sampling_rate = torchaudio.load(path)
        # torchaudio returns (channels, frames). Averaging the channels yields a
        # 1-D waveform for both mono and stereo recordings. With squeeze(),
        # stereo stayed 2-D and was treated as a batch of two utterances.
        waveform = speech_array.mean(dim=0).numpy()
        target_rate = self.processor.feature_extractor.sampling_rate
        if sampling_rate != target_rate:
            # Keyword arguments: librosa >= 0.10 makes orig_sr/target_sr keyword-only.
            waveform = librosa.resample(waveform, orig_sr=sampling_rate, target_sr=target_rate)
        return self.processor(waveform, return_tensors="pt", sampling_rate=target_rate)

    def feature2logits(self, features):
        """Run the acoustic model and return the logits of the first batch item.

        *features* is the mapping returned by :meth:`wav2feature`. It must
        contain ``input_values`` and may contain ``attention_mask``.
        Returns a NumPy array of shape ``(frames, vocab_size)``.
        """
        input_values = features["input_values"].to(self.device)
        # Some feature extractors do not return an attention mask at all.
        attention_mask = features.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)
        with torch.no_grad():
            logits = self.model(input_values, attention_mask=attention_mask).logits
        # .cpu() is required before .numpy() when the model runs on CUDA.
        return logits.cpu().numpy()[0]

    def __call__(self, path):
        """Transcribe the audio file at *path*. Returns "" when no audio is given."""
        # The Gradio microphone input is optional, so it may hand us None.
        if not path:
            return ""
        logits = self.feature2logits(self.wav2feature(path))
        return self.decoder.decode(logits)
        
# Load the pretrained Persian wav2vec 2.0 checkpoint and its processor, and
# run on the GPU when one is available.
model_name = 'masoudmzb/wav2vec2-xlsr-multilingual-53-fa'
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
wav2vec_model = Wav2Vec2ForCTC.from_pretrained(model_name)
wav2vec_model = wav2vec_model.to(device).eval()
processor = Wav2Vec2Processor.from_pretrained(model_name)

# Callable transcriber used as the Gradio prediction function.
s2t = Speech2Text()

# Theme options listed by the original author: "default", "huggingface",
# "seafoam", "grass", "peach".
audio_input = gr.inputs.Audio(
    source="microphone",
    type="filepath",
    optional=True,
    label="Record Your Beautiful Persian Voice",
)
text_output = gr.outputs.Textbox(label="Output Text")
demo = gr.Interface(
    s2t,
    inputs=audio_input,
    outputs=text_output,
    title="Persian ASR using Wav2Vec 2.0",
    description="This application displays transcribed text for given audio input",
    examples=[["Test_File1.wav"]],
    theme="grass",
)
demo.launch()