# -*- coding: utf-8 -*-
"""
Created on Fri Jun 12 16:41:54 2020

@author: luol2
"""

import io

import numpy as np
from keras import backend as K


def ml_intext(infile):
    # read classification instances from a file: blank-line-separated blocks
    # whose first line carries the label and whose remaining lines are
    # tab-separated token rows
    fin=open(infile,'r',encoding='utf-8')
    alltexts=fin.read().strip().split('\n\n')
    fin.close()
    data_list=[]
    label_list=[]
    for sents in alltexts:
        lines=sents.split('\n')
        temp_sentence=[]
        label=lines[0].split('\t')[0]
        label_list.append(label)
        for i in range(1,len(lines)):
            seg=lines[i].split('\t')
            temp_sentence.append(seg)
        data_list.append(temp_sentence)
    return data_list,label_list


def ml_intext_fn(ml_input):
    # same as ml_intext, but parses an in-memory string instead of a file
    fin=io.StringIO(ml_input)
    alltexts=fin.read().strip().split('\n\n')
    fin.close()
    data_list=[]
    label_list=[]
    for sents in alltexts:
        lines=sents.split('\n')
        temp_sentence=[]
        label=lines[0].split('\t')[0]
        label_list.append(label)
        for i in range(1,len(lines)):
            seg=lines[i].split('\t')
            temp_sentence.append(seg)
        data_list.append(temp_sentence)
    return data_list,label_list


def pun_filter(temp_entity):
    # discard any candidate span that contains punctuation
    pun_list=[',','.','!',';',':','?','(',')','[',']','{','}']
    filter_flag=0
    for ele in temp_entity:
        if ele in pun_list:
            filter_flag=1
            break
    return filter_flag


def pos_filter(temp_pos,temp_entity):
    # discard candidates that start or end with a function-word POS tag
    # (preposition, determiner, conjunction, etc.) or an auxiliary verb
    pos_list_l=['PRP']
    pos_list=['IN','DT','CC','O','MD','EX','POS','WDT','WP','WP$','WRB','TO','PRP$']
    verb_word=['is','are','was','were','had','have','has','be','been','also']
    filter_flag=0
    if (temp_entity[0] in verb_word) or (temp_entity[-1] in verb_word):
        filter_flag=1
    if (temp_pos[0] in pos_list) or (temp_pos[-1] in pos_list) or (temp_pos[0] in pos_list_l):
        filter_flag=1
    return filter_flag
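
# A minimal usage sketch for the parsers above. The block layout is inferred
# from how ml_intext_fn is consumed in this file; the tokens, tags, and
# column values below are hypothetical.
def _demo_ml_intext_fn():
    example=('HP:None\tshort stature\n'
             'short\tshort\tshort\tJJ\tO\tB\n'
             'stature\tstature\tstature\tNN\tO\tB\n')
    data,labels=ml_intext_fn(example)
    # labels == ['HP:None']
    # data[0][0] == ['short','short','short','JJ','O','B']
    return data,labels
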

def build_ngram_testset_filted(conll_input,Ngram=8):
    # enumerate token n-grams (2..Ngram) in each sentence of the CoNLL-style
    # input as candidate entities; unigram candidates are intentionally
    # skipped here (build_all_ngram_testset_filted below starts at 1).
    # Returns the candidate contexts, the original sentences, and an index
    # mapping candidate id -> [sentence id, start token, end token].
    fin_genia=io.StringIO(conll_input)
    fout_context=io.StringIO()
    fout_txt=io.StringIO()
    index_dict={}

    allentity=[]
    alltext=fin_genia.read().strip().split('\n\n')
    fin_genia.close()
    num_total=0
    for i in range(0,len(alltext)):
        lines=alltext[i].split('\n')
        ori_txt=[]
        for ele in lines:
            seg=ele.split('\t')
            ori_txt.append(seg[0])
        fout_txt.write(' '.join(ori_txt)+'\n')
        # cap the n-gram length per sentence; a local variable avoids
        # shrinking Ngram itself for every later sentence
        max_ngram=min(Ngram,len(lines))
        fout_context_list=[]
        temp_entity=[]
        temp_pos=[]
        for ngram in range(2,max_ngram+1):
            if ngram==2:
                for j in range(0, len(lines)-1):
                    sid=0
                    eid=0
                    for m in range(0,len(lines)):
                        if m==j:
                            fout_context_list.append(lines[m]+'\tO\tB')
                            sid=m
                            temp_seg=lines[m].split('\t')
                            temp_entity.append(temp_seg[0])
                            temp_pos.append(temp_seg[3])
                        elif m==j+1:
                            fout_context_list.append(lines[m]+'\tO\tB')
                            eid=m
                            temp_seg=lines[m].split('\t')
                            temp_entity.append(temp_seg[0])
                            temp_pos.append(temp_seg[3])
                        else:
                            pass
#                            fout_context_list.append(lines[m]+'\tO\tO')
                    if pun_filter(temp_entity)==0 and pos_filter(temp_pos,temp_entity)==0:
                        num_total+=1
                        if ' '.join(temp_entity) not in allentity:
                            allentity.append(' '.join(temp_entity))
                        fout_context.write('HP:None\t'+' '.join(temp_entity)+'\n')
                        fout_context.write('\n'.join(fout_context_list)+'\n\n')
                        index_dict[str(num_total)]=[i,sid,eid]
                    temp_entity=[]
                    temp_pos=[]
                    fout_context_list=[]
            else:
                for j in range(0, len(lines)-ngram+1):
                    sid=0
                    eid=0
                    for m in range(0,len(lines)):
                        if m==j:
                            fout_context_list.append(lines[m]+'\tO\tB')
                            sid=m
                            temp_seg=lines[m].split('\t')
                            temp_entity.append(temp_seg[0])
                            temp_pos.append(temp_seg[3])
                        elif m>j and m<j+ngram-1:
                            fout_context_list.append(lines[m]+'\tO\tB')
                            temp_seg=lines[m].split('\t')
                            temp_entity.append(temp_seg[0])
                            temp_pos.append(temp_seg[3])
                        elif m==j+ngram-1:
                            fout_context_list.append(lines[m]+'\tO\tB')
                            eid=m
                            temp_seg=lines[m].split('\t')
                            temp_entity.append(temp_seg[0])
                            temp_pos.append(temp_seg[3])
                        else:
                            pass
#                            fout_context_list.append(lines[m]+'\tO\tO')
                    if pun_filter(temp_entity)==0 and pos_filter(temp_pos,temp_entity)==0:
                        num_total+=1
                        if ' '.join(temp_entity) not in allentity:
                            allentity.append(' '.join(temp_entity))
                        fout_context.write('HP:None\t'+' '.join(temp_entity)+'\n')
                        fout_context.write('\n'.join(fout_context_list)+'\n\n')
                        index_dict[str(num_total)]=[i,sid,eid]
                    temp_entity=[]
                    temp_pos=[]
                    fout_context_list=[]

    return fout_context.getvalue(),fout_txt.getvalue(),index_dict


def build_all_ngram_testset_filted(conll_input,Ngram=8):
    # same as build_ngram_testset_filted, but also generates unigram
    # candidates (the n-gram loop starts at 1)
    fin_genia=io.StringIO(conll_input)
    fout_context=io.StringIO()
    fout_txt=io.StringIO()
    index_dict={}

    allentity=[]
    alltext=fin_genia.read().strip().split('\n\n')
    fin_genia.close()
    num_total=0
    for i in range(0,len(alltext)):
        lines=alltext[i].split('\n')
        ori_txt=[]
        for ele in lines:
            seg=ele.split('\t')
            ori_txt.append(seg[0])
        fout_txt.write(' '.join(ori_txt)+'\n')
        max_ngram=min(Ngram,len(lines))
        fout_context_list=[]
        temp_entity=[]
        temp_pos=[]
        for ngram in range(1,max_ngram+1):
            if ngram==1:
                for j in range(0, len(lines)):
                    sid=0
                    eid=0
                    for m in range(0,len(lines)):
                        if m==j:
                            sid=m
                            eid=m
                            fout_context_list.append(lines[m]+'\tO\tB')
                            temp_seg=lines[m].split('\t')
                            temp_entity.append(temp_seg[0])
                            temp_pos.append(temp_seg[3])
                        else:
                            pass
#                            fout_context_list.append(lines[m]+'\tO\tO')
                    if pun_filter(temp_entity)==0 and pos_filter(temp_pos,temp_entity)==0:
                        num_total+=1
                        if ' '.join(temp_entity) not in allentity:
                            allentity.append(' '.join(temp_entity))
                        fout_context.write('HP:None\t'+' '.join(temp_entity)+'\n')
                        fout_context.write('\n'.join(fout_context_list)+'\n\n')
                        index_dict[str(num_total)]=[i,sid,eid]
                    temp_entity=[]
                    temp_pos=[]
                    fout_context_list=[]
            elif ngram==2:
                for j in range(0, len(lines)-1):
                    sid=0
                    eid=0
                    for m in range(0,len(lines)):
                        if m==j:
                            fout_context_list.append(lines[m]+'\tO\tB')
                            sid=m
                            temp_seg=lines[m].split('\t')
                            temp_entity.append(temp_seg[0])
                            temp_pos.append(temp_seg[3])
                        elif m==j+1:
                            fout_context_list.append(lines[m]+'\tO\tB')
                            eid=m
                            temp_seg=lines[m].split('\t')
                            temp_entity.append(temp_seg[0])
                            temp_pos.append(temp_seg[3])
                        else:
                            pass
#                            fout_context_list.append(lines[m]+'\tO\tO')
                    if pun_filter(temp_entity)==0 and pos_filter(temp_pos,temp_entity)==0:
                        num_total+=1
                        if ' '.join(temp_entity) not in allentity:
                            allentity.append(' '.join(temp_entity))
                        fout_context.write('HP:None\t'+' '.join(temp_entity)+'\n')
                        fout_context.write('\n'.join(fout_context_list)+'\n\n')
                        index_dict[str(num_total)]=[i,sid,eid]
                    temp_entity=[]
                    temp_pos=[]
                    fout_context_list=[]
            else:
                for j in range(0, len(lines)-ngram+1):
                    sid=0
                    eid=0
                    for m in range(0,len(lines)):
                        if m==j:
                            fout_context_list.append(lines[m]+'\tO\tB')
                            sid=m
                            temp_seg=lines[m].split('\t')
                            temp_entity.append(temp_seg[0])
                            temp_pos.append(temp_seg[3])
                        elif m>j and m<j+ngram-1:
                            fout_context_list.append(lines[m]+'\tO\tB')
                            temp_seg=lines[m].split('\t')
                            temp_entity.append(temp_seg[0])
                            temp_pos.append(temp_seg[3])
                        elif m==j+ngram-1:
                            fout_context_list.append(lines[m]+'\tO\tB')
                            eid=m
                            temp_seg=lines[m].split('\t')
                            temp_entity.append(temp_seg[0])
                            temp_pos.append(temp_seg[3])
                        else:
                            pass
#                            fout_context_list.append(lines[m]+'\tO\tO')
                    if pun_filter(temp_entity)==0 and pos_filter(temp_pos,temp_entity)==0:
                        num_total+=1
                        if ' '.join(temp_entity) not in allentity:
                            allentity.append(' '.join(temp_entity))
                        fout_context.write('HP:None\t'+' '.join(temp_entity)+'\n')
                        fout_context.write('\n'.join(fout_context_list)+'\n\n')
                        index_dict[str(num_total)]=[i,sid,eid]
                    temp_entity=[]
                    temp_pos=[]
                    fout_context_list=[]

    return fout_context.getvalue(),fout_txt.getvalue(),index_dict


def output_result(test_pre,label_2_index,Top_N=3):
    # convert each row of classifier probabilities into tab-separated
    # 'label<TAB>score' pairs, highest score first, so that decode_tsv can
    # read the top prediction from the first two fields
    index_2_label={v:k for k,v in label_2_index.items()}
    test_score=[]
    for pre_score in test_pre:
        top_index=np.argsort(pre_score)[::-1][:Top_N]
        temp_score=[]
        for index in top_index:
            temp_score.append(index_2_label[int(index)]+'\t'+str(pre_score[index]))
        test_score.append('\t'.join(temp_score))
    return test_score


def decode_tsv(test_score,test_index,ml_input_txt,T=0.8):
    # map candidate-level predictions back onto their sentences; keep only
    # candidates whose top label is a real HPO id with score above T
    fin_text=io.StringIO(ml_input_txt)
    fout=io.StringIO()
    test_txt=fin_text.read().strip().split('\n')
    fin_text.close()

    sent_result={}
    for i in range(0,len(test_score)):
        seg_pre=test_score[i].split('\t')
        if float(seg_pre[1])>T and seg_pre[0]!='HP:None':
            term_id=str(i+1)
            pre_result=[test_index[term_id][1],test_index[term_id][2],seg_pre[0],seg_pre[1]]
            sent_id=str(test_index[term_id][0])
            if sent_id not in sent_result.keys():
                sent_result[sent_id]=[pre_result]
            else:
                sent_result[sent_id].append(pre_result)

    for i in range(0,len(test_txt)):
        fout.write(test_txt[i]+'\n')
        if str(i) in sent_result.keys():
            temp_result={}
            for ele in sent_result[str(i)]:
                temp_line=str(ele[0])+'\t'+str(ele[1])+'\t'+' '.join(test_txt[i].split()[ele[0]:ele[1]+1])+'\t'+ele[2]+'\t'+ele[3]
                temp_result[temp_line]=[ele[0],ele[1]]
            if len(temp_result)>=1:
                # sort entity rows by (start, end) token offsets
                temp_result=sorted(temp_result.items(), key=lambda d: (d[1][0],d[1][1]), reverse=False)
                for ent in temp_result:
                    fout.write(ent[0]+'\n')
        fout.write('\n')
    return fout.getvalue()


def score_filter(temp_entity, T=0.1):
    # keep only candidates whose score (last column) is above T
    result_list=[]
    for i in range(0,len(temp_entity)):
        if float(temp_entity[i][-1])>T:
            result_list.append(temp_entity[i])
    return result_list


def find_max_entity_nest(nest_list):
    # among nested/overlapping candidates, keep the single highest-scoring
    # candidate per HPO id
    temp_result_list={}
    for i in range(0, len(nest_list)):
        hpoid=nest_list[i][-2]
        score=float(nest_list[i][-1])
        if hpoid not in temp_result_list.keys():
            temp_result_list[hpoid]=nest_list[i]
        else:
            if score>float(temp_result_list[hpoid][-1]):
                temp_result_list[hpoid]=nest_list[i]
    new_list=[]
    for hpoid in temp_result_list.keys():
        new_list.append(temp_result_list[hpoid])
    return new_list
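
# A minimal sketch of how nested candidates are resolved; the spans, HPO ids,
# and scores below are hypothetical. Each candidate row is
# [start, end, mention, hpo_id, score], and find_max_entity_nest keeps the
# single highest-scoring row per HPO id.
def _demo_find_max_entity_nest():
    nested=[['3','4','short stature','HP:0004322','0.91'],
            ['3','5','short stature now','HP:0004322','0.62'],
            ['4','5','stature now','HP:0001250','0.80']]
    # -> [['3','4','short stature','HP:0004322','0.91'],
    #     ['4','5','stature now','HP:0001250','0.80']]
    return find_max_entity_nest(nested)
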

def duplicate_filter(temp_entity):
    # resolve overlapping spans: candidates that overlap nothing pass through
    # unchanged; groups of nested/overlapping candidates are reduced by
    # find_max_entity_nest
    result_list=[]
    if len(temp_entity)>1:
        first_entity=temp_entity[0]
        nest_list=[first_entity]
        max_eid=int(first_entity[1])
        for i in range(1,len(temp_entity)):
            segs=temp_entity[i]
            if int(segs[0])>max_eid:
                # this candidate starts after every span seen so far, so the
                # previous nest group is complete
                if len(nest_list)==1:
                    result_list.append(nest_list[0])
                    nest_list=[segs]
                    if int(segs[1])>max_eid:
                        max_eid=int(segs[1])
                else:
                    result_list.extend(find_max_entity_nest(nest_list))
                    nest_list=[segs]
                    if int(segs[1])>max_eid:
                        max_eid=int(segs[1])
            else:
                nest_list.append(segs)
                if int(segs[1])>max_eid:
                    max_eid=int(segs[1])
        if nest_list!=[]:
            if len(nest_list)==1:
                result_list.append(nest_list[0])
            else:
                result_list.extend(find_max_entity_nest(nest_list))
    else:
        result_list=temp_entity
    return result_list


def combine_strategy(test_decode_temp, T=0.8):
    fin=io.StringIO(test_decode_temp)
    fout=io.StringIO()
    documents=fin.read().strip().split('\n\n')
    fin.close()

    for doc in documents:
        lines=doc.split('\n')
        context=lines[0]
        final_entity_list=[]
        if len(lines)>1:
            # collect all entity candidates
            temp_entity=[]
            for i in range(1,len(lines)):
                temp_entity.append(lines[i].split('\t'))
            #print('all entity candidates: ',len(temp_entity))
            # drop candidates whose score is below the threshold T
            filter1=score_filter(temp_entity,T)
            #print('filter1:', len(filter1))
            filter2=duplicate_filter(filter1)
            #print('filter2:', filter2)
            final_entity_list=filter2
        fout.write(context+'\n')
        for ele in final_entity_list:
            fout.write('\t'.join(ele)+'\n')
        fout.write('\n')
    return fout.getvalue()


def model_predict(session,ml_input,nn_model,ml_input_txt,ml_input_index,Threshold):
    # run the trained classifier over the candidate set and decode the
    # predictions back onto the original sentences
    if nn_model.model_type=='cnn':
        test_set,test_label = ml_intext_fn(ml_input)
        test_x, test_y = nn_model.rep.represent_instances_all_feas(test_set,test_label,word_max_len=nn_model.hyper['sen_max'],char_max_len=nn_model.hyper['word_max'])
        input_test = []
        if nn_model.fea_dict['word'] == 1:
            input_test.append(test_x[0])
        if nn_model.fea_dict['char'] == 1:
            input_test.append(test_x[1])
        if nn_model.fea_dict['lemma'] == 1:
            input_test.append(test_x[2])
        if nn_model.fea_dict['pos'] == 1:
            input_test.append(test_x[3])
        K.set_session(session)
        test_pre = nn_model.model.predict(input_test)
    elif nn_model.model_type=='bert' or nn_model.model_type=='bioformer':
        test_set,test_label = ml_intext_fn(ml_input)
        test_x,test_y=nn_model.rep.load_data(test_set,test_label,word_max_len=nn_model.maxlen)
        K.set_session(session)
        with session.as_default():
            with session.graph.as_default():
                test_pre = nn_model.model.predict(test_x)

    test_score=output_result(test_pre, nn_model.rep.label_2_index,Top_N=3)
    test_decode_temp=decode_tsv(test_score, ml_input_index, ml_input_txt, T=Threshold)
#    test_pre_tsv=combine_strategy(test_decode_temp,T=Threshold)
    return test_decode_temp
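
# A minimal sketch of the combine step; the sentence, spans, HPO ids, and
# scores are hypothetical. combine_strategy consumes decode_tsv-style blocks
# (a sentence line followed by 'start<TAB>end<TAB>mention<TAB>hpo_id<TAB>score'
# rows), drops rows scoring at or below T, then resolves overlaps.
def _demo_combine_strategy():
    decoded=('Patient shows short stature .\n'
             '2\t3\tshort stature\tHP:0004322\t0.91\n'
             '2\t3\tshort stature\tHP:0000002\t0.15\n')
    # with T=0.8 only the HP:0004322 row survives score_filter
    return combine_strategy(decoded,T=0.8)
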

def model_predict_old(ml_input,nn_model,ml_input_txt,ml_input_index,Threshold):
    # older variant: predicts with the session stored on the model wrapper,
    # decodes with T=0.0, and applies combine_strategy itself
    if nn_model.model_type=='cnn':
        test_set,test_label = ml_intext_fn(ml_input)
        test_x, test_y = nn_model.rep.represent_instances_all_feas(test_set,test_label,word_max_len=nn_model.hyper['sen_max'],char_max_len=nn_model.hyper['word_max'])
        input_test = []
        if nn_model.fea_dict['word'] == 1:
            input_test.append(test_x[0])
        if nn_model.fea_dict['char'] == 1:
            input_test.append(test_x[1])
        if nn_model.fea_dict['lemma'] == 1:
            input_test.append(test_x[2])
        if nn_model.fea_dict['pos'] == 1:
            input_test.append(test_x[3])
        K.set_session(nn_model.session)
        with nn_model.session.as_default():
            with nn_model.session.graph.as_default():
                test_pre = nn_model.model.predict(input_test,batch_size=256)
    elif nn_model.model_type=='bert' or nn_model.model_type=='bioformer':
        test_set,test_label = ml_intext_fn(ml_input)
        test_x,test_y=nn_model.rep.load_data(test_set,test_label,word_max_len=nn_model.maxlen)
        K.set_session(nn_model.session)
        with nn_model.session.as_default():
            with nn_model.session.graph.as_default():
                test_pre = nn_model.model.predict(test_x,batch_size=128)

    test_score=output_result(test_pre, nn_model.rep.label_2_index,Top_N=3)
    test_decode_temp=decode_tsv(test_score, ml_input_index, ml_input_txt, T=0.0)
    test_pre_tsv=combine_strategy(test_decode_temp,T=Threshold)
    return test_pre_tsv


def output_txt(ml_input_txt):
    # fallback output: no candidates survived filtering, so emit the
    # sentences only, one block per sentence
    fin_text=io.StringIO(ml_input_txt)
    fout=io.StringIO()
    test_txt=fin_text.read().strip().split('\n')
    fin_text.close()
    for i in range(0,len(test_txt)):
        fout.write(test_txt[i]+'\n')
        fout.write('\n')
    return fout.getvalue()


def ml_tagging(session,ssplit_token,ml_model,Threshold):
    # pipeline entry point: build filtered 2..N-gram candidates, then
    # classify them; if no candidate survives filtering, emit sentences only
    ml_input, ml_input_txt,ml_input_index=build_ngram_testset_filted(ssplit_token)
    if len(ml_input_index)>0:
        ml_pre_tsv=model_predict(session,ml_input,ml_model,ml_input_txt,ml_input_index,Threshold)
    else:
        ml_pre_tsv=output_txt(ml_input_txt)
    return ml_pre_tsv


def ml_tagging_allngram(ssplit_token,ml_model,Threshold):
    # same flow, but with 1..N-gram candidates and the older predict path
    ml_input, ml_input_txt,ml_input_index=build_all_ngram_testset_filted(ssplit_token)
    if len(ml_input_index)>0:
        ml_pre_tsv=model_predict_old(ml_input,ml_model,ml_input_txt,ml_input_index,Threshold)
    else:
        ml_pre_tsv=output_txt(ml_input_txt)
    return ml_pre_tsv
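
# A hedged end-to-end usage sketch: 'session' is the TensorFlow session that
# holds the loaded Keras graph, 'ml_model' is the project's model wrapper
# exposing model_type/rep/model/hyper as used above, and 'ssplit_token' is
# sentence-split, tab-separated CoNLL text. Those objects are constructed
# elsewhere in the project, and the threshold below is a hypothetical value.
def _demo_ml_tagging(session,ml_model,ssplit_token,threshold=0.95):
    # returns each sentence followed by its predicted entity rows
    return ml_tagging(session,ssplit_token,ml_model,threshold)
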