# FAPM_demo / app.py
# Author: wenkai — last commit "Update app.py" (bb50772, verified)
# (Hugging Face file-page metadata: raw · history · blame · no virus · 2.24 kB)
import os
import torch
import torch.nn as nn
import pandas as pd
import torch.nn.functional as F
from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral
from lavis.models.base_model import FAPMConfig
import spaces
import gradio as gr
from esm_scripts.extract import run_demo
from esm import pretrained, FastaBatchedDataset
# from transformers import EsmTokenizer, EsmModel
# Load the FAPM model (configured for ESM-2 3B embeddings) from the local
# checkpoint file and move it to the GPU.
model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
model.load_checkpoint("model/checkpoint_mf2.pth")
model.to('cuda')
# The model is constructed directly rather than through a loader that sets
# eval mode, and torch modules start in training mode. Switch to inference
# mode so dropout layers (if any) do not randomize generation.
# NOTE(review): assumes Blip2ProteinMistral is a torch.nn.Module, as the
# .to('cuda') call above implies.
model.eval()
# Number of residues each per-residue embedding is padded (or cropped) to
# before it is handed to FAPM.
MAX_SEQ_LEN = 1024
# ESM-2 checkpoint and representation layer used by FAPM. The original demo
# read layer 36 from precomputed files under data/emb_esm2_3b/.
ESM_MODEL_NAME = 'esm2_t36_3B_UR50D'
ESM_REPR_LAYER = 36
# Lazily filled cache holding the ESM model and its batch converter.
# NOTE(review): if spaces.GPU runs the decorated function in a separate worker
# process, this cache may not persist between calls; in that case load ESM at
# import time next to `model` instead.
_ESM_CACHE = {}


def _clean_sequence(protein):
    """Return *protein* as a bare upper-case residue string.

    FASTA header lines (starting with '>') and all whitespace are removed, so
    users may paste either a raw sequence or a single FASTA record.
    """
    if not protein:
        return ''
    return ''.join(
        ''.join(line.split())
        for line in protein.splitlines()
        if not line.lstrip().startswith('>')
    ).upper()


def _load_esm():
    """Load the ESM-2 model once and return ``(model, batch_converter)``."""
    if not _ESM_CACHE:
        esm_model, alphabet = pretrained.load_model_and_alphabet(ESM_MODEL_NAME)
        esm_model.eval()
        _ESM_CACHE['model'] = esm_model.to('cuda')
        # Truncate long sequences to the same length FAPM pads/crops to.
        _ESM_CACHE['converter'] = alphabet.get_batch_converter(MAX_SEQ_LEN)
    return _ESM_CACHE['model'], _ESM_CACHE['converter']


def _embed_sequence(sequence):
    """Return per-residue ESM embeddings for *sequence* (one row per residue).

    BOS/EOS token positions are dropped so the result has the same per-token
    layout as the precomputed ``['representations'][36]`` files.
    """
    esm_model, converter = _load_esm()
    _, _, tokens = converter([('protein_name', sequence)])
    out = esm_model(tokens.to('cuda'), repr_layers=[ESM_REPR_LAYER], return_contacts=False)
    n_residues = min(len(sequence), MAX_SEQ_LEN)
    # Position 0 is BOS; residues occupy positions 1..n_residues.
    return out['representations'][ESM_REPR_LAYER][0, 1:n_residues + 1]


@spaces.GPU
def generate_caption(protein, prompt):
    """Generate function descriptions for a protein sequence with FAPM.

    Args:
        protein: amino-acid sequence typed by the user; a single FASTA record
            (header line starting with '>') is also accepted.
        prompt: free-text prompt forwarded to the model (e.g. taxonomy info).

    Returns:
        The value of ``model.generate`` for the single sample, which is asked
        for ``num_captions=10`` beams' worth of captions.

    Raises:
        gr.Error: if the sequence is empty after cleaning.
    """
    sequence = _clean_sequence(protein)
    if not sequence:
        raise gr.Error("Please provide a protein sequence.")

    with torch.no_grad():
        esm_emb = _embed_sequence(sequence)
        print("esm embedding generated")
        # F.pad pads the last dimension with the (left, right) pair, so the
        # transposes move the residue axis last. A negative width crops.
        esm_emb = F.pad(esm_emb.t(), (0, MAX_SEQ_LEN - len(esm_emb))).t().to('cuda')
        print("esm embedding processed")
        samples = {'name': ['protein_name'],
                   'image': torch.unsqueeze(esm_emb, dim=0),
                   'text_input': ['none'],
                   'prompt': [prompt]}
        # Generate the output
        prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10,
                                    temperature=1., repetition_penalty=1.0)
    return prediction
# Define the FAPM Gradio interface: two text inputs (sequence and prompt)
# mapped through generate_caption to one text output.
description = """Quick demonstration of the FAPM model for protein function prediction. Upload a protein sequence to generate a function description. Modify the prompt to provide taxonomy information.
The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."""
iface = gr.Interface(
    fn=generate_caption,
    inputs=[gr.Textbox(type="text", label="Upload sequence"), gr.Textbox(type="text", label="Prompt")],
    outputs=gr.Textbox(label="Generated description"),
    description=description,
)
# Launch the interface (starts the Gradio server when this script runs).
iface.launch()