# P-FAF-RELIK / app.py
# Hugging Face Space file by TuringsSolutions - commit 8dc4355 ("Create app.py"),
# verified, virus scan: no virus, 4.39 kB.
import os
from typing import List, Optional, Union
import gradio as gr
import spacy
from spacy.tokens import Doc, Span
from relik import Relik
from relik.inference.data.objects import TaskType, RelikOutput
from relik.retriever.pytorch_modules import GoldenRetriever
from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
from pyvis.network import Network
# RELIK Models Setup
#
# Everything below is loaded once, at import time, and the resulting
# ``relik_models`` dict is looked up by ``text_analysis`` on every request.

# Torch device for every retriever, index and reader. Defaults to "cuda" (the
# previously hard-coded value); override with e.g. RELIK_DEVICE=cpu on a host
# without a GPU (assumes the relik components accept any torch device string).
DEVICE = os.environ.get("RELIK_DEVICE", "cuda")

# Retriever/index pair used by the entity-linking model.
wikipedia_retriever = GoldenRetriever("relik-ie/encoder-e5-base-v2-wikipedia", device=DEVICE)
wikipedia_index = InMemoryDocumentIndex.from_pretrained(
    "relik-ie/encoder-e5-base-v2-wikipedia-index",
    index_precision="bf16",
    device=DEVICE,
)

# Retriever/index pair used by the relation-extraction model.
wikidata_retriever = GoldenRetriever("relik-ie/encoder-e5-small-v2-wikipedia-relations", device=DEVICE)
wikidata_index = InMemoryDocumentIndex.from_pretrained(
    "relik-ie/encoder-e5-small-v2-wikipedia-relations-index",
    index_precision="bf16",
    device=DEVICE,
)

# Model name (as shown in the UI dropdown) -> ready-to-call Relik pipeline.
relik_models = {
    "sapienzanlp/relik-entity-linking-large": Relik.from_pretrained(
        "sapienzanlp/relik-entity-linking-large",
        device=DEVICE,
        index=wikipedia_index,
        retriever=wikipedia_retriever,
        reader_kwargs={"dataset_kwargs": {"use_nme": True}},
    ),
    "relik-ie/relik-relation-extraction-small": Relik.from_pretrained(
        "relik-ie/relik-relation-extraction-small",
        index=wikidata_index,
        device=DEVICE,
        retriever=wikidata_retriever,
    ),
}
def get_span_annotations(response, doc):
    """Turn the spans predicted by ReLiK into spaCy spans plus a colour map.

    Args:
        response: ReLiK output; each item of ``response.spans`` exposes
            ``start``, ``end`` and ``label``, which are passed straight to
            ``spacy.tokens.Span`` (so ``start``/``end`` are token indices
            into ``doc``).
        doc: spaCy ``Doc`` the new spans are attached to.

    Returns:
        ``(spans, colors)`` - the ``Span`` objects in the order of
        ``response.spans``, and a dict mapping every distinct span label to
        the same fixed highlight colour.
    """
    converted = [Span(doc, item.start, item.end, item.label) for item in response.spans]
    # A single fixed colour for every label keeps the demo simple.
    label_colors = dict.fromkeys((item.label_ for item in converted), '#ff5733')
    return converted, label_colors
def iframe_srcdoc(document):
    """Embed a standalone HTML document in an ``<iframe>`` via ``srcdoc``.

    The document is HTML-escaped (including both quote characters) so it can
    live inside the single-quoted attribute without terminating it early;
    browsers unescape attribute values, so the iframe receives the original
    markup unchanged (apostrophes inside JS/JSON strings included).
    """
    from html import escape

    return (
        '<iframe style="width: 100%; height: 600px;margin:0 auto" '
        f"srcdoc='{escape(document, quote=True)}'></iframe>"
    )


def generate_graph(spans, response, colors):
    """Render ReLiK relation triplets as an interactive pyvis graph.

    Args:
        spans: spaCy spans; each becomes a node keyed by its text and
            coloured with ``colors[span.label_]``.
        response: ReLiK output whose ``triplets`` each expose
            ``subject.text``, ``object.text`` and ``label``.
        colors: mapping from span label to node colour.

    Returns:
        An ``<iframe>`` HTML string embedding the generated graph page.
    """
    net = Network(width="720px", height="600px", directed=True)
    for ent in spans:
        net.add_node(ent.text, label=ent.text, color=colors[ent.label_], size=15)

    seen_rels = set()
    for rel in response.triplets:
        source, target = rel.subject.text, rel.object.text
        key = (source, target, rel.label)
        # Identical (subject, object, relation) triplets are drawn only once.
        if key in seen_rels:
            continue
        seen_rels.add(key)
        # pyvis refuses edges whose endpoints are not nodes yet. Node ids are
        # the spaCy span texts (tokens re-joined with spaces), which may not
        # match the triplet text exactly, so add any missing endpoint first.
        for endpoint in (source, target):
            if endpoint not in net.get_nodes():
                net.add_node(endpoint, label=endpoint, size=15)
        net.add_edge(source, target, label=rel.label)

    return iframe_srcdoc(net.generate_html())
def text_analysis(Text, Model, Relation_Threshold, Window_Size, Window_Stride):
    """Annotate ``Text`` with the selected ReLiK model and render the results.

    Args:
        Text: input text from the textbox.
        Model: key of ``relik_models`` selecting the pipeline to run.
        Relation_Threshold: forwarded to the model as ``relation_threshold``.
        Window_Size: forwarded as ``window_size`` (converted to ``int``).
        Window_Stride: forwarded as ``window_stride`` (converted to ``int``).

    Returns:
        ``(entities_html, relations_html)`` - the displaCy span rendering and
        the relation-graph iframe. ``relations_html`` is ``""`` when the model
        returns no triplets; both are ``""`` for blank input.

    Raises:
        ValueError: if ``Model`` is not a key of ``relik_models``.
    """
    if Model not in relik_models:
        raise ValueError(f"Model {Model} not found.")
    # Blank input: skip inference entirely and leave both panels empty.
    if not Text or not Text.strip():
        return "", ""

    # Imported explicitly so the displacy submodule is guaranteed to be loaded.
    from spacy import displacy

    relik = relik_models[Model]
    annotated_text = relik(
        Text,
        annotation_type="word",
        relation_threshold=Relation_Threshold,
        # Window sizes are token counts; slider values may arrive as floats.
        window_size=int(Window_Size),
        window_stride=int(Window_Stride),
    )

    # Rebuild the word-level tokens as a spaCy Doc so displaCy can draw spans.
    nlp = spacy.blank("xx")
    doc = Doc(nlp.vocab, words=[token.text for token in annotated_text.tokens])
    spans, colors = get_span_annotations(annotated_text, doc)
    doc.spans["sc"] = spans

    display_el = displacy.render(doc, style="span", options={"colors": colors}).replace("\n", " ")
    # Stop highlighted spans from wrapping and tag them with the `el` id that
    # the page CSS targets.
    display_el = display_el.replace(
        "border-radius: 0.35em;", "border-radius: 0.35em; white-space: nowrap;"
    ).replace("span style", "span id='el' style")

    display_re = generate_graph(spans, annotated_text, colors) if annotated_text.triplets else ""
    return display_el, display_re
# Base Gradio theme with rose accents and large text.
theme = gr.themes.Base(
    primary_hue="rose",
    secondary_hue="rose",
    text_size="lg",
)

# Page CSS: centred title, black text inside highlighted <mark> entities, and
# no line wrapping for elements carrying the `el` id.
css = "\n".join(
    [
        "",
        "h1 { text-align: center; display: block; }",
        "mark { color: black; }",
        "#el { white-space: nowrap; }",
        "",
    ]
)
# Page layout: a title heading followed by an Interface that feeds the five
# inputs below, in order, to text_analysis and shows its two HTML results.
with gr.Blocks(fill_height=True, css=css, theme=theme) as demo:
    gr.Markdown("# ReLiK with P-FAF Integration")

    analysis_inputs = [
        gr.Textbox(label="Input Text", placeholder="Enter sentence here..."),
        gr.Dropdown(
            list(relik_models),
            value="sapienzanlp/relik-entity-linking-large",
            label="Relik Model",
        ),
        gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5, label="Relation Threshold"),
        gr.Slider(minimum=16, maximum=128, step=16, value=32, label="Window Size"),
        gr.Slider(minimum=8, maximum=64, step=8, value=16, label="Window Stride"),
    ]
    analysis_outputs = [gr.HTML(label="Entities"), gr.HTML(label="Relations")]
    # Examples supply only the input text.
    sample_sentences = [
        ["Michael Jordan was one of the best players in the NBA."],
        ["Noam Chomsky is a renowned linguist and cognitive scientist."],
    ]

    gr.Interface(
        text_analysis,
        analysis_inputs,
        analysis_outputs,
        examples=sample_sentences,
        allow_flagging="never",
    )
if __name__ == "__main__":
    # Serve the demo only when executed as a script, not when imported.
    demo.launch()