ktangri commited on
Commit
eb6ba59
1 Parent(s): 1e629a8

Add punctuation correction

Browse files
Files changed (2) hide show
  1. app.py +4 -1
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,17 +1,20 @@
1
  import gradio as gr
2
  from transformers import pipeline, Wav2Vec2ProcessorWithLM
3
  from librosa import load, resample
4
-
5
 
6
  asr_model = 'patrickvonplaten/wav2vec2-base-100h-with-lm'
7
  processor = Wav2Vec2ProcessorWithLM.from_pretrained(asr_model)
8
  asr = pipeline('automatic-speech-recognition', model=asr_model, tokenizer=processor.tokenizer, feature_extractor=processor.feature_extractor, decoder=processor.decoder)
9
 
 
 
10
  def transcribe(filepath):
11
  speech, sampling_rate = load(filepath)
12
  if sampling_rate != 16000:
13
  speech = resample(speech, sampling_rate, 16000)
14
  text = asr(speech)['text']
 
15
  return text
16
 
17
  mic = gr.inputs.Audio(source='microphone', type='filepath', label='Speech input', optional=False)
 
1
  import gradio as gr
2
  from transformers import pipeline, Wav2Vec2ProcessorWithLM
3
  from librosa import load, resample
4
+ from rpunct import RestorePuncts
5
 
6
  asr_model = 'patrickvonplaten/wav2vec2-base-100h-with-lm'
7
  processor = Wav2Vec2ProcessorWithLM.from_pretrained(asr_model)
8
  asr = pipeline('automatic-speech-recognition', model=asr_model, tokenizer=processor.tokenizer, feature_extractor=processor.feature_extractor, decoder=processor.decoder)
9
 
10
+ rpunct = RestorePuncts()
11
+
12
  def transcribe(filepath):
13
  speech, sampling_rate = load(filepath)
14
  if sampling_rate != 16000:
15
  speech = resample(speech, sampling_rate, 16000)
16
  text = asr(speech)['text']
17
+ text = rpunct.punctuate(text.lower())
18
  return text
19
 
20
  mic = gr.inputs.Audio(source='microphone', type='filepath', label='Speech input', optional=False)
requirements.txt CHANGED
@@ -3,3 +3,4 @@ transformers
3
  librosa
4
  pyctcdecode
5
  pypi-kenlm
 
 
3
  librosa
4
  pyctcdecode
5
  pypi-kenlm
6
+ git+https://github.com/anuragshas/rpunct.git