Update app.py
Browse files
app.py
CHANGED
@@ -9,17 +9,20 @@ import spaces
|
|
9 |
import gradio as gr
|
10 |
from esm_scripts.extract import run_demo
|
11 |
from esm import pretrained, FastaBatchedDataset
|
|
|
12 |
# from transformers import EsmTokenizer, EsmModel
|
13 |
|
14 |
|
15 |
# Load the model
|
16 |
-
model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
|
17 |
-
model.load_checkpoint("model/checkpoint_mf2.pth")
|
18 |
-
model.to('cuda')
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
-
model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
|
21 |
-
model_esm.to('cuda')
|
22 |
-
model_esm.eval()
|
23 |
# tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
|
24 |
# model_esm = EsmModel.from_pretrained("facebook/esm2_t36_3B_UR50D")
|
25 |
# model_esm.to('cuda')
|
@@ -32,22 +35,26 @@ def generate_caption(protein, prompt):
|
|
32 |
# f.write('>{}\n'.format("protein_name"))
|
33 |
# f.write('{}\n'.format(protein.strip()))
|
34 |
# os.system("python esm_scripts/extract.py esm2_t36_3B_UR50D /home/user/app/example.fasta /home/user/app --repr_layers 36 --truncation_seq_length 1024 --include per_tok")
|
35 |
-
# esm_emb = run_demo(protein_name='protein_name', protein_seq=protein,
|
36 |
-
# model=model_esm, alphabet=alphabet,
|
37 |
# include='per_tok', repr_layers=[36], truncation_seq_length=1024)
|
38 |
-
|
39 |
-
protein_name='protein_name'
|
40 |
-
protein_seq=protein
|
41 |
-
include='per_tok'
|
42 |
-
repr_layers=[36]
|
43 |
-
truncation_seq_length=1024
|
44 |
-
toks_per_batch=4096
|
45 |
print("start")
|
46 |
dataset = FastaBatchedDataset([protein_name], [protein_seq])
|
47 |
print("dataset prepared")
|
48 |
batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
|
49 |
print("batches prepared")
|
50 |
-
|
|
|
|
|
|
|
|
|
51 |
data_loader = torch.utils.data.DataLoader(
|
52 |
dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
|
53 |
)
|
@@ -78,12 +85,12 @@ def generate_caption(protein, prompt):
|
|
78 |
# See https://github.com/pytorch/pytorch/issues/1995
|
79 |
if "per_tok" in include:
|
80 |
result["representations"] = {
|
81 |
-
layer: t[i, 1
|
82 |
for layer, t in representations.items()
|
83 |
}
|
84 |
if "mean" in include:
|
85 |
result["mean_representations"] = {
|
86 |
-
layer: t[i, 1
|
87 |
for layer, t in representations.items()
|
88 |
}
|
89 |
if "bos" in include:
|
@@ -106,18 +113,25 @@ def generate_caption(protein, prompt):
|
|
106 |
'image': torch.unsqueeze(esm_emb, dim=0),
|
107 |
'text_input': ['none'],
|
108 |
'prompt': [prompt]}
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
# Generate the output
|
110 |
-
prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1.,
|
|
|
111 |
|
112 |
return prediction
|
113 |
# return "test"
|
114 |
|
|
|
115 |
# Define the FAPM interface
|
116 |
description = """Quick demonstration of the FAPM model for protein function prediction. Upload an protein sequence to generate a function description. Modify the Prompt to provide the taxonomy information.
|
117 |
|
118 |
The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."""
|
119 |
|
120 |
-
|
121 |
iface = gr.Interface(
|
122 |
fn=generate_caption,
|
123 |
inputs=[gr.Textbox(type="text", label="Upload sequence"), gr.Textbox(type="text", label="Prompt")],
|
|
|
9 |
import gradio as gr
|
10 |
from esm_scripts.extract import run_demo
|
11 |
from esm import pretrained, FastaBatchedDataset
|
12 |
+
|
13 |
# from transformers import EsmTokenizer, EsmModel
|
14 |
|
15 |
|
16 |
# Load the model
|
17 |
+
# model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
|
18 |
+
# model.load_checkpoint("model/checkpoint_mf2.pth")
|
19 |
+
# model.to('cuda')
|
20 |
+
|
21 |
+
# model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
|
22 |
+
# model_esm.to('cuda')
|
23 |
+
# model_esm.eval()
|
24 |
+
|
25 |
|
|
|
|
|
|
|
26 |
# tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
|
27 |
# model_esm = EsmModel.from_pretrained("facebook/esm2_t36_3B_UR50D")
|
28 |
# model_esm.to('cuda')
|
|
|
35 |
# f.write('>{}\n'.format("protein_name"))
|
36 |
# f.write('{}\n'.format(protein.strip()))
|
37 |
# os.system("python esm_scripts/extract.py esm2_t36_3B_UR50D /home/user/app/example.fasta /home/user/app --repr_layers 36 --truncation_seq_length 1024 --include per_tok")
|
38 |
+
# esm_emb = run_demo(protein_name='protein_name', protein_seq=protein,
|
39 |
+
# model=model_esm, alphabet=alphabet,
|
40 |
# include='per_tok', repr_layers=[36], truncation_seq_length=1024)
|
41 |
+
|
42 |
+
protein_name = 'protein_name'
|
43 |
+
protein_seq = protein
|
44 |
+
include = 'per_tok'
|
45 |
+
repr_layers = [36]
|
46 |
+
truncation_seq_length = 1024
|
47 |
+
toks_per_batch = 4096
|
48 |
print("start")
|
49 |
dataset = FastaBatchedDataset([protein_name], [protein_seq])
|
50 |
print("dataset prepared")
|
51 |
batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
|
52 |
print("batches prepared")
|
53 |
+
|
54 |
+
model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
|
55 |
+
model_esm.to('cuda')
|
56 |
+
model_esm.eval()
|
57 |
+
|
58 |
data_loader = torch.utils.data.DataLoader(
|
59 |
dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
|
60 |
)
|
|
|
85 |
# See https://github.com/pytorch/pytorch/issues/1995
|
86 |
if "per_tok" in include:
|
87 |
result["representations"] = {
|
88 |
+
layer: t[i, 1: truncate_len + 1].clone()
|
89 |
for layer, t in representations.items()
|
90 |
}
|
91 |
if "mean" in include:
|
92 |
result["mean_representations"] = {
|
93 |
+
layer: t[i, 1: truncate_len + 1].mean(0).clone()
|
94 |
for layer, t in representations.items()
|
95 |
}
|
96 |
if "bos" in include:
|
|
|
113 |
'image': torch.unsqueeze(esm_emb, dim=0),
|
114 |
'text_input': ['none'],
|
115 |
'prompt': [prompt]}
|
116 |
+
|
117 |
+
del model_esm
|
118 |
+
|
119 |
+
model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
|
120 |
+
model.load_checkpoint("model/checkpoint_mf2.pth")
|
121 |
+
model.to('cuda')
|
122 |
# Generate the output
|
123 |
+
prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1.,
|
124 |
+
repetition_penalty=1.0)
|
125 |
|
126 |
return prediction
|
127 |
# return "test"
|
128 |
|
129 |
+
|
130 |
# Define the FAPM interface
|
131 |
description = """Quick demonstration of the FAPM model for protein function prediction. Upload an protein sequence to generate a function description. Modify the Prompt to provide the taxonomy information.
|
132 |
|
133 |
The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."""
|
134 |
|
|
|
135 |
iface = gr.Interface(
|
136 |
fn=generate_caption,
|
137 |
inputs=[gr.Textbox(type="text", label="Upload sequence"), gr.Textbox(type="text", label="Prompt")],
|