import os
import tempfile
from functools import lru_cache

import gradio as gr
import matplotlib.pyplot as plt  # noqa: F401 -- retained; plotting below uses the pyplot-free Figure API
import numpy as np
import torch
from matplotlib.figure import Figure
from transformers import AutoTokenizer, EsmForMaskedLM

MODEL_NAME = "facebook/esm2_t6_8M_UR50D"

# The 20 standard amino acids, in the row order used by the heatmap.
AMINO_ACIDS = list("ACDEFGHIKLMNPQRSTVWY")


@lru_cache(maxsize=1)
def _load_model(model_name=MODEL_NAME):
    """Load the ESM tokenizer and masked-LM model once and reuse them across calls."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = EsmForMaskedLM.from_pretrained(model_name)
    model.eval()
    return tokenizer, model


def _normalize_sequence(sequence):
    """Return *sequence* with all whitespace removed and letters uppercased."""
    return "".join(sequence.split()).upper()


def generate_heatmap(protein_sequence, start_pos=1, end_pos=None, model_name=MODEL_NAME):
    """Render an ESM masked-marginal mutation heatmap and return the PNG path.

    Each position in ``start_pos..end_pos`` (1-based, inclusive) is masked in
    turn. For each of the 20 standard amino acids, the model's log-probability
    at that position is compared with the log-probability of the wild-type
    residue: ``LLR = log p(mutant) - log p(wild type)``.

    Args:
        protein_sequence: One-letter amino-acid sequence. Whitespace is removed
            and letters are uppercased before tokenizing.
        start_pos: First 1-based position to score.
        end_pos: Last 1-based position to score; defaults to the sequence end.
        model_name: Hugging Face id of the ESM masked-LM checkpoint.

    Returns:
        Path to a newly created PNG file containing the heatmap.

    Raises:
        ValueError: if the sequence does not tokenize to exactly one token per
            residue, or the requested position range is empty or out of bounds.
    """
    protein_sequence = _normalize_sequence(protein_sequence)
    tokenizer, model = _load_model(model_name)

    input_ids = tokenizer.encode(protein_sequence, return_tensors="pt")
    sequence_length = input_ids.shape[1] - 2  # Excluding the special tokens
    # Position indexing below assumes one token per residue; characters the
    # tokenizer does not map one-to-one would silently shift every column.
    if sequence_length != len(protein_sequence):
        raise ValueError(
            "Sequence contains characters the tokenizer cannot map one-to-one onto "
            f"residues ({sequence_length} tokens for {len(protein_sequence)} residues)."
        )

    if end_pos is None:
        end_pos = sequence_length
    if not 1 <= start_pos <= end_pos <= sequence_length:
        raise ValueError(
            f"Invalid range {start_pos}-{end_pos} for a sequence of length {sequence_length}."
        )

    n_positions = end_pos - start_pos + 1
    heatmap = np.zeros((len(AMINO_ACIDS), n_positions))
    # Loop-invariant: vocabulary ids of the 20 amino acids (heatmap row order).
    aa_token_ids = tokenizer.convert_tokens_to_ids(AMINO_ACIDS)

    with torch.no_grad():
        for column, position in enumerate(range(start_pos, end_pos + 1)):
            # The start-of-sequence special token occupies index 0, so the
            # 1-based residue position equals the token index.
            masked_input_ids = input_ids.clone()
            masked_input_ids[0, position] = tokenizer.mask_token_id

            logits = model(masked_input_ids).logits
            # log_softmax is the numerically stable form of log(softmax(x)).
            log_probs = torch.log_softmax(logits[0, position], dim=-1)

            wt_token_id = input_ids[0, position].item()
            heatmap[:, column] = (log_probs[aa_token_ids] - log_probs[wt_token_id]).numpy()

    # Use the Figure API rather than global pyplot state: Gradio runs callbacks
    # in worker threads, and no GUI backend or plt.show() is involved.
    fig = Figure(figsize=(15, 5))
    ax = fig.subplots()
    image = ax.imshow(heatmap, cmap="viridis_r", aspect="auto")
    ax.set_xticks(range(n_positions))
    ax.set_xticklabels(list(protein_sequence[start_pos - 1:end_pos]))
    ax.set_yticks(range(len(AMINO_ACIDS)))
    ax.set_yticklabels(AMINO_ACIDS)
    ax.set_xlabel("Position in Protein Sequence")
    ax.set_ylabel("Amino Acid Mutations")
    ax.set_title("Predicted Effects of Mutations on Protein Sequence (LLR)")
    fig.colorbar(image, ax=ax, label="Log Likelihood Ratio (LLR)")

    # A unique file per call keeps concurrent requests from overwriting each other.
    fd, temp_file = tempfile.mkstemp(prefix="heatmap_", suffix=".png")
    os.close(fd)
    fig.savefig(temp_file)
    return temp_file


def heatmap_interface(sequence, start, end=None):
    """Gradio callback: validate the UI inputs and return the heatmap image path.

    An empty sequence clears the output (returns None). An empty End field, a
    value <= 0, or a value past the sequence end means "score to the end of the
    sequence". Invalid input raises ``gr.Error`` so Gradio shows the message;
    the "image" output cannot display a returned string.
    """
    sequence = _normalize_sequence(sequence or "")
    if not sequence:
        return None

    start = 1 if start is None else int(start)
    end = None if end is None else int(end)

    # If end is None or greater than sequence length, set it to sequence length
    if end is None or end > len(sequence) or end <= 0:
        end = len(sequence)

    if start < 1 or start > len(sequence):
        raise gr.Error("Start position is out of bounds.")
    if start > end:
        raise gr.Error("Start position must not be greater than End position.")

    try:
        return generate_heatmap(sequence, start, end)
    except ValueError as err:
        raise gr.Error(str(err)) from err


# Define the Gradio interface
iface = gr.Interface(
    fn=heatmap_interface,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter Protein Sequence Here..."),
        gr.Number(label="Start Position", value=1),
        gr.Number(label="End Position"),  # empty / <= 0 means "to the end of the sequence"
    ],
    outputs="image",
    live=True,
)

if __name__ == "__main__":
    # Run the Gradio app
    iface.launch()