ChancesYuan committed
Commit c32018d
Parent: 9339e05

Update app.py

Files changed (1):
  1. app.py +65 -63
app.py CHANGED
@@ -5,7 +5,6 @@ import jsonlines
 import torch
 from src.modeling_bert import EXBertForMaskedLM
 from higher.patch import monkeypatch as make_functional
-# from src.models.one_shot_learner import OneShotLearner
 
 ### load KGE model
 edit_origin_model = BertForMaskedLM.from_pretrained(pretrained_model_name_or_path="ChancesYuan/KGEditor_Edit_Test")
@@ -23,7 +22,6 @@ id2ent_name = defaultdict(str)
 rel_name2id = defaultdict(str)
 id2ent_text = defaultdict(str)
 id2rel_text = defaultdict(str)
-corrupt_triple = defaultdict(list)
 
 ### init tokenizer
 tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
@@ -34,6 +32,7 @@ def init_triple_input():
     global ent2id
     global id2ent
     global rel2token
+    global rel2id
 
     with open("./dataset/fb15k237/relations.txt", "r") as f:
         lines = f.readlines()
@@ -65,10 +64,10 @@ def init_triple_input():
     ent2id = {ent: i for i, ent in enumerate(entities)}
     id2ent = {i: ent for i, ent in enumerate(entities)}
 
-    with jsonlines.open("./dataset/fb15k237/edit_test.jsonl") as f:
-        lines = []
-        for d in f:
-            corrupt_triple[" ".join(d["ori"])] = d["cor"]
+    rel2id = {
+        w: i + len(entities)
+        for i, w in enumerate(rel2token.keys())
+    }
 
 def solve(triple, alter_label, edit_task):
     print(triple, alter_label)
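Note on the hunk above (editor's sketch, not part of the commit): rel2id appends relation indices directly after the entity indices, so entities and relations share one index space that can later be shifted into the tokenizer's added-token ids with a single offset. A minimal illustration, using a relation string from the demo examples further down and a made-up entity key:

    # Sketch only: the index layout implied by ent2id / rel2id above.
    entity_idx = ent2id["/m/0dgrmp"]                      # hypothetical FB15k-237 entity key
    relation_idx = rel2id["/people/person/profession"]    # relation taken from the demo examples
    assert relation_idx >= len(ent2id)                    # relations sit after all entity indices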
@@ -77,13 +76,12 @@ def solve(triple, alter_label, edit_task):
         text_a = "[MASK]"
         text_b = id2rel_text[r] + " " + rel2token[r]
         text_c = ent2token[ent_name2id[t]] + " " + id2ent_text[ent_name2id[t]]
-        origin_label = corrupt_triple[" ".join([ent_name2id[alter_label], r, ent_name2id[t]])][0] if edit_task else ent_name2id[alter_label]
+        replace_token = [rel2id[r], ent2id[ent_name2id[t]]]
     else:
         text_a = ent2token[ent_name2id[h]]
-        # text_b = id2rel_text[r] + "[PAD]"
         text_b = id2rel_text[r] + " " + rel2token[r]
         text_c = "[MASK]" + " " + id2ent_text[ent_name2id[h]]
-        origin_label = corrupt_triple[" ".join([ent_name2id[h], r, ent_name2id[alter_label]])][2] if edit_task else ent_name2id[alter_label]
+        replace_token = [ent2id[ent_name2id[h]], rel2id[r]]
 
     if text_a == "[MASK]":
         input_text_a = tokenizer.sep_token.join(["[MASK]", id2rel_text[r] + "[PAD]"])
@@ -91,12 +89,6 @@ def solve(triple, alter_label, edit_task):
     else:
         input_text_a = "[PAD] "
         input_text_b = tokenizer.sep_token.join([id2rel_text[r] + "[PAD]", "[MASK]" + " " + id2ent_text[ent_name2id[h]]])
-
-    cond_inputs_text = "{} >> {} || {}".format(
-        add_tokenizer.added_tokens_decoder[ent2id[origin_label] + len(tokenizer)],
-        add_tokenizer.added_tokens_decoder[ent2id[ent_name2id[alter_label]] + len(tokenizer)],
-        input_text_a + input_text_b
-    )
 
     inputs = tokenizer(
         f"{text_a} [SEP] {text_b} [SEP] {text_c}",
@@ -115,14 +107,6 @@ def solve(triple, alter_label, edit_task):
         add_special_tokens=True,
     )
 
-    cond_inputs = tokenizer(
-        cond_inputs_text,
-        truncation=True,
-        max_length=64,
-        padding="max_length",
-        add_special_tokens=True,
-    )
-
     inputs = {
         "input_ids": torch.tensor(inputs["input_ids"]).unsqueeze(dim=0),
         "attention_mask": torch.tensor(inputs["attention_mask"]).unsqueeze(dim=0),
@@ -135,13 +119,46 @@ def solve(triple, alter_label, edit_task):
         "token_type_ids": torch.tensor(edit_inputs["token_type_ids"]).unsqueeze(dim=0)
     }
 
+    _, mask_idx = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)
+    logits = edit_origin_model(**inputs).logits[:, :, 30522:45473].squeeze() if edit_task else add_origin_model(**inputs).logits[:, :, 30522:45473].squeeze()
+    logits = logits[mask_idx, :]
+
+    ### origin output
+    _, origin_entity_order = torch.sort(logits, dim=1, descending=True)
+    origin_entity_order = origin_entity_order.squeeze(dim=0)
+    origin_top3 = [id2ent_name[id2ent[origin_entity_order[i].item()]] for i in range(3)]
+
+    origin_label = origin_top3[0] if edit_task else alter_label
+
+    cond_inputs_text = "{} >> {} || {}".format(
+        add_tokenizer.added_tokens_decoder[ent2id[ent_name2id[origin_label]] + len(tokenizer)],
+        add_tokenizer.added_tokens_decoder[ent2id[ent_name2id[alter_label]] + len(tokenizer)],
+        input_text_a + input_text_b
+    )
+
+    cond_inputs = tokenizer(
+        cond_inputs_text,
+        truncation=True,
+        max_length=64,
+        padding="max_length",
+        add_special_tokens=True,
+    )
+
     cond_inputs = {
         "input_ids": torch.tensor(cond_inputs["input_ids"]).unsqueeze(dim=0),
         "attention_mask": torch.tensor(cond_inputs["attention_mask"]).unsqueeze(dim=0),
         "token_type_ids": torch.tensor(cond_inputs["token_type_ids"]).unsqueeze(dim=0)
     }
 
-    return inputs, cond_inputs, edit_inputs
+    flag = 0
+    for idx, i in enumerate(edit_inputs["input_ids"][0, :].tolist()):
+        if i == tokenizer.pad_token_id and flag == 0:
+            edit_inputs["input_ids"][0, idx] = replace_token[0] + 30522
+            flag = 1
+        elif i == tokenizer.pad_token_id and flag != 0:
+            edit_inputs["input_ids"][0, idx] = replace_token[1] + 30522
+
+    return inputs, cond_inputs, edit_inputs, origin_top3
 
 def get_logits_orig_params_dict(inputs, cond_inputs, alter_label, ex_model, learner):
     with torch.enable_grad():
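Two conventions in the hunk above are easy to miss. The conditioning string handed to the hyper-network learner reads "original prediction >> desired label || masked query", built from the tokenizer's added entity tokens, and the [PAD] placeholders in edit_inputs are overwritten with the triple's own entity/relation tokens by adding 30522, the size of the bert-base-uncased vocabulary (the same boundary the logits slice 30522:45473 starts from). A rough editorial sketch with invented token spellings; the real strings come from add_tokenizer.added_tokens_decoder:

    # Sketch only (token names invented for illustration).
    ENT_OFFSET = 30522                     # len(tokenizer) for bert-base-uncased
    origin_token = "[ENT_7842]"            # stands in for the model's current top-1 entity token
    target_token = "[ENT_1130]"            # stands in for the desired (alter) entity token
    cond_inputs_text = f"{origin_token} >> {target_token} || [MASK] [SEP] profession [PAD]"
    pad_replacement_id = replace_token[0] + ENT_OFFSET    # how a [PAD] slot becomes a real input id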
@@ -149,12 +166,7 @@ def get_logits_orig_params_dict(inputs, cond_inputs, alter_label, ex_model, learner):
             input_ids=inputs["input_ids"],
             attention_mask=inputs["attention_mask"],
         ).logits
-        # print(logits.shape)
-        # logits_orig, logit_for_grad, _ = logits.split([
-        #     len(inputs["input_ids"]) - 1,
-        #     1,
-        #     0,
-        # ])
+
         input_ids = inputs['input_ids']
         _, mask_idx = (input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)
         mask_logits = logits[:, mask_idx, 30522:45473].squeeze(dim=0)
@@ -174,7 +186,6 @@ def get_logits_orig_params_dict(inputs, cond_inputs, alter_label, ex_model, learner):
         for (name, _), grad in zip(ex_model.named_parameters(), grads)
     }
 
-    # cond_inputs contains [PAD] tokens
     params_dict = learner(
         cond_inputs["input_ids"][-1:],
         cond_inputs["attention_mask"][-1:],
@@ -184,30 +195,22 @@ def get_logits_orig_params_dict(inputs, cond_inputs, alter_label, ex_model, learner):
     return params_dict
 
 def edit_process(edit_input, alter_label):
-    inputs, cond_inputs, edit_inputs = solve(edit_input, alter_label, edit_task=True)
-
-    _, mask_idx = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)
-    logits = edit_origin_model(**inputs).logits[:, :, 30522:45473].squeeze()
-    logits = logits[mask_idx, :]
-
-    ### origin output
-    _, origin_entity_order = torch.sort(logits, dim=1, descending=True)
-    origin_entity_order = origin_entity_order.squeeze(dim=0)
-    origin_top3 = [id2ent_name[id2ent[origin_entity_order[i].item()]] for i in range(3)]
+    _, cond_inputs, edit_inputs, origin_top3 = solve(edit_input, alter_label, edit_task=True)
 
     ### edit output
     fmodel = make_functional(edit_ex_model).eval()
-    params_dict = get_logits_orig_params_dict(inputs, cond_inputs, ent2id[ent_name2id[alter_label]], edit_ex_model, edit_learner)
+    params_dict = get_logits_orig_params_dict(edit_inputs, cond_inputs, ent2id[ent_name2id[alter_label]], edit_ex_model, edit_learner)
     edit_logits = fmodel(
-        input_ids=inputs["input_ids"],
-        attention_mask=inputs["attention_mask"],
+        input_ids=edit_inputs["input_ids"],
+        attention_mask=edit_inputs["attention_mask"],
         # add delta theta
         params=[
            params_dict.get(n, 0) + p
            for n, p in edit_ex_model.named_parameters()
        ],
    ).logits[:, :, 30522:45473].squeeze()
-
+
+    _, mask_idx = (edit_inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)
     edit_logits = edit_logits[mask_idx, :]
     _, edit_entity_order = torch.sort(edit_logits, dim=1, descending=True)
     edit_entity_order = edit_entity_order.squeeze(dim=0)
@@ -216,23 +219,14 @@ def edit_process(edit_input, alter_label):
     return "\n".join(origin_top3), "\n".join(edit_top3)
 
 def add_process(edit_input, alter_label):
-    inputs, cond_inputs, add_inputs = solve(edit_input, alter_label, edit_task=False)
-
-    _, mask_idx = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)
-    logits = add_origin_model(**inputs).logits[:, :, 30522:45473].squeeze()
-    logits = logits[mask_idx, :]
-
-    ### origin output
-    _, origin_entity_order = torch.sort(logits, dim=1, descending=True)
-    origin_entity_order = origin_entity_order.squeeze(dim=0)
-    origin_top3 = [id2ent_name[id2ent[origin_entity_order[i].item()]] for i in range(3)]
+    _, cond_inputs, add_inputs, origin_top3 = solve(edit_input, alter_label, edit_task=False)
 
     ### add output
     fmodel = make_functional(add_ex_model).eval()
-    params_dict = get_logits_orig_params_dict(inputs, cond_inputs, ent2id[ent_name2id[alter_label]], add_ex_model, add_learner)
+    params_dict = get_logits_orig_params_dict(add_inputs, cond_inputs, ent2id[ent_name2id[alter_label]], add_ex_model, add_learner)
     add_logits = fmodel(
-        input_ids=inputs["input_ids"],
-        attention_mask=inputs["attention_mask"],
+        input_ids=add_inputs["input_ids"],
+        attention_mask=add_inputs["attention_mask"],
         # add delta theta
         params=[
             params_dict.get(n, 0) + p
@@ -240,6 +234,7 @@ def add_process(edit_input, alter_label):
         ],
     ).logits[:, :, 30522:45473].squeeze()
 
+    _, mask_idx = (add_inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)
     add_logits = add_logits[mask_idx, :]
     _, add_entity_order = torch.sort(add_logits, dim=1, descending=True)
     add_entity_order = add_entity_order.squeeze(dim=0)
@@ -250,9 +245,6 @@ def add_process(edit_input, alter_label):
 
 with gr.Blocks() as demo:
     init_triple_input()
-    ### example
-    # edit_process("[MASK]|/people/person/profession|Jack Black", "Kellie Martin")
-    add_process("Red Skelton|/people/person/places_lived./people/place_lived/location|[MASK]", "Palm Springs")
     gr.Markdown("# KGE Editing")
 
     # multiple tabs
@@ -270,7 +262,12 @@ with gr.Blocks() as demo:
                 edit_output = gr.Textbox(label="After Edit", lines=3, placeholder="")
 
         gr.Examples(
-            examples=[["[MASK]|/people/person/profession|Jack Black", "Kellie Martin"], ["Jay-Z|/people/person/spouse_s./people/marriage/type_of_union|[MASK]", "Sydney Pollack"]],
+            examples=[["[MASK]|/people/person/profession|Jack Black", "Kellie Martin"],
+                      ["[MASK]|/people/person/nationality|United States of America", "Mark Mothersbaugh"],
+                      ["[MASK]|/people/person/gender|Male", "Iggy Pop"],
+                      ["Rachel Weisz|/people/person/nationality|[MASK]", "J.J. Abrams"],
+                      ["Jeff Goldblum|/people/person/spouse_s./people/marriage/type_of_union|[MASK]", "Sydney Pollack"],
+                      ],
             inputs=[edit_input, alter_label],
             outputs=[origin_output, edit_output],
             fn=edit_process,
@@ -290,7 +287,12 @@ with gr.Blocks() as demo:
                 add_output = gr.Textbox(label="Add Results", lines=3, placeholder="")
 
         gr.Examples(
-            examples=[["Jane Wyman|/people/person/places_lived./people/place_lived/location|[MASK]", "Palm Springs"], ["Red Skelton|/people/person/places_lived./people/place_lived/location|[MASK]", "Palm Springs"]],
+            examples=[["Jane Wyman|/people/person/places_lived./people/place_lived/location|[MASK]", "Palm Springs"],
+                      ["Darryl F. Zanuck|/people/deceased_person/place_of_death|[MASK]", "Palm Springs"],
+                      ["[MASK]|/location/location/contains|Antigua and Barbuda", "Americas"],
+                      ["Hard rock|/music/genre/artists|[MASK]", "Social Distortion"],
+                      ["[MASK]|/people/person/nationality|United States of America", "Serj Tankian"]
+                      ],
            inputs=[add_input, inductive_entity],
            outputs=[add_origin_output, add_output],
            fn=add_process,
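Taken together, the commit moves the original-model scoring and the conditioning-input construction from the two callbacks into solve(), so edit_process and add_process now only run the edited forward pass. A minimal editorial sketch (not part of the commit, and assuming the checkpoints at the top of the file load) of driving a callback directly, reusing an example pair already listed in gr.Examples:

    # Sketch only: call the Gradio callback outside the UI.
    init_triple_input()
    before, after = edit_process("[MASK]|/people/person/profession|Jack Black", "Kellie Martin")
    print("original top-3:\n" + before)
    print("edited top-3:\n" + after)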
 