import argparse
import os

import pandas as pd
from PIL import Image

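# The first positional argument of ArgumentParser is `prog`, so the text is passed as `description`.
parser = argparse.ArgumentParser(description="Create a dataset for training with OML",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)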

parser.add_argument("--root-data-path", help="Path to images for dataset", default="data/train_val/")
parser.add_argument("--image-data-path", help="Image folder in root data path", default="images/")
parser.add_argument("--train-val-split",
                    help="Train:validation split ratio in percent, e.g. 80:20", default="80:20")
parser.add_argument("--separator",
                    help="Separator between class name and instance in image names (e.g. circle1_5 uses _)",
                    default="_")

args = parser.parse_args()
config = vars(args)

root_path = config["root_data_path"]
image_path = config["image_data_path"]
separator = config["separator"]

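# Parse "train:val" (e.g. "80:20") into fractions that sum to 1.
train_part, val_part = (int(num) for num in config["train_val_split"].split(":"))
train_prc = train_part / (train_part + val_part)
val_prc = 1 - train_prc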

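# Convert every supported image to a 224x224 PNG and collect the class names.
class_names = set()
image_dir = os.path.join(root_path, image_path)
for image in os.listdir(image_dir):
    if image.endswith((".png", ".jpg", ".bmp", ".webp")):
        img_name = image.split(".")[0]
        src_path = os.path.join(image_dir, image)
        # Load and resize inside the context manager so the source file is closed before saving.
        with Image.open(src_path) as img:
            resized = img.resize((224, 224))
        resized.save(os.path.join(image_dir, img_name + ".png"), "PNG")
        if not image.endswith(".png"):
            os.remove(src_path)
        img_name = img_name.split(separator)
        class_name = img_name[0]+img_name[1]
        class_names.add(class_name)
    else:
        print(f"Skipping {image}: unsupported image format")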

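# Assign each class an integer label. Sorting keeps the labels stable across runs,
# since iteration order over a set of strings is not guaranteed to be reproducible.
class_label_dict = {}
for ind, name in enumerate(sorted(class_names)):
    class_label_dict[name] = ind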

class_count = len(class_names)
train_class_count = int(class_count*train_prc)
print(f"train_class_count = {train_class_count} (of {class_count} classes)")

df_dict = {"label": [],
           "path": [],
           "split": [],
           "is_query": [],
           "is_gallery": []}
for image in os.listdir(root_path+image_path):
    if image.endswith((".png", ".jpg", ".bmp", ".webp")):
        img_name = image.split(".")[0].split(separator)
        class_name = img_name[0]+img_name[1]
        label = class_label_dict[class_name]
        path = image_path+image
        # Labels start at 0, so a strict comparison puts exactly train_class_count classes in train.
        split = "train" if label < train_class_count else "validation"
        is_query, is_gallery = (1, 1) if split=="validation" else (None, None)
        df_dict["label"].append(label)
        df_dict["path"].append(path)
        df_dict["split"].append(split)
        df_dict["is_query"].append(is_query)
        df_dict["is_gallery"].append(is_gallery)

df = pd.DataFrame(df_dict)

df.to_csv(root_path+"df_stamps.csv", index=False)