Marco-Cheung
commited on
Commit
•
60ec677
1
Parent(s):
dd98e2f
Update app.py
Browse files
app.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import gradio as gr
|
2 |
import numpy as np
|
3 |
import torch
|
4 |
-
from transformers import AutoProcessor, pipeline, BarkModel
|
5 |
|
6 |
ASR_MODEL_NAME = "bofenghuang/whisper-large-v2-cv11-german"
|
7 |
TTS_MODEL_NAME = "suno/bark-small"
|
@@ -16,16 +16,14 @@ device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
|
16 |
# load speech translation checkpoint
|
17 |
asr_pipe = pipeline("automatic-speech-recognition", model=ASR_MODEL_NAME, chunk_length_s=10,device=device)
|
18 |
|
19 |
-
# update the generation config
|
20 |
-
MULTILINGUAL = True # set True for multilingual models, False for English-only
|
21 |
-
generation_config = GenerationConfig.from_pretrained("openai/whisper-large-v2")
|
22 |
-
|
23 |
-
|
24 |
# load text-to-speech checkpoint
|
25 |
processor = AutoProcessor.from_pretrained("suno/bark-small")
|
26 |
model = BarkModel.from_pretrained("suno/bark-small").to(device)
|
27 |
sampling_rate = model.generation_config.sample_rate
|
28 |
|
|
|
|
|
|
|
29 |
def translate(audio):
|
30 |
outputs = asr_pipe(audio, max_new_tokens=256, generate_kwargs={"task": "translate"})
|
31 |
return outputs["text"]
|
|
|
1 |
import gradio as gr
|
2 |
import numpy as np
|
3 |
import torch
|
4 |
+
from transformers import AutoProcessor, pipeline, BarkModel
|
5 |
|
6 |
ASR_MODEL_NAME = "bofenghuang/whisper-large-v2-cv11-german"
|
7 |
TTS_MODEL_NAME = "suno/bark-small"
|
|
|
16 |
# load speech translation checkpoint
|
17 |
asr_pipe = pipeline("automatic-speech-recognition", model=ASR_MODEL_NAME, chunk_length_s=10,device=device)
|
18 |
|
|
|
|
|
|
|
|
|
|
|
19 |
# load text-to-speech checkpoint
|
20 |
processor = AutoProcessor.from_pretrained("suno/bark-small")
|
21 |
model = BarkModel.from_pretrained("suno/bark-small").to(device)
|
22 |
sampling_rate = model.generation_config.sample_rate
|
23 |
|
24 |
+
# set the forced ids
|
25 |
+
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(task="translate")
|
26 |
+
|
27 |
def translate(audio):
|
28 |
outputs = asr_pipe(audio, max_new_tokens=256, generate_kwargs={"task": "translate"})
|
29 |
return outputs["text"]
|