Spaces:
Running
Running
File size: 2,677 Bytes
9594f1f 03aba46 9594f1f 03aba46 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
from torch.utils.data import DataLoader, Dataset
import torch
from transformers import ViTForImageClassification, AdamW
import os
import numpy as np
import torch
import streamlit as st
from transformers import ViTForImageClassification, ViTImageProcessor
class MRIDataset(Dataset):
    """Map-style dataset pairing image files with integer class labels.

    Each item is ``(image, label)``. ``image`` is whatever the loader returns
    for that path. ``label`` is a ``torch.long`` tensor, which is the target
    dtype ``torch.nn.CrossEntropyLoss`` requires for class indices.

    Args:
        image_paths: Sequence of image paths, one per sample.
        labels: Sequence of class indices aligned with ``image_paths``.
        transform: Optional callable ``path -> image``. When ``None``, the
            module-level ``preprocess_image`` is used. It is looked up each
            time an item is fetched, not when the dataset is built.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, transform=None):
        # Fail fast. Otherwise a short ``labels`` raises IndexError mid-epoch
        # and a long one silently ignores the extra labels.
        if len(image_paths) != len(labels):
            raise ValueError(
                f"got {len(image_paths)} image paths but {len(labels)} labels"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        load = self.transform if self.transform is not None else preprocess_image
        image = load(self.image_paths[idx])
        # Class-index targets must be integer (long) for CrossEntropyLoss.
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return image, label
# Single checkpoint name for both the weights and the image processor, so
# the preprocessing config comes from the same checkpoint as the model.
MODEL_NAME = "google/vit-base-patch16-224-in21k"

# ViT backbone with a 3-class classification head (num_labels=3), plus its processor.
model = ViTForImageClassification.from_pretrained(MODEL_NAME, num_labels=3)
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)

# Move the model to the device (GPU if available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
# weight_decay are set to transformers.AdamW's defaults so the optimizer
# behaves as before (torch's own defaults are eps=1e-8, weight_decay=0.01).
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, eps=1e-6, weight_decay=0.0)
criterion = torch.nn.CrossEntropyLoss()


def preprocess_image(image_path):
    """Load one ``.npy`` image and return ViT-ready ``pixel_values`` for it.

    Steps:
        - Load the array with ``np.load`` (``allow_pickle`` stays False).
        - Replicate single-channel data to 3 channels.
        - Min-max scale to 0..255 ``uint8``.
        - Run the result through ``processor`` (resize, rescale, normalise).

    NOTE(review): assumes each file holds one 2-D slice ``(H, W)`` or a
    channels-last ``(H, W, C)`` image. Confirm against the real data.

    Returns:
        The single image's ``pixel_values`` tensor (batch dim removed), so
        ``DataLoader`` can stack items into a ``(batch, ...)`` tensor.
    """
    array = np.load(image_path).astype(np.float32)
    if array.ndim == 3 and array.shape[-1] == 1:
        array = array[..., 0]
    if array.ndim == 2:
        # ViT expects 3 input channels; grayscale is copied across them.
        array = np.stack([array] * 3, axis=-1)
    # The processor rescales by 1/255, so raw intensities are mapped to 0..255 first.
    lo, hi = float(array.min()), float(array.max())
    if hi > lo:
        array = (array - lo) / (hi - lo) * 255.0
    else:
        array = np.zeros_like(array)
    inputs = processor(images=array.astype(np.uint8), return_tensors="pt")
    return inputs["pixel_values"][0]


# Load your dataset
image_paths = ["path_to_image1.npy", "path_to_image2.npy"]  # Update with actual image paths
labels = [0, 1]  # Corresponding labels
dataset = MRIDataset(image_paths, labels)
data_loader = DataLoader(dataset, batch_size=16, shuffle=True)
def fine_tune_model(num_epochs=10, save_path='vit_finetuned.pth'):
    """Fine-tune the module-level ``model`` on ``data_loader`` and save it.

    Training runs only when this function is called (from the Streamlit
    button). It does not run at import time, because Streamlit re-executes
    the whole script on every widget interaction. A module-level loop would
    retrain, and overwrite the checkpoint, on every page load.

    Args:
        num_epochs: Number of passes over ``data_loader``.
        save_path: File the fine-tuned ``state_dict`` is written to.

    Returns:
        Mean training loss of the last epoch, or ``0.0`` if ``num_epochs`` is 0.

    Raises:
        ValueError: If ``data_loader`` yields no batches. Without this check
            the per-epoch average would divide by zero.
    """
    num_batches = len(data_loader)
    if num_batches == 0:
        raise ValueError("data_loader is empty; add image paths and labels before training")

    epoch_loss = 0.0
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for images, targets in data_loader:
            images, targets = images.to(device), targets.to(device)
            optimizer.zero_grad()
            logits = model(pixel_values=images).logits
            loss = criterion(logits, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        epoch_loss = total_loss / num_batches
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}')

    # Save the fine-tuned model
    torch.save(model.state_dict(), save_path)
    return epoch_loss
# Streamlit front end: a page title and one button. Clicking it runs
# fine_tune_model() and then reports the value that call returned.
st.title("MRI Image Fine-Tuning with ViT")

start_clicked = st.button("Start Training")
if start_clicked:
    final_loss = fine_tune_model()
    st.write(f"Training complete with final loss: {final_loss}")
|