Implement and test HDF5OutputLayer
authorKai Li <kaili_kloud@163.com>
Sun, 23 Mar 2014 11:03:21 +0000 (19:03 +0800)
committerKai Li <kaili_kloud@163.com>
Sun, 23 Mar 2014 12:12:54 +0000 (20:12 +0800)
include/caffe/vision_layers.hpp
src/caffe/layers/hdf5_output_layer.cpp [new file with mode: 0644]
src/caffe/proto/caffe.proto
src/caffe/test/test_hdf5_output_layer.cpp [new file with mode: 0644]

index 91a2324..fb0c0dd 100644 (file)
@@ -15,6 +15,9 @@
 #include "caffe/layer.hpp"
 #include "caffe/proto/caffe.pb.h"
 
+#define HDF5_DATA_DATASET_NAME "data"
+#define HDF5_DATA_LABEL_NAME "label"
+
 namespace caffe {
 
 
@@ -478,6 +481,33 @@ class HDF5DataLayer : public Layer<Dtype> {
 
 
 template <typename Dtype>
+class HDF5OutputLayer : public Layer<Dtype> {
+ public:
+  explicit HDF5OutputLayer(const LayerParameter& param);
+  virtual ~HDF5OutputLayer();
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  inline std::string file_name() const { return file_name_; }
+
+ protected:
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual void SaveBlobs();
+
+  std::string file_name_;
+  hid_t file_id_;
+  Blob<Dtype> data_blob_;
+  Blob<Dtype> label_blob_;
+};
+
+
+template <typename Dtype>
 class SoftmaxLayer : public Layer<Dtype> {
  public:
   explicit SoftmaxLayer(const LayerParameter& param)
diff --git a/src/caffe/layers/hdf5_output_layer.cpp b/src/caffe/layers/hdf5_output_layer.cpp
new file mode 100644 (file)
index 0000000..3bf8dc2
--- /dev/null
@@ -0,0 +1,116 @@
+// Copyright 2014 BVLC and contributors.
+/*
+Contributors:
+- kloudkl@github, 2014.
+*/
+
+#include <vector>
+
+#include "hdf5.h"
+#include "hdf5_hl.h"
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/util/io.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+using std::vector;
+
+template <typename Dtype>
+HDF5OutputLayer<Dtype>::HDF5OutputLayer(const LayerParameter& param)
+    : Layer<Dtype>(param),
+      file_name_(param.hdf5_output_param().file_name()) {
+  /* create a HDF5 file */
+  file_id_ = H5Fcreate(file_name_.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT,
+                       H5P_DEFAULT);
+  CHECK_GE(file_id_, 0) << "Failed to open HDF5 file" << file_name_;
+}
+
+template <typename Dtype>
+HDF5OutputLayer<Dtype>::~HDF5OutputLayer<Dtype>() {
+  herr_t status = H5Fclose(file_id_);
+  CHECK_GE(status, 0) << "Failed to close HDF5 file " << file_name_;
+}
+
+template <typename Dtype>
+void HDF5OutputLayer<Dtype>::SaveBlobs() {
+  // TODO: no limit on the number of blobs
+  LOG(INFO) << "Saving HDF5 file" << file_name_;
+  CHECK_EQ(data_blob_.num(), label_blob_.num()) <<
+      "data blob and label blob must have the same batch size";
+  hdf5_save_nd_dataset(file_id_, HDF5_DATA_DATASET_NAME, data_blob_);
+  hdf5_save_nd_dataset(file_id_, HDF5_DATA_LABEL_NAME, label_blob_);
+  LOG(INFO) << "Successfully saved " << data_blob_.num() << " rows";
+}
+
+template <typename Dtype>
+void HDF5OutputLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  // TODO: no limit on the number of blobs
+  CHECK_EQ(bottom.size(), 2) << "HDF5OutputLayer takes two blobs as input.";
+  CHECK_EQ(top->size(), 0) << "HDF5OutputLayer takes no output blobs.";
+}
+
+template <typename Dtype>
+void HDF5OutputLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  CHECK_GE(bottom.size(), 2);
+  CHECK_EQ(bottom[0]->num(), bottom[1]->num());
+  data_blob_.Reshape(bottom[0]->num(), bottom[0]->channels(),
+                     bottom[0]->height(), bottom[0]->width());
+  label_blob_.Reshape(bottom[1]->num(), bottom[1]->channels(),
+                     bottom[1]->height(), bottom[1]->width());
+  const int data_datum_dim = bottom[0]->count() / bottom[0]->num();
+  const int label_datum_dim = bottom[1]->count() / bottom[1]->num();
+
+  for (int i = 0; i < bottom[0]->num(); ++i) {
+    memcpy(&data_blob_.mutable_cpu_data()[i * data_datum_dim],
+           &bottom[0]->cpu_data()[i * data_datum_dim],
+           sizeof(Dtype) * data_datum_dim);
+    memcpy(&label_blob_.mutable_cpu_data()[i * label_datum_dim],
+           &bottom[1]->cpu_data()[i * label_datum_dim],
+           sizeof(Dtype) * label_datum_dim);
+  }
+  SaveBlobs();
+}
+
+template <typename Dtype>
+void HDF5OutputLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  CHECK_GE(bottom.size(), 2);
+  CHECK_EQ(bottom[0]->num(), bottom[1]->num());
+  data_blob_.Reshape(bottom[0]->num(), bottom[0]->channels(),
+                     bottom[0]->height(), bottom[0]->width());
+  label_blob_.Reshape(bottom[1]->num(), bottom[1]->channels(),
+                     bottom[1]->height(), bottom[1]->width());
+  const int data_datum_dim = bottom[0]->count() / bottom[0]->num();
+  const int label_datum_dim = bottom[1]->count() / bottom[1]->num();
+
+  for (int i = 0; i < bottom[0]->num(); ++i) {
+    CUDA_CHECK(cudaMemcpy(&data_blob_.mutable_cpu_data()[i * data_datum_dim],
+           &bottom[0]->gpu_data()[i * data_datum_dim],
+           sizeof(Dtype) * data_datum_dim, cudaMemcpyDeviceToHost));
+    CUDA_CHECK(cudaMemcpy(&label_blob_.mutable_cpu_data()[i * label_datum_dim],
+           &bottom[1]->gpu_data()[i * label_datum_dim],
+           sizeof(Dtype) * label_datum_dim, cudaMemcpyDeviceToHost));
+  }
+  SaveBlobs();
+}
+
+template <typename Dtype>
+Dtype HDF5OutputLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+  return Dtype(0.);
+}
+
+template <typename Dtype>
+Dtype HDF5OutputLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+  return Dtype(0.);
+}
+
+INSTANTIATE_CLASS(HDF5OutputLayer);
+
+}  // namespace caffe
index 5a73a44..362764a 100644 (file)
@@ -125,6 +125,12 @@ message LayerParameter {
   // the other dimensions must be the same for all the bottom blobs.
   // By default it will concatenate blobs along the channels dimension.
   optional uint32 concat_dim = 65 [default = 1];
+  
+  optional HDF5OutputParameter hdf5_output_param = 1001;
+}
+
+message HDF5OutputParameter {
+  optional string file_name = 1;
 }
 
 message LayerConnection {
diff --git a/src/caffe/test/test_hdf5_output_layer.cpp b/src/caffe/test/test_hdf5_output_layer.cpp
new file mode 100644 (file)
index 0000000..3cbfb3f
--- /dev/null
@@ -0,0 +1,127 @@
+// Copyright 2014 kloudkl@github
+
+#include <cuda_runtime.h>
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/util/io.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/proto/caffe.pb.h"
+#include "caffe/test/test_caffe_main.hpp"
+
+namespace caffe {
+using std::string;
+using std::vector;
+
+extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
+
+template <typename Dtype>
+class HDF5OutputLayerTest : public ::testing::Test {
+ protected:
+  HDF5OutputLayerTest()
+      : output_file_name_("/tmp/test_hdf5_output_layer-sample_data.hdf5"),
+        input_file_name_("src/caffe/test/test_data/sample_data.h5"),
+        blob_data_(new Blob<Dtype>()),
+        blob_label_(new Blob<Dtype>()),
+        num_(5),
+        channels_(8),
+        height_(5),
+        width_(5) {
+  }
+  virtual void SetUp() {
+  }
+
+  virtual ~HDF5OutputLayerTest() {
+    delete blob_data_;
+    delete blob_label_;
+  }
+
+  void CheckBlobEqual(const Blob<Dtype>& b1, const Blob<Dtype>& b2);
+
+  string output_file_name_;
+  string input_file_name_;
+  Blob<Dtype>* const blob_data_;
+  Blob<Dtype>* const blob_label_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+  int num_;
+  int channels_;
+  int height_;
+  int width_;
+};
+
+template <typename Dtype>
+void HDF5OutputLayerTest<Dtype>::CheckBlobEqual(
+    const Blob<Dtype>& b1, const Blob<Dtype>& b2) {
+  EXPECT_EQ(b1.num(), b2.num());
+  EXPECT_EQ(b1.channels(), b2.channels());
+  EXPECT_EQ(b1.height(), b2.height());
+  EXPECT_EQ(b1.width(), b2.width());
+  for (int n = 0; n < b1.num(); ++n) {
+    for (int c = 0; c < b1.channels(); ++c) {
+      for (int h = 0; h < b1.height(); ++h) {
+        for (int w = 0; w < b1.width(); ++w) {
+          EXPECT_EQ(b1.data_at(n, c, h, w), b1.data_at(n, c, h, w));
+        }
+      }
+    }
+  }
+}
+
+typedef ::testing::Types<float, double> Dtypes;
+TYPED_TEST_CASE(HDF5OutputLayerTest, Dtypes);
+
+TYPED_TEST(HDF5OutputLayerTest, TestForward) {
+  LOG(INFO) << "Loading HDF5 file " << this->input_file_name_;
+  hid_t file_id = H5Fopen(this->input_file_name_.c_str(), H5F_ACC_RDONLY,
+                          H5P_DEFAULT);
+  ASSERT_GE(file_id, 0) << "Failed to open HDF5 file" <<
+      this->input_file_name_;
+  hdf5_load_nd_dataset(file_id, HDF5_DATA_DATASET_NAME, 0, 4,
+                       this->blob_data_);
+  hdf5_load_nd_dataset(file_id, HDF5_DATA_LABEL_NAME, 0, 4,
+                       this->blob_label_);
+  herr_t status = H5Fclose(file_id);
+  EXPECT_GE(status, 0) << "Failed to close HDF5 file " <<
+      this->input_file_name_;
+  this->blob_bottom_vec_.push_back(this->blob_data_);
+  this->blob_bottom_vec_.push_back(this->blob_label_);
+
+  Caffe::Brew modes[] = { Caffe::CPU, Caffe::GPU };
+  for (int m = 0; m < 2; ++m) {
+    Caffe::set_mode(modes[m]);
+    LayerParameter param;
+    param.mutable_hdf5_output_param()->set_file_name(this->output_file_name_);
+    // This code block ensures that the layer is deconstructed and
+    //   the output hdf5 file is closed.
+    {
+      HDF5OutputLayer<TypeParam> layer(param);
+      EXPECT_EQ(layer.file_name(), this->output_file_name_);
+      layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
+      layer.Forward(this->blob_bottom_vec_, &this->blob_top_vec_);
+    }
+    hid_t file_id = H5Fopen(this->output_file_name_.c_str(), H5F_ACC_RDONLY,
+                            H5P_DEFAULT);
+    ASSERT_GE(file_id, 0) << "Failed to open HDF5 file" <<
+        this->input_file_name_;
+
+    Blob<TypeParam>* blob_data = new Blob<TypeParam>();
+    hdf5_load_nd_dataset(file_id, HDF5_DATA_DATASET_NAME, 0, 4,
+                         blob_data);
+    this->CheckBlobEqual(*(this->blob_data_), *blob_data);
+
+    Blob<TypeParam>* blob_label = new Blob<TypeParam>();
+    hdf5_load_nd_dataset(file_id, HDF5_DATA_LABEL_NAME, 0, 4,
+                         blob_label);
+    this->CheckBlobEqual(*(this->blob_label_), *blob_label);
+
+    herr_t status = H5Fclose(file_id);
+    EXPECT_GE(status, 0) << "Failed to close HDF5 file " <<
+        this->output_file_name_;
+  }
+}
+
+}  // namespace caffe