File size: 10,000 Bytes
2b26389
72b0e49
2b26389
 
 
 
 
 
 
892748f
9b993cf
f3ed046
 
 
1a0324b
 
2b26389
892748f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08b9eb6
3daa625
 
 
 
f3ed046
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a0324b
72b0e49
ca55ed1
3daa625
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c8e59d5
3daa625
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3705c34
77b966b
3705c34
77b966b
3705c34
cdf31f1
61cedea
2b26389
2bc812b
 
 
 
2b26389
 
 
 
 
892748f
2b26389
 
 
 
f3ed046
 
 
1167137
f3ed046
 
 
 
 
1167137
 
43b95d1
1167137
 
 
 
2b26389
 
 
 
 
 
 
 
c3846ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
892748f
c3846ee
892748f
c3846ee
 
 
892748f
c3846ee
 
ca55ed1
 
 
 
 
 
c3846ee
ca55ed1
c3846ee
 
 
 
 
ca55ed1
c3846ee
 
2b26389
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import os
import torch
import torch.nn as nn
import pandas as pd
import torch.nn.functional as F
from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral
from lavis.models.base_model import FAPMConfig
import spaces
import gradio as gr
# from esm_scripts.extract import run_demo
from esm import pretrained, FastaBatchedDataset
from data.evaluate_data.utils import Ontology
import difflib
import re


# Load the model
# model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
# model.load_checkpoint("model/checkpoint_mf2.pth")
# model.to('cuda')

def get_model(type='Molecule Function'):
    """Build a FAPM model and load the checkpoint for one ontology namespace.

    Args:
        type: Namespace label; one of 'Molecule Function',
            'Biological Process' or 'Cellar Component'. The parameter name
            shadows the builtin but is kept so keyword callers keep working.

    Returns:
        The ``Blip2ProteinMistral`` instance with its checkpoint loaded,
        after ``.to('cuda')`` has been called on it.

    Raises:
        ValueError: If ``type`` is not one of the supported labels.
    """
    checkpoint_paths = {
        'Molecule Function': "model/checkpoint_mf2.pth",
        'Biological Process': "model/checkpoint_bp1.pth",
        'Cellar Component': "model/checkpoint_cc2.pth",
    }
    # Validate before building the model so a bad label fails fast instead of
    # silently yielding a model without any checkpoint loaded.
    if type not in checkpoint_paths:
        raise ValueError(
            f"Unknown model type {type!r}; expected one of {sorted(checkpoint_paths)}"
        )

    model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
    model.load_checkpoint(checkpoint_paths[type])
    model.to('cuda')
    # The original fell off the end and returned None, leaving callers that
    # store the result (e.g. the `models` registry) with no usable model.
    return model


# Registry keyed by the label shown in the UI dropdown; each value is whatever
# get_model() returns for that label. Built eagerly at import time.
models = {
    namespace: get_model(namespace)
    for namespace in ('Molecule Function', 'Biological Process', 'Cellar Component')
}


# ESM encoder ('esm2_t36_3B_UR50D') and its alphabet; the encoder is moved to
# the GPU and switched to evaluation mode once, at import time.
model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
model_esm = model_esm.to('cuda').eval()

# Ontology graph loaded from the OBO file, with relationships enabled.
godb = Ontology('data/go1.4-basic.obo', with_rels=True)

# Two-column, '|'-separated table of term id and description; rows with a
# missing value are dropped.
go_des = pd.read_csv('data/go_descriptions1.4.txt', sep='|', header=None)
go_des.columns = ['id', 'text']
go_des = go_des.dropna()
# Ids use '_' in the file; switch them to the ':' form.
go_des['id'] = go_des['id'].apply(lambda go_id: re.sub('_', ':', go_id))
go_obo_set = set(go_des['id'].tolist())
# Descriptions are lower-cased so they can be compared with lower-cased text.
go_des['text'] = go_des['text'].apply(lambda description: description.lower())

# Lookup tables in both directions: description -> id and id -> description.
GO_dict = {description: go_id for description, go_id in zip(go_des['text'], go_des['id'])}
Func_dict = {go_id: description for go_id, description in zip(go_des['id'], go_des['text'])}

# Vocabulary used to snap generated text onto known terms: the descriptions of
# every distinct id in the 'gos' column of the MF terms pickle.
terms_mf = pd.read_pickle('data/terms/mf_terms.pkl')
choices_mf = [Func_dict[go_id] for go_id in set(terms_mf['gos'])]
choices = {term.lower(): term for term in choices_mf}


def _esm_token_embeddings(protein_seq, repr_layer=36, truncation_seq_length=1024,
                          toks_per_batch=4096):
    """Return the per-token ESM representation of one sequence as a CPU tensor.

    Args:
        protein_seq: Amino-acid sequence string.
        repr_layer: ESM layer whose representation is returned; negative
            values index from the end.
        truncation_seq_length: Maximum number of residues kept.
        toks_per_batch: Token budget passed to the ESM batch sampler.

    Returns:
        A cloned tensor holding one row per kept residue (the leading BOS
        position is skipped).

    Raises:
        ValueError: If ``repr_layer`` is outside the model's layer range.
        RuntimeError: If the data loader yields no batch.
    """
    dataset = FastaBatchedDataset(['protein_name'], [protein_seq])
    batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(truncation_seq_length),
        batch_sampler=batches,
    )

    num_layers = model_esm.num_layers
    if not -(num_layers + 1) <= repr_layer <= num_layers:
        raise ValueError(f"repr_layer {repr_layer} is out of range for {num_layers} layers")
    # Map a negative layer index onto its non-negative equivalent.
    layer = (repr_layer + num_layers + 1) % (num_layers + 1)

    with torch.no_grad():
        for _labels, strs, toks in data_loader:
            if torch.cuda.is_available():
                toks = toks.to(device="cuda", non_blocking=True)
            out = model_esm(toks, repr_layers=[layer], return_contacts=False)
            hidden = out["representations"][layer].to(device="cpu")
            keep = min(truncation_seq_length, len(strs[0]))
            # Index 0 is the BOS token, so residues start at index 1. clone()
            # detaches the slice from the larger batch tensor
            # (see https://github.com/pytorch/pytorch/issues/1995).
            return hidden[0, 1: keep + 1].clone()
    raise RuntimeError("ESM data loader produced no batches")


@spaces.GPU
def generate_caption(model_id, protein, prompt):
    """Predict function terms for a protein sequence and describe them in text.

    Args:
        model_id: Key into the module-level ``models`` registry.
        protein: Amino-acid sequence as typed or pasted by the user.
        prompt: Optional taxonomy prompt; ``None`` is sent to the model as
            ``'none'``, any other value is lower-cased.

    Returns:
        A sentence listing the matched terms, each followed by its score in
        parentheses, or a short message when the sequence is empty or no
        generated term matches the known vocabulary.
    """
    # Function-scope import keeps this block self-contained.
    import ast

    max_len = 1024  # residues kept by the encoder and padded length fed to the model

    # Drop spaces and line breaks from pasted sequences so they are not
    # counted as residues when the embedding is sliced.
    protein_seq = "".join(protein.split()) if protein else ""
    if not protein_seq:
        return "Please provide an amino acid sequence."

    print("start")
    esm_emb = _esm_token_embeddings(protein_seq, repr_layer=36,
                                    truncation_seq_length=max_len)
    print("esm embedding generated")
    # Zero-pad along the residue axis up to max_len rows, then move to the GPU.
    esm_emb = F.pad(esm_emb.t(), (0, max_len - len(esm_emb))).t().to('cuda')

    prompt = 'none' if prompt is None else prompt.lower()
    samples = {'name': ['protein_name'],
               'image': torch.unsqueeze(esm_emb, dim=0),
               'text_input': ['none'],
               'prompt': [prompt]}

    model = models[model_id]
    prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10,
                                temperature=1., repetition_penalty=1.0)

    # prediction[0] is a '; '-separated string of "(text, score)" literals.
    # NOTE(review): `choices` is built from the MF term list only, so outputs of
    # the other models are matched against the MF vocabulary — confirm intended.
    pred_terms = []
    seen = set()
    for item in prediction[0].split('; '):
        # literal_eval parses plain Python literals only; unlike eval() it
        # cannot execute code embedded in generated text.
        try:
            entry = ast.literal_eval(item)
            txt, prob = entry[0], entry[1]
        except (ValueError, SyntaxError, TypeError, IndexError):
            continue  # skip malformed entries instead of failing the request
        if not isinstance(txt, str):
            continue
        sim_list = difflib.get_close_matches(txt.lower(), choices, n=1, cutoff=0.9)
        if sim_list and sim_list[0] not in seen:
            seen.add(sim_list[0])
            pred_terms.append(f"{sim_list[0]}({prob})")

    if not pred_terms:
        return "No confident function prediction could be matched for the given amino acid sequence."
    return (
        "Based on the given amino acid sequence, the protein appears to have a primary function of "
        f"{', '.join(pred_terms)}"
    )


# Define the FAPM interface
# Markdown shown at the top of the demo page.
description = """Quick demonstration of the FAPM model for protein function prediction. Upload a protein sequence to generate a function description. Modify the Prompt to provide the taxonomy information.

The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."""

# iface = gr.Interface(
#     fn=generate_caption,
#     inputs=[gr.Textbox(type="text", label="Upload sequence"), gr.Textbox(type="text", label="Prompt")],
#     outputs=gr.Textbox(label="Generated description"),
#     description=description
# )
# # Launch the interface
# iface.launch()

# Custom stylesheet passed to gr.Blocks(css=...); it gives the element with
# id "output" a fixed height, scrolling overflow and a thin border.
# NOTE(review): no component visible in this file sets elem_id="output", so
# this rule may not match anything — confirm.
css = """
  #output {
    height: 500px; 
    overflow: auto; 
    border: 1px solid #ccc; 
  }
"""

# Page layout: intro text, then a single tab with the inputs in the left
# column and the generated text in the right column.
with gr.Blocks(css=css) as demo:
    gr.Markdown(description)
    with gr.Tab(label="Protein caption"):
        with gr.Row():
            with gr.Column():
                model_selector = gr.Dropdown(
                    label="Model",
                    choices=list(models),
                    value='Molecule Function',
                )
                input_protein = gr.Textbox(label="Upload sequence", type="text")
                prompt = gr.Textbox(label="Taxonomy Prompt (Optional)", type="text")
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                output_text = gr.Textbox(label="Output Text")
        # Example rows are [model label, sequence, taxonomy prompt].
        # O14813 train index 127, 266, 738, 1060 test index 4
        gr.Examples(
            label='Try examples',
            fn=generate_caption,
            inputs=[model_selector, input_protein, prompt],
            outputs=[output_text],
            cache_examples=True,
            examples=[
                ["Molecule Function", "MDYSYLNSYDSCVAAMEASAYGDFGACSQPGGFQYSPLRPAFPAAGPPCPALGSSNCALGALRDHQPAPYSAVPYKFFPEPSGLHEKRKQRRIRTTFTSAQLKELERVFAETHYPDIYTREELALKIDLTEARVQVWFQNRRAKFRKQERAASAKGAAGAAGAKKGEARCSSEDDDSKESTCSPTPDSTASLPPPPAPGLASPRLSPSPLPVALGSGPGPGPGPQPLKGALWAGVAGGGGGGPGAGAAELLKAWQPAESGPGPFSGVLSSFHRKPGPALKTNLF", ''],
                ["Molecule Function", "MKTLALFLVLVCVLGLVQSWEWPWNRKPTKFPIPSPNPRDKWCRLNLGPAWGGRC", ''],
                ["Molecule Function", "MAAAGGARLLRAASAVLGGPAGRWLHHAGSRAGSSGLLRNRGPGGSAEASRSLSVSARARSSSEDKITVHFINRDGETLTTKGKVGDSLLDVVVENNLDIDGFGACEGTLACSTCHLIFEDHIYEKLDAITDEENDMLDLAYGLTDRSRLGCQICLTKSMDNMTVRVPETVADARQSIDVGKTS", 'Homo'],
                ["Molecule Function", 'MASAELSREENVYMAKLAEQAERYEEMVEFMEKVAKTVDSEELTVEERNLLSVAYKNVIGARRASWRIISSIEQKEEGRGNEDRVTLIKDYRGKIETELTKICDGILKLLETHLVPSSTAPESKVFYLKMKGDYYRYLAEFKTGAERKDAAENTMVAYKAAQDIALAELAPTHPIRLGLALNFSVFYYEILNSPDRACSLAKQAFDEAISELDTLSEESYKDSTLIMQLLRDNLTLWTSDISEDPAEEIREAPKRDSSEGQ', 'Zea'],
                ["Molecule Function", 'MIKAAVTKESLYRMNTLMEAFQGFLGLDLGEFTFKVKPGVFLLTDVKSYLIGDKYDDAFNALIDFVLRNDRDAVEGTETDVSIRLGLSPSDMVVKRQDKTFTFTHGDLEFEVHWINL', 'Bacteriophage'],
                ["Molecule Function", 'MNDLMIQLLDQFEMGLRERAIKVMATINDEKHRFPMELNKKQCSLMLLGTTDTTTFDMRFNSKKDFPRIKGAREKYPRDAVIEWYHQNWMRTEVKQ', 'Bacteriophage'],
            ],
        )
        # The Submit button runs the same function as the examples.
        submit_btn.click(
            fn=generate_caption,
            inputs=[model_selector, input_protein, prompt],
            outputs=[output_text],
        )

# Start the app with debug output enabled.
demo.launch(debug=True)