ydshieh commited on
Commit
9628bd8
1 Parent(s): 4c37b45

Add dataset script

Browse files
Files changed (1) hide show
  1. image_caption_dataset.py +187 -0
image_caption_dataset.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import json
3
+ import os
4
+
5
+ import datasets
6
+ import pandas as pd
7
+ import numpy as np
8
+
9
+
10
+ class ImageCaptionBuilderConfig(datasets.BuilderConfig):
11
+
12
+ def __init__(self, name, splits, langs, prefix_before_image_fn=False, zfill=1, **kwargs):
13
+
14
+ super().__init__(name, **kwargs)
15
+
16
+ self.splits = splits
17
+ self.langs = langs
18
+ self.prefix_before_image_fn = prefix_before_image_fn
19
+ self.zfill = zfill
20
+
21
+
22
+ # TODO: Add BibTeX citation
23
+ # Find for instance the citation on arxiv or on the dataset repo/website
24
+ _CITATION = """\
25
+ @InProceedings{None,
26
+ title = {Generic images to captions dataset},
27
+ author={Yih-Dar SHIEH},
28
+ year={2020}
29
+ }
30
+ """
31
+
32
+ # TODO: Add description of the dataset here
33
+ # You can copy an official description
34
+ _DESCRIPTION = """\
35
+
36
+ """
37
+
38
+ # TODO: Add a link to an official homepage for the dataset here
39
+ _HOMEPAGE = ""
40
+
41
+ # TODO: Add the licence for the dataset here if you can find it
42
+ _LICENSE = ""
43
+
44
+ # TODO: Add link to the official dataset URLs here
45
+ # The HuggingFace dataset library don't host the datasets but only point to the original files
46
+ # This can be an arbitrary nested dict/list of URLs (see below in `_split_generators` method)
47
+ _URLs = {}
48
+
49
+
50
+ # TODO: Name of the dataset usually match the script name with CamelCase instead of snake_case
51
+ class ImageCaptionDataset(datasets.GeneratorBasedBuilder):
52
+ """TODO: Short description of my dataset."""
53
+
54
+ VERSION = datasets.Version("0.0.0")
55
+
56
+ BUILDER_CONFIG_CLASS = ImageCaptionBuilderConfig
57
+ BUILDER_CONFIGS = [
58
+ ImageCaptionBuilderConfig(name='coco_2017', splits=['train', 'valid'], prefix_before_image_fn=False, zfill=12, langs=['en', 'fr']),
59
+ ImageCaptionBuilderConfig(name='cc3m', splits=['train', 'valid'], prefix_before_image_fn=True, zfill=8, langs=['en', 'fr']),
60
+ ImageCaptionBuilderConfig(name='cc12m', splits=['train', 'valid'], prefix_before_image_fn=True, zfill=8, langs=['en', 'fr'])
61
+ ]
62
+ DEFAULT_CONFIG_NAME = "coco_2017"
63
+
64
+ def _info(self):
65
+ # TODO: This method specifies the datasets.DatasetInfo object which contains informations and typings for the dataset
66
+
67
+ feature_dict = {
68
+ "image_id": datasets.Value("int64"),
69
+ "id": datasets.Value("int64"),
70
+ "caption": datasets.Value("string"),
71
+ }
72
+ for lang in self.config.langs:
73
+ feature_dict[lang] = datasets.Value("string")
74
+ feature_dict["image_url"] = datasets.Value("string")
75
+ feature_dict["image_file"] = datasets.Value("string")
76
+
77
+ features = datasets.Features(feature_dict)
78
+
79
+ return datasets.DatasetInfo(
80
+ # This is the description that will appear on the datasets page.
81
+ description=_DESCRIPTION,
82
+ # This defines the different columns of the dataset and their types
83
+ features=features, # Here we define them above because they are different between the two configurations
84
+ # If there's a common (input, target) tuple from the features,
85
+ # specify them here. They'll be used if as_supervised=True in
86
+ # builder.as_dataset.
87
+ supervised_keys=None,
88
+ # Homepage of the dataset for documentation
89
+ homepage=_HOMEPAGE,
90
+ # License for the dataset if available
91
+ license=_LICENSE,
92
+ # Citation for the dataset
93
+ citation=_CITATION,
94
+ )
95
+
96
+ def _split_generators(self, dl_manager):
97
+ """Returns SplitGenerators."""
98
+ # TODO: This method is tasked with downloading/extracting the data and defining the splits depending on the configuration
99
+ # If several configurations are possible (listed in BUILDER_CONFIGS), the configuration selected by the user is in self.config.name
100
+
101
+ data_dir = self.config.data_dir
102
+
103
+ splits = []
104
+ for split in self.config.splits:
105
+ if split == 'train':
106
+ dataset = datasets.SplitGenerator(
107
+ name=datasets.Split.TRAIN,
108
+ # These kwargs will be passed to _generate_examples
109
+ gen_kwargs={
110
+ "jsonl_dir": os.path.join(data_dir, f'{self.config.name}_jsonls', 'train'),
111
+ "image_dir": os.path.join(data_dir, f'{self.config.name}_images', 'train'),
112
+ "split": "train",
113
+ }
114
+ )
115
+ elif split in ['val', 'valid', 'validation', 'dev']:
116
+ dataset = datasets.SplitGenerator(
117
+ name=datasets.Split.VALIDATION,
118
+ # These kwargs will be passed to _generate_examples
119
+ gen_kwargs={
120
+ "jsonl_dir": os.path.join(data_dir, f'{self.config.name}_jsonls', 'valid'),
121
+ "image_dir": os.path.join(data_dir, f'{self.config.name}_images', 'valid'),
122
+ "split": "valid",
123
+ },
124
+ )
125
+ elif split == 'test':
126
+ dataset = datasets.SplitGenerator(
127
+ name=datasets.Split.TEST,
128
+ # These kwargs will be passed to _generate_examples
129
+ gen_kwargs={
130
+ "jsonl_dir": os.path.join(data_dir, f'{self.config.name}_jsonls', 'test'),
131
+ "image_dir": os.path.join(data_dir, f'{self.config.name}_images', 'test'),
132
+ "split": "test",
133
+ },
134
+ )
135
+ else:
136
+ continue
137
+
138
+ splits.append(dataset)
139
+
140
+ return splits
141
+
142
+ def _generate_examples(
143
+ # method parameters are unpacked from `gen_kwargs` as given in `_split_generators`
144
+ self, jsonl_dir, image_dir, split
145
+ ):
146
+ """ Yields examples as (key, example) tuples. """
147
+ # This method handles input defined in _split_generators to yield (key, example) tuples from the dataset.
148
+ # The `key` is here for legacy reason (tfds) and is not important in itself.
149
+
150
+ if split == 'dev':
151
+ split = 'valid'
152
+
153
+ fns = [os.path.join(jsonl_dir, fn) for fn in os.listdir(jsonl_dir) if os.path.isfile(os.path.join(jsonl_dir, fn)) and fn.endswith("jsonl")]
154
+
155
+ for jsonl_file in fns:
156
+
157
+ with open(jsonl_file, 'r', encoding='UTF-8') as fp:
158
+
159
+ for id_, line in enumerate(fp):
160
+
161
+ ex = json.loads(line)
162
+
163
+ example = {
164
+ "image_id": ex['image_id'],
165
+ "id": ex["id"],
166
+ "caption": ex["caption"],
167
+ }
168
+
169
+ for lang in self.config.langs:
170
+ example[lang] = ex[lang]
171
+
172
+ if 'image_url' in ex:
173
+ example['image_url'] = ex['image_url']
174
+ else:
175
+ example['image_url'] = ''
176
+
177
+ fn = f'{str(ex["image_id"]).zfill(self.config.zfill)}.jpg'
178
+ if self.config.prefix_before_image_fn:
179
+ fn = f'{self.config.name}_{split}_' + fn
180
+
181
+ image_file = os.path.join(image_dir, fn)
182
+ example['image_file'] = image_file
183
+
184
+ if not os.path.isfile(image_file):
185
+ continue
186
+
187
+ yield id_, example