# `spaces` stays the first import: ZeroGPU needs it loaded before any package
# that may initialize CUDA.
import spaces

import logging
import random
from functools import lru_cache
from html import escape as html_escape
from textwrap import dedent

import gradio as gr
import networkx as nx
import rapidjson
import spacy
from pyvis.network import Network
from spacy import displacy
from spacy.tokens import Span

from phi3_instruct_graph import MODEL_LIST, Phi3InstructGraph

logger = logging.getLogger(__name__)


@spaces.GPU
def extract(text, model):
    """Run the selected Phi-3 Instruct Graph model on *text* and parse its output.

    Args:
        text: Free-form input text to extract a graph from.
        model: A model identifier, as offered in ``MODEL_LIST``.

    Returns:
        The JSON object returned by the model, parsed with ``rapidjson``.
        Callers in this file expect ``"nodes"`` and ``"edges"`` keys.

    Raises:
        ValueError: If the model output is not valid JSON
            (``rapidjson.JSONDecodeError``).
    """
    # NOTE(review): a new Phi3InstructGraph is constructed on every call; if its
    # constructor loads weights this is costly — confirm and cache if so.
    extractor = Phi3InstructGraph(model=model)
    result = extractor.extract(text)
    return rapidjson.loads(result)


def handle_text(text):
    """Collapse every run of whitespace in *text* (newlines included) to one space.

    Leading and trailing whitespace is removed as well.
    """
    return " ".join(text.split())


def get_random_color():
    """Return a uniformly random colour as a ``#rrggbb`` hex string."""
    return f"#{random.randint(0, 0xFFFFFF):06x}"


def get_random_light_color():
    """Return a random light colour as a ``#rrggbb`` hex string.

    Each RGB channel is drawn from 128-255, so dark text stays readable on it.
    """
    r = random.randint(128, 255)
    g = random.randint(128, 255)
    b = random.randint(128, 255)
    return f"#{r:02x}{g:02x}{b:02x}"


def find_token_indices(doc, substring, text):
    """Locate every non-overlapping occurrence of *substring* as a token range.

    An occurrence is kept only if it starts exactly at the start of a token and
    ends exactly at the end of a token in *doc*. Occurrences that cut through a
    token are logged and skipped.

    Args:
        doc: A spaCy ``Doc`` built from *text*.
        substring: Literal text to search for.
        text: The text *doc* was built from, so character offsets line up.

    Returns:
        A list of ``{"start": int, "end": int}`` token indices (``end``
        exclusive), in order of appearance. Empty if nothing aligned.
    """
    result = []
    if not substring:
        # str.find("") matches at every offset without advancing, which would
        # loop forever; an empty name can never be highlighted anyway.
        logger.warning("Token boundaries not found for '%s'", substring)
        return result

    start_index = text.find(substring)
    while start_index != -1:
        end_index = start_index + len(substring)
        # Strict char_span returns None unless both offsets fall exactly on
        # token boundaries: one lookup instead of scanning every token.
        span = doc.char_span(start_index, end_index)
        if span is None:
            logger.warning(
                "Token boundaries not found for '%s' at index %d",
                substring,
                start_index,
            )
        else:
            result.append({"start": span.start, "end": span.end})
        # Resume after this match, so occurrences never overlap.
        start_index = text.find(substring, end_index)

    if not result:
        logger.warning("Token boundaries not found for '%s'", substring)
    return result


@lru_cache(maxsize=1)
def _blank_multilingual_nlp():
    """Build (once) and return a blank multi-language spaCy pipeline."""
    return spacy.blank("xx")


def create_custom_entity_viz(data, full_text):
    """Render the graph's nodes as labelled, coloured spans over *full_text*.

    Args:
        data: Parsed graph whose ``"nodes"`` entries carry ``"id"`` (the text to
            highlight) and ``"type"`` (the span label).
        full_text: The text the graph was extracted from.

    Returns:
        HTML produced by displaCy's ``"span"`` renderer. Each label that
        received at least one span gets a random light colour.
    """
    doc = _blank_multilingual_nlp()(full_text)
    spans = []
    colors = {}

    for node in data["nodes"]:
        for match in find_token_indices(doc, node["id"], full_text):
            start, end = match["start"], match["end"]
            if not (start < len(doc) and end <= len(doc)):
                continue
            # First accepted span wins; later overlapping spans are dropped,
            # since set_ents() rejects overlapping entities.
            if any(s.start < end and start < s.end for s in spans):
                continue
            spans.append(Span(doc, start, end, label=node["type"]))
            if node["type"] not in colors:
                colors[node["type"]] = get_random_light_color()

    doc.set_ents(spans, default="unmodified")
    # The span renderer reads doc.spans["sc"] by default.
    doc.spans["sc"] = spans

    options = {
        "colors": colors,
        "ents": list(colors.keys()),
        "style": "ent",
        "manual": True,
    }
    return displacy.render(doc, style="span", options=options)


def create_graph(json_data):
    """Render the extracted graph as an interactive pyvis network.

    Args:
        json_data: Parsed graph with ``"nodes"`` (``id``, ``type``,
            ``detailed_type``) and ``"edges"`` (``from``, ``to``, ``label``).

    Returns:
        An ``<iframe>`` HTML string whose ``srcdoc`` holds the full pyvis page.
    """
    # DiGraph keeps A->B and B->A as distinct edges; nx.Graph would merge them
    # and overwrite the first edge's label and direction.
    G = nx.DiGraph()
    for node in json_data["nodes"]:
        G.add_node(node["id"], title=f"{node['type']}: {node['detailed_type']}")
    for edge in json_data["edges"]:
        G.add_edge(edge["from"], edge["to"], title=edge["label"], label=edge["label"])

    nt = Network(
        width="720px",
        height="600px",
        directed=True,
        notebook=False,
        bgcolor="#111827",
        font_color="white",
    )
    nt.from_nx(G)
    nt.barnes_hut(
        gravity=-3000,
        central_gravity=0.3,
        spring_length=50,
        spring_strength=0.001,
        damping=0.09,
        overlap=0,
    )

    page = nt.generate_html()
    # Scripts inserted via innerHTML do not run, so the pyvis page is isolated
    # in an iframe. The whole document is entity-escaped into srcdoc, which
    # keeps quotes inside its JavaScript intact. "allow-scripts" without
    # "allow-same-origin" lets vis.js run while keeping the model-derived
    # content away from the parent page.
    return (
        '<iframe style="width: 100%; height: 620px; margin: 0 auto; border: 0;" '
        'sandbox="allow-scripts" '
        f'srcdoc="{html_escape(page, quote=True)}"></iframe>'
    )


def process_and_visualize(text, model):
    """Gradio callback: extract a graph from *text* and build all three views.

    Returns:
        ``(graph_html, entities_html, json_data)``, matching the order of the
        ``outputs`` wired to the submit button.

    Raises:
        gr.Error: If *text* or *model* is empty.
    """
    if not text or not model:
        raise gr.Error("Text and model must be provided.")

    json_data = extract(text, model)
    entities_viz = create_custom_entity_viz(json_data, text)
    graph_html = create_graph(json_data)
    return graph_html, entities_viz, json_data


with gr.Blocks(title="Phi-3 Instruct Graph (by Emergent Methods)") as demo:
    gr.Markdown("# Phi-3 Instruct Graph (by Emergent Methods)")
    gr.Markdown("Extract a JSON graph from a text input and visualize it.")

    with gr.Row():
        with gr.Column(scale=1):
            input_model = gr.Dropdown(
                MODEL_LIST,
                label="Model",
            )
            input_text = gr.TextArea(label="Text", info="The text to be extracted")
            gr.Examples(
                examples=[
                    handle_text(
                        """Legendary rock band Aerosmith has officially announced
                        their retirement from touring after 54 years, citing lead
                        singer Steven Tyler's unrecoverable vocal cord injury. The
                        decision comes after months of unsuccessful treatment for
                        Tyler's fractured larynx, which he suffered in September
                        2023."""
                    ),
                    handle_text(
                        """Pop star Justin Timberlake, 43, had his driver's license
                        suspended by a New York judge during a virtual court hearing
                        on August 2, 2024. The suspension follows Timberlake's arrest
                        for driving while intoxicated (DWI) in Sag Harbor on June 18.
                        Timberlake, who is currently on tour in Europe, pleaded not
                        guilty to the charges."""
                    ),
                ],
                inputs=input_text,
            )
            submit_button = gr.Button("Extract and Visualize")

        with gr.Column(scale=1):
            output_entity_viz = gr.HTML(label="Entities Visualization", show_label=True)
            output_graph = gr.HTML(label="Graph Visualization", show_label=True)
            # process_and_visualize returns the parsed JSON as a third value;
            # without a matching component Gradio silently discards it.
            output_json = gr.JSON(label="JSON Graph")

    submit_button.click(
        fn=process_and_visualize,
        inputs=[input_text, input_model],
        outputs=[output_graph, output_entity_viz, output_json],
    )

demo.launch(share=False)