multimodalart's picture
Upload folder using huggingface_hub
a891a57 verified
raw
history blame
No virus
2.51 kB
import pickle
import numpy as np
import os
import os.path as osp
import sys
import mxnet as mx
class RecBuilder():
    """Incrementally write an MXNet indexed RecordIO dataset into a new directory.

    Construction creates the target directory and opens ``train.idx`` /
    ``train.rec`` for writing.  Records are appended with :meth:`add`
    (labels assigned automatically, one per call) or :meth:`add_image`
    (label supplied by the caller).  :meth:`close` closes the record writer
    and writes the ``train.meta`` pickle and the ``property`` text file.
    """

    def __init__(self, path, image_size=(112, 112)):
        """Create the output directory and open the record writer.

        Args:
            path: Directory to create; it must not exist yet.
            image_size: Pair whose two elements are written to the
                ``property`` file by :meth:`close`.  It is not used here to
                resize or validate images.

        Raises:
            AssertionError: If ``path`` already exists.
        """
        self.path = path
        self.image_size = image_size
        self.widx = 0        # index of the next record to write
        self.wlabel = 0      # label that the next add() call will assign
        self.max_label = -1  # largest label recorded so far (-1: none yet)
        assert not osp.exists(path), '%s exists' % path
        os.makedirs(path)
        self.writer = mx.recordio.MXIndexedRecordIO(
            os.path.join(path, 'train.idx'),
            os.path.join(path, 'train.rec'),
            'w')
        self.meta = []       # one dict per record, pickled by close()

    def _write_record(self, img, label, idlabel):
        """Pack one item and append it at the current write index.

        ``label`` goes into the record header; ``idlabel`` is the single
        class id stored in the per-record metadata.
        """
        idx = self.widx
        header = mx.recordio.IRHeader(0, label, idx, 0)
        if isinstance(img, np.ndarray):
            # Arrays are encoded to JPEG (quality 95) by MXNet.
            s = mx.recordio.pack_img(header, img, quality=95, img_fmt='.jpg')
        else:
            # Anything else is handed to pack() unchanged -- presumably
            # already-encoded image bytes; confirm against callers.
            s = mx.recordio.pack(header, img)
        self.writer.write_idx(idx, s)
        self.meta.append({'image_index': idx, 'image_classes': [idlabel]})
        self.widx += 1

    def add(self, imgs):
        """Append every item of ``imgs`` under one automatically assigned label.

        All items of a single call share the label ``self.wlabel``, which is
        then incremented for the next call.

        Args:
            imgs: Non-empty sized collection of items; ``np.ndarray`` items
                should be BGR images.

        Raises:
            AssertionError: If ``imgs`` is empty.
        """
        assert len(imgs) > 0
        label = self.wlabel
        for img in imgs:
            self._write_record(img, label, label)
        # Use max() like add_image() does, so a larger label recorded earlier
        # through add_image() is not lost when the two methods are mixed.
        self.max_label = max(self.max_label, label)
        self.wlabel += 1

    def add_image(self, img, label):
        """Append a single item with a caller-supplied label.

        Args:
            img: Item to store; an ``np.ndarray`` should be a BGR image.
            label: Label stored in the record header.  When it is a list,
                its first element is taken as the class id for the metadata
                and for ``max_label`` tracking.
        """
        idlabel = label[0] if isinstance(label, list) else label
        self._write_record(img, label, idlabel)
        self.max_label = max(self.max_label, idlabel)

    def close(self):
        """Close the record writer and write the ``train.meta`` and ``property`` files.

        ``property`` holds two lines: ``max_label + 1`` followed by the two
        ``image_size`` elements (comma separated), then the number of
        records written.
        """
        # Close the writer explicitly so the .idx/.rec files are complete on
        # return rather than whenever the writer object is garbage collected.
        self.writer.close()
        with open(osp.join(self.path, 'train.meta'), 'wb') as pfile:
            pickle.dump(self.meta, pfile, protocol=pickle.HIGHEST_PROTOCOL)
        print('stat:', self.widx, self.wlabel)
        with open(os.path.join(self.path, 'property'), 'w') as f:
            f.write("%d,%d,%d\n" % (self.max_label + 1,
                                    self.image_size[0],
                                    self.image_size[1]))
            f.write("%d\n" % (self.widx))