GH29BERT / app.py
KeXing
Upload app.py
2b3e5e1 verified
raw
history blame contribute delete
No virus
4.48 kB
import gradio as gr
import pandas as pd
import io
from Bio import SeqIO
from tape import ProteinBertModel, ProteinBertConfig, TAPETokenizer # type: ignore
from tape.models import modeling_bert
import numpy as np
import torch
# Tokenizer that converts amino-acid sequence text into token ids using
# TAPE's 'iupac' vocabulary.
tokenizer = TAPETokenizer(vocab='iupac')

# Model hyper-parameters. Not referenced by the code visible in this file
# (the full model objects are unpickled below); kept as a module-level name.
config = modeling_bert.ProteinBertConfig(
    num_hidden_layers=5,
    num_attention_heads=8,
    hidden_size=400,
)

# Load the pickled encoder and the downstream classifier.
# map_location="cpu" remaps any CUDA tensors in the checkpoints so loading
# also works on a host without a GPU (the app runs inference on CPU).
# NOTE(review): torch.load unpickles arbitrary Python objects; only load
# checkpoints from a trusted source.
# NOTE(review): recent PyTorch releases reportedly default to
# weights_only=True, which rejects fully pickled modules — confirm against
# the pinned torch version.
bert_model = torch.load('models/transformer1500_95p_500.pt', map_location="cpu")
class_model = torch.load('models/down_model_500_kfold1.pt', map_location="cpu")

# The encoder checkpoint is expected to be a wrapper exposing the real model
# as `.module`; fall back to the object itself when it is not wrapped.
bert_model = getattr(bert_model, "module", bert_model)
bert_model = bert_model.to("cpu")
bert_model = bert_model.eval()

# Put the classifier in inference mode too, so any train-only layers it may
# contain do not affect predictions.
class_model = class_model.to("cpu").eval()
def func(name):
    """Predict the cluster of a single protein sequence.

    Args:
        name: Raw sequence text. All spaces, tabs and line-break characters
            are removed before tokenization, so a sequence wrapped over
            several lines is accepted.

    Returns:
        The string ``"cluster N"``, where ``N`` is the 1-based index of the
        highest-scoring class produced by ``class_model``.
    """
    # Remove every whitespace character so wrapped input forms one sequence.
    translation_table = str.maketrans("", "", " \t\n\r\f\v")
    sequence = name.translate(translation_table)

    # Wrap the encoded sequence in a list to form a batch of size one.
    token_ids = torch.tensor([tokenizer.encode(sequence)])

    # Pure inference: disable autograd so no graph is built or kept alive.
    with torch.no_grad():
        bert_output = bert_model(token_ids)
        # The second element of the encoder output feeds the classifier
        # (presumably the pooled sequence embedding — confirm against TAPE).
        class_output = class_model(bert_output[1])
        class_output = torch.softmax(class_output, dim=1)

    # Class indices are 0-based; reported cluster numbers start at 1.
    cluster = torch.argmax(class_output, dim=1) + 1
    return "cluster " + str(cluster.item())
def func_mult(name):
    """Predict the cluster of every sequence in a multi-sequence text input.

    Two input layouts are accepted:

    * FASTA text (at least one line starts with ``>``): parsed with
      ``process_fasta``, exactly as before.
    * Plain text with one sequence per line: blank lines are skipped. This
      previously produced no predictions because the FASTA parser found no
      records, although the input box asks for line-separated sequences.

    Args:
        name: The text entered by the user.

    Returns:
        One ``func`` result per sequence, joined with newlines. An empty
        string when the input contains no sequences.
    """
    lines = name.splitlines()
    if any(line.startswith(">") for line in lines):
        sequence_list = process_fasta(name)
    else:
        stripped = (line.strip() for line in lines)
        sequence_list = [line for line in stripped if line]
    return "\n".join(func(sequence) for sequence in sequence_list)
def process_fasta(fasta_content):
    """Extract the sequences from FASTA-formatted text.

    Args:
        fasta_content: FASTA data as a single string.

    Returns:
        A list with the sequence of each parsed record, as ``str``, in the
        order the records appear.
    """
    # SeqIO.parse is given a file-like handle, so wrap the text in one.
    handle = io.StringIO(fasta_content)
    return [str(entry.seq) for entry in SeqIO.parse(handle, "fasta")]
def read_fasta_file(file_path):
    """Read every sequence from a FASTA file.

    Args:
        file_path: Value handed straight to ``SeqIO.parse`` as the FASTA
            source.

    Returns:
        A list with the sequence of each parsed record, as ``str``, in file
        order.
    """
    return [str(entry.seq) for entry in SeqIO.parse(file_path, "fasta")]
def func_file(file_path):
    """Predict the cluster of every sequence stored in a FASTA file.

    Args:
        file_path: FASTA source forwarded to ``read_fasta_file``.

    Returns:
        One ``func`` result per sequence, joined with newlines.
    """
    # Read the whole file first, then predict sequence by sequence.
    sequences = read_fasta_file(file_path)
    predictions = [func(sequence) for sequence in sequences]
    return "\n".join(predictions)
def upload_file(files):
    """Return the path of the first uploaded file.

    Args:
        files: Sequence of uploaded file objects, each exposing its path as
            a ``name`` attribute.

    Returns:
        The ``name`` of the first element; any further uploads are ignored.
    """
    uploaded_paths = [uploaded.name for uploaded in files]
    return uploaded_paths[0]
def save_to_txt(data):
    """Write the given text to ``output.txt`` and return that path.

    The file is created in the current working directory and is overwritten
    on every call.

    Args:
        data: Text to store.

    Returns:
        The relative path ``"output.txt"``.
    """
    output_path = "output.txt"
    # Write the data to the TXT file.
    with open(output_path, mode='w') as handle:
        handle.write(data)
    # Hand the file path back to the caller.
    return output_path
# Custom stylesheet handed to gr.Blocks: sets the page background colour,
# styles components tagged with the "button" class (the Predict buttons) and
# enlarges text tagged with the "feedback" class (the page title).
css = """
.gradio-container {background-color: #EDEFF7}
.button {background-color: #515D90; color:#FFFFFF}
.feedback {font-size: 36px}
"""
# Build the web UI: three tabs (single sequence, multiple sequences, FASTA
# file), each wired to its prediction function.
with gr.Blocks(css=css, title="GH29 Prediction", theme=gr.themes.Soft()) as demo:
    # Page title, enlarged through the "feedback" CSS class.
    gr.Markdown("GH29 Prediction", elem_classes="feedback")

    # Tab 1: one raw sequence in, one prediction out.
    with gr.Tab("Single sequence input"):
        with gr.Row():
            single_input = gr.Textbox(lines=10, placeholder="Please input sequence data (note: do not input fasta data)", label="Input")
            single_output = gr.Textbox(lines=10, label="Output", show_copy_button=True)
        single_button = gr.Button("Predict", elem_classes="button")

    # Tab 2: several sequences pasted as text. The handler parses the text
    # as FASTA, so the placeholder asks for FASTA records.
    with gr.Tab("Multiple sequence input"):
        multiple_input = gr.Textbox(lines=10, placeholder="Please enter multiple sequences in FASTA format (each sequence preceded by a '>' header line)", label="Input")
        multiple_button = gr.Button("Predict", elem_classes="button")
        multiple_output = gr.Textbox(lines=10, label="Output", show_copy_button=True)

    # Tab 3: upload a FASTA file, predict, and optionally save the result.
    with gr.Tab("FASTA input"):
        with gr.Row():
            file_upload = gr.File(label="Fasta File", interactive=False, scale=2)
            file_output_textbox = gr.Textbox(lines=15, label="Output", scale=3, container=True, autoscroll=True, show_copy_button=True)
            file_output_file = gr.File(label="Output File", scale=2)
        with gr.Row():
            # File extensions must carry a leading dot; a bare word is
            # treated as a file category rather than an extension.
            upload_button = gr.UploadButton("Click to Upload a File", file_types=[".fasta", ".fa", ".faa"], scale=2, size="sm", file_count="multiple")
            # Show the first uploaded file in the read-only file component.
            upload_button.upload(upload_file, upload_button, file_upload)
            file_button = gr.Button("Predict", scale=3, size="lg", elem_classes="button")
            file_button_GenerateFile = gr.Button("Save to File", scale=2, size="sm")

    # Event wiring: each button runs its handler and fills its output.
    single_button.click(func, inputs=single_input, outputs=single_output)
    multiple_button.click(func_mult, inputs=multiple_input, outputs=multiple_output)
    file_button.click(func_file, inputs=file_upload, outputs=file_output_textbox)
    file_button_GenerateFile.click(save_to_txt, inputs=file_output_textbox, outputs=file_output_file)

demo.launch(share=True)