Mikunono committed on
Commit
2096aa8
1 Parent(s): 928fb21

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -18
app.py CHANGED
@@ -47,26 +47,19 @@ import librosa
47
 
48
  ########################ASR model###############################
49
 
50
- from transformers import WhisperProcessor, WhisperForConditionalGeneration
51
 
52
- # load model and processor
53
- processor = WhisperProcessor.from_pretrained("openai/whisper-base")
54
- model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
55
- model.config.forced_decoder_ids = None
56
 
57
- sample_rate = 16000
 
 
 
58
 
59
- def ASR_model(audio, sr=16000):
60
- DB_audio = audio
61
- input_features = processor(audio, sampling_rate=sr, return_tensors="pt").input_features
62
- # generate token ids
63
- predicted_ids = model.generate(input_features)
64
- # decode token ids to text
65
- transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)
66
-
67
- transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
68
-
69
- return transcription
70
 
71
 
72
 
@@ -82,7 +75,7 @@ def print_like_dislike(x: gr.LikeData):
82
  def upfile(files):
83
  x = librosa.load(files, sr=16000)
84
  print(x[0])
85
- text = ASR_model(x[0])
86
  return [text[0], text[0]]
87
 
88
  def transcribe(audio):
 
47
 
48
  ########################ASR model###############################
49
 
50
+ from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
51
 
52
+ model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr").to("cuda")
53
+ processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr", do_upper_case=True)
 
 
54
 
55
def RallyListen(audio):
    """Transcribe a waveform with the module-level Speech2Text model.

    Args:
        audio: Raw audio samples passed straight to ``processor``. The call
            fixes ``sampling_rate=16000``, so the samples are expected to be
            at 16 kHz already.

    Returns:
        The result of ``processor.batch_decode`` with special tokens
        stripped: one decoded string per generated sequence.
    """
    features = processor(audio, sampling_rate=16000, padding=True, return_tensors="pt")

    # Follow the model's own device instead of hard-coding "cuda", so the
    # inputs always end up next to the weights wherever the model was placed.
    device = model.device
    input_features = features.input_features.to(device)
    attention_mask = features.attention_mask.to(device)

    gen_tokens = model.generate(input_features=input_features, attention_mask=attention_mask)
    return processor.batch_decode(gen_tokens, skip_special_tokens=True)
 
 
 
 
 
 
 
 
63
 
64
 
65
 
 
75
  def upfile(files):
76
  x = librosa.load(files, sr=16000)
77
  print(x[0])
78
+ text = RallyListen(x[0])
79
  return [text[0], text[0]]
80
 
81
  def transcribe(audio):