File size: 5,394 Bytes
544843e
3ebab96
f1cf01a
72fab60
 
 
f1cf01a
544843e
0ad1db9
 
3ebab96
1879824
29efdf2
11b23f3
 
3ebab96
 
 
 
 
 
 
 
 
f1cf01a
3ebab96
 
f1cf01a
3ebab96
11b23f3
 
 
 
3ebab96
544843e
 
 
 
 
11b23f3
 
 
 
 
 
f1cf01a
11b23f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d11ace4
f1cf01a
11b23f3
3ebab96
 
f1cf01a
11b23f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa69b98
11b23f3
aa69b98
 
11b23f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa69b98
11b23f3
 
 
 
 
 
3ebab96
 
 
 
 
 
 
11b23f3
3329a43
dc433d9
 
f1cf01a
3ebab96
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import torch
import gradio as gr
import librosa 
import tempfile
from typing import Optional
from TTS.config import load_config
from transformers import AutoFeatureExtractor, AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer


# When True, the neutralizer is run with plain (non-sampling) decoding; when
# False, predict_and_ctc_lm_decode switches to beam search with a repetition
# penalty. NOTE(review): nothing in this file reassigns it.
first_generation = True

# Torch device name used for the neutralizer's input tensors: the first CUDA
# GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


def load_and_fix_data(input_file, model_sampling_rate):
    """Load an audio file as a mono waveform at ``model_sampling_rate`` Hz.

    Args:
        input_file: Path to an audio file readable by ``librosa.load``.
        model_sampling_rate: Sampling rate (Hz) the ASR model expects.

    Returns:
        A 1-D waveform array sampled at ``model_sampling_rate``.
    """
    # sr=None keeps the file's native rate so the audio is resampled exactly
    # once below (librosa's default would first resample to 22050 Hz).
    # mono=True down-mixes multi-channel input, so the result is always 1-D.
    speech, sample_rate = librosa.load(input_file, sr=None, mono=True)
    if sample_rate != model_sampling_rate:
        # orig_sr/target_sr are keyword-only in librosa >= 0.10; passing them
        # positionally raises TypeError there.
        speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=model_sampling_rate)
    return speech


# Spanish wav2vec2 ASR checkpoint. Its feature extractor is loaded only to
# learn the sampling rate that input audio must be resampled to.
feature_extractor = AutoFeatureExtractor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-spanish")
sampling_rate = feature_extractor.sampling_rate

asr = pipeline("automatic-speech-recognition", model="jonatasgrosman/wav2vec2-large-xlsr-53-spanish")

# Seq2seq text neutralizer. `prefix` is prepended to every input before
# tokenization (empty for this checkpoint).
prefix = ''
model_checkpoint = "hackathon-pln-es/es_text_neutralizer"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# Place the model on the same device the inference code sends its input
# tensors to; leaving it on the CPU crashes generate() on CUDA hosts.
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device)

# Coqui TTS model manager, used to download the synthesis model and vocoder.
manager = ModelManager()
# NOTE(review): MODEL_NAMES is not referenced elsewhere in the visible file.
MODEL_NAMES = manager.list_tts_models()


def postproc(input_sentence, preds):
    """Clean up the neutralizer's output text.

    Contracts "de el" into "del", collapses repeated spaces, upper-cases the
    first character, tightens spacing around "." and ",", and restores the
    capitalization of words that were capitalized in ``input_sentence``
    (other than its first word).

    Args:
        input_sentence: The text that was fed to the neutralizer.
        preds: The decoded neutralizer output.

    Returns:
        The cleaned-up output text with surrounding whitespace stripped.
    """
    import re

    # Contract only the standalone article "de el"; a plain substring replace
    # would also turn "de ella"/"de ellos"/"de elle" into "della"/...
    preds = re.sub(r"\b([Dd])e el\b", r"\1el", preds)
    # Collapse any run of spaces (a single '  ' -> ' ' pass leaves 3+ spaces).
    preds = re.sub(r" {2,}", " ", preds)
    if not preds:
        return preds
    if preds[0].islower():
        # Upper-case only the first character: str.capitalize() would also
        # lower-case the rest of the sentence, including proper names.
        preds = preds[0].upper() + preds[1:]
    preds = preds.replace(' . ', '. ').replace(' , ', ', ')

    # Restore capitalized words (names) from the input sentence.
    words = input_sentence.split(' ')
    first_word = words[0]
    prev_letter = ''
    for word in words:
        if not word:
            continue
        if word[0].isupper() and word != first_word and word.lower() in preds:
            if prev_letter == '.':
                # Word starts a new sentence in the input.
                preds = preds.replace('. ' + word.lower() + ' ', '. ' + word + ' ')
            elif word[-1] == '.':
                # Word carries its own trailing period; replace without a space.
                preds = preds.replace(word.lower(), word)
            else:
                preds = preds.replace(word.lower() + ' ', word + ' ')
        prev_letter = word[-1]
    return preds.strip()
    
# Coqui TTS model used for synthesis (downloaded as "tts_models/<model_name>")
# and the maximum number of characters of neutralized text sent to it.
model_name, MAX_TXT_LEN = "es/mai/tacotron2-DDC", 100

# Ready-to-use Synthesizer instances keyed by TTS model name, so the download
# and model construction happen once per process instead of on every request.
_SYNTHESIZER_CACHE = {}


def _get_synthesizer(tts_model_name):
    """Return a Coqui ``Synthesizer`` for ``tts_model_name``, building it on first use.

    Downloads the TTS model (and the default vocoder named by its model item,
    if any) through the module-level ``manager`` and memoizes the result.
    """
    synthesizer = _SYNTHESIZER_CACHE.get(tts_model_name)
    if synthesizer is not None:
        return synthesizer
    model_path, config_path, model_item = manager.download_model(f"tts_models/{tts_model_name}")
    vocoder_name: Optional[str] = model_item["default_vocoder"]
    vocoder_path = None
    vocoder_config_path = None
    if vocoder_name is not None:
        vocoder_path, vocoder_config_path, _ = manager.download_model(vocoder_name)
    synthesizer = Synthesizer(
        model_path, config_path, None, None, vocoder_path, vocoder_config_path,
    )
    _SYNTHESIZER_CACHE[tts_model_name] = synthesizer
    return synthesizer


def _neutralize_text(transcribed_text):
    """Rewrite ``transcribed_text`` with the seq2seq neutralizer and post-process it."""
    inputs = tokenizer([prefix + transcribed_text], return_tensors="pt", padding=True)
    # Send the inputs to wherever the model's weights actually live, so this
    # works whether or not the model has been moved to ``device``.
    model_device = model.device
    generate_kwargs = {
        "input_ids": inputs["input_ids"].to(model_device),
        "attention_mask": inputs["attention_mask"].to(model_device),
        "do_sample": False,  # deterministic decoding
    }
    # NOTE(review): first_generation is never reassigned in the visible file,
    # so the beam-search settings below are currently unused.
    if not first_generation:
        generate_kwargs.update(num_beams=2, repetition_penalty=2.5, early_stopping=True)
    with torch.no_grad():
        output_sequence = model.generate(**generate_kwargs)
    decoded = tokenizer.decode(
        output_sequence[0], skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    return postproc(transcribed_text, preds=decoded)


def predict_and_ctc_lm_decode(input_file, speaker_idx: Optional[str] = None):
    """Transcribe, gender-neutralize and re-synthesize a recorded audio clip.

    Args:
        input_file: Path to the input audio file.
        speaker_idx: Optional speaker identifier forwarded to ``Synthesizer.tts``.

    Returns:
        Path to a temporary ``.wav`` file with the synthesized speech. The file
        is created with ``delete=False`` and is not removed automatically.
    """
    speech = load_and_fix_data(input_file, sampling_rate)
    transcribed_text = asr(speech, chunk_length_s=5, stride_length_s=1)["text"]
    text = _neutralize_text(transcribed_text)
    if len(text) > MAX_TXT_LEN:
        text = text[:MAX_TXT_LEN]
        print(f"Input text was cutoff since it went over the {MAX_TXT_LEN} character limit.")
    print(text, model_name)
    synthesizer = _get_synthesizer(model_name)
    wavs = synthesizer.tts(text, speaker_idx)
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        synthesizer.save_wav(wavs, fp)
        return fp.name




# Gradio UI: microphone audio in, synthesized gender-neutralized audio out.
demo = gr.Interface(
    predict_and_ctc_lm_decode,
    inputs=[
        gr.inputs.Audio(source="microphone", type="filepath", label="Record your audio")
    ],
    outputs=gr.outputs.Audio(label="Output"),
    examples=[["Example3.wav"]],
    title="Generate-Gender-Neutralized-Audios",
    description=(
        "This is a Gradio demo for generating gender-neutralized audio. To use it, simply "
        "provide an audio input (via microphone or audio recording), which will then be "
        "transcribed and gender-neutralized using pre-trained models. Finally, with the "
        "help of Coqui's TTS model, gender-neutralized audio is generated."
    ),
    layout="horizontal",
    theme="huggingface",
)
# NOTE(review): enable_queue/cache_examples are passed to launch() as in the
# gradio 2.x API; newer gradio moved cache_examples to Interface — confirm
# against the pinned gradio version.
demo.launch(enable_queue=True, cache_examples=True)