Marco-Cheung committed on
Commit
60ec677
1 Parent(s): dd98e2f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -6
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
  import numpy as np
3
  import torch
4
- from transformers import AutoProcessor, pipeline, BarkModel, GenerationConfig
5
 
6
  ASR_MODEL_NAME = "bofenghuang/whisper-large-v2-cv11-german"
7
  TTS_MODEL_NAME = "suno/bark-small"
@@ -16,16 +16,14 @@ device = "cuda:0" if torch.cuda.is_available() else "cpu"
16
  # load speech translation checkpoint
17
  asr_pipe = pipeline("automatic-speech-recognition", model=ASR_MODEL_NAME, chunk_length_s=10,device=device)
18
 
19
- # update the generation config
20
- MULTILINGUAL = True # set True for multilingual models, False for English-only
21
- generation_config = GenerationConfig.from_pretrained("openai/whisper-large-v2")
22
-
23
-
24
  # load text-to-speech checkpoint
25
  processor = AutoProcessor.from_pretrained("suno/bark-small")
26
  model = BarkModel.from_pretrained("suno/bark-small").to(device)
27
  sampling_rate = model.generation_config.sample_rate
28
 
 
 
 
29
  def translate(audio):
30
  outputs = asr_pipe(audio, max_new_tokens=256, generate_kwargs={"task": "translate"})
31
  return outputs["text"]
 
1
  import gradio as gr
2
  import numpy as np
3
  import torch
4
+ from transformers import AutoProcessor, pipeline, BarkModel
5
 
6
  ASR_MODEL_NAME = "bofenghuang/whisper-large-v2-cv11-german"
7
  TTS_MODEL_NAME = "suno/bark-small"
 
16
  # load speech translation checkpoint
17
  asr_pipe = pipeline("automatic-speech-recognition", model=ASR_MODEL_NAME, chunk_length_s=10,device=device)
18
 
 
 
 
 
 
19
  # load text-to-speech checkpoint
20
  processor = AutoProcessor.from_pretrained("suno/bark-small")
21
  model = BarkModel.from_pretrained("suno/bark-small").to(device)
22
  sampling_rate = model.generation_config.sample_rate
23
 
24
+ # set the forced ids
25
+ model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(task="translate")
26
+
27
  def translate(audio):
28
  outputs = asr_pipe(audio, max_new_tokens=256, generate_kwargs={"task": "translate"})
29
  return outputs["text"]