In [4]:
import torch
import torch.nn as nn
import numpy as np

from pathlib import Path
import os
from PIL import Image

from model import VAE
from losses import *

In [2]:
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import pandas as pd
import re
from sklearn.model_selection import train_test_split

In [1]:
# Root directory of the stamp images; each file name starts with its integer label.
IMAGE_FOLDER: str = './data/images/'

In [5]:
def label_from_name(image_name):
    """Parse the integer class label encoded as the leading digits of a file name.

    Raises:
        ValueError: if ``image_name`` does not start with a digit (instead of the
            opaque ``TypeError`` from indexing a failed ``re.match``).
    """
    match = re.match(r'^\d+', image_name)
    if match is None:
        raise ValueError(f'cannot parse a numeric label prefix from {image_name!r}')
    return int(match[0])


# sorted(): os.listdir order is arbitrary, which would make the row order of
# `data` (and hence any split of it) depend on the filesystem.
image_names = sorted(os.listdir(IMAGE_FOLDER))
data = pd.DataFrame({'image_name': image_names})
data['label'] = data['image_name'].apply(label_from_name)

In [None]:
class StampDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs for stamp images.

    Args:
        data: DataFrame with an ``image_name`` column (file name relative to
            ``image_folder``) and a ``label`` column. Rows are addressed
            positionally via ``.iloc``, so the index need not be 0..n-1.
        image_folder: directory containing the image files (str or Path).
        transform: optional callable applied to the RGB PIL image.
    """

    def __init__(self, data, image_folder=Path(IMAGE_FOLDER), transform=None):
        super().__init__()
        # Accept plain strings too: the `/` join in __getitem__ needs a Path.
        self.image_folder = Path(image_folder)
        self.data = data
        self.transform = transform

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``."""
        row = self.data.iloc[idx]
        # `with` closes the file handle Image.open keeps open lazily;
        # convert('RGB') forces a full load and guarantees 3 channels even for
        # grayscale / palette / RGBA files.
        with Image.open(self.image_folder / row['image_name']) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, row['label']

    def __len__(self):
        return len(self.data)

In [6]:
# Stratified 70/30 split (class proportions preserved in both parts);
# random_state pins the split so results survive a kernel restart.
train_data, val_data = train_test_split(
    data,
    test_size=0.3,
    shuffle=True,
    stratify=data['label'],
    random_state=42,
)

In [None]:
# Input resolution fed to the VAE.
IMAGE_SIZE = (118, 118)

# NOTE: no Normalize step (it was commented out in the original); images stay
# in the [0, 1] range produced by ToTensor.
train_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomVerticalFlip(0.5),
    transforms.ToTensor(),
])

# Validation: same resize, no augmentation.
val_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

train_dataset = StampDataset(train_data, transform=train_transform)
val_dataset = StampDataset(val_data, transform=val_transform)

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=256)
# Validation order is fixed: shuffling buys nothing for evaluation (Lightning
# warns about it) and keeps the preview images stable across epochs.
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=256)

In [8]:
import pytorch_lightning as pl
from torch import optim
from pytorch_lightning.loggers import TensorBoardLogger

from torchvision.utils import make_grid

In [9]:
# Per-channel statistics (same values as the commented-out Normalize transforms),
# reshaped to (3, 1, 1) so they broadcast against a (3, H, W) image tensor.
MEAN, STD = (
    torch.tensor(channel_stats).view(3, 1, 1)
    for channel_stats in (
        (0.76302232, 0.77820438, 0.81879729),
        (0.16563211, 0.14949341, 0.1055889),
    )
)

In [9]:
class LitModel(pl.LightningModule):
    """Lightning wrapper training a VAE with an auxiliary batch-hard triplet loss.

    The optimised loss is ``vae_loss + alpha * triplet_loss``; the triplet term is
    computed on ``encoding.mean`` returned by the VAE (presumably the latent
    posterior mean -- confirm against ``model.VAE``).

    Args:
        alpha: weight of the triplet term relative to the VAE loss.
    """

    # Number of validation images rendered in each TensorBoard grid (4 x 4).
    n_preview = 16

    def __init__(self, alpha=1e-3):
        super().__init__()
        self.vae = VAE()
        self.vae_loss = VAELoss()
        self.triplet_loss = BatchHardTripletLoss(margin=1.)
        self.alpha = alpha
        # First validation images of the current epoch: captured in
        # validation_step and consumed in on_validation_epoch_end, so the hook
        # no longer depends on the notebook-global `val_loader`.
        self._preview_images = None

    def forward(self, x):
        return self.vae(x)

    def configure_optimizers(self):
        return optim.AdamW(self.parameters(), lr=3e-4)

    def _shared_step(self, batch, stage):
        """Compute, log and return the combined loss for one batch.

        ``stage`` is the metric-name prefix (``"train"`` or ``"val"``). Lightning
        picks on_step/on_epoch defaults from the running hook, so logging from
        this helper behaves exactly as logging inline in each step.
        """
        images, labels = batch
        # Adds a trailing dim to the labels before the triplet loss
        # (TODO confirm the shape BatchHardTripletLoss expects).
        labels = labels.unsqueeze(1)
        recon_images, encoding = self.vae(images)

        vae_loss = self.vae_loss(recon_images, images, encoding)
        self.log(f"{stage}_vae_loss", vae_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)

        triplet_loss = self.triplet_loss(encoding.mean, labels)
        self.log(f"{stage}_triplet_loss", triplet_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)

        loss = self.alpha * triplet_loss + vae_loss
        self.log(f"{stage}_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        if batch_idx == 0:
            # clone() so the stored slice does not keep the whole batch alive.
            self._preview_images = batch[0][: self.n_preview].detach().clone()
        return self._shared_step(batch, "val")

    def on_validation_epoch_end(self):
        """Log grids of original and reconstructed validation images to TensorBoard."""
        images = self._preview_images
        self._preview_images = None
        if images is None or self.logger is None:
            return

        experiment = self.logger.experiment
        experiment.add_image('original images', make_grid(images.cpu(), nrow=4), self.current_epoch)

        with torch.no_grad():
            recon_images, _ = self.vae(images.to(self.device))
        experiment.add_image(
            'reconstructed images',
            make_grid(recon_images.detach().cpu(), nrow=4),
            self.current_epoch,
        )

In [10]:
# Seed Python/NumPy/torch RNGs so weight init and loader shuffling are reproducible.
pl.seed_everything(42)
litmodel = LitModel()

In [11]:
# TensorBoard event files are written under ./reconstruction_logs.
logger = TensorBoardLogger(save_dir="reconstruction_logs")

In [12]:
# Maximum number of training epochs handed to the Trainer.
epochs: int = 100

In [None]:
# Fit the VAE; accelerator="auto" lets Lightning pick GPU/CPU.
trainer = pl.Trainer(
    max_epochs=epochs,
    accelerator="auto",
    logger=logger,
)
trainer.fit(
    model=litmodel,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

In [None]:
%tensorboard

In [8]:
# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [11]:
from huggingface_hub import hf_hub_download

In [12]:
# Pretrained TorchScript stamp-embedding model from the Hugging Face Hub
# (the repo name suggests a ViT-S/8 backbone).
emb_weights_path = hf_hub_download(repo_id="stamps-labs/vits8-stamp", filename="vits8stamp-torchscript.pth")
# map_location lets a model serialized on GPU load on a CPU-only machine;
# eval() disables any train-only behaviour (e.g. dropout) for inference.
emb_model = torch.jit.load(emb_weights_path, map_location=device).eval()

In [21]:
# Transform for the embedding model: 224x224 input, no augmentation.
# NOTE: this rebinds `val_transform`; `val_dataset` created earlier still holds
# the previous transform object, so the VAE pipeline is unaffected.
val_transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)

In [28]:
@torch.no_grad()
def embed_image(image_name):
    """Return the embedding of one image in IMAGE_FOLDER as a list of floats.

    Uses the `val_transform`, `emb_model` and `device` bound at call time. Runs
    under no_grad so no autograd graph is built for pure inference.
    """
    # `with` closes the file; convert('RGB') guarantees 3 input channels.
    with Image.open(Path(IMAGE_FOLDER) / image_name) as img:
        batch = val_transform(img.convert('RGB')).unsqueeze(0).to(device)
    return emb_model(batch)[0].tolist()


# Both splits are embedded because the export cell concatenates
# train_data['embed'] and val_data['embed']. assign() returns new frames and
# avoids SettingWithCopyWarning on the train_test_split outputs.
train_data = train_data.assign(embed=train_data['image_name'].apply(embed_image))
val_data = val_data.assign(embed=val_data['image_name'].apply(embed_image))

In [34]:
# Stack train rows then val rows into one table of embedding vectors and a
# matching single-column label table (same row order).
# DataFrame.append was removed in pandas 2.0; pd.concat is the supported API.
embeds = pd.concat(
    [pd.DataFrame(train_data['embed'].tolist()), pd.DataFrame(val_data['embed'].tolist())],
    ignore_index=True,
)
labels = pd.concat([train_data[['label']], val_data[['label']]], ignore_index=True)


In [35]:
# One embedding vector per row, tab-separated, with no header or index column.
embeds.to_csv(path_or_buf='./all_embeds.tsv', header=False, index=False, sep='\t')

In [36]:
# One label per row, in the same order as all_embeds.tsv; no header or index.
labels.to_csv(path_or_buf='./all_labels.tsv', header=False, index=False, sep='\t')

In [126]:
# torch.save does not create parent directories; make sure ./models exists.
os.makedirs('./models', exist_ok=True)
# NOTE(review): assumes `litmodel.vae.encode` is an nn.Module (it must expose
# state_dict()) -- confirm against model.VAE.
torch.save(litmodel.vae.encode.state_dict(), './models/encoder.pth')

In [129]:
# Pull the first training sample -- an (image tensor, label) pair -- to inspect
# the output of the training transform pipeline.
sample_idx = 0
im = train_dataset[sample_idx]

In [132]:
# NOTE(review): `Encoder` is not explicitly imported in this notebook (only VAE
# from `model`); it presumably arrives via `from losses import *` -- confirm,
# and prefer an explicit import.
model = Encoder()
# map_location='cpu' lets weights saved from a GPU run load on CPU-only hosts;
# load_state_dict then copies them into the (CPU) module as before.
model.load_state_dict(torch.load('./models/encoder.pth', map_location='cpu'))

