{ "cells": [ { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import numpy as np\n", "\n", "from pathlib import Path\n", "import os\n", "from PIL import Image\n", "\n", "from model import VAE\n", "from losses import *" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import DataLoader, Dataset\n", "from torchvision import transforms\n", "import pandas as pd\n", "import re\n", "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "IMAGE_FOLDER = './data/images/'" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "image_names = os.listdir(IMAGE_FOLDER)\n", "data = pd.DataFrame({'image_name': image_names})\n", "data['label'] = data['image_name'].apply(lambda x: int(re.match('^\\d+', x)[0]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class StampDataset(Dataset):\n", " def __init__(self, data, image_folder=Path(IMAGE_FOLDER), transform=None):\n", " super().__init__()\n", " self.image_folder = image_folder\n", " self.data = data\n", " self.transform = transform\n", "\n", " def __getitem__(self, idx):\n", " image = Image.open(self.image_folder / self.data.iloc[idx]['image_name'])\n", " label = self.data.iloc[idx]['label']\n", " if self.transform:\n", " image = self.transform(image)\n", "\n", " return image, label\n", "\n", " \n", " def __len__(self):\n", " return len(self.data)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "train_data, val_data = train_test_split(data, test_size=0.3, shuffle=True, stratify=data['label'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_transform = transforms.Compose([\n", " transforms.Resize((118, 118)),\n", " transforms.RandomHorizontalFlip(0.5),\n", " transforms.RandomVerticalFlip(0.5),\n", " transforms.ToTensor(),\n", " # transforms.Normalize((0.76302232, 0.77820438, 0.81879729), (0.16563211, 0.14949341, 0.1055889)),\n", "])\n", "\n", "val_transform = transforms.Compose([\n", " transforms.Resize((118, 118)),\n", " transforms.ToTensor(),\n", " # transforms.Normalize((0.76302232, 0.77820438, 0.81879729), (0.16563211, 0.14949341, 0.1055889)),\n", "])\n", "train_dataset = StampDataset(train_data, transform=train_transform)\n", "val_dataset = StampDataset(val_data, transform=val_transform)\n", "\n", "train_loader = DataLoader(train_dataset, shuffle=True, batch_size=256)\n", "val_loader = DataLoader(val_dataset, shuffle=True, batch_size=256)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "import pytorch_lightning as pl\n", "from torch import optim\n", "from pytorch_lightning.loggers import TensorBoardLogger\n", "\n", "from torchvision.utils import make_grid" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "MEAN = torch.tensor((0.76302232, 0.77820438, 0.81879729)).view(3, 1, 1)\n", "STD = torch.tensor((0.16563211, 0.14949341, 0.1055889)).view(3, 1, 1)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "class LitModel(pl.LightningModule):\n", " def __init__(self, alpha=1e-3):\n", " super().__init__()\n", " self.vae = VAE()\n", " self.vae_loss = VAELoss()\n", " self.triplet_loss = 
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "class LitModel(pl.LightningModule):\n", "    def __init__(self, alpha=1e-3):\n", "        super().__init__()\n", "        self.vae = VAE()\n", "        self.vae_loss = VAELoss()\n", "        self.triplet_loss = BatchHardTripletLoss(margin=1.)\n", "        self.alpha = alpha\n", "\n", "    def forward(self, x):\n", "        return self.vae(x)\n", "\n", "    def configure_optimizers(self):\n", "        optimizer = optim.AdamW(self.parameters(), lr=3e-4)\n", "        return optimizer\n", "        # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 1000)\n", "        # return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler}\n", "\n", "    def training_step(self, batch, batch_idx):\n", "        images, labels = batch\n", "        labels = labels.unsqueeze(1)\n", "        recon_images, encoding = self.vae(images)\n", "        vae_loss = self.vae_loss(recon_images, images, encoding)\n", "        self.log(\"train_vae_loss\", vae_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)\n", "        triplet_loss = self.triplet_loss(encoding.mean, labels)\n", "        self.log(\"train_triplet_loss\", triplet_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)\n", "        loss = self.alpha * triplet_loss + vae_loss\n", "        self.log(\"train_loss\", loss, on_epoch=True, prog_bar=True, logger=True)\n", "        return loss\n", "\n", "    def validation_step(self, batch, batch_idx):\n", "        images, labels = batch\n", "        labels = labels.unsqueeze(1)\n", "        recon_images, encoding = self.vae(images)\n", "        vae_loss = self.vae_loss(recon_images, images, encoding)\n", "        self.log(\"val_vae_loss\", vae_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)\n", "        triplet_loss = self.triplet_loss(encoding.mean, labels)\n", "        self.log(\"val_triplet_loss\", triplet_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)\n", "        loss = self.alpha * triplet_loss + vae_loss\n", "        self.log(\"val_loss\", loss, on_epoch=True, prog_bar=True, logger=True)\n", "        return loss\n", "\n", "    def on_validation_epoch_end(self):\n", "        # Log a grid of original and reconstructed images to TensorBoard.\n", "        images, _ = next(iter(val_loader))\n", "        image_unflat = images.detach().cpu()\n", "        image_grid = make_grid(image_unflat[:16], nrow=4)\n", "        self.logger.experiment.add_image('original images', image_grid, self.current_epoch)\n", "\n", "        recon_images, _ = self.vae(images.to(self.device))\n", "        image_unflat = recon_images.detach().cpu()\n", "        image_grid = make_grid(image_unflat[:16], nrow=4)\n", "        self.logger.experiment.add_image('reconstructed images', image_grid, self.current_epoch)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "litmodel = LitModel()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "logger = TensorBoardLogger(\"reconstruction_logs\")" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "epochs = 100" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trainer = pl.Trainer(accelerator=\"auto\", max_epochs=epochs, logger=logger)\n", "trainer.fit(model=litmodel, train_dataloaders=train_loader, val_dataloaders=val_loader)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%load_ext tensorboard\n", "%tensorboard --logdir reconstruction_logs" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "device = 'cuda' if torch.cuda.is_available() else 'cpu'" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import hf_hub_download" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# Pretrained ViT-S/8 stamp embedder, loaded as TorchScript from the Hugging Face Hub.\n", "emb_model = torch.jit.load(hf_hub_download(repo_id=\"stamps-labs/vits8-stamp\", filename=\"vits8stamp-torchscript.pth\")).to(device)" ] },
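{ "cell_type": "markdown", "metadata": {}, "source": [ "A quick sanity check (an added illustration, not part of the original notebook): embed two stamps with the hub model and compare them by cosine similarity; images of the same class should score noticeably higher than unrelated ones. `probe_transform` and `embed` are local helpers that mirror the 224×224 preprocessing used below." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch: compare two stamp embeddings by cosine similarity.\n", "import torch.nn.functional as F\n", "\n", "probe_transform = transforms.Compose([\n", "    transforms.Resize((224, 224)),\n", "    transforms.ToTensor(),\n", "])\n", "\n", "def embed(name):\n", "    # helper (assumption): one forward pass through the TorchScript embedder\n", "    with torch.no_grad():\n", "        return emb_model(probe_transform(Image.open(Path(IMAGE_FOLDER) / name)).unsqueeze(0).to(device))\n", "\n", "emb_a = embed(data.iloc[0]['image_name'])\n", "emb_b = embed(data.iloc[1]['image_name'])\n", "F.cosine_similarity(emb_a, emb_b).item()" ] },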
"execution_count": 21, "metadata": {}, "outputs": [], "source": [ "val_transform = transforms.Compose([\n", " transforms.Resize((224, 224)),\n", " transforms.ToTensor(),\n", " # transforms.Normalize((0.76302232, 0.77820438, 0.81879729), (0.16563211, 0.14949341, 0.1055889)),\n", "])" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "train_data['embed'] = train_data['image_name'].apply(lambda x: emb_model(val_transform(Image.open(Path(IMAGE_FOLDER) / x)).unsqueeze(0).to(device))[0].tolist())" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\javid\\AppData\\Local\\Temp\\ipykernel_23064\\1572292890.py:1: FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.\n", " embeds = pd.DataFrame(train_data['embed'].tolist()).append(pd.DataFrame(val_data['embed'].tolist()), ignore_index=True)\n", "C:\\Users\\javid\\AppData\\Local\\Temp\\ipykernel_23064\\1572292890.py:2: FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.\n", " labels = pd.DataFrame(train_data['label']).append(pd.DataFrame(val_data['label']), ignore_index=True)\n" ] } ], "source": [ "embeds = pd.DataFrame(train_data['embed'].tolist()).append(pd.DataFrame(val_data['embed'].tolist()), ignore_index=True)\n", "labels = pd.DataFrame(train_data['label']).append(pd.DataFrame(val_data['label']), ignore_index=True)" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "embeds.to_csv('./all_embeds.tsv', sep='\\t', index=False, header=False)" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "labels.to_csv('./all_labels.tsv', sep='\\t', index=False, header=False)" ] }, { "cell_type": "code", "execution_count": 126, "metadata": {}, "outputs": [], "source": [ "torch.save(litmodel.vae.encode.state_dict(), './models/encoder.pth')" ] }, { "cell_type": "code", "execution_count": 129, "metadata": {}, "outputs": [], "source": [ "im = train_dataset[0]" ] }, { "cell_type": "code", "execution_count": 132, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 132, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = Encoder()\n", "model.load_state_dict(torch.load('./models/encoder.pth'))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.0" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }