Nuo Chen commited on
Commit
34dbc36
1 Parent(s): b672eb7

Update app.py

Browse files
Files changed (1) hide show
  1. gradio_samples/bertviz/app.py +24 -24
gradio_samples/bertviz/app.py CHANGED
@@ -15,7 +15,7 @@ from tqdm.notebook import tqdm
15
  from torch.utils.data import DataLoader
16
  from functools import partial
17
 
18
- from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM
19
 
20
  from bertviz import model_view, head_view
21
  from bertviz_gradio import head_view_mod
@@ -23,44 +23,44 @@ from bertviz_gradio import head_view_mod
23
 
24
 
25
  model_es = "Helsinki-NLP/opus-mt-en-es"
26
- # model_fr = "Helsinki-NLP/opus-mt-en-fr"
27
- # model_zh = "Helsinki-NLP/opus-mt-en-zh"
28
- # model_sw = "Helsinki-NLP/opus-mt-en-sw"
29
 
30
  tokenizer_es = AutoTokenizer.from_pretrained(model_es)
31
- # tokenizer_fr = AutoTokenizer.from_pretrained(model_fr)
32
- # tokenizer_zh = AutoTokenizer.from_pretrained(model_zh)
33
- # tokenizer_sw = AutoTokenizer.from_pretrained(model_sw)
34
 
35
- model_tr_es = AutoModel.from_pretrained(model_es)
36
- # model_tr_fr = MarianMTModel.from_pretrained(model_fr)
37
- # model_tr_zh = MarianMTModel.from_pretrained(model_zh)
38
- # model_tr_sw = MarianMTModel.from_pretrained(model_sw)
39
 
40
  model_es = inseq.load_model("Helsinki-NLP/opus-mt-en-es", "input_x_gradient")
41
- # model_fr = inseq.load_model("Helsinki-NLP/opus-mt-en-fr", "input_x_gradient")
42
- # model_zh = inseq.load_model("Helsinki-NLP/opus-mt-en-zh", "input_x_gradient")
43
- # model_sw = inseq.load_model("Helsinki-NLP/opus-mt-en-sw", "input_x_gradient")
44
 
45
  dict_models = {
46
  'en-es': model_es,
47
- # 'en-fr': model_fr,
48
- # 'en-zh': model_zh,
49
- # 'en-sw': model_sw,
50
  }
51
 
52
  dict_models_tr = {
53
  'en-es': model_tr_es,
54
- # 'en-fr': model_tr_fr,
55
- # 'en-zh': model_tr_zh,
56
- # 'en-sw': model_tr_sw,
57
  }
58
 
59
  dict_tokenizer_tr = {
60
  'en-es': tokenizer_es,
61
- # 'en-fr': tokenizer_fr,
62
- # 'en-zh': tokenizer_zh,
63
- # 'en-sw': tokenizer_sw,
64
  }
65
 
66
  saliency_examples = [
@@ -196,4 +196,4 @@ with gr.Blocks(js="plotsjs_bertviz.js") as demo:
196
  # demo.load(None,None,None,js="plotsjs.js")
197
 
198
  if __name__ == "__main__":
199
- demo.launch()
 
15
  from torch.utils.data import DataLoader
16
  from functools import partial
17
 
18
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
19
 
20
  from bertviz import model_view, head_view
21
  from bertviz_gradio import head_view_mod
 
23
 
24
 
25
  model_es = "Helsinki-NLP/opus-mt-en-es"
26
+ model_fr = "Helsinki-NLP/opus-mt-en-fr"
27
+ model_zh = "Helsinki-NLP/opus-mt-en-zh"
28
+ model_sw = "Helsinki-NLP/opus-mt-en-sw"
29
 
30
  tokenizer_es = AutoTokenizer.from_pretrained(model_es)
31
+ tokenizer_fr = AutoTokenizer.from_pretrained(model_fr)
32
+ tokenizer_zh = AutoTokenizer.from_pretrained(model_zh)
33
+ tokenizer_sw = AutoTokenizer.from_pretrained(model_sw)
34
 
35
+ model_tr_es =AutoModelForSeq2SeqLM.from_pretrained(model_es)
36
+ model_tr_fr = AutoModelForSeq2SeqLM.from_pretrained(model_fr)
37
+ model_tr_zh =AutoModelForSeq2SeqLM.from_pretrained(model_zh)
38
+ model_tr_sw = AutoModelForSeq2SeqLM.from_pretrained(model_sw)
39
 
40
  model_es = inseq.load_model("Helsinki-NLP/opus-mt-en-es", "input_x_gradient")
41
+ model_fr = inseq.load_model("Helsinki-NLP/opus-mt-en-fr", "input_x_gradient")
42
+ model_zh = inseq.load_model("Helsinki-NLP/opus-mt-en-zh", "input_x_gradient")
43
+ model_sw = inseq.load_model("Helsinki-NLP/opus-mt-en-sw", "input_x_gradient")
44
 
45
  dict_models = {
46
  'en-es': model_es,
47
+ 'en-fr': model_fr,
48
+ 'en-zh': model_zh,
49
+ 'en-sw': model_sw,
50
  }
51
 
52
  dict_models_tr = {
53
  'en-es': model_tr_es,
54
+ 'en-fr': model_tr_fr,
55
+ 'en-zh': model_tr_zh,
56
+ 'en-sw': model_tr_sw,
57
  }
58
 
59
  dict_tokenizer_tr = {
60
  'en-es': tokenizer_es,
61
+ 'en-fr': tokenizer_fr,
62
+ 'en-zh': tokenizer_zh,
63
+ 'en-sw': tokenizer_sw,
64
  }
65
 
66
  saliency_examples = [
 
196
  # demo.load(None,None,None,js="plotsjs.js")
197
 
198
  if __name__ == "__main__":
199
+ demo.launch(share=True)