File size: 4,263 Bytes
053ee5a
5321388
9b3cda0
5321388
9b3cda0
5321388
7fdc7b5
 
 
 
 
9b3cda0
7fdc7b5
 
 
9b3cda0
 
 
 
5321388
9b3cda0
7fdc7b5
 
5321388
9b3cda0
 
 
7fdc7b5
9b3cda0
7fdc7b5
 
9b3cda0
7fdc7b5
 
 
 
9b3cda0
 
5321388
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b3cda0
5321388
 
 
 
9b3cda0
 
7fdc7b5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import gradio as gr
import pandas as pd

import io

from Bio import SeqIO
from tape import ProteinBertModel, ProteinBertConfig, TAPETokenizer  # type: ignore
from tape.models import modeling_bert
import numpy as np
import torch


# Tokenizer built on TAPE's 'iupac' vocabulary; used by func() to encode sequences.
tokenizer = TAPETokenizer(vocab='iupac')
# NOTE(review): `config` is not referenced by any other code visible in this file;
# the models below are loaded as whole pickled objects rather than built from it.
config = modeling_bert.ProteinBertConfig(num_hidden_layers=5, num_attention_heads=8, hidden_size=400)

# SECURITY NOTE: torch.load unpickles arbitrary Python objects. Only load
# checkpoint files that come from a trusted source.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects fully pickled modules like these - confirm the pinned torch version.
#
# map_location="cpu" keeps both models on the CPU regardless of the device they
# were saved from. Without it, class_model (which is never moved explicitly)
# could end up on a different device than bert_model's CPU output, or loading
# could fail outright on a host without the original device.
bert_model = torch.load('models/transformer1500_95p_500.pt', map_location="cpu")
class_model = torch.load('models/down_model_500_kfold1.pt', map_location="cpu")

# The encoder checkpoint is a wrapper object; the actual model lives on its
# `.module` attribute (presumably a DataParallel-style wrapper - TODO confirm).
bert_model = bert_model.module
bert_model = bert_model.to("cpu")
bert_model = bert_model.eval()

# Put the classification head in inference mode too, mirroring bert_model, so
# that train-only layers (dropout etc., if present) do not affect predictions.
class_model = class_model.eval()


def func(name):
    """Predict the cluster label for a single protein sequence.

    Args:
        name: Raw sequence text. Spaces, tabs, newlines, carriage returns,
            form feeds and vertical tabs are removed before tokenization.

    Returns:
        A string of the form ``"cluster N"``, where N is the 1-based index of
        the highest-scoring class produced by ``class_model``.
    """
    translation_table = str.maketrans("", "", " \t\n\r\f\v")
    sequence = name.translate(translation_table)

    # Wrap the encoded sequence in a list so the tensor has a leading
    # batch dimension of size 1.
    token_ids = torch.tensor([tokenizer.encode(sequence)])

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        bert_output = bert_model(token_ids)
        # The classifier consumes the second element of the encoder output
        # (presumably the pooled sequence embedding - TODO confirm).
        class_output = class_model(bert_output[1])

    # argmax is taken directly on the raw scores: softmax is monotonic, so
    # applying it first would not change which class wins.
    cluster = torch.argmax(class_output, dim=1).item() + 1

    return "cluster " + str(cluster)


def func_mult(name):
    """Predict clusters for several sequences given one per line.

    Args:
        name: Text containing sequences separated by ``"\\n"``. Each line is
            stripped of surrounding whitespace and empty lines are ignored.

    Returns:
        The per-sequence results of ``func`` joined by newlines, in input
        order. An input with no non-empty lines yields an empty string.
    """
    stripped_lines = (line.strip() for line in name.split("\n"))
    # Keep only non-empty lines, then classify each remaining sequence.
    predictions = [func(sequence) for sequence in stripped_lines if sequence]
    return "\n".join(predictions)


def read_fasta_file(file_path):
    """Return the sequences of every record in a FASTA source as strings.

    Args:
        file_path: Value handed straight to ``SeqIO.parse`` with the
            ``"fasta"`` format.

    Returns:
        A list with ``str(record.seq)`` for each parsed record, in file order.
    """
    return [str(record.seq) for record in SeqIO.parse(file_path, "fasta")]


def func_file(file_path):
    """Predict clusters for every sequence found in a FASTA file.

    Args:
        file_path: FASTA source, forwarded unchanged to ``read_fasta_file``.

    Returns:
        The ``func`` result for each sequence, joined by newlines in the
        order the records were read.
    """
    predictions = [func(sequence) for sequence in read_fasta_file(file_path)]
    return "\n".join(predictions)


def upload_file(files):
    """Return the path of the first uploaded file.

    Args:
        files: Sequence of uploaded file objects, each exposing a ``name``
            attribute that holds its path.

    Returns:
        The ``name`` of the first element of ``files``. Any further uploads
        are ignored.

    Raises:
        IndexError: If ``files`` is empty.
    """
    # Only the first upload is used, so read its path directly instead of
    # building a list of every path first.
    return files[0].name


def save_to_txt(data):
    """Write the given text to ``output.txt`` and return that file name.

    Args:
        data: Text to store. Any existing ``output.txt`` in the current
            working directory is overwritten.

    Returns:
        The file name ``"output.txt"``.
    """
    file_name = "output.txt"
    # Write the data to the TXT file. The encoding is pinned so the result
    # does not depend on the platform's default locale encoding.
    with open(file_name, mode='w', encoding="utf-8") as file:
        file.write(data)

    # Return the file path.
    return file_name


# Custom stylesheet passed to gr.Blocks below: sets the page background and
# defines the "button" and "feedback" classes that components reference via
# elem_classes.
css = """
.gradio-container {background-color: #EDEFF7}
.button {background-color: #515D90; color:#FFFFFF}
.feedback {font-size: 36px}
"""

# Build the web UI: three tabs (single sequence, multiple sequences, FASTA
# upload), each wired to the matching prediction function defined above.
with gr.Blocks(css=css, title="GH29 Prediction", theme=gr.themes.Soft()) as demo:
    gr.Markdown("GH29 Prediction", elem_classes="feedback")

    with gr.Tab("Single sequence input"):
        with gr.Row():
            single_input = gr.Textbox(lines=10, placeholder="Please input sequence data (note: do not input fasta data)", label="Input")
            single_output = gr.Textbox(lines=10, label="Output", show_copy_button=True)
        single_button = gr.Button("Predict", elem_classes="button")

    with gr.Tab("Multiple sequence input"):
        multiple_input = gr.Textbox(lines=10, placeholder="Please enter multiple sequence data separated by line breaks (do not enter fasta data)", label="Input")
        multiple_button = gr.Button("Predict", elem_classes="button")
        multiple_output = gr.Textbox(lines=10, label="Output", show_copy_button=True)

    with gr.Tab("FASTA input"):
        with gr.Row():
            # Display-only file slot; it is filled by the upload button below.
            file_upload = gr.File(label="Fasta File", interactive=False, scale=2)
            file_output_textbox = gr.Textbox(lines=15, label="Output", scale=3, container=True, autoscroll=True, show_copy_button=True)
            file_output_file = gr.File(label="Output File", scale=2)
        with gr.Row():
            # File extensions must carry a leading dot; a bare word such as
            # "fasta" is interpreted as a media-type category instead.
            upload_button = gr.UploadButton("Click to Upload a File", file_types=[".fasta"], scale=2, size="sm", file_count="multiple")
            # upload_file keeps only the first of the uploaded files.
            upload_button.upload(upload_file, upload_button, file_upload)
            file_button = gr.Button("Predict", scale=3, size="lg", elem_classes="button")
            file_button_GenerateFile = gr.Button("Save to file", scale=2, size="sm")

    # Event wiring: each button feeds its input component into the handler
    # and shows the returned value in the output component.
    single_button.click(func, inputs=single_input, outputs=single_output)
    multiple_button.click(func_mult, inputs=multiple_input, outputs=multiple_output)
    file_button.click(func_file, inputs=file_upload, outputs=file_output_textbox)
    # Saves whatever text is currently in the FASTA output box to a file.
    file_button_GenerateFile.click(save_to_txt, inputs=file_output_textbox, outputs=file_output_file)


# Start the app. This runs whenever the module is executed or imported, since
# there is no __main__ guard. share=True asks Gradio to also create a public
# share link (see Gradio's launch() documentation).
demo.launch(share=True)