AmelieSchreiber committed on
Commit
c45c45f
1 Parent(s): 225109f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +48 -0
README.md CHANGED
@@ -20,4 +20,52 @@ tags:
20
  "eval_mcc": 0.25511446421928063,
21
  "eval_precision": 0.08547382057474782,
22
  "eval_recall": 0.7899691877651231,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  ```
 
20
  "eval_mcc": 0.25511446421928063,
21
  "eval_precision": 0.08547382057474782,
22
  "eval_recall": 0.7899691877651231,
23
+ ```
24
+
25
+ ## Using the Model
26
+
27
+ ```python
28
+ from transformers import AutoModelForTokenClassification, AutoTokenizer
29
+ from peft import PeftModel
30
+ import torch
31
+
32
+ # Path to the saved LoRA model
33
+ model_path = "AmelieSchreiber/esm2_t12_35M_ptm_qlora_2100K"
34
+ # ESM2 base model
35
+ base_model_path = "facebook/esm2_t12_35M_UR50D"
36
+
37
+ # Load the model
38
+ base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
39
+ loaded_model = PeftModel.from_pretrained(base_model, model_path)
40
+
41
+ # Ensure the model is in evaluation mode
42
+ loaded_model.eval()
43
+
44
+ # Load the tokenizer
45
+ loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)
46
+
47
+ # Protein sequence for inference
48
+ protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT" # Replace with your actual sequence
49
+
50
+ # Tokenize the sequence
51
+ inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')
52
+
53
+ # Run the model
54
+ with torch.no_grad():
55
+     logits = loaded_model(**inputs).logits
56
+
57
+ # Get predictions
58
+ tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
59
+ predictions = torch.argmax(logits, dim=2)
60
+
61
+ # Define labels
62
+ id2label = {
63
+     0: "No ptm site",
63
+     1: "ptm site"
65
+ }
66
+
67
+ # Print the predicted labels for each token
68
+ for token, prediction in zip(tokens, predictions[0].numpy()):
69
+     if token not in ['<pad>', '<cls>', '<eos>']:
69
+         print((token, id2label[prediction]))
71
  ```