wenkai commited on
Commit
cfac7fd
1 Parent(s): 3660015

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -10,6 +10,7 @@ import gradio as gr
10
  from esm_scripts.extract import run_demo
11
  from esm import pretrained
12
 
 
13
 
14
  model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
15
  model_esm.eval()
@@ -18,7 +19,7 @@ model_esm = model_esm.cuda()
18
  # Load the model
19
  model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
20
  model.load_checkpoint("model/checkpoint_mf2.pth")
21
- model.to('cuda')
22
 
23
 
24
  @spaces.GPU
 
10
  from esm_scripts.extract import run_demo
11
  from esm import pretrained
12
 
13
+ print(torch.cuda.get_device_name(0))
14
 
15
  model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
16
  model_esm.eval()
 
19
  # Load the model
20
  model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
21
  model.load_checkpoint("model/checkpoint_mf2.pth")
22
+ model = model.cuda()
23
 
24
 
25
  @spaces.GPU