GH29BERT / app.py
KeXing
Upload 6 files
7fdc7b5
raw
history blame
No virus
925 Bytes
import gradio as gr
from tape import ProteinBertModel, ProteinBertConfig, TAPETokenizer # type: ignore
from tape.models import modeling_bert
import numpy as np
import torch
# Tokenizer for the TAPE 'iupac' vocabulary; used by greet() to encode sequences.
tokenizer = TAPETokenizer(vocab='iupac')

# Architecture hyper-parameters. Not referenced by the other code visible in
# this file (the models below are loaded as whole objects), but kept as a
# module-level name in case anything imports it.
config = modeling_bert.ProteinBertConfig(
    num_hidden_layers=5,
    num_attention_heads=8,
    hidden_size=400,
)

# NOTE(review): torch.load unpickles arbitrary objects; only ship checkpoint
# files from a trusted source.
# map_location="cpu" keeps loading working on CPU-only hosts even when the
# checkpoints were saved from a GPU, and matches greet(), which builds its
# input tensor on the CPU.
bert_model = torch.load('models/bert.pt', map_location='cpu')
class_model = torch.load('models/class.pt', map_location='cpu')

# Inference only: switch off dropout (and any other train-time behaviour) so
# that the same sequence always yields the same prediction.
bert_model.eval()
class_model.eval()
def greet(name):
    """Predict the cluster of the protein sequence entered in the UI.

    The text is tokenized with the module-level ``tokenizer``, passed through
    ``bert_model``, and the second element of its output is fed to
    ``class_model``. The index of the highest classifier score, shifted to be
    1-based, is reported.

    Args:
        name: Raw text from the input box (the protein sequence). Whitespace
            anywhere in the text is ignored and letters are upper-cased.

    Returns:
        ``"cluster <k>"`` with ``k`` the 1-based index of the top-scoring
        class, or a short prompt when no sequence was entered.
    """
    # The input box is multi-line, so pasted sequences may contain newlines or
    # spaces. The tokenizer works per character, so strip all whitespace.
    # Upper-casing assumes the 'iupac' vocabulary is upper-case only
    # -- TODO confirm against the tokenizer's vocab.
    sequence = "".join((name or "").split()).upper()
    if not sequence:
        # Nothing to classify; avoid returning a meaningless prediction.
        return "Please enter a protein sequence."

    # Batch of one sequence.
    token_ids = torch.tensor([tokenizer.encode(sequence)])

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        bert_output = bert_model(token_ids)
        # bert_output[1] is presumably the pooled sequence embedding -- verify.
        class_output = class_model(bert_output[1])

    # argmax is 0-based; clusters are reported starting from 1.
    cluster = torch.argmax(class_output, dim=1) + 1
    return f"cluster {cluster.item()}"
# Custom input box; see the official Gradio documentation for the
# available settings.
_input_box = gr.Textbox(lines=3, placeholder="Name Here...", label="my input")

# Web UI: one text input routed through greet(), plain-text output.
demo = gr.Interface(fn=greet, inputs=_input_box, outputs="text")

# share=True asks Gradio to also create a public share link.
demo.launch(share=True)