import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.model_selection import train_test_split
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from PIL import Image
from pathlib import Path
from random import randint

from utils import *


class StampDataset(Dataset):
    """Dataset of stamp images with bounding-box annotations.

    Arguments:
        data -- list of dicts, each with 'file_path' (image path relative to
            image_folder), 'box_nb' (number of boxes on the image) and
            'boxes' (COCO-format boxes of shape (4,)). Loaded via
            read_data() when omitted.
        image_folder -- path to the folder containing the images
            (defaults to IMAGE_FOLDER from utils).
        transforms -- albumentations pipeline applied to image and boxes,
            or None for raw samples.
    """

    def __init__(self, data=None, image_folder=None, transforms=None):
        # Resolve defaults lazily: evaluating read_data()/Path(...) in the
        # signature would run once at import time and share the result
        # across every instance.
        self.data = read_data() if data is None else data
        self.image_folder = (
            Path(IMAGE_FOLDER) if image_folder is None else image_folder
        )
        self.transforms = transforms

    def __getitem__(self, idx):
        """Return (image, target_tensor) for sample *idx*.

        If augmentation drops every box (e.g. a crop misses all stamps),
        a different random sample is returned instead, so training never
        sees a boxless target.
        """
        item = self.data[idx]
        image_fn = self.image_folder / item['file_path']
        boxes = item['boxes']
        box_nb = item['box_nb']

        # Per-box one-hot class labels; column 0 marks "stamp" for every box.
        labels = torch.zeros((box_nb, 2), dtype=torch.int64)
        labels[:, 0] = 1

        img = np.array(Image.open(image_fn))
        try:
            if self.transforms:
                sample = self.transforms(**{
                    "image": img,
                    "bboxes": boxes,
                    "labels": labels,
                })
                img = sample['image']
                boxes = sample['bboxes']
                if not boxes:
                    # All boxes were removed by the augmentation; resample.
                    raise ValueError("all boxes dropped by transforms")
        except (ValueError, RuntimeError):
            # Best-effort retry with a random index instead of failing the
            # whole epoch on one bad augmentation. Deliberately narrow: the
            # previous bare `except:` also swallowed KeyboardInterrupt and
            # genuine bugs.
            return self.__getitem__(randint(0, len(self.data) - 1))

        # as_tensor handles both the transformed list of box tuples and the
        # raw annotation list (transforms=None), yielding shape (N, 4).
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        target_tensor = boxes_to_tensor(boxes)
        return img, target_tensor

    def __len__(self):
        return len(self.data)


def collate_fn(batch):
    """Collate variable-size detection samples into per-field tuples."""
    return tuple(zip(*batch))


def get_datasets(data_path=ANNOTATIONS_PATH, train_transforms=None, val_transforms=None):
    """Create StampDataset objects for training, validation and testing.

    Arguments:
        data_path -- string or Path to the annotations file.
        train_transforms -- transforms applied during training; a default
            augmentation pipeline is used when None.
        val_transforms -- transforms applied during validation and testing;
            a default resize/normalize pipeline is used when None.

    Returns:
        (train_dataset, val_dataset, test_dataset) -- three StampDataset
        objects from a 90/10 train/test split, with the train part further
        split 80/20 into train/val.
    """
    data = read_data(data_path)
    if train_transforms is None:
        train_transforms = A.Compose([
            A.RandomCropNearBBox(max_part_shift=0.6, p=0.4),
            A.Resize(height=448, width=448),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            # Optional heavier augmentation, currently disabled:
            # A.Affine(scale=(0.9, 1.1), translate_percent=(0.05, 0.1), rotate=(-45, 45), shear=(-30, 30), p=0.3),
            # A.Blur(blur_limit=4, p=0.3),
            A.Normalize(),
            ToTensorV2(p=1.0),
        ], bbox_params={
            "format": "coco",
            'label_fields': ['labels'],
        })

    if val_transforms is None:
        val_transforms = A.Compose([
            A.Resize(height=448, width=448),
            A.Normalize(),
            ToTensorV2(p=1.0),
        ], bbox_params={
            "format": "coco",
            'label_fields': ['labels'],
        })

    train, test_data = train_test_split(data, test_size=0.1, shuffle=True)
    train_data, val_data = train_test_split(train, test_size=0.2, shuffle=True)

    train_dataset = StampDataset(train_data, transforms=train_transforms)
    val_dataset = StampDataset(val_data, transforms=val_transforms)
    test_dataset = StampDataset(test_data, transforms=val_transforms)

    return train_dataset, val_dataset, test_dataset


def get_loaders(batch_size=8, data_path=ANNOTATIONS_PATH, num_workers=0, train_transforms=None, val_transforms=None):
    """Create DataLoaders for training and validation.

    Arguments:
        batch_size -- number of images per batch.
        data_path -- string or Path to the annotations file.
        num_workers -- worker processes for the training loader.
        train_transforms -- transforms applied during training (see get_datasets).
        val_transforms -- transforms applied during validation (see get_datasets).

    Returns:
        (train_loader, val_loader) -- tuple of DataLoader for training and
        validation.
    """
    # Forward the transforms: previously they were accepted but silently
    # ignored, so callers always got the default pipelines.
    train_dataset, val_dataset, _ = get_datasets(
        data_path,
        train_transforms=train_transforms,
        val_transforms=val_transforms)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_fn,
        drop_last=True)

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        collate_fn=collate_fn)

    return train_loader, val_loader