wenkai commited on
Commit
d3260a6
1 Parent(s): cfac7fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -4
app.py CHANGED
@@ -10,16 +10,14 @@ import gradio as gr
10
  from esm_scripts.extract import run_demo
11
  from esm import pretrained
12
 
13
- print(torch.cuda.get_device_name(0))
14
 
15
  model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
16
- model_esm.eval()
17
- model_esm = model_esm.cuda()
18
 
19
  # Load the model
20
  model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
21
  model.load_checkpoint("model/checkpoint_mf2.pth")
22
- model = model.cuda()
23
 
24
 
25
  @spaces.GPU
@@ -50,6 +48,7 @@ def generate_caption(protein, prompt):
50
  return_contacts = "contacts" in include
51
  assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
52
  repr_layers = [(i + model_esm.num_layers + 1) % (model_esm.num_layers + 1) for i in repr_layers]
 
53
  with torch.no_grad():
54
  for batch_idx, (labels, strs, toks) in enumerate(data_loader):
55
  print(
 
10
  from esm_scripts.extract import run_demo
11
  from esm import pretrained
12
 
 
13
 
14
  model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
15
+ model_esm.to('cuda')
 
16
 
17
  # Load the model
18
  model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
19
  model.load_checkpoint("model/checkpoint_mf2.pth")
20
+ model.to('cuda')
21
 
22
 
23
  @spaces.GPU
 
48
  return_contacts = "contacts" in include
49
  assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
50
  repr_layers = [(i + model_esm.num_layers + 1) % (model_esm.num_layers + 1) for i in repr_layers]
51
+ model_esm.eval()
52
  with torch.no_grad():
53
  for batch_idx, (labels, strs, toks) in enumerate(data_loader):
54
  print(