VatsalPatel18's picture
Update main.py
2ef5a8a verified
raw
history blame contribute delete
No virus
1.16 kB
import torch
from PIL import Image
from classes.genomic_plip_model import GenomicPLIPModel
from classes.binary_neural_classifier import SimpleNN
from transformers import CLIPImageProcessor
def load_and_preprocess_image(image_path, clip_processor_path):
    """Load an image file and convert it into CLIP pixel values.

    Args:
        image_path: Path (or file-like object) accepted by ``PIL.Image.open``.
        clip_processor_path: Name or local path passed to
            ``CLIPImageProcessor.from_pretrained``.

    Returns:
        The ``pixel_values`` entry produced by the processor for a
        one-image batch, as a PyTorch tensor (``return_tensors="pt"``).
    """
    clip_processor = CLIPImageProcessor.from_pretrained(clip_processor_path)
    # Open inside a context manager so the underlying file is always closed;
    # Pillow keeps the handle open after loading multi-frame formats (e.g. TIFF).
    # convert() returns a fully loaded, independent image, so it stays valid
    # after the file is closed.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    inputs = clip_processor(images=[image], return_tensors="pt")
    return inputs["pixel_values"]
def genomic_plip_predictions(image_tensor, model_path):
    """Run a pretrained GenomicPLIPModel over an image tensor.

    The model is loaded from ``model_path`` on every call, placed on CUDA when
    available (CPU otherwise), switched to eval mode and run without gradient
    tracking.

    Args:
        image_tensor: Tensor moved to the model's device and passed straight to
            the model; presumably the ``pixel_values`` returned by
            ``load_and_preprocess_image`` — TODO confirm.
        model_path: Name or path passed to ``GenomicPLIPModel.from_pretrained``.

    Returns:
        Whatever the model's forward pass returns, left on the model's device.
    """
    use_cuda = torch.cuda.is_available()
    target_device = torch.device("cuda" if use_cuda else "cpu")

    plip = GenomicPLIPModel.from_pretrained(model_path).to(target_device)
    plip.eval()

    with torch.no_grad():
        return plip(image_tensor.to(target_device))
def classify_tiles(pred_data, model_path):
    """Score model features with the binary classifier and average the result.

    Args:
        pred_data: Input fed directly to ``SimpleNN``; presumably the tensor
            returned by ``genomic_plip_predictions`` — TODO confirm. Tensor
            input is moved onto the classifier's device before use.
        model_path: Path to a checkpoint loadable by ``torch.load`` and
            containing a ``state_dict`` for ``SimpleNN``.

    Returns:
        float: the mean over all elements of the classifier's output.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SimpleNN().to(device)
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted
    # files. Consider weights_only=True (torch>=1.13) if the file is external.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    # Keep inputs on the same device as the weights; a CPU tensor would
    # otherwise fail against CUDA-resident parameters.
    if isinstance(pred_data, torch.Tensor):
        pred_data = pred_data.to(device)
    with torch.no_grad():
        output = model(pred_data).mean()
    return output.item()