File size: 5,623 Bytes
e6f2745
 
 
 
 
 
e02be2a
 
e6f2745
da80bd2
 
 
 
3e08cc6
 
e02be2a
 
 
 
 
 
e6f2745
 
 
 
 
 
e02be2a
e6f2745
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5fdb117
e6f2745
e02be2a
 
e6f2745
 
a419564
e6f2745
4e6bd61
e6f2745
e02be2a
e6f2745
 
 
a419564
e6f2745
 
 
a419564
5bffab4
e6f2745
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da80bd2
 
 
 
 
e6f2745
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import streamlit as st
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from time import time
import torch


def load_tok_and_data(lan, n_examples=1001):
    """Build the mREBEL tokenizer for ``lan`` and fetch test examples for it.

    Args:
        lan: Language key. It must be a key of the module-level ``_Tokens``
            mapping and is also passed to ``load_dataset`` as the dataset
            configuration name.
        n_examples: Number of examples taken from the streamed test split.
            Defaults to 1001, the amount the original code always loaded.

    Returns:
        A ``(tokenizer, dataset)`` tuple, where ``dataset`` is a list holding
        the examples taken from the stream.

    Raises:
        KeyError: If ``lan`` has no entry in ``_Tokens``.
    """
    # Resolve the language token once, before any loading, so an unsupported
    # language fails immediately instead of after the tokenizer is fetched.
    lang_token = _Tokens[lan]

    tokenizer = AutoTokenizer.from_pretrained("Babelscape/mrebel-large", tgt_lang="tp_XX")
    # Point the tokenizer's source-language state at the selected language.
    tokenizer._src_lang = lang_token
    tokenizer.cur_lang_code_id = tokenizer.convert_tokens_to_ids(lang_token)
    tokenizer.set_src_lang_special_tokens(lang_token)

    stream = load_dataset('Babelscape/SREDFM', lan, split="test", streaming=True, trust_remote_code=True)
    # Materialize the streamed prefix so callers can index examples by position.
    return (tokenizer, list(stream.take(n_examples)))
    
@st.cache_resource
def load_model():
    """Load the ``Babelscape/mrebel-large`` seq2seq model in eval mode.

    The model is moved to ``cuda:0`` when CUDA is available. Elapsed times are
    printed before and after loading. The function is wrapped in
    ``st.cache_resource``.

    Returns:
        The loaded model.
    """
    started = time()
    print("+++++ loading Model", time() - started)
    model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/mrebel-large")
    if torch.cuda.is_available():
        model.to("cuda:0")
    model.eval()
    print("+++++ loaded model", time() - started)
    return model

def extract_triplets_typed(text):
    """Parse a decoded model output string into typed triplets.

    The text is stripped, the markers ``<s>``, ``<pad>``, ``</s>``, ``tp_XX``
    and ``__en__`` are removed, and the remainder is split on whitespace.
    ``<triplet>`` / ``<relation>`` start a new head; any other ``<...>`` token
    is a type tag that closes the head or the tail span; plain tokens are
    appended to whichever span is currently open.

    Args:
        text: Decoded output string.

    Returns:
        A list of dicts with keys ``head``, ``head_type``, ``type``, ``tail``
        and ``tail_type``, in order of appearance.
    """
    def snapshot():
        # Freeze the parser state as one triplet record.
        return {
            'head': head.strip(),
            'head_type': head_type,
            'type': rel.strip(),
            'tail': tail.strip(),
            'tail_type': tail_type,
        }

    found = []
    head = head_type = rel = tail = tail_type = ''
    # Parser mode: 'x' nothing open, 't' reading head, 's' reading tail,
    # 'o' reading relation.
    mode = 'x'

    cleaned = text.strip()
    for marker in ("<s>", "<pad>", "</s>", "tp_XX", "__en__"):
        cleaned = cleaned.replace(marker, "")

    for piece in cleaned.split():
        if piece in ("<triplet>", "<relation>"):
            # A new head starts; flush the triplet in progress, if any.
            mode = 't'
            if rel:
                found.append(snapshot())
                rel = ''
            head = ''
            continue

        if piece.startswith("<") and piece.endswith(">"):
            if mode in ('t', 'o'):
                # The tag follows a head (or a finished relation): it is the
                # head type, and a tail span begins.
                mode = 's'
                if rel:
                    found.append(snapshot())
                tail = ''
                head_type = piece[1:-1]
            else:
                # The tag follows a tail: it is the tail type, and the
                # relation span begins.
                mode = 'o'
                tail_type = piece[1:-1]
                rel = ''
            continue

        if mode == 't':
            head += ' ' + piece
        elif mode == 's':
            tail += ' ' + piece
        elif mode == 'o':
            rel += ' ' + piece

    # Emit the trailing triplet only when every field has been filled.
    if head and rel and tail and tail_type and head_type:
        found.append(snapshot())
    return found

# Introductory blurb shown at the top of the demo page.
st.markdown("""This is a demo for the ACL 2023 paper [RED$^{FM}$: a Filtered and Multilingual Relation Extraction Dataset](https://arxiv.org/abs/2306.09802). The pre-trained model is able to extract triplets for up to 400 relation types from Wikidata or be used in downstream Relation Extraction task by fine-tuning. Find the model card [here](https://huggingface.co/Babelscape/mrebel-large). Read more about it in the [paper](https://arxiv.org/abs/2306.09802) and in the original [repository](https://github.com/Babelscape/rebel#REDFM).""")

model = load_model()

# Maps each selectable language to the tokenizer language token used for it.
# 'nl' was offered in the selectbox but missing here, so choosing Dutch raised
# KeyError inside load_tok_and_data.
_Tokens = {
    'en': 'en_XX',
    'de': 'de_DE',
    'ca': 'ca_XX',
    'ar': 'ar_AR',
    'el': 'el_EL',
    'es': 'es_XX',
    'it': 'it_IT',
    'ja': 'ja_XX',
    'ko': 'ko_KR',
    'hi': 'hi_IN',
    'nl': 'nl_XX',
    'pt': 'pt_XX',
    'ru': 'ru_RU',
    'pl': 'pl_PL',
    'zh': 'zh_CN',
    'fr': 'fr_XX',
    'vi': 'vi_VN',
    'sv': 'sv_SE',
}

# Derive the options from the mapping so every offered language has a token.
# Sorted order matches the original hand-written tuple; Catalan stays the
# default selection.
_languages = tuple(sorted(_Tokens))
lan = st.selectbox(
    'Select a Language',
    _languages, index=_languages.index('ca'))

tokenizer, dataset = load_tok_and_data(lan)

# Input selection: either free text typed by the user or an example picked
# from the loaded dataset.
agree = st.checkbox('Free input', False)
if agree:
    text = st.text_input('Input text (current example in catalan)', 'Els Red Hot Chili Peppers es van formar a Los Angeles per Kiedis, Flea, el guitarrista Hillel Slovak i el bateria Jack Irons.')
    print(text)
else:
    # Cap the slider at the last loaded index so a split with fewer than 1001
    # examples cannot produce an IndexError.
    dataset_example = st.slider('dataset id', 0, min(1000, len(dataset) - 1), 0)
    text = dataset[dataset_example]['text']

# Generation controls.
length_penalty = st.slider('length_penalty', 0, 10, 1)
num_beams = st.slider('num_beams', 1, 20, 3)
if num_beams > 1:
    num_return_sequences = st.slider('num_return_sequences', 1, num_beams, 2)
else:
    # With a single beam only one sequence can be returned; the original
    # slider call used a default of 2 with a maximum of 1 here.
    num_return_sequences = 1

gen_kwargs = {
    "max_length": 256,
    "length_penalty": length_penalty,
    "num_beams": num_beams,
    "num_return_sequences": num_return_sequences,
    "forced_bos_token_id": None,
}

# Tokenize the chosen text and move the tensors to the model's device.
model_inputs = tokenizer(text, max_length=256, padding=True, truncation=True, return_tensors='pt')
input_ids = model_inputs["input_ids"].to(model.device)
attention_mask = model_inputs["attention_mask"].to(model.device)

# Generate with decoding forced to start from the "tp_XX" token.
generated_tokens = model.generate(
    input_ids,
    attention_mask=attention_mask,
    decoder_start_token_id=tokenizer.convert_tokens_to_ids("tp_XX"),
    **gen_kwargs,
)

# Special tokens are kept because the triplet parser relies on them.
decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)

st.title('Input text')
st.write(text)

if not agree:
    # Show the annotations stored with the selected dataset example; each
    # relation references its subject/object by index into the entity list.
    st.title('Silver output')
    example = dataset[dataset_example]
    entities = example['entities']
    relations = [
        {
            'subject': entities[trip['subject']],
            'predicate': trip['predicate'],
            'object': entities[trip['object']],
        }
        for trip in example['relations']
    ]
    st.write(relations)

st.title('Prediction text')
# Drop sequence/padding markers before displaying the raw predictions.
cleaned_preds = []
for pred in decoded_preds:
    for marker in ('<s>', '</s>', '<pad>'):
        pred = pred.replace(marker, '')
    cleaned_preds.append(pred)
decoded_preds = cleaned_preds
st.write(decoded_preds)

# One section of parsed triplets per returned sequence.
for idx, sentence in enumerate(decoded_preds):
    st.title(f'Prediction triplets sentence {idx}')
    st.write(extract_triplets_typed(sentence))