# Source: app.py from a Hugging Face Space (3.44 kB, scan status: no virus).
# Last commit: 00d2cb3 "Update app.py" by AmelieSchreiber.
import functools
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from transformers import AutoTokenizer, EsmForMaskedLM
@functools.lru_cache(maxsize=1)
def _load_esm_model(model_name):
    """Load the tokenizer and masked-LM model once and reuse them across calls."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = EsmForMaskedLM.from_pretrained(model_name)
    return tokenizer, model


def generate_heatmap(protein_sequence, start_pos=1, end_pos=None):
    """Render a mutation-effect heatmap for a protein sequence and save it as a PNG.

    For every position in ``start_pos..end_pos`` (1-based, inclusive) the
    residue is replaced by the mask token, the masked language model is run,
    and the log-likelihood ratio ``log p(variant) - log p(wild type)`` is
    stored for each of the 20 standard amino acids.

    Args:
        protein_sequence: Amino-acid sequence written with one-letter codes.
        start_pos: First position to score (1-based).
        end_pos: Last position to score (1-based, inclusive). ``None`` or a
            value past the end of the sequence means "up to the last residue".

    Returns:
        Path of the PNG file the heatmap was written to.

    Raises:
        ValueError: If ``start_pos`` is smaller than 1 or larger than the
            (clamped) ``end_pos``.
    """
    tokenizer, model = _load_esm_model("facebook/esm2_t6_8M_UR50D")

    # Tokenize the input sequence.
    input_ids = tokenizer.encode(protein_sequence, return_tensors="pt")
    # Assumes exactly one leading and one trailing special token, so that
    # token index == 1-based residue position.
    sequence_length = input_ids.shape[1] - 2

    # Clamp the end so the trailing special token is never masked or scored.
    if end_pos is None or end_pos > sequence_length:
        end_pos = sequence_length
    if not 1 <= start_pos <= end_pos:
        raise ValueError(
            f"Invalid position range {start_pos}-{end_pos} for a sequence "
            f"of length {sequence_length}."
        )

    amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
    # The vocabulary lookup does not depend on the position: do it once.
    aa_token_ids = torch.tensor(tokenizer.convert_tokens_to_ids(amino_acids))

    n_positions = end_pos - start_pos + 1
    heatmap = np.zeros((len(amino_acids), n_positions))

    # Calculate LLRs for each position and amino acid.
    with torch.no_grad():
        for column, position in enumerate(range(start_pos, end_pos + 1)):
            # Mask the target position.
            masked_input_ids = input_ids.clone()
            masked_input_ids[0, position] = tokenizer.mask_token_id

            logits = model(masked_input_ids).logits
            # log_softmax avoids the underflow of log(softmax(x)).
            log_probs = torch.log_softmax(logits[0, position].double(), dim=0)

            wt_token_id = input_ids[0, position].item()
            llr = log_probs[aa_token_ids] - log_probs[wt_token_id]
            heatmap[:, column] = llr.numpy()

    # Label columns with the tokens that were actually scored so the number
    # of labels always matches the number of columns.
    tick_labels = tokenizer.convert_ids_to_tokens(
        input_ids[0, start_pos:end_pos + 1].tolist()
    )

    # Use an explicit figure instead of pyplot's global state, and do not
    # call plt.show(): it is useless on a server and, with interactive
    # backends, leaves nothing to save afterwards.
    fig, ax = plt.subplots(figsize=(15, 5))
    try:
        image = ax.imshow(heatmap, cmap="viridis_r", aspect="auto")
        ax.set_xticks(np.arange(n_positions))
        ax.set_xticklabels(tick_labels)
        ax.set_yticks(np.arange(len(amino_acids)))
        ax.set_yticklabels(amino_acids)
        ax.set_xlabel("Position in Protein Sequence")
        ax.set_ylabel("Amino Acid Mutations")
        ax.set_title("Predicted Effects of Mutations on Protein Sequence (LLR)")
        fig.colorbar(image, ax=ax, label="Log Likelihood Ratio (LLR)")

        # NOTE(review): the fixed file name is shared by all requests;
        # consider a unique temporary file if requests can overlap.
        temp_file = "temp_heatmap.png"
        fig.savefig(temp_file)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

    return temp_file
def heatmap_interface(sequence, start, end=None):
    """Validate the UI inputs and return the path of the rendered heatmap.

    Args:
        sequence: Protein sequence typed by the user; whitespace is ignored.
        start: 1-based start position; ``None`` (cleared field) means 1.
        end: 1-based inclusive end position. ``None``, a non-positive value
            or a value past the end of the sequence means "last residue".

    Returns:
        Path of the heatmap image, or ``None`` when no sequence has been
        entered yet (nothing to draw).

    Raises:
        gr.Error: If the start position is outside the sequence or lies
            after the end position.
    """
    # The multi-line textbox may contain spaces/newlines; drop them so that
    # len(sequence) matches the number of residues.
    sequence = "".join((sequence or "").split())
    if not sequence:
        # With live updates the callback fires on an empty textbox too;
        # return "no image" instead of an error.
        return None

    # Convert start and end to integers (gr.Number delivers floats or None).
    start = 1 if start is None else int(start)
    if end is not None:
        end = int(end)

    # If end is missing or out of range, fall back to the sequence length.
    if end is None or end > len(sequence) or end <= 0:
        end = len(sequence)

    # The output component is an image, so a returned message string would
    # be treated as a file path; report problems via gr.Error instead.
    if start < 1 or start > len(sequence):
        raise gr.Error("Start position is out of bounds.")
    if start > end:
        raise gr.Error("Start position must not be greater than end position.")

    # Generate heatmap
    return generate_heatmap(sequence, start, end)
# Build the Gradio UI: a sequence textbox plus start/end position fields,
# wired to heatmap_interface and producing an image.
iface = gr.Interface(
    live=True,
    outputs="image",
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter Protein Sequence Here..."),
        gr.Number(label="Start Position", value=1),
        # End position is intentionally left without a default value.
        gr.Number(label="End Position"),
    ],
    fn=heatmap_interface,
)

# Start serving the app.
iface.launch()