PereLluis13 committed on
Commit
e02be2a
1 Parent(s): 33a7ca5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -13
app.py CHANGED
@@ -4,26 +4,24 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
  from time import time
5
  import torch
6
 
7
- @st.cache_resource(
8
- allow_output_mutation=True,
9
- hash_funcs={
10
- AutoTokenizer: lambda x: None,
11
- AutoModelForSeq2SeqLM: lambda x: None,
12
- },
13
- suppress_st_warning=True
14
- )
15
- def load_models(lan):
16
  st_time = time()
17
  tokenizer = AutoTokenizer.from_pretrained("Babelscape/mrebel-large", src_lang=_Tokens[lan], tgt_lang="tp_XX")
 
 
 
 
 
 
 
18
  print("+++++ loading Model", time() - st_time)
19
  model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/mrebel-large")
20
  if torch.cuda.is_available():
21
  _ = model.to("cuda:0") # comment if no GPU available
22
  _ = model.eval()
23
  print("+++++ loaded model", time() - st_time)
24
- dataset = load_dataset('Babelscape/SREDFM', lan, split="validation", streaming=True)
25
- dataset = [example for example in dataset.take(1001)]
26
- return (tokenizer, model, dataset)
27
 
28
  def extract_triplets_typed(text):
29
  triplets = []
@@ -63,13 +61,15 @@ def extract_triplets_typed(text):
63
 
64
  st.markdown("""This is a demo for the Findings of EMNLP 2021 paper [REBEL: Relation Extraction By End-to-end Language generation](https://github.com/Babelscape/rebel/blob/main/docs/EMNLP_2021_REBEL__Camera_Ready_.pdf). The pre-trained model is able to extract triplets for up to 200 relation types from Wikidata or be used in downstream Relation Extraction task by fine-tuning. Find the model card [here](https://huggingface.co/Babelscape/rebel-large). Read more about it in the [paper](https://aclanthology.org/2021.findings-emnlp.204) and in the original [repository](https://github.com/Babelscape/rebel).""")
65
 
 
 
66
  lan = st.selectbox(
67
  'Select a Language',
68
  ('ar', 'ca', 'de', 'el', 'en', 'es', 'fr', 'hi', 'it', 'ja', 'ko', 'nl', 'pl', 'pt', 'ru', 'sv', 'vi', 'zh'), index=1)
69
 
70
  _Tokens = {'en': 'en_XX', 'de': 'de_DE', 'ca': 'ca_XX', 'ar': 'ar_AR', 'el': 'el_EL', 'it': 'it_IT', 'ja': 'ja_XX', 'ko': 'ko_KR', 'hi': 'hi_IN', 'pt': 'pt_XX', 'ru': 'ru_RU', 'pl': 'pl_PL', 'zh': 'zh_CN', 'fr': 'fr_XX', 'vi': 'vi_VN', 'sv':'sv_SE'}
71
 
72
- tokenizer, model, dataset = load_models(lan)
73
 
74
  agree = st.checkbox('Free input', False)
75
  if agree:
 
4
  from time import time
5
  import torch
6
 
7
+
8
+ def load_tok_and_data(lan):
 
 
 
 
 
 
 
9
  st_time = time()
10
  tokenizer = AutoTokenizer.from_pretrained("Babelscape/mrebel-large", src_lang=_Tokens[lan], tgt_lang="tp_XX")
11
+ dataset = load_dataset('Babelscape/SREDFM', lan, split="validation", streaming=True)
12
+ dataset = [example for example in dataset.take(1001)]
13
+ return (tokenizer, dataset)
14
+
15
+ @st.cache_resource
16
+ def load_model():
17
+ st_time = time()
18
  print("+++++ loading Model", time() - st_time)
19
  model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/mrebel-large")
20
  if torch.cuda.is_available():
21
  _ = model.to("cuda:0") # comment if no GPU available
22
  _ = model.eval()
23
  print("+++++ loaded model", time() - st_time)
24
+ return model
 
 
25
 
26
  def extract_triplets_typed(text):
27
  triplets = []
 
61
 
62
  st.markdown("""This is a demo for the Findings of EMNLP 2021 paper [REBEL: Relation Extraction By End-to-end Language generation](https://github.com/Babelscape/rebel/blob/main/docs/EMNLP_2021_REBEL__Camera_Ready_.pdf). The pre-trained model is able to extract triplets for up to 200 relation types from Wikidata or be used in downstream Relation Extraction task by fine-tuning. Find the model card [here](https://huggingface.co/Babelscape/rebel-large). Read more about it in the [paper](https://aclanthology.org/2021.findings-emnlp.204) and in the original [repository](https://github.com/Babelscape/rebel).""")
63
 
64
+ model = load_model()
65
+
66
  lan = st.selectbox(
67
  'Select a Language',
68
  ('ar', 'ca', 'de', 'el', 'en', 'es', 'fr', 'hi', 'it', 'ja', 'ko', 'nl', 'pl', 'pt', 'ru', 'sv', 'vi', 'zh'), index=1)
69
 
70
  _Tokens = {'en': 'en_XX', 'de': 'de_DE', 'ca': 'ca_XX', 'ar': 'ar_AR', 'el': 'el_EL', 'it': 'it_IT', 'ja': 'ja_XX', 'ko': 'ko_KR', 'hi': 'hi_IN', 'pt': 'pt_XX', 'ru': 'ru_RU', 'pl': 'pl_PL', 'zh': 'zh_CN', 'fr': 'fr_XX', 'vi': 'vi_VN', 'sv':'sv_SE'}
71
 
72
+ tokenizer, dataset = load_tok_and_data(lan)
73
 
74
  agree = st.checkbox('Free input', False)
75
  if agree: