wenkai committed
Commit 0b7981d
Parent: 0e6b0a0

Update app.py

Files changed (1):
  1. app.py +56 -3
app.py CHANGED
@@ -28,9 +28,62 @@ def generate_caption(protein, prompt):
     # f.write('>{}\n'.format("protein_name"))
     # f.write('{}\n'.format(protein.strip()))
     # os.system("python esm_scripts/extract.py esm2_t36_3B_UR50D /home/user/app/example.fasta /home/user/app --repr_layers 36 --truncation_seq_length 1024 --include per_tok")
-    esm_emb = run_demo(protein_name='protein_name', protein_seq=protein,
-                       model=model_esm, alphabet=alphabet,
-                       include='per_tok', repr_layers=[36], truncation_seq_length=1024)
+    # esm_emb = run_demo(protein_name='protein_name', protein_seq=protein,
+    #                    model=model_esm, alphabet=alphabet,
+    #                    include='per_tok', repr_layers=[36], truncation_seq_length=1024)
+    protein_name='protein_name'
+    protein_seq=protein
+    include='per_tok'
+    repr_layers=[36]
+    truncation_seq_length=1024
+    toks_per_batch=4096
+
+    dataset = FastaBatchedDataset([protein_name], [protein_seq])
+    batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
+    data_loader = torch.utils.data.DataLoader(
+        dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
+    )
+    print(f"Read sequences")
+    return_contacts = "contacts" in include
+    assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
+    repr_layers = [(i + model_esm.num_layers + 1) % (model_esm.num_layers + 1) for i in repr_layers]
+    with torch.no_grad():
+        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
+            print(
+                f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
+            )
+            if torch.cuda.is_available():
+                toks = toks.to(device="cuda", non_blocking=True)
+            out = model_esm(toks, repr_layers=repr_layers, return_contacts=return_contacts)
+            logits = out["logits"].to(device="cpu")
+            representations = {
+                layer: t.to(device="cpu") for layer, t in out["representations"].items()
+            }
+            if return_contacts:
+                contacts = out["contacts"].to(device="cpu")
+            for i, label in enumerate(labels):
+                result = {"label": label}
+                truncate_len = min(truncation_seq_length, len(strs[i]))
+                # Call clone on tensors to ensure tensors are not views into a larger representation
+                # See https://github.com/pytorch/pytorch/issues/1995
+                if "per_tok" in include:
+                    result["representations"] = {
+                        layer: t[i, 1 : truncate_len + 1].clone()
+                        for layer, t in representations.items()
+                    }
+                if "mean" in include:
+                    result["mean_representations"] = {
+                        layer: t[i, 1 : truncate_len + 1].mean(0).clone()
+                        for layer, t in representations.items()
+                    }
+                if "bos" in include:
+                    result["bos_representations"] = {
+                        layer: t[i, 0].clone() for layer, t in representations.items()
+                    }
+                if return_contacts:
+                    result["contacts"] = contacts[i, : truncate_len, : truncate_len].clone()
+    esm_emb = result['representations'][36]
+
     print("esm embedding generated")
     esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t().to('cuda')
     print("esm embedding processed")
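The added block inlines the per-token embedding path of the commented-out esm_scripts/extract.py call instead of going through run_demo. Below is a minimal standalone sketch of the same idea, assuming the fair-esm package; it uses the smaller esm2_t33_650M_UR50D checkpoint and an illustrative sequence, whereas app.py loads esm2_t36_3B_UR50D and reads layer 36.

import torch
import esm

# Load a smaller ESM-2 checkpoint for brevity (app.py uses esm2_t36_3B_UR50D).
model_esm, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
model_esm.eval()
batch_converter = alphabet.get_batch_converter()

# Illustrative sequence; in app.py the sequence comes from the Gradio input.
data = [("protein_name", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG")]
labels, strs, toks = batch_converter(data)

with torch.no_grad():
    out = model_esm(toks, repr_layers=[33], return_contacts=False)

# Per-token representations: drop the BOS token and keep only the real residues.
esm_emb = out["representations"][33][0, 1 : len(strs[0]) + 1]
print(esm_emb.shape)  # (sequence_length, 1280) for the 650M model

The committed version additionally routes the sequence through FastaBatchedDataset and a DataLoader, which matters mainly for batching and for applying the 1024-token truncation consistently with the original extract.py script.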