From 9711bb999df787edc04fbe488368a9dbaf241311 Mon Sep 17 00:00:00 2001
From: Yangqing Jia
Date: Sun, 29 Sep 2013 17:58:03 -0700
Subject: [PATCH] convert scripts

---
 src/caffe/pyutil/convert.py        | 14 +++++++--
 src/caffe/pyutil/convert_ilsvrc.py | 63 --------------------------------------
 2 files changed, 12 insertions(+), 65 deletions(-)
 delete mode 100644 src/caffe/pyutil/convert_ilsvrc.py

diff --git a/src/caffe/pyutil/convert.py b/src/caffe/pyutil/convert.py
index 56650c6..efcea42 100644
--- a/src/caffe/pyutil/convert.py
+++ b/src/caffe/pyutil/convert.py
@@ -12,7 +12,17 @@ def blobproto_to_array(blob):
 def array_to_blobproto(arr):
   if arr.ndim != 4:
     raise ValueError('Incorrect array shape.')
-  blob = caffe_pb2.Blob()
+  blob = caffe_pb2.BlobProto()
   blob.num, blob.channels, blob.height, blob.width = arr.shape;
   blob.data.extend(arr.flat)
-  return blob
\ No newline at end of file
+  return blob
+
+def array_to_datum(arr):
+  if arr.ndim != 3:
+    raise ValueError('Incorrect array shape.')
+  if arr.dtype != np.uint8:
+    raise TypeError('Input array has to be of type uint8.')
+  datum = caffe_pb2.Datum()
+  datum.channels, datum.height, datum.width = arr.shape
+  datum.data = arr.tostring()
+  return datum
diff --git a/src/caffe/pyutil/convert_ilsvrc.py b/src/caffe/pyutil/convert_ilsvrc.py
deleted file mode 100644
index c494c01..0000000
--- a/src/caffe/pyutil/convert_ilsvrc.py
+++ /dev/null
@@ -1,63 +0,0 @@
-"""This script converts images stored in the ILSVRC format to a leveldb,
-converting every image to a 256*256 image as well as converting them to channel
-first storage. The output will be shuffled - so that a sequential read will
-result in pseudo-random minibatches.
-"""
-
-from decaf import util
-from decaf.util import transform
-import glob
-import leveldb
-import numpy as np
-import random
-import os
-from skimage import io
-import sys
-
-from caffe.proto import caffe_pb2
-
-def main(argv):
-  root = argv[0]
-  db_name = argv[1]
-  db = leveldb.LevelDB(db_name)
-  synsets = glob.glob(os.path.join(root, "n????????"))
-  synsets.sort()
-  print 'A total of %d synsets' % len(synsets)
-  all_files = [glob.glob(os.path.join(root, synset, "*.JPEG"))
-               for synset in synsets]
-  all_labels = [[i] * len(files) for i, files in enumerate(all_files)]
-  all_files = sum(all_files, [])
-  all_labels = sum(all_labels, [])
-  print 'A total of %d files' % len(all_files)
-  random_indices = list(range(len(all_files)))
-  random.shuffle(random_indices)
-  datum = caffe_pb2.Datum()
-  datum.blob.num = 1
-  datum.blob.channels = 3
-  datum.blob.height = 256
-  datum.blob.width = 256
-  my_timer = util.Timer()
-  batch = leveldb.WriteBatch()
-  for i in range(len(all_files))[:1281]:
-    filename = all_files[random_indices[i]]
-    basename = os.path.basename(filename)
-    label = all_labels[random_indices[i]]
-    image = io.imread(filename)
-    image = transform.scale_and_extract(transform.as_rgb(image), 256)
-    image = np.ascontiguousarray(image.swapaxes(1,2).swapaxes(0,1))
-    del datum.blob.data[:]
-    datum.blob.data.extend(list(image.flatten()))
-    datum.label = label
-    batch.Put('%d_%d_%s' % (i, label, basename),
-              datum.SerializeToString())
-    print '(%d %s) Wrote file %s' % (i, my_timer.total(), basename)
-    if (i % 256 and i > 0):
-      # write and start a new batch
-      db.Write(batch)
-      batch = leveldb.WriteBatch()
-
-if __name__ == '__main__':
-  if len(sys.argv) != 3:
-    print 'Usage: convert_ilsvrc.py DATA_ROOT OUTPUT_DB'
-  else:
-    main(sys.argv[1:])
\ No newline at end of file
-- 
2.7.4