Update app.py
Browse files
app.py
CHANGED
@@ -10,16 +10,14 @@ import gradio as gr
|
|
10 |
from esm_scripts.extract import run_demo
|
11 |
from esm import pretrained
|
12 |
|
13 |
- print(torch.cuda.get_device_name(0))
|
14 |
|
15 |
model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
|
16 |
- model_esm.
|
17 |
- model_esm = model_esm.cuda()
|
18 |
|
19 |
# Load the model
|
20 |
model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
|
21 |
model.load_checkpoint("model/checkpoint_mf2.pth")
|
22 |
- model
|
23 |
|
24 |
|
25 |
@spaces.GPU
|
@@ -50,6 +48,7 @@ def generate_caption(protein, prompt):
|
|
50 |
return_contacts = "contacts" in include
|
51 |
assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
|
52 |
repr_layers = [(i + model_esm.num_layers + 1) % (model_esm.num_layers + 1) for i in repr_layers]
|
|
|
53 |
with torch.no_grad():
|
54 |
for batch_idx, (labels, strs, toks) in enumerate(data_loader):
|
55 |
print(
|
|
|
10 |
from esm_scripts.extract import run_demo
|
11 |
from esm import pretrained
|
12 |
|
|
|
13 |
|
14 |
model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
|
15 |
+ model_esm.to('cuda')
|
|
|
16 |
|
17 |
# Load the model
|
18 |
model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
|
19 |
model.load_checkpoint("model/checkpoint_mf2.pth")
|
20 |
+ model.to('cuda')
|
21 |
|
22 |
|
23 |
@spaces.GPU
|
|
|
48 |
return_contacts = "contacts" in include
|
49 |
assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
|
50 |
repr_layers = [(i + model_esm.num_layers + 1) % (model_esm.num_layers + 1) for i in repr_layers]
|
51 |
+ model_esm.eval()
|
52 |
with torch.no_grad():
|
53 |
for batch_idx, (labels, strs, toks) in enumerate(data_loader):
|
54 |
print(
|