# ViT-MRI-FineTuning / traininginVIT
# Author: Tanusree88 — commit 03aba46 ("Update traininginVIT", verified).
# Hugging Face file page: 2.68 kB; the security scan reported no virus.
from torch.utils.data import DataLoader, Dataset
import torch
from transformers import ViTForImageClassification, AdamW
import os
import numpy as np
import torch
import streamlit as st
from transformers import ViTForImageClassification, ViTImageProcessor
def _load_npy_as_tensor(path, size=224):
    """Load one ``.npy`` image as a ViT-ready ``(3, size, size)`` float tensor.

    The array is min-max scaled to [0, 1] and resized with bilinear
    interpolation. It is then normalised with mean 0.5 / std 0.5 per channel
    (the ``ViTImageProcessor`` defaults), so values end up in [-1, 1].

    Args:
        path: Path to a NumPy ``.npy`` file.
        size: Output height and width in pixels.

    Returns:
        A ``torch.float32`` tensor of shape ``(3, size, size)``.

    Raises:
        ValueError: If the array is not ``(H, W)``, ``(1, H, W)`` or ``(3, H, W)``.
    """
    # NOTE(review): assumes each file holds a single 2-D slice or a
    # channel-first image. Volumes or channel-last arrays need a custom
    # ``preprocess`` callable passed to MRIDataset.
    tensor = torch.as_tensor(np.load(path), dtype=torch.float32)  # np.load refuses pickles by default
    if tensor.ndim == 2:
        tensor = tensor.unsqueeze(0)
    if tensor.ndim != 3 or tensor.shape[0] not in (1, 3):
        raise ValueError(
            f"{path}: expected an array of shape (H, W), (1, H, W) or (3, H, W), "
            f"got {tuple(tensor.shape)}"
        )
    # ViT expects 3 input channels, so replicate a single grayscale channel.
    if tensor.shape[0] == 1:
        tensor = tensor.repeat(3, 1, 1)
    low, high = tensor.min(), tensor.max()
    if high > low:
        tensor = (tensor - low) / (high - low)
    else:
        tensor = torch.zeros_like(tensor)  # constant image: avoid 0/0
    tensor = torch.nn.functional.interpolate(
        tensor.unsqueeze(0), size=(size, size), mode="bilinear", align_corners=False
    ).squeeze(0)
    return (tensor - 0.5) / 0.5


class MRIDataset(Dataset):
    """Map-style dataset pairing MRI image files with integer class labels.

    Args:
        image_paths: Sequence of image file paths, one per sample.
        labels: Sequence of integer class ids aligned with ``image_paths``.
        preprocess: Optional callable ``path -> torch.Tensor``. It defaults to
            ``_load_npy_as_tensor``, which reads ``.npy`` files.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, preprocess=None):
        if len(image_paths) != len(labels):
            raise ValueError(
                f"got {len(image_paths)} image paths but {len(labels)} labels"
            )
        self.image_paths = image_paths
        self.labels = labels
        # No ``preprocess_image`` exists in this file, so default to the .npy loader.
        self.preprocess = _load_npy_as_tensor if preprocess is None else preprocess

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for sample ``idx``."""
        image = self.preprocess(self.image_paths[idx])
        # CrossEntropyLoss expects int64 class-index targets.
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return image, label
# Hugging Face checkpoint shared by the model and its image processor.
MODEL_CHECKPOINT = "google/vit-base-patch16-224-in21k"

# ViT with a 3-way classification head (num_labels=3). Head weights that are
# not present in the checkpoint are newly initialised by from_pretrained.
model = ViTForImageClassification.from_pretrained(MODEL_CHECKPOINT, num_labels=3)
# Load the processor from the SAME checkpoint so its preprocessing config
# matches the backbone, instead of a different checkpoint's config.
# NOTE(review): `processor` is not referenced anywhere else in this file.
processor = ViTImageProcessor.from_pretrained(MODEL_CHECKPOINT)

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Optimizer and loss function.
# torch.optim.AdamW replaces transformers.AdamW, which is deprecated and
# missing from newer transformers releases. eps and weight_decay are pinned to
# the old transformers defaults (1e-6 and 0.0), because torch's own defaults
# differ (1e-8 and 0.01).
# NOTE(review): the file-level `from transformers import ... AdamW` import
# will also fail on those newer releases and should be dropped.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, eps=1e-6, weight_decay=0.0)
criterion = torch.nn.CrossEntropyLoss()

# Training data. The paths and labels below are placeholders: replace them
# with the real dataset.
image_paths = ["path_to_image1.npy", "path_to_image2.npy"]  # Update with actual image paths
labels = [0, 1]  # Corresponding labels
dataset = MRIDataset(image_paths, labels)
data_loader = DataLoader(dataset, batch_size=16, shuffle=True)
# Default number of passes over the training data.
num_epochs = 10


def fine_tune_model(epochs=num_epochs, *, net=None, loader=None, opt=None,
                    loss_fn=None, dev=None, save_path="vit_finetuned.pth"):
    """Fine-tune the classifier and return the mean loss of the final epoch.

    Training happens only when this function is called, not at import time.
    Streamlit re-executes the whole script on every interaction, so a
    module-level training loop would retrain (or crash) before the UI renders.

    Each keyword argument defaults to the module-level object with the same
    role (``model``, ``data_loader``, ``optimizer``, ``criterion``,
    ``device``). This keeps the no-argument call working.

    Args:
        epochs: Number of passes over ``loader``; must be >= 1.
        net: Model called as ``net(pixel_values=batch).logits``.
        loader: Iterable of ``(images, labels)`` batches.
        opt: Optimizer over ``net``'s parameters.
        loss_fn: Loss taking ``(logits, labels)``.
        dev: Device each batch is moved to.
        save_path: File to write ``net.state_dict()`` to after training;
            ``None`` skips saving.

    Returns:
        The average per-batch loss of the last epoch, as a float.

    Raises:
        ValueError: If ``epochs`` < 1 or ``loader`` yields no batches.
    """
    net = model if net is None else net
    loader = data_loader if loader is None else loader
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn
    dev = device if dev is None else dev
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    epoch_loss = 0.0
    for epoch in range(epochs):
        net.train()
        total_loss = 0.0
        num_batches = 0
        # Use `targets`, not `labels`, so the module-level label list is not clobbered.
        for images, targets in loader:
            images, targets = images.to(dev), targets.to(dev)
            opt.zero_grad()
            logits = net(pixel_values=images).logits
            loss = loss_fn(logits, targets)
            loss.backward()
            opt.step()
            total_loss += loss.item()
            num_batches += 1
        # Count batches as we go instead of dividing by len(loader), which
        # would raise ZeroDivisionError on an empty dataset.
        if num_batches == 0:
            raise ValueError("data loader yielded no batches; check the dataset")
        epoch_loss = total_loss / num_batches
        print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss}')

    # Persist the fine-tuned weights.
    if save_path is not None:
        torch.save(net.state_dict(), save_path)
    return epoch_loss
# --- Streamlit front end ----------------------------------------------------
# A title plus one button. When st.button() reports a click, call
# fine_tune_model() and display the loss value it returns.
st.title("MRI Image Fine-Tuning with ViT")
start_requested = st.button("Start Training")
if start_requested:
    # Blocks until fine_tune_model() returns, then reports its result.
    final_loss = fine_tune_model()
    st.write(f"Training complete with final loss: {final_loss}")