misc update: add random mirroring to the data layer, fix the solver display condition, speed up the col2im GPU kernel, and replace train_alexnet with a generic train_net driver plus ImageNet net/solver prototxts and a layer speed benchmark
author Yangqing Jia <jiayq84@gmail.com>
Thu, 10 Oct 2013 21:45:10 +0000 (14:45 -0700)
committer Yangqing Jia <jiayq84@gmail.com>
Thu, 10 Oct 2013 21:46:02 +0000 (14:46 -0700)
src/Makefile
src/caffe/layers/data_layer.cpp
src/caffe/optimization/solver.cpp
src/caffe/proto/caffe.proto
src/caffe/util/im2col.cu
src/programs/imagenet.prototxt [new file with mode: 0644]
src/programs/imagenet_solver.prototxt [new file with mode: 0644]
src/programs/net_speed_benchmark.cpp [new file with mode: 0644]
src/programs/train_alexnet.cpp [deleted file]
src/programs/train_net.cpp [new file with mode: 0644]

diff --git a/src/Makefile b/src/Makefile
index d78e99b..27fd760 100644
@@ -42,12 +42,13 @@ LIBRARIES := cuda cudart cublas protobuf glog mkl_rt mkl_intel_thread curand \
                leveldb snappy opencv_core opencv_highgui pthread tcmalloc
 WARNINGS := -Wall
 
-CXXFLAGS += -fPIC $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
+CXXFLAGS += -pthread -fPIC -O2 $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
+NVCCFLAGS := -Xcompiler -fPIC -O2 $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
 LDFLAGS += $(foreach librarydir,$(LIBRARY_DIRS),-L$(librarydir))
 LDFLAGS += $(foreach library,$(LIBRARIES),-l$(library))
 
 LINK = $(CXX) $(CXXFLAGS) $(CPPFLAGS) $(LDFLAGS) $(WARNINGS)
-NVCC = nvcc ${CXXFLAGS:-fPIC=-Xcompiler -fPIC} $(CPPFLAGS) $(CUDA_ARCH)
+NVCC = nvcc $(NVCCFLAGS) $(CPPFLAGS) $(CUDA_ARCH)
 
 .PHONY: all test clean distclean linecount program
 
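Aside on the Makefile change: the old substitution ${CXXFLAGS:-fPIC=-Xcompiler -fPIC} only rewrote the -fPIC token, so any host-only flag added to CXXFLAGS (such as the -pthread introduced here) would have reached nvcc verbatim, and nvcc rejects flags it does not recognize. A dedicated NVCCFLAGS variable keeps the host and device compiler flag sets independent.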
diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp
index 5b95770..7993a43 100644
@@ -24,39 +24,64 @@ void* DataLayerPrefetch(void* layer_pointer) {
   const Dtype subtraction = layer->layer_param_.subtraction();
   const int batchsize = layer->layer_param_.batchsize();
   const int cropsize = layer->layer_param_.cropsize();
+  const bool mirror = layer->layer_param_.mirror();
+  if (mirror && cropsize == 0) {
+    LOG(FATAL) << "Current implementation requires mirror and cropsize to be "
+        << "set at the same time.";
+  }
+  // datum scales
+  const int channels = layer->datum_channels_;
+  const int height = layer->datum_height_;
+  const int width = layer->datum_width_;
+  const int size = layer->datum_size_;
   for (int itemid = 0; itemid < batchsize; ++itemid) {
     // get a blob
     datum.ParseFromString(layer->iter_->value().ToString());
     const string& data = datum.data();
     if (cropsize) {
       CHECK(data.size()) << "Image cropping only support uint8 data";
-      int h_offset = rand() % (layer->datum_height_ - cropsize);
-      int w_offset = rand() % (layer->datum_width_ - cropsize);
-      for (int c = 0; c < layer->datum_channels_; ++c) {
-        for (int h = 0; h < cropsize; ++h) {
-          for (int w = 0; w < cropsize; ++w) {
-            top_data[((itemid * layer->datum_channels_ + c) * cropsize + h) * cropsize + w] =
-                static_cast<Dtype>((uint8_t)data[
-                    (c * layer->datum_height_ + h + h_offset) * layer->datum_width_
-                    + w + w_offset]
-                ) * scale - subtraction;
+      int h_offset = rand() % (height - cropsize);
+      int w_offset = rand() % (width - cropsize);
+      if (mirror && rand() % 2) {
+        // Copy mirrored version
+        for (int c = 0; c < channels; ++c) {
+          for (int h = 0; h < cropsize; ++h) {
+            for (int w = 0; w < cropsize; ++w) {
+              top_data[((itemid * channels + c) * cropsize + h) * cropsize + cropsize - 1 - w] =
+                  static_cast<Dtype>((uint8_t)data[
+                      (c * height + h + h_offset) * width + w + w_offset]
+                  ) * scale - subtraction;
+            }
+          }
+        }
+      } else {
+        // Normal copy
+        for (int c = 0; c < channels; ++c) {
+          for (int h = 0; h < cropsize; ++h) {
+            for (int w = 0; w < cropsize; ++w) {
+              top_data[((itemid * channels + c) * cropsize + h) * cropsize + w] =
+                  static_cast<Dtype>((uint8_t)data[
+                      (c * height + h + h_offset) * width + w + w_offset]
+                  ) * scale - subtraction;
+            }
           }
         }
       }
     } else {
       // we will prefer to use data() first, and then try float_data()
       if (data.size()) {
-        for (int j = 0; j < layer->datum_size_; ++j) {
-          top_data[itemid * layer->datum_size_ + j] =
+        for (int j = 0; j < size; ++j) {
+          top_data[itemid * size + j] =
               (static_cast<Dtype>((uint8_t)data[j]) * scale) - subtraction;
         }
       } else {
-        for (int j = 0; j < layer->datum_size_; ++j) {
-          top_data[itemid * layer->datum_size_ + j] =
+        for (int j = 0; j < size; ++j) {
+          top_data[itemid * size + j] =
               (datum.float_data(j) * scale) - subtraction;
         }
       }
     }
+
     top_label[itemid] = datum.label();
     // go to the next iter
     layer->iter_->Next();
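Note on the mirrored branch above: the write goes to column (cropsize - 1 - w) while the read comes from column (w + w_offset), so each row of the crop is written reversed. A minimal host-side sketch of the same index math (toy sizes; the names src and dst are illustrative, not from the layer):

    // Illustration of the crop + mirror indexing above, with itemid = 0,
    // scale = 1 and subtraction = 0.
    #include <cstdint>
    #include <cstdio>

    int main() {
      const int channels = 1, height = 4, width = 4, cropsize = 2;
      const int h_offset = 1, w_offset = 1;
      uint8_t src[height * width];
      for (int i = 0; i < height * width; ++i) src[i] = i;  // values 0..15
      float dst[channels * cropsize * cropsize];
      for (int c = 0; c < channels; ++c)
        for (int h = 0; h < cropsize; ++h)
          for (int w = 0; w < cropsize; ++w)
            // same expression as the mirrored copy in the layer
            dst[(c * cropsize + h) * cropsize + cropsize - 1 - w] =
                src[(c * height + h + h_offset) * width + w + w_offset];
      // prints "6 5" then "10 9": the 2x2 crop at offset (1,1), rows reversed
      for (int h = 0; h < cropsize; ++h)
        printf("%g %g\n", dst[h * cropsize], dst[h * cropsize + 1]);
      return 0;
    }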
diff --git a/src/caffe/optimization/solver.cpp b/src/caffe/optimization/solver.cpp
index 1afe293..d9ab2c1 100644
@@ -1,5 +1,7 @@
 // Copyright Yangqing Jia 2013
 
+#include <cstdio>
+
 #include <algorithm>
 #include <string>
 #include <vector>
@@ -34,7 +36,7 @@ void Solver<Dtype>::Solve(Net<Dtype>* net) {
     if (param_.snapshot() > 0 && iter_ % param_.snapshot() == 0) {
       Snapshot(false);
     }
-    if (param_.display() && iter_ % param_.display()) {
+    if (param_.display() && iter_ % param_.display() == 0) {
       LOG(ERROR) << "Iteration " << iter_ << ", loss = " << loss;
     }
   }
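Note the display fix above: iter_ % param_.display() is nonzero exactly when iter_ is not a multiple of display, so the old condition logged the loss on every iteration except the intended ones. Appending == 0 inverts the test, so that display: 100 logs at iterations 100, 200, 300, and so on.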
@@ -47,14 +49,14 @@ void Solver<Dtype>::Snapshot(bool is_final) {
   NetParameter net_param;
   // For intermediate results, we will also dump the gradient values.
   net_->ToProto(&net_param, !is_final);
-  stringstream ss;
-  ss << param_.snapshot_prefix();
+  string filename(param_.snapshot_prefix());
   if (is_final) {
-    ss << "_final";
+    filename += "_final";
   } else {
-    ss << "_iter_" << iter_;
+    char iter_str_buffer[20];
+    sprintf(iter_str_buffer, "_iter_%d", iter_);
+    filename += iter_str_buffer;
   }
-  string filename = ss.str();
   LOG(ERROR) << "Snapshotting to " << filename;
   WriteProtoToBinaryFile(net_param, filename.c_str());
 }
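For reference, with the snapshot_prefix "alexnet_train" used in the solver file below, this writes snapshots named alexnet_train_iter_15000, alexnet_train_iter_30000, ..., plus alexnet_train_final at the end. The 20-byte buffer is ample: "_iter_" plus the digits of any positive 32-bit int and the trailing NUL need at most 17 bytes.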
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 048144c..afefcca 100644
@@ -66,6 +66,8 @@ message LayerParameter {
   optional uint32 batchsize = 19;
   // For data layers, specify if we would like to randomly crop an image.
   optional uint32 cropsize = 20 [default = 0];
+  // For data layers, specify if we want to randomly mirror data.
+  optional bool mirror = 21 [default = false];
 
   // The blobs containing the numeric parameters of the layer
   repeated BlobProto blobs = 50;
diff --git a/src/caffe/util/im2col.cu b/src/caffe/util/im2col.cu
index 81ac3c2..0b0c8b8 100644
@@ -9,6 +9,7 @@
 
 namespace caffe {
 
+
 template <typename Dtype>
 __global__ void im2col_gpu_kernel(const int n, const Dtype* data_im,
   const int height, const int width, const int ksize,
@@ -48,6 +49,7 @@ void im2col_gpu(const Dtype* data_im, const int channels,
   CUDA_POST_KERNEL_CHECK;
 }
 
+
 // Explicit instantiation
 template void im2col_gpu<float>(const float* data_im, const int channels,
   const int height, const int width, const int ksize, const int stride,
@@ -71,13 +73,24 @@ __global__ void col2im_gpu_kernel(const int n, const Dtype* data_col,
     int w_col_end = min(w / stride + 1, width_col);
     int h_col_start = (h < ksize) ? 0 : (h - ksize) / stride + 1;
     int h_col_end = min(h / stride + 1, height_col);
+    /*
     for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
       for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
         // the col location: [c * width * height + h_out, w_out]
-        int c_col = c * ksize * ksize + (h - h_col * stride) * ksize + (w - w_col * stride); 
+        int c_col = c * ksize * ksize + (h - h_col * stride) * ksize + (w - w_col * stride);
         val += data_col[(c_col * height_col + h_col) * width_col + w_col];
       }
     }
+    */
+    // equivalent implementation
+    int offset = (c * ksize * ksize + h * ksize + w) * height_col * width_col;
+    int coeff_h_col = (1 - stride * ksize * height_col) * width_col;
+    int coeff_w_col = (1 - stride * height_col * width_col);
+    for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
+      for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
+        val += data_col[offset + h_col * coeff_h_col + w_col * coeff_w_col];
+      }
+    }
     data_im[index] = val;
   }
 }
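The "equivalent implementation" is the same sum with the index arithmetic hoisted out of the inner loops. Writing K = ksize and S = stride, the original column index

    (c_col * height_col + h_col) * width_col + w_col,
    where c_col = c*K*K + (h - h_col*S)*K + (w - w_col*S)

expands to

    (c*K*K + h*K + w) * height_col * width_col       // offset (loop-invariant)
    + h_col * (1 - S*K*height_col) * width_col       // coeff_h_col
    + w_col * (1 - S*height_col*width_col)           // coeff_w_col

which is exactly offset + h_col * coeff_h_col + w_col * coeff_w_col in the new code, so each inner iteration does one multiply-add per loop index instead of recomputing c_col from scratch.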
diff --git a/src/programs/imagenet.prototxt b/src/programs/imagenet.prototxt
new file mode 100644
index 0000000..65b7432
--- /dev/null
@@ -0,0 +1,332 @@
+name: "CaffeNet"
+layers {
+  layer {
+    name: "data"
+    type: "data"
+    source: "/home/jiayq/caffe-train-leveldb"
+    batchsize: 96
+    subtraction: 114
+    cropsize: 227
+    mirror: true
+  }
+  top: "data"
+  top: "label"
+}
+layers {
+  layer {
+    name: "conv1"
+    type: "conv"
+    num_output: 96
+    kernelsize: 11
+    stride: 4
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+      value: 0
+    }
+  }
+  bottom: "data"
+  top: "conv1"
+}
+layers {
+  layer {
+    name: "relu1"
+    type: "relu"
+  }
+  bottom: "conv1"
+  top: "relu1"
+}
+layers {
+  layer {
+    name: "pool1"
+    type: "pool"
+    pool: MAX
+    kernelsize: 3
+    stride: 2
+  }
+  bottom: "relu1"
+  top: "pool1"
+}
+layers {
+  layer {
+    name: "norm1"
+    type: "lrn"
+    local_size: 5
+    alpha: 0.0001
+    beta: 0.75
+  }
+  bottom: "pool1"
+  top: "norm1"
+}
+layers {
+  layer {
+    name: "pad2"
+    type: "padding"
+    pad: 2
+  }
+  bottom: "norm1"
+  top: "pad2"
+}
+layers {
+  layer {
+    name: "conv2"
+    type: "conv"
+    num_output: 256
+    group: 2
+    kernelsize: 5
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+      value: 1
+    }
+  }
+  bottom: "pad2"
+  top: "conv2"
+}
+layers {
+  layer {
+    name: "relu2"
+    type: "relu"
+  }
+  bottom: "conv2"
+  top: "relu2"
+}
+layers {
+  layer {
+    name: "pool2"
+    type: "pool"
+    pool: MAX
+    kernelsize: 3
+    stride: 2
+  }
+  bottom: "relu2"
+  top: "pool2"
+}
+layers {
+  layer {
+    name: "norm2"
+    type: "lrn"
+    local_size: 5
+    alpha: 0.0001
+    beta: 0.75
+  }
+  bottom: "pool2"
+  top: "norm2"
+}
+layers {
+  layer {
+    name: "pad3"
+    type: "padding"
+    pad: 1
+  }
+  bottom: "norm2"
+  top: "pad3"
+}
+layers {
+  layer {
+    name: "conv3"
+    type: "conv"
+    num_output: 384
+    kernelsize: 3
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+      value: 0
+    }
+  }
+  bottom: "pad3"
+  top: "conv3"
+}
+layers {
+  layer {
+    name: "relu3"
+    type: "relu"
+  }
+  bottom: "conv3"
+  top: "relu3"
+}
+layers {
+  layer {
+    name: "pad4"
+    type: "padding"
+    pad: 1
+  }
+  bottom: "relu3"
+  top: "pad4"
+}
+layers {
+  layer {
+    name: "conv4"
+    type: "conv"
+    num_output: 384
+    group: 2
+    kernelsize: 3
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+      value: 1
+    }
+  }
+  bottom: "pad4"
+  top: "conv4"
+}
+layers {
+  layer {
+    name: "relu4"
+    type: "relu"
+  }
+  bottom: "conv4"
+  top: "relu4"
+}
+layers {
+  layer {
+    name: "pad5"
+    type: "padding"
+    pad: 1
+  }
+  bottom: "relu4"
+  top: "pad5"
+}
+layers {
+  layer {
+    name: "conv5"
+    type: "conv"
+    num_output: 256
+    group: 2
+    kernelsize: 3
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+      value: 1
+    }
+  }
+  bottom: "pad5"
+  top: "conv5"
+}
+layers {
+  layer {
+    name: "relu5"
+    type: "relu"
+  }
+  bottom: "conv5"
+  top: "relu5"
+}
+layers {
+  layer {
+    name: "pool5"
+    type: "pool"
+    kernelsize: 3
+    pool: MAX
+    stride: 2
+  }
+  bottom: "relu5"
+  top: "pool5"
+}
+layers {
+  layer {
+    name: "fc6"
+    type: "innerproduct"
+    num_output: 4096
+    weight_filler {
+      type: "gaussian"
+      std: 0.005
+    }
+    bias_filler {
+      type: "constant"
+      value: 1
+    }
+  }
+  bottom: "pool5"
+  top: "fc6"
+}
+layers {
+  layer {
+    name: "relu6"
+    type: "relu"
+  }
+  bottom: "fc6"
+  top: "relu6"
+}
+layers {
+  layer {
+    name: "drop6"
+    type: "dropout"
+    dropout_ratio: 0.5
+  }
+  bottom: "relu6"
+  top: "drop6"
+}
+layers {
+  layer {
+    name: "fc7"
+    type: "innerproduct"
+    num_output: 4096
+    weight_filler {
+      type: "gaussian"
+      std: 0.005
+    }
+    bias_filler {
+      type: "constant"
+      value: 1
+    }
+  }
+  bottom: "drop6"
+  top: "fc7"
+}
+layers {
+  layer {
+    name: "relu7"
+    type: "relu"
+  }
+  bottom: "fc7"
+  top: "relu7"
+}
+layers {
+  layer {
+    name: "drop7"
+    type: "dropout"
+    dropout_ratio: 0.5
+  }
+  bottom: "relu7"
+  top: "drop7"
+}
+layers {
+  layer {
+    name: "fc8"
+    type: "innerproduct"
+    num_output: 1000
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+      value: 0
+    }
+  }
+  bottom: "drop7"
+  top: "fc8"
+}
+layers {
+  layer {
+    name: "loss"
+    type: "softmax_loss"
+  }
+  bottom: "fc8"
+  bottom: "label"
+}
\ No newline at end of file
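This is the AlexNet-style architecture of Krizhevsky et al. (named "CaffeNet" here): five convolutional layers with group: 2 on conv2/conv4/conv5 echoing the original two-GPU split, LRN after pool1 and pool2, two 4096-wide fully connected layers with dropout, and a 1000-way softmax loss. The standalone padding layers (pad2..pad5) are presumably needed because the conv layer does not yet pad its input itself.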
diff --git a/src/programs/imagenet_solver.prototxt b/src/programs/imagenet_solver.prototxt
new file mode 100644
index 0000000..58b0dfe
--- /dev/null
@@ -0,0 +1,10 @@
+base_lr: 0.02
+lr_policy: "step"
+gamma: 0.1
+stepsize: 340000
+display: 100
+max_iter: 1200000
+momentum: 0.9
+weight_decay: 0.0005
+snapshot: 15000
+snapshot_prefix: "alexnet_train"
\ No newline at end of file
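Assuming the "step" policy computes lr = base_lr * gamma^floor(iter / stepsize), this schedule decays the learning rate tenfold every 340000 iterations: 0.02 for iterations 0-339999, then 0.002, 0.0002, and finally 0.00002 until max_iter at 1200000, snapshotting every 15000 iterations.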
diff --git a/src/programs/net_speed_benchmark.cpp b/src/programs/net_speed_benchmark.cpp
new file mode 100644
index 0000000..560c5d8
--- /dev/null
@@ -0,0 +1,62 @@
+// Copyright 2013 Yangqing Jia
+
+#include <cuda_runtime.h>
+#include <fcntl.h>
+#include <google/protobuf/text_format.h>
+
+#include <cstring>
+#include <ctime>
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/net.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/proto/caffe.pb.h"
+#include "caffe/util/io.hpp"
+#include "caffe/optimization/solver.hpp"
+
+using namespace caffe;
+
+int main(int argc, char** argv) {
+  cudaSetDevice(1);
+  Caffe::set_mode(Caffe::GPU);
+  Caffe::set_phase(Caffe::TRAIN);
+  int repeat = 100;
+
+  NetParameter net_param;
+  ReadProtoFromTextFile(argv[1],
+      &net_param);
+  vector<Blob<float>*> bottom_vec;
+  Net<float> caffe_net(net_param, bottom_vec);
+
+  // Run the network without training.
+  LOG(ERROR) << "Performing Forward";
+  caffe_net.Forward(bottom_vec);
+  LOG(ERROR) << "Performing Backward";
+  LOG(ERROR) << "Initial loss: " << caffe_net.Backward();
+
+  const vector<shared_ptr<Layer<float> > >& layers = caffe_net.layers();
+  vector<vector<Blob<float>*> >& bottom_vecs = caffe_net.bottom_vecs();
+  vector<vector<Blob<float>*> >& top_vecs = caffe_net.top_vecs();
+  LOG(ERROR) << "*** Benchmark begins ***";
+  for (int i = 0; i < layers.size(); ++i) {
+    const string& layername = layers[i]->layer_param().name();
+    clock_t start = clock();
+    for (int j = 0; j < repeat; ++j) {
+      layers[i]->Forward(bottom_vecs[i], &top_vecs[i]);
+    }
+    LOG(ERROR) << layername << "\tforward: "
+        << float(clock() - start) / CLOCKS_PER_SEC << " seconds.";
+  }
+  for (int i = layers.size() - 1; i >= 0; --i) {
+    const string& layername = layers[i]->layer_param().name();
+    clock_t start = clock();
+    for (int j = 0; j < repeat; ++j) {
+      layers[i]->Backward(top_vecs[i], true, &bottom_vecs[i]);
+    }
+    LOG(ERROR) << layername << "\tbackward: "
+        << float(clock() - start) / CLOCKS_PER_SEC << " seconds.";
+  }
+  LOG(ERROR) << "*** Benchmark ends ***";
+  return 0;
+}
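One caveat on these numbers: clock() measures host CPU time, and CUDA kernel launches are asynchronous, so a layer's Forward can return before its kernels have finished and the cost gets billed to whichever later call happens to block. A minimal sketch of the forward loop body using CUDA events instead (standard CUDA runtime API; layers, bottom_vecs, top_vecs, repeat and layername as in the benchmark above):

    // Time `repeat` forward passes of layer i with CUDA events, which are
    // recorded on the GPU timeline and therefore include kernel execution.
    cudaEvent_t start_event, stop_event;
    cudaEventCreate(&start_event);
    cudaEventCreate(&stop_event);
    cudaEventRecord(start_event, 0);
    for (int j = 0; j < repeat; ++j) {
      layers[i]->Forward(bottom_vecs[i], &top_vecs[i]);
    }
    cudaEventRecord(stop_event, 0);
    cudaEventSynchronize(stop_event);  // wait for all queued work to finish
    float elapsed_ms = 0;
    cudaEventElapsedTime(&elapsed_ms, start_event, stop_event);
    LOG(ERROR) << layername << "\tforward: " << elapsed_ms / 1000 << " seconds.";
    cudaEventDestroy(start_event);
    cudaEventDestroy(stop_event);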
diff --git a/src/programs/train_alexnet.cpp b/src/programs/train_alexnet.cpp
deleted file mode 100644
index 99e37c3..0000000
+++ /dev/null
@@ -1,73 +0,0 @@
-// Copyright 2013 Yangqing Jia
-
-#include <cuda_runtime.h>
-#include <fcntl.h>
-#include <google/protobuf/text_format.h>
-
-#include <cstring>
-
-#include "caffe/blob.hpp"
-#include "caffe/common.hpp"
-#include "caffe/net.hpp"
-#include "caffe/filler.hpp"
-#include "caffe/proto/caffe.pb.h"
-#include "caffe/util/io.hpp"
-#include "caffe/optimization/solver.hpp"
-
-using namespace caffe;
-
-int main(int argc, char** argv) {
-  cudaSetDevice(0);
-  Caffe::set_mode(Caffe::GPU);
-  Caffe::set_phase(Caffe::TRAIN);
-
-  NetParameter net_param;
-  ReadProtoFromTextFile(argv[1],
-      &net_param);
-  vector<Blob<float>*> bottom_vec;
-  Net<float> caffe_net(net_param, bottom_vec);
-
-  // Run the network without training.
-  LOG(ERROR) << "Performing Forward";
-  caffe_net.Forward(bottom_vec);
-  LOG(ERROR) << "Performing Backward";
-  LOG(ERROR) << "Initial loss: " << caffe_net.Backward();
-
-  /*
-  // Now, let's dump all the layers
-  string output_prefix("alexnet_initial_dump_");
-  const vector<string>& blob_names = caffe_net.blob_names();
-  const vector<shared_ptr<Blob<float> > >& blobs = caffe_net.blobs();
-  for (int blobid = 0; blobid < caffe_net.blobs().size(); ++blobid) {
-    // Serialize blob
-    LOG(ERROR) << "Dumping " << blob_names[blobid];
-    BlobProto output_blob_proto;
-    blobs[blobid]->ToProto(&output_blob_proto);
-    WriteProtoToBinaryFile(output_blob_proto, output_prefix + blob_names[blobid]);
-  }
-  */
-
-  SolverParameter solver_param;
-  solver_param.set_base_lr(0.01);
-  solver_param.set_display(1);
-  solver_param.set_max_iter(60000);
-  solver_param.set_lr_policy("fixed");
-  solver_param.set_momentum(0.9);
-  solver_param.set_weight_decay(0.0005);
-  solver_param.set_snapshot(5000);
-  solver_param.set_snapshot_prefix("alexnet");
-
-  LOG(ERROR) << "Starting Optimization";
-  SGDSolver<float> solver(solver_param);
-  solver.Solve(&caffe_net);
-  LOG(ERROR) << "Optimization Done.";
-
-  // Run the network after training.
-  LOG(ERROR) << "Performing Forward";
-  caffe_net.Forward(bottom_vec);
-  LOG(ERROR) << "Performing Backward";
-  float loss = caffe_net.Backward();
-  LOG(ERROR) << "Final loss: " << loss;
-
-  return 0;
-}
diff --git a/src/programs/train_net.cpp b/src/programs/train_net.cpp
new file mode 100644
index 0000000..4111043
--- /dev/null
@@ -0,0 +1,39 @@
+// Copyright 2013 Yangqing Jia
+
+#include <cuda_runtime.h>
+#include <fcntl.h>
+#include <google/protobuf/text_format.h>
+
+#include <cstring>
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/net.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/proto/caffe.pb.h"
+#include "caffe/util/io.hpp"
+#include "caffe/optimization/solver.hpp"
+
+using namespace caffe;
+
+int main(int argc, char** argv) {
+  cudaSetDevice(0);
+  Caffe::set_mode(Caffe::GPU);
+  Caffe::set_phase(Caffe::TRAIN);
+
+  NetParameter net_param;
+  ReadProtoFromTextFile(argv[1],
+      &net_param);
+  vector<Blob<float>*> bottom_vec;
+  Net<float> caffe_net(net_param, bottom_vec);
+
+  SolverParameter solver_param;
+  ReadProtoFromTextFile(argv[2], &solver_param);
+
+  LOG(ERROR) << "Starting Optimization";
+  SGDSolver<float> solver(solver_param);
+  solver.Solve(&caffe_net);
+  LOG(ERROR) << "Optimization Done.";
+
+  return 0;
+}
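Given the other files added in this commit, the intended invocation is presumably

    train_net imagenet.prototxt imagenet_solver.prototxt

with argv[1] naming the net definition and argv[2] the solver settings. Note that argc is never checked, so calling it without both arguments would read past argv; a hardened driver would validate argc first.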