Vineel Pratap committed on
Commit
a4107b1
1 Parent(s): 9f2bd1d

update_model

Browse files
app.py CHANGED
@@ -1,24 +1,65 @@
1
  import gradio as gr
2
  from zeroshot import process, ZS_EXAMPLES
3
 
4
- with gr.Blocks() as demo:
5
- gr.Markdown("")
6
  gr.Markdown(
7
  "<p align='center' style='font-size: 20px;'>MMS Zero-shot ASR Demo. See our arXiV <a href='https://arxiv.org/'>paper</a> for model details.</p>"
8
  )
9
  gr.HTML(
10
- """<center>The demo works on input audio in any language, as long as you provide a list of words for that language and an optional n-gram language model (even a simple 1-gram model will work!) to help with accuracy.</center>"""
11
  )
12
  with gr.Row():
13
  with gr.Column():
14
  audio = gr.Audio(label="Audio Input\n(use microphone or upload a file)")
 
15
  with gr.Row():
16
- words_file = gr.File(label="Words File\n(one word per line)")
17
  lm_file = gr.File(label="Language Model\n(optional)")
18
- btn = gr.Button("Submit")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  with gr.Column():
20
  text = gr.Textbox(label="Transcript")
21
- btn.click(process, inputs=[audio, words_file, lm_file], outputs=text)
 
 
 
 
 
 
 
 
 
 
 
 
22
  examples = gr.Examples(examples=ZS_EXAMPLES, inputs=[audio, words_file])
23
 
24
- demo.launch(share=True)
 
1
  import gradio as gr
2
  from zeroshot import process, ZS_EXAMPLES
3
 
4
# Gradio UI for the MMS zero-shot ASR demo.
# Left column: audio, text data, optional language model and decoder settings.
# Right column: the transcript returned by `process`.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(
        "<p align='center' style='font-size: 20px;'>MMS Zero-shot ASR Demo. See our arXiv <a href='https://arxiv.org/'>paper</a> for model details.</p>"
    )
    gr.HTML(
        """<center>The demo works on input audio in any language, as long as you provide a list of words or sentences for that language and an optional n-gram language model (even a simple 1-gram model will work!) to help with accuracy.<br>We recommend having a minimum of 5000 distinct words in the textfile to achieve a good performance.</center>"""
    )
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(label="Audio Input\n(use microphone or upload a file)")

            with gr.Row():
                words_file = gr.File(label="Text Data")
                lm_file = gr.File(label="Language Model\n(optional)")

            with gr.Accordion("Advanced Settings", open=False):
                gr.Markdown(
                    "The following parameters are used for beam-search decoding. Use the default values if you are not sure."
                )
                with gr.Row():
                    wscore = gr.Slider(
                        minimum=-10.0,
                        maximum=10.0,
                        value=0,
                        step=0.1,
                        interactive=True,
                        label="Word Insertion Score",
                    )
                    lmscore = gr.Slider(
                        minimum=-10.0,
                        maximum=10.0,
                        value=0,
                        step=0.1,
                        interactive=True,
                        label="Language Model Score",
                    )
                with gr.Row():
                    # Both checkboxes start ticked and are forwarded to
                    # `process` together with the slider values.
                    wscore_usedefault = gr.Checkbox(
                        label="Use Default Word Insertion Score", value=True
                    )
                    lmscore_usedefault = gr.Checkbox(
                        label="Use Default Language Model Score", value=True
                    )
            # elem_id="submit" is the hook for the #submit rule in style.css.
            btn = gr.Button("Submit", elem_id="submit")
        with gr.Column():
            text = gr.Textbox(label="Transcript")
    # The input order must match the positional parameters of `process`.
    btn.click(
        process,
        inputs=[
            audio,
            words_file,
            lm_file,
            wscore,
            lmscore,
            wscore_usedefault,
            lmscore_usedefault,
        ],
        outputs=text,
    )
    examples = gr.Examples(examples=ZS_EXAMPLES, inputs=[audio, words_file])

demo.launch()
style.css ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #submit {
2
+ margin: auto;
3
+ color: #fff;
4
+ background: #1565c0;
5
+ border-radius: 100vh;
6
+ }
upload/mms_zs/config.json ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_dropout": 0.0,
3
+ "adapter_attn_dim": null,
4
+ "adapter_kernel_size": 3,
5
+ "adapter_stride": 2,
6
+ "add_adapter": false,
7
+ "apply_spec_augment": true,
8
+ "architectures": [
9
+ "Wav2Vec2ForCTC"
10
+ ],
11
+ "attention_dropout": 0.1,
12
+ "bos_token_id": 1,
13
+ "classifier_proj_size": 256,
14
+ "codevector_dim": 768,
15
+ "contrastive_logits_temperature": 0.1,
16
+ "conv_bias": true,
17
+ "conv_dim": [
18
+ 512,
19
+ 512,
20
+ 512,
21
+ 512,
22
+ 512,
23
+ 512,
24
+ 512
25
+ ],
26
+ "conv_kernel": [
27
+ 10,
28
+ 3,
29
+ 3,
30
+ 3,
31
+ 3,
32
+ 2,
33
+ 2
34
+ ],
35
+ "conv_stride": [
36
+ 5,
37
+ 2,
38
+ 2,
39
+ 2,
40
+ 2,
41
+ 2,
42
+ 2
43
+ ],
44
+ "ctc_loss_reduction": "sum",
45
+ "ctc_zero_infinity": false,
46
+ "diversity_loss_weight": 0.1,
47
+ "do_stable_layer_norm": true,
48
+ "eos_token_id": 2,
49
+ "feat_extract_activation": "gelu",
50
+ "feat_extract_dropout": 0.0,
51
+ "feat_extract_norm": "layer",
52
+ "feat_proj_dropout": 0.1,
53
+ "feat_quantizer_dropout": 0.0,
54
+ "final_dropout": 0.0,
55
+ "gradient_checkpointing": false,
56
+ "hidden_act": "gelu",
57
+ "hidden_dropout": 0.1,
58
+ "hidden_size": 1024,
59
+ "initializer_range": 0.02,
60
+ "intermediate_size": 4096,
61
+ "layer_norm_eps": 1e-05,
62
+ "layerdrop": 0.1,
63
+ "mask_feature_length": 10,
64
+ "mask_feature_min_masks": 0,
65
+ "mask_feature_prob": 0.0,
66
+ "mask_time_length": 10,
67
+ "mask_time_min_masks": 2,
68
+ "mask_time_prob": 0.075,
69
+ "model_type": "wav2vec2",
70
+ "num_adapter_layers": 3,
71
+ "num_attention_heads": 16,
72
+ "num_codevector_groups": 2,
73
+ "num_codevectors_per_group": 320,
74
+ "num_conv_pos_embedding_groups": 16,
75
+ "num_conv_pos_embeddings": 128,
76
+ "num_feat_extract_layers": 7,
77
+ "num_hidden_layers": 24,
78
+ "num_negatives": 100,
79
+ "output_hidden_size": 1024,
80
+ "pad_token_id": 0,
81
+ "proj_codevector_dim": 768,
82
+ "tdnn_dilation": [
83
+ 1,
84
+ 2,
85
+ 3,
86
+ 1,
87
+ 1
88
+ ],
89
+ "tdnn_dim": [
90
+ 512,
91
+ 512,
92
+ 512,
93
+ 512,
94
+ 1500
95
+ ],
96
+ "tdnn_kernel": [
97
+ 5,
98
+ 3,
99
+ 3,
100
+ 1,
101
+ 1
102
+ ],
103
+ "torch_dtype": "float32",
104
+ "transformers_version": "4.42.1",
105
+ "use_weighted_layer_sum": false,
106
+ "vocab_size": 32,
107
+ "xvector_output_dim": 512
108
+ }
upload/mms_zs/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:39baa2c87b9abd9910c1982bf82aabda3dbe3ba615e20d5ee0be1026975dcb8c
3
+ size 1261938632
upload/mms_zs/preprocessor_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_normalize": true,
3
+ "feature_extractor_type": "Wav2Vec2FeatureExtractor",
4
+ "feature_size": 1,
5
+ "padding_side": "right",
6
+ "padding_value": 0,
7
+ "processor_class": "Wav2Vec2Processor",
8
+ "return_attention_mask": true,
9
+ "sampling_rate": 16000
10
+ }
upload/mms_zs/special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<s>",
3
+ "eos_token": "</s>",
4
+ "pad_token": "<pad>",
5
+ "unk_token": "<unk>"
6
+ }
upload/mms_zs/tokenizer_config.json ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<pad>",
5
+ "lstrip": true,
6
+ "normalized": false,
7
+ "rstrip": true,
8
+ "single_word": false,
9
+ "special": false
10
+ },
11
+ "1": {
12
+ "content": "<s>",
13
+ "lstrip": true,
14
+ "normalized": false,
15
+ "rstrip": true,
16
+ "single_word": false,
17
+ "special": false
18
+ },
19
+ "2": {
20
+ "content": "</s>",
21
+ "lstrip": true,
22
+ "normalized": false,
23
+ "rstrip": true,
24
+ "single_word": false,
25
+ "special": false
26
+ },
27
+ "3": {
28
+ "content": "<unk>",
29
+ "lstrip": true,
30
+ "normalized": false,
31
+ "rstrip": true,
32
+ "single_word": false,
33
+ "special": false
34
+ }
35
+ },
36
+ "bos_token": "<s>",
37
+ "clean_up_tokenization_spaces": true,
38
+ "do_lower_case": false,
39
+ "eos_token": "</s>",
40
+ "model_max_length": 1000000000000000019884624838656,
41
+ "pad_token": "<pad>",
42
+ "processor_class": "Wav2Vec2Processor",
43
+ "replace_word_delimiter_char": " ",
44
+ "target_lang": null,
45
+ "tokenizer_class": "Wav2Vec2CTCTokenizer",
46
+ "unk_token": "<unk>",
47
+ "word_delimiter_token": "|"
48
+ }
upload/mms_zs/tokens.txt ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <s>
2
+ <pad>
3
+ </s>
4
+ <unk>
5
+ |
6
+ a
7
+ i
8
+ e
9
+ n
10
+ o
11
+ u
12
+ t
13
+ k
14
+ m
15
+ s
16
+ r
17
+ l
18
+ h
19
+ g
20
+ d
21
+ y
22
+ b
23
+ p
24
+ c
25
+ w
26
+ j
27
+ '
28
+ v
29
+ z
30
+ f
31
+ q
32
+ x
upload/mms_zs/vocab.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "'": 26,
3
+ "</s>": 2,
4
+ "<pad>": 0,
5
+ "<s>": 1,
6
+ "<unk>": 3,
7
+ "a": 5,
8
+ "b": 21,
9
+ "c": 23,
10
+ "d": 19,
11
+ "e": 7,
12
+ "f": 29,
13
+ "g": 18,
14
+ "h": 17,
15
+ "i": 6,
16
+ "j": 25,
17
+ "k": 12,
18
+ "l": 16,
19
+ "m": 13,
20
+ "n": 8,
21
+ "o": 9,
22
+ "p": 22,
23
+ "q": 30,
24
+ "r": 15,
25
+ "s": 14,
26
+ "t": 11,
27
+ "u": 10,
28
+ "v": 27,
29
+ "w": 24,
30
+ "x": 31,
31
+ "y": 20,
32
+ "z": 28,
33
+ "|": 4
34
+ }
zeroshot.py CHANGED
@@ -16,34 +16,17 @@ UROMAN_PL = os.path.join(uroman_dir, "bin", "uroman.pl")
16
 
17
  ASR_SAMPLING_RATE = 16_000
18
 
19
- MODEL_ID = "facebook/mms-1b-all"
 
 
 
 
20
 
21
  processor = AutoProcessor.from_pretrained(MODEL_ID)
22
  model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
23
 
24
- lm_decoding_config = {}
25
- lm_decoding_configfile = hf_hub_download(
26
- repo_id="facebook/mms-cclms",
27
- filename="decoding_config.json",
28
- subfolder="mms-1b-all",
29
- )
30
-
31
- with open(lm_decoding_configfile) as f:
32
- lm_decoding_config = json.loads(f.read())
33
-
34
- decoding_config = lm_decoding_config["eng"]
35
-
36
- lm_file = hf_hub_download(
37
- repo_id="facebook/mms-cclms",
38
- filename=decoding_config["lmfile"].rsplit("/", 1)[1],
39
- subfolder=decoding_config["lmfile"].rsplit("/", 1)[0],
40
- )
41
 
42
- token_file = hf_hub_download(
43
- repo_id="facebook/mms-cclms",
44
- filename=decoding_config["tokensfile"].rsplit("/", 1)[1],
45
- subfolder=decoding_config["tokensfile"].rsplit("/", 1)[0],
46
- )
47
 
48
  def error_check_file(filepath):
49
  if not isinstance(filepath, str):
@@ -53,13 +36,15 @@ def error_check_file(filepath):
53
  if not os.path.exists(filepath):
54
  return "Input file '{}' doesn't exists".format(type(filepath))
55
 
 
56
  def norm_uroman(text):
57
  text = text.lower()
58
  text = text.replace("’", "'")
59
  text = re.sub("([^a-z' ])", " ", text)
60
- text = re.sub(' +', ' ', text)
61
  return text.strip()
62
 
 
63
  def uromanize(words):
64
  iso = "xxx"
65
  with tempfile.NamedTemporaryFile() as tf, tempfile.NamedTemporaryFile() as tf2:
@@ -72,24 +57,35 @@ def uromanize(words):
72
  lexicon = {}
73
  with open(tf2.name) as f:
74
  for idx, line in enumerate(f):
 
 
75
  line = re.sub(r"\s+", " ", norm_uroman(line)).strip()
76
  lexicon[words[idx]] = " ".join(line) + " |"
77
  return lexicon
78
 
79
 
80
  def load_lexicon(filepath):
81
- words = []
82
  with open(filepath) as f:
83
  for line in f:
84
  line = line.strip()
85
  # ignore invalid words.
86
  if not line or " " in line or len(line) > 50:
87
  continue
88
- words.append(line)
89
- return uromanize(words)
90
-
91
-
92
- def process(audio_data, words_file, lm_path=None):
 
 
 
 
 
 
 
 
 
93
  if isinstance(audio_data, tuple):
94
  # microphone
95
  sr, audio_samples = audio_data
@@ -101,17 +97,18 @@ def process(audio_data, words_file, lm_path=None):
101
  audio_samples = librosa.load(audio_data, sr=ASR_SAMPLING_RATE, mono=True)[0]
102
  # print(audio_samples[:10])
103
  # print("I'm here 102")
104
- # print("len audio_samples", len(audio_samples))
105
  lang_code = "eng"
106
- processor.tokenizer.set_target_lang(lang_code)
107
  # print("I'm here 107")
108
- model.load_adapter(lang_code)
109
  # print("I'm here 109")
110
  inputs = processor(
111
  audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
112
  )
113
  # print("I'm here 106")
114
-
 
115
  # set device
116
  if torch.cuda.is_available():
117
  device = torch.device("cuda")
@@ -123,27 +120,37 @@ def process(audio_data, words_file, lm_path=None):
123
  device = torch.device("mps")
124
  else:
125
  device = torch.device("cpu")
126
-
127
  model.to(device)
128
  inputs = inputs.to(device)
129
  # print("I'm here 122")
130
  with torch.no_grad():
131
  outputs = model(**inputs).logits
132
 
133
- # Setup lexicon and decoder
134
  # print("before uroman")
135
  lexicon = load_lexicon(words_file)
136
  # print("after uroman")
137
  # print("len lexicon", len(lexicon))
138
  with tempfile.NamedTemporaryFile() as lexicon_file:
139
-
140
  with open(lexicon_file.name, "w") as f:
141
  idx = 10
142
  for word, spelling in lexicon.items():
143
  f.write(word + " " + spelling + "\n")
144
- if idx%100 == 0:
145
  print(word, spelling, flush=True)
146
- idx+=1
 
 
 
 
 
 
 
 
 
 
147
  beam_search_decoder = ctc_decoder(
148
  lexicon=lexicon_file.name,
149
  tokens=token_file,
@@ -151,9 +158,9 @@ def process(audio_data, words_file, lm_path=None):
151
  nbest=1,
152
  beam_size=500,
153
  beam_size_token=50,
154
- lm_weight=float(decoding_config["lmweight"]),
155
- word_score=float(decoding_config["wordscore"]),
156
- sil_score=float(decoding_config["silweight"]),
157
  blank_token="<s>",
158
  )
159
 
@@ -163,8 +170,6 @@ def process(audio_data, words_file, lm_path=None):
163
  return transcription
164
 
165
 
166
- ZS_EXAMPLES = [
167
- ["upload/english.mp3", "upload/words_top10k.txt"]
168
- ]
169
 
170
- # print(process("upload/english.mp3", "upload/words_top10k.txt"))
 
16
 
17
  ASR_SAMPLING_RATE = 16_000
18
 
19
# Default decoder scores used by process() when the matching "use default"
# flag is set. The word score depends on whether a language model was given.
WORD_SCORE_DEFAULT_IF_LM = -0.18
WORD_SCORE_DEFAULT_IF_NOLM = -3.5
LM_SCORE_DEFAULT = 1.48

# Original (misspelled) names, kept as aliases because process() refers to them.
WORD_SCORE_DEAULT_IF_LM = WORD_SCORE_DEFAULT_IF_LM
WORD_SCORE_DEAULT_IF_NOLM = WORD_SCORE_DEFAULT_IF_NOLM
LM_SCORE_DEAULT = LM_SCORE_DEFAULT

# Local directory holding the model checkpoint, tokenizer and token list.
MODEL_ID = "upload/mms_zs"

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)

# Token list handed to the CTC decoder; derived from MODEL_ID so the two
# paths cannot drift apart.
token_file = os.path.join(MODEL_ID, "tokens.txt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
 
 
 
 
 
30
 
31
  def error_check_file(filepath):
32
  if not isinstance(filepath, str):
 
36
  if not os.path.exists(filepath):
37
  return "Input file '{}' doesn't exists".format(type(filepath))
38
 
39
+
40
def norm_uroman(text):
    """Reduce romanized text to lowercase a-z, apostrophes and single spaces.

    The text is lowercased, the typographic apostrophe is mapped to the
    ASCII one, every other character is turned into a space, runs of
    spaces are collapsed, and the ends are trimmed.
    """
    lowered = text.lower().replace("’", "'")
    # Anything outside a-z, apostrophe and space acts as a separator.
    letters_only = re.sub("([^a-z' ])", " ", lowered)
    return re.sub(" +", " ", letters_only).strip()
46
 
47
+
48
  def uromanize(words):
49
  iso = "xxx"
50
  with tempfile.NamedTemporaryFile() as tf, tempfile.NamedTemporaryFile() as tf2:
 
57
  lexicon = {}
58
  with open(tf2.name) as f:
59
  for idx, line in enumerate(f):
60
+ if not line.strip():
61
+ continue
62
  line = re.sub(r"\s+", " ", norm_uroman(line)).strip()
63
  lexicon[words[idx]] = " ".join(line) + " |"
64
  return lexicon
65
 
66
 
67
def load_lexicon(filepath):
    """Build the decoding lexicon from a text file of words or sentences.

    Every whitespace-separated token in the file is treated as a word.
    Words are lowercased and de-duplicated in order of first appearance,
    then passed to ``uromanize``, whose result is returned unchanged.

    Lines containing several words used to be dropped entirely, which
    silently ignored sentence input; each line is now split into words
    and the 50-character limit is applied per word.

    Args:
        filepath: Path to the text file (read as UTF-8).
    """
    # A dict is used as an insertion-ordered set.
    words = {}
    with open(filepath, encoding="utf-8") as f:
        for line in f:
            for w in line.split():
                # ignore invalid (overly long) words.
                if len(w) > 50:
                    continue
                words[w.lower()] = True
    return uromanize(list(words))
78
+
79
+
80
+ def process(
81
+ audio_data,
82
+ words_file,
83
+ lm_path=None,
84
+ wscore=None,
85
+ lmscore=None,
86
+ wscore_usedefault=True,
87
+ lmscore_usedefault=True,
88
+ ):
89
  if isinstance(audio_data, tuple):
90
  # microphone
91
  sr, audio_samples = audio_data
 
97
  audio_samples = librosa.load(audio_data, sr=ASR_SAMPLING_RATE, mono=True)[0]
98
  # print(audio_samples[:10])
99
  # print("I'm here 102")
100
+ print("len audio_samples", len(audio_samples))
101
  lang_code = "eng"
102
+ # processor.tokenizer.set_target_lang(lang_code)
103
  # print("I'm here 107")
104
+ # model.load_adapter(lang_code)
105
  # print("I'm here 109")
106
  inputs = processor(
107
  audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
108
  )
109
  # print("I'm here 106")
110
+ print("inputs type", type(inputs))
111
+ # print("inputs size", inputs.size)
112
  # set device
113
  if torch.cuda.is_available():
114
  device = torch.device("cuda")
 
120
  device = torch.device("mps")
121
  else:
122
  device = torch.device("cpu")
123
+ device = torch.device("cpu")
124
  model.to(device)
125
  inputs = inputs.to(device)
126
  # print("I'm here 122")
127
  with torch.no_grad():
128
  outputs = model(**inputs).logits
129
 
130
+ # Setup lexicon and decoder
131
  # print("before uroman")
132
  lexicon = load_lexicon(words_file)
133
  # print("after uroman")
134
  # print("len lexicon", len(lexicon))
135
  with tempfile.NamedTemporaryFile() as lexicon_file:
136
+
137
  with open(lexicon_file.name, "w") as f:
138
  idx = 10
139
  for word, spelling in lexicon.items():
140
  f.write(word + " " + spelling + "\n")
141
+ if idx % 100 == 0:
142
  print(word, spelling, flush=True)
143
+ idx += 1
144
+
145
+ if wscore_usedefault:
146
+ wscore = (
147
+ WORD_SCORE_DEAULT_IF_LM
148
+ if lm_path is not None
149
+ else WORD_SCORE_DEAULT_IF_NOLM
150
+ )
151
+ if lmscore_usedefault:
152
+ lmscore = LM_SCORE_DEAULT if lm_path is not None else 0
153
+
154
  beam_search_decoder = ctc_decoder(
155
  lexicon=lexicon_file.name,
156
  tokens=token_file,
 
158
  nbest=1,
159
  beam_size=500,
160
  beam_size_token=50,
161
+ lm_weight=lmscore,
162
+ word_score=wscore,
163
+ sil_score=0,
164
  blank_token="<s>",
165
  )
166
 
 
170
  return transcription
171
 
172
 
173
# Example rows for the demo UI: [audio path, text-data path].
ZS_EXAMPLES = [["upload/english.mp3", "upload/words_top10k.txt"]]


if __name__ == "__main__":
    # Manual smoke test. It is guarded so that importing this module (as
    # app.py does) no longer runs a full transcription at import time.
    print(process("upload/english.mp3", "upload/words_top10k.txt"))