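# Gradio Space: Dhivehi speech recognition with MMS Zero-shot (300M), decoded
# with a lexicon-constrained beam-search CTC decoder and an external 5-gram LM.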
import spaces
import gradio as gr
import librosa
import torch
from transformers import Wav2Vec2ForCTC, AutoProcessor
from huggingface_hub import hf_hub_download
from torchaudio.models.decoder import ctc_decoder
# https://github.com/facebookresearch/fairseq/tree/main/examples/mms/zero_shot
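# Beam-search defaults from the MMS zero-shot recipe linked above:
# word-insertion scores for decoding with and without an LM, plus the LM weight.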
ASR_SAMPLING_RATE = 16_000
WORD_SCORE_DEFAULT_IF_LM = -0.18
WORD_SCORE_DEFAULT_IF_NOLM = -3.5
LM_SCORE_DEFAULT = 1.48
MODEL_ID = "mms-meta/mms-zeroshot-300m"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
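# Weights load once at import time; the model is moved to the GPU lazily inside
# the @spaces.GPU handler below.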
token_file = hf_hub_download(
    repo_id=MODEL_ID,
    filename="tokens.txt",
)

# 5-gram Dhivehi LM binary (a KenLM-style file, as accepted by ctc_decoder's lm=).
lm5gram = hf_hub_download(
    repo_id="alakxender/w2v-bert-2.0-dhivehi-syn",
    filename="language_model/5gram.bin",
)
lex_files = [
    "dv.domain.news.small.v1.lexicon",
    "dv.domain.news.small.v2.lexicon",
    "dv.domain.news.large.v1.lexicon",
    "dv.domain.stories.small.v1.lexicon",
]
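# Only the first (small news-domain) lexicon is downloaded; change the index to
# decode against a different domain vocabulary.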
lexicon_file = hf_hub_download(
    repo_type="dataset",
    repo_id="alakxender/dv-domain-lexicons",
    filename=lex_files[0],
)
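# On ZeroGPU Spaces, @spaces.GPU allocates a GPU only for the duration of each
# call, so all CUDA work stays inside this handler.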
@spaces.GPU
def transcribe(
    audio_data,
    wscore=None,
    lmscore=None,
    wscore_usedefault=True,
    lmscore_usedefault=True,
    uselm=True,
    reference=None,  # reference transcript from the UI; not used in decoding
):
    if not audio_data:
        yield "ERROR: Empty audio data"
        return
    # Normalize the input: the microphone yields (sample_rate, int16 samples),
    # a file upload yields a path string. Both end up as 16 kHz float mono.
    if isinstance(audio_data, tuple):
        sr, audio_samples = audio_data
        audio_samples = (audio_samples / 32768.0).astype(float)  # int16 -> [-1, 1)
        if sr != ASR_SAMPLING_RATE:
            audio_samples = librosa.resample(
                audio_samples, orig_sr=sr, target_sr=ASR_SAMPLING_RATE
            )
    else:
        assert isinstance(audio_data, str)
        audio_samples = librosa.load(audio_data, sr=ASR_SAMPLING_RATE, mono=True)[0]
    inputs = processor(
        audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
    )

    # Run the acoustic model on GPU when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    inputs = inputs.to(device)

    with torch.no_grad():
        outputs = model(**inputs).logits
    # Decoding parameters: attach the LM only when requested, then fall back
    # to the recipe defaults when the "use default" boxes are ticked.
    lm_path = lm5gram if uselm else None
    if lm_path is not None and not lm_path.strip():
        lm_path = None

    if wscore_usedefault:
        wscore = (
            WORD_SCORE_DEFAULT_IF_LM
            if lm_path is not None
            else WORD_SCORE_DEFAULT_IF_NOLM
        )
    if lmscore_usedefault:
        lmscore = LM_SCORE_DEFAULT if lm_path is not None else 0
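    # torchaudio's flashlight-based lexicon beam-search decoder. The MMS
    # zero-shot token set uses "<s>" as the CTC blank, hence the non-default
    # blank_token below.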
    beam_search_decoder = ctc_decoder(
        lexicon=lexicon_file,
        tokens=token_file,
        lm=lm_path,
        nbest=1,
        beam_size=500,
        beam_size_token=50,
        lm_weight=lmscore,
        word_score=wscore,
        sil_score=0,
        blank_token="<s>",
    )

    # The flashlight decoder runs on CPU, so move the logits off the GPU first.
    beam_search_result = beam_search_decoder(outputs.to("cpu"))
    transcription = " ".join(beam_search_result[0][0].words).strip()
    yield transcription
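# Hypothetical standalone usage (transcribe is a generator; outside a ZeroGPU
# Space the @spaces.GPU wrapper has no effect):
#   print(next(transcribe("samples/audio1.mp3")))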
styles = """
.thaana textarea {
    font-size: 18px !important;
    font-family: 'MV_Faseyha', 'Faruma', 'A_Faruma' !important;
    line-height: 1.8 !important;
}
.textbox2 textarea {
    display: none;
}
"""
with gr.Blocks(css=styles) as demo:
    gr.Markdown("# <center> Transcribe Dhivehi Audio with MMS-ZEROSHOT</center>")
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(
                label="Audio Input\n(use microphone or upload a file)",
                min_length=1,
                max_length=60,
            )
            with gr.Accordion("Advanced Settings", open=False):
                gr.Markdown(
                    "The following parameters are used for beam-search decoding. Use the default values if you are not sure."
                )
                with gr.Row():
                    with gr.Column():
                        wscore_usedefault = gr.Checkbox(
                            label="Use Default Word Insertion Score", value=True
                        )
                        wscore = gr.Slider(
                            minimum=-10.0,
                            maximum=10.0,
                            value=WORD_SCORE_DEFAULT_IF_LM,
                            step=0.1,
                            interactive=False,
                            label="Word Insertion Score",
                        )
                    with gr.Column():
                        lmscore_usedefault = gr.Checkbox(
                            label="Use Default Language Model Score", value=True
                        )
                        lmscore = gr.Slider(
                            minimum=-10.0,
                            maximum=10.0,
                            value=LM_SCORE_DEFAULT,
                            step=0.1,
                            interactive=False,
                            label="Language Model Score",
                        )
                    with gr.Column():
                        uselm = gr.Checkbox(
                            label="Use LM",
                            value=True,
                        )
            btn = gr.Button("Submit", elem_id="submit")
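            # gr.on with no explicit triggers fires whenever any listed input
            # changes, re-rendering both sliders with the matching defaults.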
            @gr.on(
                inputs=[wscore_usedefault, lmscore_usedefault, uselm],
                outputs=[wscore, lmscore],
            )
            def update_slider(ws, ls, lm):
                # Mirror the defaults used in transcribe(): with an LM the word
                # score is -0.18 and the LM weight 1.48; without, -3.5 and 0.
                ws_slider = gr.Slider(
                    minimum=-10.0,
                    maximum=10.0,
                    value=WORD_SCORE_DEFAULT_IF_LM if lm else WORD_SCORE_DEFAULT_IF_NOLM,
                    step=0.1,
                    interactive=not ws,
                    label="Word Insertion Score",
                )
                ls_slider = gr.Slider(
                    minimum=-10.0,
                    maximum=10.0,
                    value=LM_SCORE_DEFAULT if lm else 0,
                    step=0.1,
                    interactive=not ls,
                    label="Language Model Score",
                )
                return ws_slider, ls_slider
        with gr.Column():
            text = gr.Textbox(label="Transcript", rtl=True, elem_classes="thaana")
            reference = gr.Textbox(label="Reference Transcript", visible=False)

    btn.click(
        transcribe,
        inputs=[
            audio,
            wscore,
            lmscore,
            wscore_usedefault,
            lmscore_usedefault,
            uselm,
            reference,
        ],
        outputs=[text],
    )
    # Sample clips paired with Dhivehi reference transcripts; the second column
    # fills the hidden reference textbox.
    gr.Examples(
        examples=[
            ["samples/audio1.mp3", "އަޅުގަނޑުވެސް ދާކަށް ބޭނުމެއްނުވި"],
            ["samples/audio2.wav", "ރަނގަޅަށްވިއްޔާ އެވާނީ މުސްކުޅި ކުރެހުމަކަށް"],
            [
                "samples/audio3.wav",
                "އެއީ ޞަހްޔޫނީންގެ ޒަމާންވީ ރޭވުމެއްގެ ދަށުން މެދުނުކެނޑި ކުރިއަށްވާ ޕްރޮގްރާމެއް",
            ],
        ],
        inputs=[audio, reference],
        label="Dhivehi Audio Samples",
    )
demo.launch(show_api=False)