File size: 2,677 Bytes
9594f1f
 
 
03aba46
 
 
 
 
9594f1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03aba46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from torch.utils.data import DataLoader, Dataset
import torch
from transformers import ViTForImageClassification, AdamW
import os
import numpy as np
import torch
import streamlit as st
from transformers import ViTForImageClassification, ViTImageProcessor

# Custom dataset class for loading images
class MRIDataset(Dataset):
    """Map-style dataset pairing image files with integer class labels.

    Each item is ``(image, label)``. ``image`` is whatever the preprocessing
    callable returns for the file path. ``label`` is a ``torch.long`` scalar
    tensor, the dtype ``CrossEntropyLoss`` expects for class indices.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of class indices, one per entry in ``image_paths``.
        transform: Optional callable mapping a path to a model-ready tensor.
            When omitted, a module-level ``preprocess_image`` function is used.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, transform=None):
        # Fail fast here instead of with an IndexError (or silently ignored
        # labels) in the middle of a training epoch.
        if len(image_paths) != len(labels):
            raise ValueError(
                "image_paths and labels must have the same length "
                f"({len(image_paths)} != {len(labels)})"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # NOTE(review): `preprocess_image` is not defined in this file. Either
        # define it at module level or pass `transform=`; otherwise this line
        # raises NameError. It is looked up at call time, so a definition
        # placed after this class is also picked up.
        load = self.transform if self.transform is not None else preprocess_image
        image = load(self.image_paths[idx])
        # Pin int64 explicitly rather than letting torch infer the dtype from
        # the label's Python/NumPy type.
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return image, label

# Single checkpoint shared by the model and its image processor, so that the
# preprocessing (resize / normalization) matches the weights being fine-tuned.
MODEL_CHECKPOINT = "google/vit-base-patch16-224-in21k"

# Load the ViT model with a 3-label classification head, plus its processor.
model = ViTForImageClassification.from_pretrained(MODEL_CHECKPOINT, num_labels=3)
processor = ViTImageProcessor.from_pretrained(MODEL_CHECKPOINT)

# Move the model to the device (GPU if available, otherwise CPU).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Define optimizer and loss function.
# torch.optim.AdamW replaces the deprecated (and, in recent releases, removed)
# transformers.AdamW. eps=1e-6 and weight_decay=0.0 are the old
# transformers.AdamW defaults, so the parameter-update rule is unchanged.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, eps=1e-6, weight_decay=0.0)
# Takes raw logits and integer class-index targets.
criterion = torch.nn.CrossEntropyLoss()

# Assemble the training data: file paths, their class indices, and a
# shuffling loader that batches them for the fine-tuning loop.
# NOTE(review): these paths are placeholders; replace them with real files.
image_paths = ["path_to_image1.npy", "path_to_image2.npy"]
labels = [0, 1]  # one class index per entry in image_paths, same order
dataset = MRIDataset(image_paths=image_paths, labels=labels)
data_loader = DataLoader(dataset=dataset, batch_size=16, shuffle=True)

# Fine-tuning loop
def fine_tune_model(num_epochs=10):
    """Fine-tune the module-level ViT ``model`` and return the final loss.

    Runs ``num_epochs`` passes over ``data_loader``, updating ``model`` with
    ``optimizer`` and ``criterion`` (all defined at module level). Prints the
    mean batch loss of every epoch and saves the resulting weights to
    ``vit_finetuned.pth``.

    Training lives in this function rather than at module top level because
    Streamlit re-executes the whole script on every interaction. A top-level
    loop would retrain the model on every page load and button click, while
    the "Start Training" button only reported a placeholder loss.

    Args:
        num_epochs: Number of passes over the training data.

    Returns:
        Mean batch loss of the last epoch, or 0.0 if no epoch ran.
    """
    epoch_loss = 0.0
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for batch_images, batch_labels in data_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            optimizer.zero_grad()
            logits = model(pixel_values=batch_images).logits
            loss = criterion(logits, batch_labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # max(..., 1) guards against an empty loader, which would otherwise
        # raise ZeroDivisionError.
        epoch_loss = total_loss / max(len(data_loader), 1)
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}')

    # Save the fine-tuned model
    torch.save(model.state_dict(), 'vit_finetuned.pth')
    return epoch_loss

# Streamlit front end: one button launches fine-tuning, then the final loss
# reported by fine_tune_model() is shown on the page.
st.title("MRI Image Fine-Tuning with ViT")

start_requested = st.button("Start Training")
if start_requested:
    final_loss = fine_tune_model()
    st.write(f"Training complete with final loss: {final_loss}")