# mms-zeroshot/zeroshot.py
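"""Zero-shot ASR with MMS: a Wav2Vec2 CTC model with a roman-character output
vocabulary is combined with a user-supplied word list. The words are
romanized with uroman to build a decoding lexicon, and transcription is done
with torchaudio's lexicon-constrained CTC beam-search decoder, optionally
with an n-gram language model."""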
import os
import tempfile
import re
import librosa
import torch
import json
import numpy as np
from transformers import Wav2Vec2ForCTC, AutoProcessor
from huggingface_hub import hf_hub_download
from torchaudio.models.decoder import ctc_decoder
uroman_dir = "uroman"
assert os.path.exists(uroman_dir)
UROMAN_PL = os.path.join(uroman_dir, "bin", "uroman.pl")
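# uroman.pl transliterates text in (almost) any script into the Latin
# alphabet, so user-supplied words can be spelled with the model's roman
# character set.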
ASR_SAMPLING_RATE = 16_000
WORD_SCORE_DEFAULT_IF_LM = -0.18
WORD_SCORE_DEFAULT_IF_NOLM = -3.5
LM_SCORE_DEFAULT = 1.48
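# Default word-insertion and LM weights for the CTC beam-search decoder,
# depending on whether an external n-gram LM is supplied.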
MODEL_ID = "upload/mms_zs"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
token_file = "upload/mms_zs/tokens.txt"
class MY_LOG:
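    """Tiny accumulator for progress messages (presumably streamed to the
    demo UI via the generator in `process`)."""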
def __init__(self):
self.text = "[START]"
def add(self, new_log):
self.text = self.text + "\n" + new_log
self.text = self.text.strip()
return self.text
def error_check_file(filepath):
    if not isinstance(filepath, str):
        return "Expected file to be of type 'str'. Instead got {}".format(
            type(filepath)
        )
    if not os.path.exists(filepath):
        return "Input file '{}' doesn't exist".format(filepath)
def norm_uroman(text):
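    """Lowercase a romanized line and keep only a-z, apostrophes, and single
    spaces, e.g. "Don’t STOP!" -> "don't stop"."""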
text = text.lower()
text = text.replace("’", "'")
text = re.sub("([^a-z' ])", " ", text)
text = re.sub(" +", " ", text)
return text.strip()
def uromanize(words):
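    """Romanize `words` with uroman and map each word to its space-separated
    roman spelling plus the "|" word-boundary token, e.g.
    "hello" -> "h e l l o |"."""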
iso = "xxx"
with tempfile.NamedTemporaryFile() as tf, tempfile.NamedTemporaryFile() as tf2:
with open(tf.name, "w") as f:
f.write("\n".join(words))
cmd = f"perl " + UROMAN_PL
cmd += f" -l {iso} "
cmd += f" < {tf.name} > {tf2.name}"
os.system(cmd)
lexicon = {}
with open(tf2.name) as f:
for idx, line in enumerate(f):
if not line.strip():
continue
line = re.sub(r"\s+", "", norm_uroman(line)).strip()
lexicon[words[idx]] = " ".join(line) + " |"
return lexicon
def filter_lexicon(lexicon, word_counts):
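    """When several words share the same roman spelling, keep only one of
    them (highest count, then fewest characters); the decoder cannot tell
    words with identical spellings apart anyway."""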
spelling_to_words = {}
for w, s in lexicon.items():
spelling_to_words.setdefault(s, [])
spelling_to_words[s].append(w)
lexicon = {}
for s, ws in spelling_to_words.items():
if len(ws) > 1:
            # keep the word with the highest count, then the fewest characters
            ws.sort(key=lambda w: (-word_counts[w], len(w)))
lexicon[ws[0]] = s
return lexicon
def load_words(filepath):
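    """Return a dict mapping each lowercased, whitespace-separated word in
    `filepath` to its occurrence count."""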
words = {}
with open(filepath) as f:
for line in f:
line = line.strip().lower()
# ignore invalid words.
for w in line.split():
words.setdefault(w, 0)
words[w] += 1
return words
def process(
audio_data,
words_file,
lm_path=None,
wscore=None,
lmscore=None,
wscore_usedefault=True,
lmscore_usedefault=True,
reference=None
):
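    """Transcribe `audio_data` (a (sample_rate, samples) tuple from a
    microphone, or a path to an audio file) using the words in `words_file`
    as the decoding lexicon. Yields (transcription, log_text) pairs as the
    pipeline progresses; `lm_path` may point to a KenLM n-gram model for
    torchaudio's ctc_decoder."""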
transcription, logs = "", MY_LOG()
if not audio_data or not words_file:
yield "ERROR: Empty audio data or words file", logs.text
return
if isinstance(audio_data, tuple):
# microphone
sr, audio_samples = audio_data
audio_samples = (audio_samples / 32768.0).astype(float)
assert sr == ASR_SAMPLING_RATE, "Invalid sampling rate"
else:
# file upload
assert isinstance(audio_data, str)
audio_samples = librosa.load(audio_data, sr=ASR_SAMPLING_RATE, mono=True)[0]
yield transcription, logs.add(f"Number of audio samples: {len(audio_samples)}")
inputs = processor(
audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
)
# set device
if torch.cuda.is_available():
device = torch.device("cuda")
elif (
hasattr(torch.backends, "mps")
and torch.backends.mps.is_available()
and torch.backends.mps.is_built()
):
device = torch.device("mps")
else:
device = torch.device("cpu")
model.to(device)
inputs = inputs.to(device)
yield transcription, logs.add(f"Using device: {device}")
with torch.no_grad():
outputs = model(**inputs).logits
# Setup lexicon and decoder
yield transcription, logs.add(f"Loading words....")
try:
word_counts = load_words(words_file)
except Exception as e:
yield f"ERROR: Loading words failed '{str(e)}'", logs.text
return
yield transcription, logs.add(
f"Loaded {len(word_counts)} words.\nPreparing lexicon...."
)
try:
lexicon = uromanize(list(word_counts.keys()))
except Exception as e:
yield f"ERROR: Creating lexicon failed '{str(e)}'", logs.text
return
yield transcription, logs.add(f"Leixcon size: {len(lexicon)}")
if lm_path is None:
yield transcription, logs.add(f"Filtering lexicon....")
lexicon = filter_lexicon(lexicon, word_counts)
yield transcription, logs.add(
f"Ok. Leixcon size after filtering: {len(lexicon)}"
)
with tempfile.NamedTemporaryFile() as lexicon_file:
if lm_path is not None and not lm_path.strip():
lm_path = None
        with open(lexicon_file.name, "w") as f:
            for word, spelling in lexicon.items():
                f.write(word + " " + spelling + "\n")
if wscore_usedefault:
wscore = (
WORD_SCORE_DEFAULT_IF_LM
if lm_path is not None
else WORD_SCORE_DEFAULT_IF_NOLM
)
if lmscore_usedefault:
lmscore = LM_SCORE_DEFAULT if lm_path is not None else 0
yield transcription, logs.add(
f"Using word score: {wscore}\nUsing lm score: {lmscore}"
)
beam_search_decoder = ctc_decoder(
lexicon=lexicon_file.name,
tokens=token_file,
lm=lm_path,
nbest=1,
beam_size=500,
beam_size_token=50,
lm_weight=lmscore,
word_score=wscore,
sil_score=0,
blank_token="<s>",
)
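        # torchaudio's decoder (flashlight-text under the hood) runs on CPU,
        # so move the emissions off the accelerator before decoding.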
beam_search_result = beam_search_decoder(outputs.to("cpu"))
transcription = " ".join(beam_search_result[0][0].words).strip()
yield transcription, logs.add(f"[DONE]")
# for i in process("upload/english/english.mp3", "upload/english/c4_5k_sentences.txt"):
# print(i)
# for i in process("upload/ligurian/ligurian_1.mp3", "upload/ligurian/zenamt_5k_sentences.txt"):
# print(i)
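# Minimal CLI sketch (an assumption, not part of the original demo): pass an
# audio path and a word-list path, stream the logs, print the final result.
if __name__ == "__main__":
    import sys

    final_transcription = ""
    for final_transcription, log_text in process(sys.argv[1], sys.argv[2]):
        print(log_text)
    print("Transcription:", final_transcription)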