gigant committed on
Commit
0cf0d6f
1 Parent(s): 4f1f3ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -2
app.py CHANGED
@@ -7,11 +7,87 @@ import spacy
7
  import gradio as gr
8
  import en_core_web_trf
9
  import numpy as np
 
 
 
10
 
11
  dataset = load_dataset("gigant/tib_transcripts")
12
 
13
  nlp = en_core_web_trf.load()
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def half_circle_layout(n_nodes, sentence_node=True):
16
  pos = {}
17
  for i_node in range(n_nodes - 1):
@@ -127,19 +203,23 @@ def convert_jraph_to_networkx_graph(jraph_graph: jraph.GraphsTuple) -> nx.Graph:
127
  int(senders[e]), int(receivers[e]), edge_feature=edges[e])
128
  return nx_graph
129
 
130
- def plot_graph_sentence(sentence, graph_type="both"):
131
  # sentences = dataset["train"][0]["abstract"].split(".")
132
  docs = dependency_parser([sentence])
133
  if graph_type == "dependency":
134
  graphs = construct_dependency_graph(docs)
135
  elif graph_type == "structural":
136
  graphs = construct_structural_graph(docs)
137
- elif graph_type == "both":
138
  graphs = construct_both_graph(docs)
 
 
139
  g = to_jraph(graphs[0])
140
  adj_mat = get_adjacency_matrix(g)
141
  nx_graph = convert_jraph_to_networkx_graph(g)
142
  pos = half_circle_layout(len(graphs[0]["nodes"]))
 
 
143
  plot = plt.figure(figsize=(12, 6))
144
  nx.draw(nx_graph, pos=pos,
145
  labels={i: e for i,e in enumerate(graphs[0]["nodes"])},
@@ -160,6 +240,8 @@ def get_list_sentences(id):
160
  return gr.update(choices = dataset["train"][id]["transcript"].split("."))
161
 
162
  with gr.Blocks() as demo:
 
 
163
  with gr.Tab("From transcript"):
164
  with gr.Row():
165
  with gr.Column():
 
7
  import gradio as gr
8
  import en_core_web_trf
9
  import numpy as np
10
+ import benepar
11
+ import re
12
+
13
 
14
  dataset = load_dataset("gigant/tib_transcripts")
15
 
16
  nlp = en_core_web_trf.load()
17
 
18
+ benepar.download('benepar_en3')
19
+ nlp.add_pipe('benepar', config={'model': 'benepar_en3'})
20
+
21
def parse_tree(sentence):
    """Parse a bracketed (s-expression style) constituency string into nested lists.

    Example: "(S (NP a))" -> [['S', ['NP', 'a']]].  The outermost return value
    is always a list wrapping the top-level items.

    Raises:
        ValueError: if the parentheses in `sentence` are unbalanced.
    """
    root = []
    current = root
    parents = []  # ancestors of `current`; acts as the parse stack
    for tok in re.compile(r'(?:([()])|\s+)').split(sentence):
        if not tok:
            # re.split yields empty strings around matches; skip them
            continue
        if tok == '(':
            # open a new child list and descend into it
            parents.append(current)
            child = []
            current.append(child)
            current = child
        elif tok == ')':
            if not parents:
                raise ValueError("Unbalanced parentheses")
            current = parents.pop()
        else:
            current.append(tok)
    if parents:
        raise ValueError("Unbalanced parentheses")
    return root
38
+
39
class Tree():
    """Minimal n-ary tree node used to represent a constituency parse.

    Each node carries a `name` (constituent label or token), a list of
    `children` Trees, and an integer `id` assigned in preorder by
    `set_all_ids` (None until then).
    """

    def __init__(self, name, children):
        self.children = children
        self.name = name
        self.id = None  # filled in later by set_all_ids / set_id_rec

    def set_id_rec(self, id=0):
        """Assign `id` to this node, then preorder-number the subtree.

        Returns the largest id used in this subtree, so a parent can
        continue numbering its next child from there.
        """
        self.id = id
        highest = id
        for child in self.children:
            highest = child.set_id_rec(id=highest + 1)
        return highest

    def set_all_ids(self):
        """Number the whole tree in preorder, starting at 0 for this root."""
        self.set_id_rec(0)

    def print_tree(self, level=0):
        """Render the subtree, one node per line, indented with dashes."""
        lines = [f'|{"-" * level} {self.name} ({self.id})']
        lines.extend(child.print_tree(level + 1) for child in self.children)
        return "\n".join(lines)

    def __str__(self):
        return self.print_tree(0)

    def get_list_nodes(self):
        """Return all node names in preorder (matches the ids' order)."""
        nodes = [self.name]
        for child in self.children:
            nodes.extend(child.get_list_nodes())
        return nodes
61
+
62
def rec_const_parsing(list_nodes):
    """Recursively convert a nested-list parse (from `parse_tree`) into a Tree.

    A list is treated as [label, child, child, ...]; a bare string becomes
    a leaf with no children.  Node ids are left unset (assign them with
    Tree.set_all_ids on the result).
    """
    if isinstance(list_nodes, list):
        # first element is the constituent label, the rest are subtrees
        name, children = list_nodes[0], list_nodes[1:]
    else:
        name, children = list_nodes, []
    # fix: the original used `enumerate` here but never used the index
    return Tree(name, [rec_const_parsing(child) for child in children])
68
+
69
def tree_to_graph(t):
    """Collect the directed parent->child edges of an id-annotated tree.

    Edges are listed depth-first: each child edge is immediately followed
    by that child's own subtree edges.  Returns two parallel lists
    (senders, receivers) of node ids, both empty for a leaf.
    """
    edges = []
    for child in t.children:
        edges.append((t.id, child.id))
        sub_senders, sub_receivers = tree_to_graph(child)
        edges.extend(zip(sub_senders, sub_receivers))
    senders = [s for s, _ in edges]
    receivers = [r for _, r in edges]
    return senders, receivers
79
+
80
def construct_constituency_graph(docs):
    """Build a graph dict for the constituency parse of the first sentence.

    Only the first sentence of the first doc is used.  Assumes the benepar
    pipe is loaded so `sent._.parse_string` is available — TODO confirm at
    call sites that `docs` comes from the benepar-enabled pipeline.

    Returns a one-element list containing a dict with keys "nodes",
    "senders", "receivers", and "edge_labels" (always empty here), matching
    the format of the other construct_*_graph helpers.
    """
    doc = docs[0]
    sent = list(doc.sents)[0]
    # fix: removed leftover debug print of the raw parse string
    t = rec_const_parsing(parse_tree(sent._.parse_string)[0])
    t.set_all_ids()
    senders, receivers = tree_to_graph(t)
    nodes = t.get_list_nodes()
    graphs = [{"nodes": nodes, "senders": senders, "receivers": receivers, "edge_labels": {}}]
    return graphs
90
+
91
  def half_circle_layout(n_nodes, sentence_node=True):
92
  pos = {}
93
  for i_node in range(n_nodes - 1):
 
203
  int(senders[e]), int(receivers[e]), edge_feature=edges[e])
204
  return nx_graph
205
 
206
+ def plot_graph_sentence(sentence, graph_type="constituency"):
207
  # sentences = dataset["train"][0]["abstract"].split(".")
208
  docs = dependency_parser([sentence])
209
  if graph_type == "dependency":
210
  graphs = construct_dependency_graph(docs)
211
  elif graph_type == "structural":
212
  graphs = construct_structural_graph(docs)
213
+ elif graph_type == "structural+dependency":
214
  graphs = construct_both_graph(docs)
215
+ elif graph_type == "constituency":
216
+ graphs = construct_constituency_graph(docs)
217
  g = to_jraph(graphs[0])
218
  adj_mat = get_adjacency_matrix(g)
219
  nx_graph = convert_jraph_to_networkx_graph(g)
220
  pos = half_circle_layout(len(graphs[0]["nodes"]))
221
+ if graph_type == "constituency":
222
+ pos = nx.planar_layout(nx_graph)
223
  plot = plt.figure(figsize=(12, 6))
224
  nx.draw(nx_graph, pos=pos,
225
  labels={i: e for i,e in enumerate(graphs[0]["nodes"])},
 
240
  return gr.update(choices = dataset["train"][id]["transcript"].split("."))
241
 
242
  with gr.Blocks() as demo:
243
+ with gr.Row():
244
+ graph_type = gr.Dropdown(label="Graph type", choices=["structural", "dependency", "structural+dependency", "constituency"], value="structural+dependency")
245
  with gr.Tab("From transcript"):
246
  with gr.Row():
247
  with gr.Column():