KeXing committed
Commit
9b3cda0
1 Parent(s): c2f2143

Upload 3 files

Files changed (3)
  1. app.py +21 -6
  2. down_model_500_kfold1.pt +3 -0
  3. transformer1500_95p_500.pt +3 -0
app.py CHANGED
@@ -1,29 +1,44 @@
-import gradio as gr
+import gradio as gr
+
+
 from tape import ProteinBertModel, ProteinBertConfig, TAPETokenizer # type: ignore
 from tape.models import modeling_bert
 import numpy as np
 import torch

+
 tokenizer = TAPETokenizer(vocab='iupac')
 config=modeling_bert.ProteinBertConfig(num_hidden_layers=5,num_attention_heads=8,hidden_size=400)

-bert_model = torch.load('models/bert.pt')
-class_model=torch.load('models/class.pt')
+bert_model = torch.load('models/transformer1500_95p_500.pt')
+class_model=torch.load('models/down_model_500_kfold1.pt')
+
+bert_model=bert_model.module
+bert_model=bert_model.to('cpu')
+bert_model=bert_model.eval()



 def greet(name):

-    #translation_table = str.maketrans("", "", " \t\n\r\f\v")
-    name = name.replace(" ", "").replace("\n", "").replace("\t", "")
+
+
+    translation_table = str.maketrans("", "", " \t\n\r\f\v")
+    name = name.translate(translation_table)
     token_ids = torch.tensor([tokenizer.encode(name)])
-
+    token_ids = token_ids
     bert_output = bert_model(token_ids)
     class_output=class_model(bert_output[1])
+    class_output = torch.softmax(class_output, dim=1)
     cluster = torch.argmax(class_output, dim=1) + 1
     cluster=cluster.item()

     return "cluster "+str(cluster)
+
+
+
+
+
 demo = gr.Interface(
     fn=greet,
     # custom input box
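A note on the two behavioral changes inside greet: str.maketrans("", "", " \t\n\r\f\v") builds a deletion table covering every ASCII whitespace character, so pasted sequences with tabs, carriage returns, or form feeds now tokenize cleanly (the old replace chain only handled spaces, newlines, and tabs); and the added torch.softmax only rescales the logits to probabilities, so the argmax cluster is unchanged. A minimal sketch of both, using a hypothetical input sequence not taken from this repo:

import torch

# Delete every ASCII whitespace character before tokenizing
# (the previous code missed \r, \f, and \v).
translation_table = str.maketrans("", "", " \t\n\r\f\v")
seq = "MKT AYIA\tKQ\r\nRQISFVK"            # hypothetical pasted sequence
print(seq.translate(translation_table))     # -> MKTAYIAKQRQISFVK

# Softmax is monotonic per row, so argmax over the probabilities picks
# the same cluster as argmax over the raw logits.
logits = torch.tensor([[1.2, -0.3, 2.7]])   # dummy class scores
probs = torch.softmax(logits, dim=1)
assert torch.argmax(probs, dim=1).item() == torch.argmax(logits, dim=1).item()
print("cluster", torch.argmax(probs, dim=1).item() + 1)   # -> cluster 3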
down_model_500_kfold1.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:54ffee6f21ff8d0c743fbf8cad638ed609030e79a6359ec115005c98f0cef9c7
+size 2491739
transformer1500_95p_500.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:77877d4c8f835aacae45e311d5ecfdc180270a22bf7206a33b4a358fbdcb5d6d
+size 80270463
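Both .pt files are committed as Git LFS pointers: the three lines in each diff are the pointer text, not the weights themselves (the real objects are about 2.4 MB and 80 MB per the size fields). A clone made without LFS support leaves those small text files in place, and the torch.load calls in app.py will then fail. A quick hedged check, not part of this repo, that the checkpoints were actually fetched (run git lfs pull if they were not); note also that app.py loads from a models/ path while this commit adds the files at the repository root:

from pathlib import Path

# Hypothetical helper: an un-fetched Git LFS object is a small text file
# whose first line is the spec header shown in the diffs above.
for name in ("down_model_500_kfold1.pt", "transformer1500_95p_500.pt"):
    path = Path(name)
    head = path.read_bytes()[:64]
    if head.startswith(b"version https://git-lfs.github.com/spec/v1"):
        print(f"{name}: still an LFS pointer - run `git lfs pull` first")
    else:
        print(f"{name}: checkpoint present ({path.stat().st_size} bytes)")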