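# Gradio app: records Spanish speech, transcribes it with wav2vec2, rewrites the
# transcription with a gender-neutralizing seq2seq model, and reads the result
# aloud with a Coqui TTS voice.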
import tempfile
from typing import Optional

import torch
import gradio as gr
import librosa
from transformers import AutoFeatureExtractor, AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer

first_generation = True  # greedy decoding on the first pass; beam search otherwise
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def load_and_fix_data(input_file, model_sampling_rate):
    """Load an audio file, downmix it to mono and resample it to the model's rate."""
    speech, sample_rate = librosa.load(input_file)
    if len(speech.shape) > 1:
        speech = speech[:, 0] + speech[:, 1]  # sum the two channels to mono
    if sample_rate != model_sampling_rate:
        speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=model_sampling_rate)
    return speech
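# Load the Spanish ASR pipeline, the text-neutralizer seq2seq model and the Coqui TTS model manager.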
feature_extractor = AutoFeatureExtractor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-spanish")
sampling_rate = feature_extractor.sampling_rate
asr = pipeline("automatic-speech-recognition", model="jonatasgrosman/wav2vec2-large-xlsr-53-spanish")
prefix = ''
model_checkpoint = "hackathon-pln-es/es_text_neutralizer"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device)
manager = ModelManager()
MODEL_NAMES = manager.list_tts_models()
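# Post-process the neutralizer output: fix contractions, punctuation spacing and
# restore the capitalization of proper names found in the original transcription.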
def postproc(input_sentence, preds):
    try:
        preds = preds.replace('De el', 'Del').replace('de el', 'del').replace('  ', ' ')
        if preds[0].islower():
            preds = preds.capitalize()
        preds = preds.replace(' . ', '. ').replace(' , ', ', ')
        # Re-capitalize proper names that appear capitalized in the input sentence
        prev_letter = ''
        for word in input_sentence.split(' '):
            if word:
                if word[0].isupper():
                    if word.lower() in preds and word != input_sentence.split(' ')[0]:
                        if prev_letter == '.':
                            preds = preds.replace('. ' + word.lower() + ' ', '. ' + word + ' ')
                        else:
                            if word[-1] == '.':
                                preds = preds.replace(word.lower(), word)
                            else:
                                preds = preds.replace(word.lower() + ' ', word + ' ')
                prev_letter = word[-1]
        preds = preds.strip()  # drop the trailing space
    except Exception:
        pass
    return preds
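# Coqui TTS voice used for synthesis and the maximum text length sent to it.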
model_name = "es/mai/tacotron2-DDC"
MAX_TXT_LEN = 100
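# End-to-end inference: transcribe the recording, neutralize the text,
# then synthesize the neutralized text with the Coqui TTS voice.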
def predict_and_ctc_lm_decode(input_file, speaker_idx: Optional[str] = None):
    speech = load_and_fix_data(input_file, sampling_rate)
    transcribed_text = asr(speech, chunk_length_s=5, stride_length_s=1)
    transcribed_text = transcribed_text["text"]
    inputs = tokenizer([prefix + transcribed_text], return_tensors="pt", padding=True)
    with torch.no_grad():
        if first_generation:
            output_sequence = model.generate(
                input_ids=inputs["input_ids"].to(device),
                attention_mask=inputs["attention_mask"].to(device),
                do_sample=False,  # disable sampling to test if batching affects output
            )
        else:
            output_sequence = model.generate(
                input_ids=inputs["input_ids"].to(device),
                attention_mask=inputs["attention_mask"].to(device),
                do_sample=False,  # disable sampling to test if batching affects output
                num_beams=2,
                repetition_penalty=2.5,
                # length_penalty=1.0,
                early_stopping=True,
            )
    preds = postproc(transcribed_text,
                     preds=tokenizer.decode(output_sequence[0], skip_special_tokens=True, clean_up_tokenization_spaces=True))
    if len(preds) > MAX_TXT_LEN:
        preds = preds[:MAX_TXT_LEN]
        print(f"Input text was cut off since it went over the {MAX_TXT_LEN} character limit.")
    print(preds, model_name)
    # download model
    model_path, config_path, model_item = manager.download_model(f"tts_models/{model_name}")
    vocoder_name: Optional[str] = model_item["default_vocoder"]
    # download vocoder
    vocoder_path = None
    vocoder_config_path = None
    if vocoder_name is not None:
        vocoder_path, vocoder_config_path, _ = manager.download_model(vocoder_name)
    # init synthesizer
    synthesizer = Synthesizer(
        model_path, config_path, None, None, vocoder_path, vocoder_config_path,
    )
    # synthesize
    if synthesizer is None:
        raise NameError("model not found")
    wavs = synthesizer.tts(preds, speaker_idx)
    # write the waveform to a temporary file and return its path to Gradio
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        synthesizer.save_wav(wavs, fp)
    return fp.name
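# Gradio UI: a microphone input and an audio output wired to the inference function.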
gr.Interface(
    predict_and_ctc_lm_decode,
    inputs=[
        gr.inputs.Audio(source="microphone", type="filepath", label="Record your audio")
    ],
    outputs=gr.outputs.Audio(label="Output"),
    examples=[["audio1.wav"], ["travel.wav"], ["example2.wav"], ["example3.wav"]],
    title="Generate-Gender-Neutralized-Audios",
    description="This is a Gradio demo for generating gender-neutralized audio. Provide an audio input (via microphone or an audio recording); it is transcribed and gender-neutralized using pre-trained models, and the neutralized text is then turned back into speech with Coqui's TTS model.",
    # article="<p><center><img src='........e'></center></p>",
    layout="horizontal",
    theme="huggingface",
).launch(enable_queue=True, cache_examples=True)