File size: 3,664 Bytes
8dc4355
 
 
 
 
 
 
 
 
 
6546065
 
8dc4355
 
6546065
 
8dc4355
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6546065
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os
from typing import List, Optional, Union
import gradio as gr
import spacy
from spacy.tokens import Doc, Span
from relik import Relik
from relik.inference.data.objects import TaskType, RelikOutput
from pyvis.network import Network

# RELIK Models Setup
def setup_relik_model(model_name: str, device: str):
    """Instantiate a pretrained ReLiK pipeline.

    Args:
        model_name: Checkpoint identifier handed to ``Relik.from_pretrained``.
        device: Device string forwarded as the ``device`` keyword argument.

    Returns:
        The object produced by ``Relik.from_pretrained`` for these arguments.
    """
    loaded = Relik.from_pretrained(model_name, device=device)
    return loaded

def _resolve_device() -> str:
    """Pick the device the ReLiK models are loaded on.

    The ``RELIK_DEVICE`` environment variable wins when set (e.g. ``"cpu"``
    or ``"cuda:1"``). Otherwise CUDA is used when torch reports it available,
    with the CPU as a fallback so the demo can still start on hosts without
    a GPU instead of failing at import time.
    """
    override = os.environ.get("RELIK_DEVICE")
    if override:
        return override
    try:
        import torch  # local import: only needed for device detection
    except ImportError:
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


# Checkpoints offered by the demo; the dict below preserves this order, and
# the UI dropdown lists ``relik_models.keys()`` in that order.
RELIK_MODEL_NAMES = (
    "sapienzanlp/relik-entity-linking-large",
    "relik-ie/relik-relation-extraction-small",
)

# Models are loaded eagerly, once, when the module is imported.
relik_models = {
    name: setup_relik_model(name, _resolve_device()) for name in RELIK_MODEL_NAMES
}

def get_span_annotations(response, doc):
    """Convert ReLiK spans into spaCy spans plus a label-to-color map.

    Args:
        response: ReLiK output whose ``spans`` each expose ``start``, ``end``
            and ``label``; start/end are used as token offsets into ``doc``.
        doc: spaCy ``Doc`` the new spans are attached to.

    Returns:
        ``(spans, colors)``: the spaCy ``Span`` objects in the order ReLiK
        produced them, and a dict mapping each span label to one fixed
        highlight color.
    """
    highlight = '#ff5733'  # Simple fixed color for demonstration
    spacy_spans = [Span(doc, item.start, item.end, item.label) for item in response.spans]
    label_colors = {}
    for sp in spacy_spans:
        label_colors[sp.label_] = highlight
    return spacy_spans, label_colors

def generate_graph(spans, response, colors):
    """Render entities and relation triplets as an embeddable pyvis graph.

    Args:
        spans: spaCy spans; each becomes a node keyed by its text and colored
            with ``colors[span.label_]``.
        response: ReLiK output whose ``triplets`` each expose ``subject`` and
            ``object`` (both with ``.text``) and a ``label``.
        colors: Mapping from span label to node color.

    Returns:
        An ``<iframe>`` string whose ``srcdoc`` holds the graph's standalone
        HTML page.
    """
    from html import escape  # local import: only this block needs it

    g = Network(width="720px", height="600px", directed=True)
    node_ids = set()
    for ent in spans:
        g.add_node(ent.text, label=ent.text, color=colors[ent.label_], size=15)
        node_ids.add(ent.text)

    seen_rels = set()
    for rel in response.triplets:
        source, target = rel.subject.text, rel.object.text
        key = (source, target, rel.label)
        if key in seen_rels:
            continue
        seen_rels.add(key)
        # Triplet endpoint text is not guaranteed to equal a node id built from
        # the spaCy spans (presumably they can differ around punctuation or
        # whitespace -- unverified). pyvis rejects edges to unknown nodes, so
        # add any missing endpoint with default styling instead of crashing.
        for endpoint in (source, target):
            if endpoint not in node_ids:
                g.add_node(endpoint, label=endpoint, size=15)
                node_ids.add(endpoint)
        g.add_edge(source, target, label=rel.label)

    page = g.generate_html()
    # Entity-escape the whole page for a double-quoted attribute; the browser
    # decodes it back, so quotes inside the page's JSON/JS stay intact
    # (rewriting ' to " would break literals such as "O'Neal").
    return (
        '<iframe style="width: 100%; height: 600px;margin:0 auto" '
        f'srcdoc="{escape(page, quote=True)}"></iframe>'
    )

def text_analysis(Text, Model, Relation_Threshold, Window_Size, Window_Stride):
    """Annotate ``Text`` with a ReLiK model and render the results as HTML.

    Args:
        Text: Input text to annotate.
        Model: Key into ``relik_models`` selecting the pipeline to run.
        Relation_Threshold: Forwarded to ReLiK as ``relation_threshold``.
        Window_Size: Forwarded to ReLiK as ``window_size`` (coerced to int).
        Window_Stride: Forwarded to ReLiK as ``window_stride`` (coerced to int).

    Returns:
        ``(display_el, display_re)``: displaCy span markup for the entities,
        and the relation-graph iframe ("" when ReLiK returned no triplets).
        Both are "" when ``Text`` is empty or whitespace-only.

    Raises:
        ValueError: If ``Model`` is not one of the loaded models.
    """
    # Import the submodule explicitly instead of relying on `import spacy`
    # having loaded it as a side effect.
    from spacy import displacy

    if Model not in relik_models:
        raise ValueError(f"Model {Model} not found.")
    # Nothing to annotate: skip the model call and render empty outputs.
    if not (Text and Text.strip()):
        return "", ""
    relik = relik_models[Model]
    nlp = spacy.blank("xx")
    annotated_text = relik(
        Text,
        annotation_type="word",
        relation_threshold=Relation_Threshold,
        # Slider values may arrive as floats; window sizes are token counts.
        window_size=int(Window_Size),
        window_stride=int(Window_Stride),
    )
    doc = Doc(nlp.vocab, words=[token.text for token in annotated_text.tokens])
    spans, colors = get_span_annotations(annotated_text, doc)
    doc.spans["sc"] = spans
    display_el = displacy.render(doc, style="span", options={"colors": colors}).replace("\n", " ")
    # Keep highlighted spans on one line and tag them for the `#el` CSS rule.
    display_el = display_el.replace("border-radius: 0.35em;", "border-radius: 0.35em; white-space: nowrap;").replace("span style", "span id='el' style")
    display_re = generate_graph(spans, annotated_text, colors) if annotated_text.triplets else ""
    return display_el, display_re

# Base Gradio theme with rose accents and large text.
theme = gr.themes.Base(
    primary_hue="rose",
    secondary_hue="rose",
    text_size="lg",
)
css = """
h1 { text-align: center; display: block; }
mark { color: black; }
#el { white-space: nowrap; }
"""

# UI defaults, shared by the input components and the example rows.
DEFAULT_MODEL = "sapienzanlp/relik-entity-linking-large"
DEFAULT_RELATION_THRESHOLD = 0.5
DEFAULT_WINDOW_SIZE = 32
DEFAULT_WINDOW_STRIDE = 16

with gr.Blocks(fill_height=True, css=css, theme=theme) as demo:
    gr.Markdown("# ReLiK with P-FAF Integration")
    gr.Interface(
        text_analysis,
        [
            gr.Textbox(label="Input Text", placeholder="Enter sentence here..."),
            gr.Dropdown(list(relik_models.keys()), value=DEFAULT_MODEL, label="Relik Model"),
            gr.Slider(minimum=0, maximum=1, step=0.05, value=DEFAULT_RELATION_THRESHOLD, label="Relation Threshold"),
            gr.Slider(minimum=16, maximum=128, step=16, value=DEFAULT_WINDOW_SIZE, label="Window Size"),
            gr.Slider(minimum=8, maximum=64, step=8, value=DEFAULT_WINDOW_STRIDE, label="Window Stride"),
        ],
        [gr.HTML(label="Entities"), gr.HTML(label="Relations")],
        # One value per input, so picking an example fully determines the run
        # instead of depending on how Gradio fills the missing columns.
        examples=[
            [text, DEFAULT_MODEL, DEFAULT_RELATION_THRESHOLD, DEFAULT_WINDOW_SIZE, DEFAULT_WINDOW_STRIDE]
            for text in (
                "Michael Jordan was one of the best players in the NBA.",
                "Noam Chomsky is a renowned linguist and cognitive scientist.",
            )
        ],
        allow_flagging="never",
    )

# Launch after leaving the Blocks context so the app is built from the
# completed layout.
if __name__ == "__main__":
    demo.launch()