{"cells":[{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2024-09-22T14:41:17.988745Z","iopub.status.busy":"2024-09-22T14:41:17.988306Z","iopub.status.idle":"2024-09-22T14:41:17.994222Z","shell.execute_reply":"2024-09-22T14:41:17.993225Z","shell.execute_reply.started":"2024-09-22T14:41:17.988710Z"},"trusted":true},"outputs":[],"source":["import os\n","import requests\n","import zipfile\n","import io\n","\n","def lesgooo(name_it):\n"," response = requests.get(name_it, stream=True)\n","\n"," # Create an in-memory bytes buffer for the zip file\n"," zip_buffer = io.BytesIO()\n","\n"," # Download the file in chunks and write to the in-memory buffer\n"," for chunk in response.iter_content(chunk_size=1024):\n"," if chunk:\n"," zip_buffer.write(chunk)\n"," \n"," # Seek to the beginning of the buffer\n"," zip_buffer.seek(0)\n","\n"," # Open the zip file in memory\n"," with zipfile.ZipFile(zip_buffer, 'r') as zip_ref:\n"," # Extract all files directly to your desired directory\n"," zip_ref.extractall('/kaggle/working/')\n","\n","# Function to download and unzip files\n","def download_and_unzip(index):\n"," images_zip = f\"https://download.visinf.tu-darmstadt.de/data/from_games/data/{str(index).zfill(2)}_images.zip\"\n"," labels_zip = f\"https://download.visinf.tu-darmstadt.de/data/from_games/data/{str(index).zfill(2)}_labels.zip\"\n"," \n"," lesgooo(images_zip)\n"," lesgooo(labels_zip)\n","\n","# Loop through indices 1 to 10\n","for i in range(1, 11):\n"," download_and_unzip(i)\n"," print(f\"Part {i} done\")"]},{"cell_type":"code","execution_count":11,"metadata":{"execution":{"iopub.execute_input":"2024-09-22T16:21:09.902762Z","iopub.status.busy":"2024-09-22T16:21:09.902339Z","iopub.status.idle":"2024-09-22T16:21:09.929962Z","shell.execute_reply":"2024-09-22T16:21:09.929117Z","shell.execute_reply.started":"2024-09-22T16:21:09.902721Z"},"trusted":true},"outputs":[],"source":["import os\n","from PIL import Image\n","from torch.utils.data import Dataset, DataLoader\n","from torchvision import transforms\n","\n","# Custom Dataset class\n","class ImageLabelDataset(Dataset):\n"," def __init__(self, images_dir, labels_dir, transform=None):\n"," self.images_dir = images_dir\n"," self.labels_dir = labels_dir\n"," self.image_filenames = sorted(os.listdir(images_dir))\n"," self.label_filenames = sorted(os.listdir(labels_dir))\n"," self.transform = transform\n","\n"," assert len(self.image_filenames) == len(self.label_filenames), \"Mismatch in number of images and labels\"\n","\n"," def __len__(self):\n"," return len(self.image_filenames)\n","\n"," def __getitem__(self, idx):\n"," image_path = os.path.join(self.images_dir, self.image_filenames[idx])\n"," label_path = os.path.join(self.labels_dir, self.label_filenames[idx])\n"," \n"," # Open image and label using PIL\n"," image = Image.open(image_path).convert('RGB') # Ensure image is RGB\n"," label = Image.open(label_path).convert('RGB') # Ensure label is RGB (or use 'L' for grayscale)\n","\n"," if self.transform:\n"," # Apply the same transformation to both the image and the label\n"," image = self.transform(image)\n"," label = self.transform(label)\n","\n"," return image, label\n","\n","resize_dim = 512 # resize to dimensions\n","batch_size = 12 # Adjust batch size according to your GPU memory capacity\n"," \n","# Define transformations (resize, convert to tensor, normalize)\n","transform = transforms.Compose([\n"," transforms.Resize((resize_dim, resize_dim)), # Resize images\n"," transforms.ToTensor(), # Convert 
images to PyTorch tensors\n"," transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # Normalize to [-1, 1] for each channel\n","])\n","\n","# Create dataset instances for training\n","images_dir = '/kaggle/working/images/' # Replace with your image directory path\n","labels_dir = '/kaggle/working/labels/' # Replace with your label directory path\n","dataset = ImageLabelDataset(images_dir, labels_dir, transform=transform)\n","\n","# Create DataLoader with batch size control\n","dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8)"]},{"cell_type":"code","execution_count":2,"metadata":{"_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","execution":{"iopub.execute_input":"2024-09-22T16:16:46.786170Z","iopub.status.busy":"2024-09-22T16:16:46.785554Z","iopub.status.idle":"2024-09-22T16:16:46.810539Z","shell.execute_reply":"2024-09-22T16:16:46.809524Z","shell.execute_reply.started":"2024-09-22T16:16:46.786122Z"},"trusted":true},"outputs":[],"source":["import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","\n","class DepthwiseSeparableConv(nn.Module):\n"," def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):\n"," super(DepthwiseSeparableConv, self).__init__()\n"," self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, groups=in_channels)\n"," self.pointwise = nn.Conv2d(in_channels, out_channels, 1)\n","\n"," def forward(self, x):\n"," x = self.depthwise(x)\n"," x = self.pointwise(x)\n"," return x\n","\n","class LinearAttention(nn.Module):\n"," def __init__(self, dim, heads=4, dim_head=64):\n"," super(LinearAttention, self).__init__()\n"," self.heads = heads\n"," self.scale = dim_head ** -0.5\n"," inner_dim = dim_head * heads\n"," self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)\n"," self.to_out = nn.Conv2d(inner_dim, dim, 1)\n","\n"," def forward(self, x):\n"," b, c, h, w = x.shape\n"," qkv = self.to_qkv(x).chunk(3, dim=1)\n"," q, k, v = map(lambda t: t.reshape(b, self.heads, -1, h * w), qkv)\n"," q = q * self.scale\n"," k = k.softmax(dim=-1)\n"," context = torch.einsum('bhdn,bhen->bhde', k, v)\n"," out = torch.einsum('bhde,bhdn->bhen', context, q)\n"," out = out.reshape(b, -1, h, w)\n"," return self.to_out(out)\n","\n","class ResidualBlock(nn.Module):\n"," def __init__(self, channels):\n"," super(ResidualBlock, self).__init__()\n"," self.conv1 = DepthwiseSeparableConv(channels, channels, 3, padding=1)\n"," self.in1 = nn.InstanceNorm2d(channels)\n"," self.conv2 = DepthwiseSeparableConv(channels, channels, 3, padding=1)\n"," self.in2 = nn.InstanceNorm2d(channels)\n","\n"," def forward(self, x):\n"," residual = x\n"," out = F.relu(self.in1(self.conv1(x)))\n"," out = self.in2(self.conv2(out))\n"," out += residual\n"," return F.relu(out)\n","\n","class EfficientRealformer(nn.Module):\n"," def __init__(self, input_channels=3, output_channels=3, base_channels=64, num_residuals=6):\n"," super(EfficientRealformer, self).__init__()\n"," \n"," # Encoder\n"," self.encoder = nn.Sequential(\n"," DepthwiseSeparableConv(input_channels, base_channels, 7, padding=3),\n"," nn.InstanceNorm2d(base_channels),\n"," nn.ReLU(inplace=True),\n"," DepthwiseSeparableConv(base_channels, base_channels * 2, 3, stride=2, padding=1),\n"," nn.InstanceNorm2d(base_channels * 2),\n"," nn.ReLU(inplace=True),\n"," DepthwiseSeparableConv(base_channels * 2, base_channels * 4, 3, stride=2, padding=1),\n"," nn.InstanceNorm2d(base_channels * 4),\n"," 
nn.ReLU(inplace=True)\n","        )\n","\n","        # Transformer blocks\n","        self.transformer_blocks = nn.ModuleList([\n","            nn.Sequential(\n","                LinearAttention(base_channels * 4),\n","                ResidualBlock(base_channels * 4)\n","            ) for _ in range(num_residuals)\n","        ])\n","\n","        # Decoder\n","        self.decoder = nn.Sequential(\n","            nn.ConvTranspose2d(base_channels * 4, base_channels * 2, 3, stride=2, padding=1, output_padding=1),\n","            nn.InstanceNorm2d(base_channels * 2),\n","            nn.ReLU(inplace=True),\n","            nn.ConvTranspose2d(base_channels * 2, base_channels, 3, stride=2, padding=1, output_padding=1),\n","            nn.InstanceNorm2d(base_channels),\n","            nn.ReLU(inplace=True),\n","            DepthwiseSeparableConv(base_channels, output_channels, 7, padding=3)\n","        )\n","\n","        # Style encoder\n","        self.style_encoder = nn.Sequential(\n","            DepthwiseSeparableConv(input_channels, base_channels, 3, stride=2, padding=1),\n","            nn.InstanceNorm2d(base_channels),\n","            nn.ReLU(inplace=True),\n","            DepthwiseSeparableConv(base_channels, base_channels * 2, 3, stride=2, padding=1),\n","            nn.InstanceNorm2d(base_channels * 2),\n","            nn.ReLU(inplace=True),\n","            DepthwiseSeparableConv(base_channels * 2, base_channels * 4, 3, stride=2, padding=1),\n","            nn.AdaptiveAvgPool2d(1),\n","            nn.Flatten(),\n","            nn.Linear(base_channels * 4, base_channels * 4)\n","        )\n","\n","    def forward(self, content, style):\n","        # Encode content\n","        x = self.encoder(content)\n","\n","        # Extract style features\n","        style_features = self.style_encoder(style)\n","\n","        # Apply transformer blocks with style injection (broadcast the style vector over the spatial dimensions)\n","        for block in self.transformer_blocks:\n","            x = block(x)\n","            x = x + style_features.view(*style_features.shape, 1, 1)\n","\n","        # Decode\n","        output = self.decoder(x)\n","        return torch.tanh(output)  # Ensure output is in [-1, 1] range"]},{"cell_type":"code","execution_count":3,"metadata":{"execution":{"iopub.execute_input":"2024-09-22T16:16:46.811827Z","iopub.status.busy":"2024-09-22T16:16:46.811543Z","iopub.status.idle":"2024-09-22T16:16:46.822285Z","shell.execute_reply":"2024-09-22T16:16:46.821363Z","shell.execute_reply.started":"2024-09-22T16:16:46.811794Z"},"trusted":true},"outputs":[],"source":["def total_variation_loss(x):\n","    return torch.sum(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])) + torch.sum(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]))\n","\n","def combined_loss(output, target):\n","    l1_loss = nn.L1Loss()(output, target)\n","    tv_loss = total_variation_loss(output)\n","    return l1_loss + 0.0001 * tv_loss\n","\n","def psnr(img1, img2):\n","    mse = torch.mean((img1 - img2) ** 2)\n","    if mse == 0:\n","        return float('inf')\n","    # Inputs are normalized to [-1, 1], so the peak-to-peak signal value is 2.0\n","    return 20 * torch.log10(2.0 / torch.sqrt(mse))"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2024-09-22T16:21:12.853117Z","iopub.status.busy":"2024-09-22T16:21:12.852217Z","iopub.status.idle":"2024-09-22T17:16:46.272031Z","shell.execute_reply":"2024-09-22T17:16:46.268249Z","shell.execute_reply.started":"2024-09-22T16:21:12.853074Z"},"trusted":true},"outputs":[],"source":["import torch\n","import torch.nn as nn\n","from torch.utils.data import DataLoader\n","from tqdm import tqdm\n","import os\n","\n","# Instantiate the model\n","model = EfficientRealformer(input_channels=3, output_channels=3, base_channels=64, num_residuals=6)\n","\n","# Move model to the appropriate device (GPU if available)\n","device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","\n","# Wrap in DataParallel only when multiple GPUs (e.g. 2 x T4) are actually available\n","if torch.cuda.device_count() > 1:\n","    model = nn.DataParallel(model, device_ids=[0, 1])\n","\n","# Move model to 
device\n","model = model.to(device)\n","\n","# Optimizer\n","optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)\n","\n","# Number of epochs\n","num_epochs = 20\n","\n","best_loss = float('inf')\n","\n","# Path to save the best model\n","save_dir = 'saved_models'\n","os.makedirs(save_dir, exist_ok=True) # Create directory if it doesn't exist\n","best_model_path = os.path.join(save_dir, 'best_model.pth')\n","\n","# Initialize best loss as a large value\n","best_loss = float('inf')\n","\n","# Training loop with tqdm\n","for epoch in range(num_epochs):\n"," model.train()\n"," running_loss = 0.0\n"," running_psnr = 0.0\n"," \n"," # Wrap dataloader with tqdm for progress bar\n"," pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f\"Epoch {epoch+1}/{num_epochs}\")\n"," \n"," for batch_idx, (input, target) in pbar:\n"," # Move data to the same device as the model\n"," input, target = input.to(device), target.to(device)\n"," \n"," optimizer.zero_grad() # Clear the gradients from the last step\n"," output = model(input, target) # Forward pass\n","# print(f\"Input shape: {input.shape}, Output shape: {output.shape}, Target shape: {target.shape}\")\n"," loss = combined_loss(output, target) # Compute the loss\n"," \n"," loss.backward() # Backward pass (compute gradients)\n"," optimizer.step() # Update the weights\n"," \n"," running_loss += loss.item() # Accumulate the loss for this batch\n"," \n"," # Calculate metrics\n"," current_psnr = psnr(output, target).item()\n"," running_psnr += current_psnr\n"," \n"," # Update progress bar\n"," pbar.set_postfix({\n"," 'loss': loss.item(),\n"," 'psnr': current_psnr,\n"," })\n"," \n"," # Calculate the average loss and metrics for the epoch\n"," epoch_loss = running_loss / len(dataloader)\n"," avg_psnr = running_psnr / len(dataloader)\n"," \n"," print(f\"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, PSNR: {avg_psnr:.4f}\")\n"," \n"," # Save the best model based on the lowest loss\n"," if epoch_loss < best_loss:\n"," best_loss = epoch_loss\n"," torch.save(model.state_dict(), best_model_path)\n"," print(f\"New best model saved at epoch {epoch+1} with loss {best_loss:.4f}\")\n","\n","print(\"Training complete\")"]},{"cell_type":"code","execution_count":10,"metadata":{"execution":{"iopub.execute_input":"2024-09-22T16:20:38.886138Z","iopub.status.busy":"2024-09-22T16:20:38.885742Z","iopub.status.idle":"2024-09-22T16:20:39.010224Z","shell.execute_reply":"2024-09-22T16:20:39.009275Z","shell.execute_reply.started":"2024-09-22T16:20:38.886101Z"},"trusted":true},"outputs":[{"data":{"text/plain":["0"]},"execution_count":10,"metadata":{},"output_type":"execute_result"}],"source":["import gc\n","torch.cuda.empty_cache()\n","gc.collect()\n","#RuntimeError: The size of tensor a (768) must match the size of tensor b (512) at non-singleton dimension 3"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.status.busy":"2024-09-22T16:17:14.709235Z","iopub.status.idle":"2024-09-22T16:17:14.709692Z","shell.execute_reply":"2024-09-22T16:17:14.709509Z","shell.execute_reply.started":"2024-09-22T16:17:14.709481Z"},"trusted":true},"outputs":[],"source":["total_params = sum(p.numel() for p in model.parameters())\n","print(total_params)\n","print(model)"]}],"metadata":{"kaggle":{"accelerator":"nvidiaTeslaT4","dataSources":[],"dockerImageVersionId":30762,"isGpuEnabled":true,"isInternetEnabled":true,"language":"python","sourceType":"notebook"},"kernelspec":{"display_name":"Python 
3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.14"}},"nbformat":4,"nbformat_minor":4}