"""Fine-tune a ViT image classifier on MRI images and trigger training from Streamlit.

Run with ``streamlit run <this file>``. Streamlit re-executes the whole script on
every widget interaction, so the expensive work (model loading, training, saving)
lives inside :func:`fine_tune_model` and only runs when "Start Training" is clicked.
"""
import os

import numpy as np
import streamlit as st
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import ViTForImageClassification, ViTImageProcessor

# One checkpoint name for both the weights and the image processor, so the
# resize/normalisation applied to inputs matches the model being fine-tuned.
MODEL_NAME = "google/vit-base-patch16-224-in21k"
NUM_LABELS = 3

# Load your dataset: update with actual image paths and their corresponding labels
# (integers in range(NUM_LABELS)).
image_paths = ["path_to_image1.npy", "path_to_image2.npy"]
labels = [0, 1]

# Move work to the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = ViTImageProcessor.from_pretrained(MODEL_NAME)


def preprocess_image(image_path):
    """Load one ``.npy`` image and convert it to ViT ``pixel_values``.

    NOTE(review): assumes each file holds a single 2-D slice ``(H, W)`` or an
    ``(H, W, C)`` array with ``C`` in {1, 3}; confirm against how the files are
    produced.

    Args:
        image_path: path to a ``.npy`` file readable by ``np.load``.

    Returns:
        A float tensor ``(3, height, width)``, resized and normalised by ``processor``.

    Raises:
        ValueError: if the stored array has an unsupported shape.
    """
    array = np.load(image_path).astype(np.float32)
    if array.ndim == 3 and array.shape[-1] == 1:
        array = array[..., 0]
    if array.ndim == 2:
        # ViT expects three input channels; replicate the grayscale slice.
        array = np.stack([array] * 3, axis=-1)
    if array.ndim != 3 or array.shape[-1] != 3:
        raise ValueError(f"Unsupported image shape {array.shape} in {image_path}")

    # Min-max scale to 0..255 uint8 so the processor's default rescale (1/255) and
    # normalisation behave exactly as they would for an ordinary RGB image.
    low, high = float(array.min()), float(array.max())
    if high > low:
        array = (array - low) / (high - low)
    else:
        array = np.zeros_like(array)
    array = np.round(array * 255.0).astype(np.uint8)

    pixel_values = processor(images=array, return_tensors="pt")["pixel_values"]
    return pixel_values[0]  # drop the batch dimension; DataLoader re-batches


# Custom dataset class for loading images
class MRIDataset(Dataset):
    """Map-style dataset pairing ``.npy`` image paths with integer class labels."""

    def __init__(self, image_paths, labels):
        """Store parallel sequences of image paths and labels.

        Raises:
            ValueError: if the two sequences differ in length.
        """
        if len(image_paths) != len(labels):
            raise ValueError(
                f"Got {len(image_paths)} image paths but {len(labels)} labels"
            )
        self.image_paths = image_paths
        self.labels = labels

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(pixel_values, label)`` for sample ``idx``."""
        image = preprocess_image(self.image_paths[idx])
        # CrossEntropyLoss requires int64 class indices.
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return image, label


def fine_tune_model(num_epochs=10, batch_size=16, learning_rate=1e-4, log=print):
    """Fine-tune a fresh ViT classifier on ``image_paths``/``labels`` and save it.

    The weights are written to ``vit_finetuned.pth`` in the working directory.

    Args:
        num_epochs: number of full passes over the training data.
        batch_size: mini-batch size for the DataLoader.
        learning_rate: AdamW learning rate.
        log: callable receiving one progress message per epoch (e.g. ``st.write``).

    Returns:
        The mean training loss of the final epoch.

    Raises:
        ValueError: if ``num_epochs`` < 1 or there is no training data.
    """
    if num_epochs < 1:
        raise ValueError("num_epochs must be at least 1")

    dataset = MRIDataset(image_paths, labels)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    if len(data_loader) == 0:
        raise ValueError("No training images configured; update image_paths/labels.")

    model = ViTForImageClassification.from_pretrained(MODEL_NAME, num_labels=NUM_LABELS)
    model.to(device)
    # torch.optim.AdamW replaces transformers.AdamW, which is deprecated and has been
    # removed from recent transformers releases.
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    criterion = torch.nn.CrossEntropyLoss()

    epoch_loss = 0.0
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for images, batch_labels in data_loader:
            images, batch_labels = images.to(device), batch_labels.to(device)
            optimizer.zero_grad()
            logits = model(pixel_values=images).logits
            loss = criterion(logits, batch_labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        epoch_loss = total_loss / len(data_loader)
        log(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss}")

    # Save the fine-tuned model
    torch.save(model.state_dict(), "vit_finetuned.pth")
    return epoch_loss


# Streamlit UI to trigger fine-tuning and display results
st.title("MRI Image Fine-Tuning with ViT")

if st.button("Start Training"):
    # Train only on an explicit click, never as a side effect of a script rerun.
    try:
        with st.spinner("Training..."):
            final_loss = fine_tune_model(log=st.write)
    except (FileNotFoundError, ValueError) as err:
        st.error(f"Training failed: {err}")
    else:
        st.write(f"Training complete with final loss: {final_loss}")