File size: 1,318 Bytes
053ee5a
9b3cda0
 
7fdc7b5
 
 
 
 
9b3cda0
7fdc7b5
 
 
9b3cda0
 
 
 
 
 
7fdc7b5
 
 
 
4616fcc
9b3cda0
 
 
 
7fdc7b5
9b3cda0
7fdc7b5
 
9b3cda0
7fdc7b5
 
 
 
9b3cda0
 
 
 
 
7fdc7b5
 
 
 
4a09b16
6da398d
7fdc7b5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import gradio as gr


from tape import ProteinBertModel, ProteinBertConfig, TAPETokenizer  # type: ignore
from tape.models import modeling_bert
import numpy as np
import torch


# Tokenizer built with TAPE's 'iupac' vocabulary; used to encode pasted sequences.
tokenizer = TAPETokenizer(vocab='iupac')
# NOTE(review): `config` is not referenced anywhere in the visible code; presumably
# it records the architecture of the pickled encoder -- confirm before removing.
config = modeling_bert.ProteinBertConfig(
    num_hidden_layers=5,
    num_attention_heads=8,
    hidden_size=400,
)

# map_location='cpu' remaps every stored tensor onto the CPU while unpickling, so
# checkpoints that were saved from a GPU also load on hosts without CUDA.
# NOTE: torch.load unpickles arbitrary Python objects -- only load trusted files.
bert_model = torch.load('models/transformer1500_95p_500.pt', map_location='cpu')
class_model = torch.load('models/down_model_500_kfold1.pt', map_location='cpu')

# The encoder checkpoint is unwrapped through `.module`; presumably it was saved
# inside a DataParallel-style wrapper -- confirm against the training script.
bert_model = bert_model.module
bert_model = bert_model.to('cpu')
bert_model = bert_model.eval()

# Give the classifier head the same treatment as the encoder: without eval() any
# dropout layers would stay active and make predictions non-deterministic.
class_model = class_model.to('cpu')
class_model = class_model.eval()



def greet(name):
    """Predict the cluster of a protein sequence and return it as display text.

    Args:
        name: Protein sequence as plain text. All whitespace (spaces, tabs,
            line breaks and other Unicode whitespace from pasted text) is
            removed before the sequence is tokenized.

    Returns:
        ``"cluster N"``, where N is the 1-based index of the highest-scoring
        class, or a short prompt message when no sequence was supplied.
    """
    # str.split() with no arguments splits on every whitespace character, so
    # joining the pieces strips all of them in one pass.
    sequence = "".join(name.split()) if name else ""
    if not sequence:
        # Without this guard an empty sequence would still be sent through the
        # model and produce a meaningless cluster number.
        return "Please enter a protein sequence."

    # Wrap the encoded ids in a list to add the batch dimension (batch of one).
    token_ids = torch.tensor([tokenizer.encode(sequence)])

    # Inference only: no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        bert_output = bert_model(token_ids)
        # NOTE(review): index 1 is presumably the pooled per-sequence
        # embedding of the encoder output -- confirm against the model.
        class_scores = class_model(bert_output[1])

    # Softmax is monotonic, so argmax over the raw scores selects the same
    # class; the +1 converts the 0-based index to a 1-based cluster number.
    cluster = torch.argmax(class_scores, dim=1).item() + 1
    return f"cluster {cluster}"





# Custom input and output text boxes; see the official Gradio documentation for
# the available configuration options.
_sequence_box = gr.Textbox(
    lines=3,
    placeholder="",
    label="Paste a protein sequence in plain text (not in FASTA format)",
)
_prediction_box = gr.Textbox(
    lines=3,
    placeholder="",
    label="Cluster prediction",
)

# Wire the prediction function to the two text boxes and start the web UI.
demo = gr.Interface(fn=greet, inputs=_sequence_box, outputs=_prediction_box)
demo.launch(share=True)