misc update
authorYangqing Jia <jiayq84@gmail.com>
Tue, 1 Oct 2013 23:27:48 +0000 (16:27 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Tue, 1 Oct 2013 23:27:48 +0000 (16:27 -0700)
.gitignore
src/Makefile
src/caffe/layers/loss_layer.cu
src/caffe/layers/softmax_layer.cpp [deleted file]
src/caffe/layers/softmax_layer.cu [new file with mode: 0644]
src/caffe/test/test_solver_mnist.cpp [deleted file]
src/caffe/util/io.cpp
src/caffe/util/io.hpp
src/caffe/vision_layers.hpp
src/programs/convert_dataset.cpp [new file with mode: 0644]

index 14428f6..bc38afc 100644 (file)
@@ -19,8 +19,9 @@
 *.pb.cc
 *_pb2.py
 
-# test files
+# bin files
 *.testbin
+*.bin
 
 # vim swp files
 *.swp
index 05b7bc0..31d225e 100644 (file)
@@ -79,8 +79,8 @@ $(PROTO_GEN_CC): $(PROTO_SRCS)
        protoc $(PROTO_SRCS) --cpp_out=. --python_out=.
 
 clean:
-       @- $(RM) $(NAME) $(TEST_BINS)
-       @- $(RM) $(OBJS) $(TEST_OBJS) 
+       @- $(RM) $(NAME) $(TEST_BINS) $(PROGRAM_BINS)
+       @- $(RM) $(OBJS) $(TEST_OBJS) $(PROGRAM_OBJS)
        @- $(RM) $(PROTO_GEN_HEADER) $(PROTO_GEN_CC) $(PROTO_GEN_PY)
 
 distclean: clean
index 1ea0626..737f1a2 100644 (file)
@@ -47,6 +47,7 @@ Dtype MultinomialLogisticLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>
 
 // TODO: implement the GPU version for multinomial loss
 
+
 template <typename Dtype>
 void EuclideanLossLayer<Dtype>::SetUp(
   const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
diff --git a/src/caffe/layers/softmax_layer.cpp b/src/caffe/layers/softmax_layer.cpp
deleted file mode 100644 (file)
index a263ad3..0000000
+++ /dev/null
@@ -1,91 +0,0 @@
-// Copyright 2013 Yangqing Jia
-
-#include <algorithm>
-#include <vector>
-
-#include "caffe/layer.hpp"
-#include "caffe/vision_layers.hpp"
-#include "caffe/util/math_functions.hpp"
-
-using std::max;
-
-namespace caffe {
-
-template <typename Dtype>
-void SoftmaxLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top) {
-  CHECK_EQ(bottom.size(), 1) << "Softmax Layer takes a single blob as input.";
-  CHECK_EQ(top->size(), 1) << "Softmax Layer takes a single blob as output.";
-  (*top)[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
-      bottom[0]->height(), bottom[0]->width());
-  sum_multiplier_.Reshape(1, bottom[0]->channels(),
-      bottom[0]->height(), bottom[0]->width());
-  Dtype* multiplier_data = sum_multiplier_.mutable_cpu_data();
-  for (int i = 0; i < sum_multiplier_.count(); ++i) {
-    multiplier_data[i] = 1.;
-  }
-  scale_.Reshape(bottom[0]->num(), 1, 1, 1);
-};
-
-template <typename Dtype>
-void SoftmaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-    vector<Blob<Dtype>*>* top) {
-  const Dtype* bottom_data = bottom[0]->cpu_data();
-  Dtype* top_data = (*top)[0]->mutable_cpu_data();
-  Dtype* scale_data = scale_.mutable_cpu_data();
-  int num = bottom[0]->num();
-  int dim = bottom[0]->count() / bottom[0]->num();
-  memcpy(top_data, bottom_data, sizeof(Dtype) * bottom[0]->count());
-  // we need to subtract the max to avoid numerical issues, compute the exp,
-  // and then normalize.
-  // Compute sum
-  for (int i = 0; i < num; ++i) {
-    scale_data[i] = bottom_data[i*dim];
-    for (int j = 0; j < dim; ++j) {
-      scale_data[i] = max(scale_data[i], bottom_data[i * dim + j]);
-    }
-  }
-  // subtraction
-  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
-    scale_data, sum_multiplier_.cpu_data(), 1., top_data);
-  // Perform exponentiation
-  caffe_exp<Dtype>(num * dim, top_data, top_data);
-  // sum after exp
-  caffe_cpu_gemv<Dtype>(CblasNoTrans, num, dim, 1., top_data,
-      sum_multiplier_.cpu_data(), 0., scale_data);
-  // Do division
-  for (int i = 0; i < num; ++i) {
-    caffe_scal<Dtype>(dim, Dtype(1.) / scale_data[i], top_data + i * dim);
-  }
-}
-
-template <typename Dtype>
-Dtype SoftmaxLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
-    vector<Blob<Dtype>*>* bottom) {
-  const Dtype* top_diff = top[0]->cpu_diff();
-  const Dtype* top_data = top[0]->cpu_data();
-  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
-  Dtype* scale_data = scale_.mutable_cpu_data();
-  int num = top[0]->num();
-  int dim = top[0]->count() / top[0]->num();
-  memcpy(bottom_diff, top_diff, sizeof(Dtype) * top[0]->count());
-  // Compute inner1d(top_diff, top_data) and subtract them from the bottom diff
-  for (int i = 0; i < num; ++i) {
-    scale_data[i] = caffe_cpu_dot<Dtype>(dim, top_diff + i * dim,
-        top_data + i * dim);
-  }
-  // subtraction
-  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
-      scale_data, sum_multiplier_.cpu_data(), 1., bottom_diff);
-  // elementwise multiplication
-  caffe_mul<Dtype>(top[0]->count(), bottom_diff, top_data, bottom_diff);
-  return Dtype(0);
-}
-
-// TODO(Yangqing): implement the GPU version of softmax.
-
-INSTANTIATE_CLASS(SoftmaxLayer);
-
-
-}  // namespace caffe
diff --git a/src/caffe/layers/softmax_layer.cu b/src/caffe/layers/softmax_layer.cu
new file mode 100644 (file)
index 0000000..a765969
--- /dev/null
@@ -0,0 +1,181 @@
+// Copyright 2013 Yangqing Jia
+
+#include <algorithm>
+#include <cfloat>
+#include <vector>
+#include <thrust/device_vector.h>
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/util/math_functions.hpp"
+
+using std::max;
+
+namespace caffe {
+
+template <typename Dtype>
+void SoftmaxLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  CHECK_EQ(bottom.size(), 1) << "Softmax Layer takes a single blob as input.";
+  CHECK_EQ(top->size(), 1) << "Softmax Layer takes a single blob as output.";
+  (*top)[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
+      bottom[0]->height(), bottom[0]->width());
+  sum_multiplier_.Reshape(1, bottom[0]->channels(),
+      bottom[0]->height(), bottom[0]->width());
+  Dtype* multiplier_data = sum_multiplier_.mutable_cpu_data();
+  for (int i = 0; i < sum_multiplier_.count(); ++i) {
+    multiplier_data[i] = 1.;
+  }
+  scale_.Reshape(bottom[0]->num(), 1, 1, 1);
+};
+
+template <typename Dtype>
+void SoftmaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* top_data = (*top)[0]->mutable_cpu_data();
+  Dtype* scale_data = scale_.mutable_cpu_data();
+  int num = bottom[0]->num();
+  int dim = bottom[0]->count() / bottom[0]->num();
+  memcpy(top_data, bottom_data, sizeof(Dtype) * bottom[0]->count());
+  // we need to subtract the max to avoid numerical issues, compute the exp,
+  // and then normalize.
+  for (int i = 0; i < num; ++i) {
+    scale_data[i] = bottom_data[i*dim];
+    for (int j = 0; j < dim; ++j) {
+      scale_data[i] = max(scale_data[i], bottom_data[i * dim + j]);
+    }
+  }
+  // subtraction
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
+    scale_data, sum_multiplier_.cpu_data(), 1., top_data);
+  // Perform exponentiation
+  caffe_exp<Dtype>(num * dim, top_data, top_data);
+  // sum after exp
+  caffe_cpu_gemv<Dtype>(CblasNoTrans, num, dim, 1., top_data,
+      sum_multiplier_.cpu_data(), 0., scale_data);
+  // Do division
+  for (int i = 0; i < num; ++i) {
+    caffe_scal<Dtype>(dim, Dtype(1.) / scale_data[i], top_data + i * dim);
+  }
+}
+
+template <typename Dtype>
+__global__ void kernel_get_max(const int num, const int dim,
+    const Dtype* data, Dtype* out) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < num) {
+    Dtype maxval = -FLT_MAX;
+    for (int i = 0; i < dim; ++i) {
+      maxval = max(data[index * dim + i], maxval);
+    }
+    out[index] = maxval;
+  }
+}
+
+template <typename Dtype>
+__global__ void kernel_softmax_div(const int num, const int dim,
+    const Dtype* scale, Dtype* data) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < num * dim) {
+    int n = index / dim;
+    data[index] /= scale[n];
+  }
+}
+
+template <typename Dtype>
+__global__ void kernel_exp(const int num, const Dtype* data, Dtype* out) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < num) {
+    out[index] = exp(data[index]);
+  }
+}
+
+template <typename Dtype>
+void SoftmaxLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = (*top)[0]->mutable_gpu_data();
+  Dtype* scale_data = scale_.mutable_gpu_data();
+  int num = bottom[0]->num();
+  int dim = bottom[0]->count() / bottom[0]->num();
+  CUDA_CHECK(cudaMemcpy(top_data, bottom_data,
+      sizeof(Dtype) * bottom[0]->count(), cudaMemcpyDeviceToDevice));
+  // we need to subtract the max to avoid numerical issues, compute the exp,
+  // and then normalize.
+  // Compute max
+  kernel_get_max<Dtype><<<CAFFE_GET_BLOCKS(num), CAFFE_CUDA_NUM_THREADS>>>(
+      num, dim, bottom_data, scale_data);
+  // subtraction
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
+      scale_data, sum_multiplier_.gpu_data(), 1., top_data);
+  // Perform exponentiation
+  kernel_exp<Dtype><<<CAFFE_GET_BLOCKS(num * dim), CAFFE_CUDA_NUM_THREADS>>>(
+      num * dim, top_data, top_data);
+  // sum after exp
+  caffe_gpu_gemv<Dtype>(CblasNoTrans, num, dim, 1., top_data,
+      sum_multiplier_.gpu_data(), 0., scale_data);
+  // Do division
+  kernel_softmax_div<Dtype><<<CAFFE_GET_BLOCKS(num * dim), CAFFE_CUDA_NUM_THREADS>>>(
+      num, dim, scale_data, top_data);
+}
+
+template <typename Dtype>
+Dtype SoftmaxLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  const Dtype* top_diff = top[0]->cpu_diff();
+  const Dtype* top_data = top[0]->cpu_data();
+  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+  Dtype* scale_data = scale_.mutable_cpu_data();
+  int num = top[0]->num();
+  int dim = top[0]->count() / top[0]->num();
+  memcpy(bottom_diff, top_diff, sizeof(Dtype) * top[0]->count());
+  // Compute inner1d(top_diff, top_data) and subtract them from the bottom diff
+  for (int i = 0; i < num; ++i) {
+    scale_data[i] = caffe_cpu_dot<Dtype>(dim, top_diff + i * dim,
+        top_data + i * dim);
+  }
+  // subtraction
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
+      scale_data, sum_multiplier_.cpu_data(), 1., bottom_diff);
+  // elementwise multiplication
+  caffe_mul<Dtype>(top[0]->count(), bottom_diff, top_data, bottom_diff);
+  return Dtype(0);
+}
+
+// TODO(Yangqing): implement the GPU version of softmax.
+template <typename Dtype>
+Dtype SoftmaxLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+  const Dtype* top_diff = top[0]->gpu_diff();
+  const Dtype* top_data = top[0]->gpu_data();
+  Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
+  int num = top[0]->num();
+  int dim = top[0]->count() / top[0]->num();
+  CUDA_CHECK(cudaMemcpy(bottom_diff, top_diff,
+      sizeof(Dtype) * top[0]->count(), cudaMemcpyDeviceToDevice));
+  // Compute inner1d(top_diff, top_data) and subtract them from the bottom diff
+  // cuda dot returns the result to cpu, so we temporarily change the pointer
+  // mode
+  CUBLAS_CHECK(cublasSetPointerMode(Caffe::cublas_handle(),
+      CUBLAS_POINTER_MODE_DEVICE));
+  Dtype* scale_data = scale_.mutable_gpu_data();
+  for (int i = 0; i < num; ++i) {
+    caffe_gpu_dot<Dtype>(dim, top_diff + i * dim,
+        top_data + i * dim, scale_data + i);
+  }
+  CUBLAS_CHECK(cublasSetPointerMode(Caffe::cublas_handle(),
+      CUBLAS_POINTER_MODE_HOST));
+  // subtraction
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
+      scale_.gpu_data(), sum_multiplier_.gpu_data(), 1., bottom_diff);
+  // elementwise multiplication
+  caffe_gpu_mul<Dtype>(top[0]->count(), bottom_diff, top_data, bottom_diff);
+  return Dtype(0);
+}
+
+INSTANTIATE_CLASS(SoftmaxLayer);
+
+
+}  // namespace caffe
diff --git a/src/caffe/test/test_solver_mnist.cpp b/src/caffe/test/test_solver_mnist.cpp
deleted file mode 100644 (file)
index 4c8d3fd..0000000
+++ /dev/null
@@ -1,108 +0,0 @@
-// Copyright 2013 Yangqing Jia
-
-#include <cuda_runtime.h>
-#include <fcntl.h>
-#include <google/protobuf/text_format.h>
-#include <google/protobuf/io/zero_copy_stream_impl.h>
-#include <gtest/gtest.h>
-
-#include <cstring>
-
-#include "caffe/blob.hpp"
-#include "caffe/common.hpp"
-#include "caffe/net.hpp"
-#include "caffe/filler.hpp"
-#include "caffe/proto/caffe.pb.h"
-#include "caffe/util/io.hpp"
-#include "caffe/optimization/solver.hpp"
-
-#include "caffe/test/test_caffe_main.hpp"
-
-namespace caffe {
-
-template <typename Dtype>
-class MNISTSolverTest : public ::testing::Test {};
-
-typedef ::testing::Types<float> Dtypes;
-TYPED_TEST_CASE(MNISTSolverTest, Dtypes);
-
-TYPED_TEST(MNISTSolverTest, TestSolve) {
-  Caffe::set_mode(Caffe::GPU);
-
-  NetParameter net_param;
-  ReadProtoFromTextFile("caffe/test/data/lenet.prototxt",
-      &net_param);
-  vector<Blob<TypeParam>*> bottom_vec;
-  Net<TypeParam> caffe_net(net_param, bottom_vec);
-
-  // Run the network without training.
-  LOG(ERROR) << "Performing Forward";
-  caffe_net.Forward(bottom_vec);
-  LOG(ERROR) << "Performing Backward";
-  LOG(ERROR) << "Initial loss: " << caffe_net.Backward();
-
-  SolverParameter solver_param;
-  solver_param.set_base_lr(0.01);
-  solver_param.set_display(0);
-  solver_param.set_max_iter(6000);
-  solver_param.set_lr_policy("inv");
-  solver_param.set_gamma(0.0001);
-  solver_param.set_power(0.75);
-  solver_param.set_momentum(0.9);
-
-  LOG(ERROR) << "Starting Optimization";
-  SGDSolver<TypeParam> solver(solver_param);
-  solver.Solve(&caffe_net);
-  LOG(ERROR) << "Optimization Done.";
-
-  // Run the network after training.
-  LOG(ERROR) << "Performing Forward";
-  caffe_net.Forward(bottom_vec);
-  LOG(ERROR) << "Performing Backward";
-  TypeParam loss = caffe_net.Backward();
-  LOG(ERROR) << "Final loss: " << loss;
-  EXPECT_LE(loss, 0.5);
-
-  NetParameter trained_net_param;
-  caffe_net.ToProto(&trained_net_param);
-  // LOG(ERROR) << "Writing to disk.";
-  // WriteProtoToBinaryFile(trained_net_param,
-  //     "caffe/test/data/lenet_trained.prototxt");
-
-  NetParameter traintest_net_param;
-  ReadProtoFromTextFile("caffe/test/data/lenet_traintest.prototxt",
-      &traintest_net_param);
-  Net<TypeParam> caffe_traintest_net(traintest_net_param, bottom_vec);
-  caffe_traintest_net.CopyTrainedLayersFrom(trained_net_param);
-
-  // Test run
-  double train_accuracy = 0;
-  int batch_size = traintest_net_param.layers(0).layer().batchsize();
-  for (int i = 0; i < 60000 / batch_size; ++i) {
-    const vector<Blob<TypeParam>*>& result =
-        caffe_traintest_net.Forward(bottom_vec);
-    train_accuracy += result[0]->cpu_data()[0];
-  }
-  train_accuracy /= 60000 / batch_size;
-  LOG(ERROR) << "Train accuracy:" << train_accuracy;
-  EXPECT_GE(train_accuracy, 0.98);
-
-  NetParameter test_net_param;
-  ReadProtoFromTextFile("caffe/test/data/lenet_test.prototxt", &test_net_param);
-  Net<TypeParam> caffe_test_net(test_net_param, bottom_vec);
-  caffe_test_net.CopyTrainedLayersFrom(trained_net_param);
-
-  // Test run
-  double test_accuracy = 0;
-  batch_size = test_net_param.layers(0).layer().batchsize();
-  for (int i = 0; i < 10000 / batch_size; ++i) {
-    const vector<Blob<TypeParam>*>& result =
-        caffe_test_net.Forward(bottom_vec);
-    test_accuracy += result[0]->cpu_data()[0];
-  }
-  test_accuracy /= 10000 / batch_size;
-  LOG(ERROR) << "Test accuracy:" << test_accuracy;
-  EXPECT_GE(test_accuracy, 0.98);
-}
-
-}  // namespace caffe
index 0d4f9bb..b7a830b 100644 (file)
@@ -47,6 +47,28 @@ void ReadImageToProto(const string& filename, BlobProto* proto) {
   }
 }
 
+void ReadImageToDatum(const string& filename, const int label, Datum* datum) {
+  Mat cv_img;
+  cv_img = cv::imread(filename, CV_LOAD_IMAGE_COLOR);
+  CHECK(cv_img.data) << "Could not open or find the image.";
+  DCHECK_EQ(cv_img.channels(), 3);
+  datum->set_channels(3);
+  datum->set_height(cv_img.rows);
+  datum->set_width(cv_img.cols);
+  datum->set_label(label);
+  datum->clear_data();
+  datum->clear_float_data();
+  string* datum_string = datum->mutable_data();
+  for (int c = 0; c < 3; ++c) {
+    for (int h = 0; h < cv_img.rows; ++h) {
+      for (int w = 0; w < cv_img.cols; ++w) {
+        datum_string->push_back(static_cast<char>(cv_img.at<Vec3b>(h, w)[c]));
+      }
+    }
+  }
+}
+
+
 void WriteProtoToImage(const string& filename, const BlobProto& proto) {
   CHECK_EQ(proto.num(), 1);
   CHECK(proto.channels() == 3 || proto.channels() == 1);
index 57beef1..ab45936 100644 (file)
@@ -33,6 +33,8 @@ inline void WriteBlobToImage(const string& filename, const Blob<Dtype>& blob) {
   WriteProtoToImage(filename, proto);
 }
 
+void ReadImageToDatum(const string& filename, const int label, Datum* datum);
+
 void ReadProtoFromTextFile(const char* filename,
     Message* proto);
 inline void ReadProtoFromTextFile(const string& filename,
index 3aa43b2..74c5978 100644 (file)
@@ -267,12 +267,12 @@ class SoftmaxLayer : public Layer<Dtype> {
  protected:
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
-  // virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-  //     vector<Blob<Dtype>*>* top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
   virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-  // virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
-  //     const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+     const bool propagate_down, vector<Blob<Dtype>*>* bottom);
 
   // sum_multiplier is just used to carry out sum using blas
   Blob<Dtype> sum_multiplier_;
diff --git a/src/programs/convert_dataset.cpp b/src/programs/convert_dataset.cpp
new file mode 100644 (file)
index 0000000..7fb6a04
--- /dev/null
@@ -0,0 +1,66 @@
+// Copyright 2013 Yangqing Jia
+// This program converts a set of images to a leveldb by storing them as Datum
+// proto buffers.
+// Usage:
+//    convert_dataset ROOTFOLDER LISTFILE DB_NAME
+// where ROOTFOLDER is the root folder that holds all the images, and LISTFILE
+// should be a list of files as well as their labels, in the format as
+// subfolder1/file1.JPEG 0
+// ....
+
+#include <glog/logging.h>
+#include <leveldb/db.h>
+
+#include <string>
+#include <iostream>
+#include <fstream>
+
+#include "caffe/proto/caffe.pb.h"
+#include "caffe/util/io.hpp"
+
+using namespace caffe;
+using std::string;
+
+// A utility function to generate random strings
+void GenerateRandomPrefix(const int n, string* key) {
+  const char* kCHARS = "abcdefghijklmnopqrstuvwxyz";
+  key->clear();
+  for (int i = 0; i < n; ++i) {
+    key->push_back(kCHARS[rand() % 26]);
+  }
+  key->push_back('_');
+}
+
+int main(int argc, char** argv) {
+  ::google::InitGoogleLogging(argv[0]);
+  std::ifstream infile(argv[2]);
+  leveldb::DB* db;
+  leveldb::Options options;
+  options.error_if_exists = true;
+  options.create_if_missing = true;
+  LOG(INFO) << "Opening leveldb " << argv[3];
+  leveldb::Status status = leveldb::DB::Open(
+      options, argv[3], &db);
+  CHECK(status.ok()) << "Failed to open leveldb " << argv[3];
+
+  string root_folder(argv[1]);
+  string filename;
+  int label;
+  Datum datum;
+  string key;
+  string value;
+  while (infile >> filename >> label) {
+    ReadImageToDatum(root_folder + filename, label, &datum);
+    // get the key, and add a random string so the leveldb will have permuted
+    // data
+    GenerateRandomPrefix(8, &key);
+    key += filename;
+    // get the value
+    datum.SerializeToString(&value);
+    db->Put(leveldb::WriteOptions(), key, value);
+    LOG(ERROR) << "Writing " << key;
+  }
+
+  delete db;
+  return 0;
+}