#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"
+#define HDF5_DATA_DATASET_NAME "data"
+#define HDF5_DATA_LABEL_NAME "label"
+
namespace caffe {
template <typename Dtype>
+class HDF5OutputLayer : public Layer<Dtype> {
+ public:
+ explicit HDF5OutputLayer(const LayerParameter& param);
+ virtual ~HDF5OutputLayer();
+ virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ inline std::string file_name() const { return file_name_; }
+
+ protected:
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ virtual void SaveBlobs();
+
+ std::string file_name_;
+ hid_t file_id_;
+ Blob<Dtype> data_blob_;
+ Blob<Dtype> label_blob_;
+};
+
+
+template <typename Dtype>
class SoftmaxLayer : public Layer<Dtype> {
public:
explicit SoftmaxLayer(const LayerParameter& param)
--- /dev/null
+// Copyright 2014 BVLC and contributors.
+/*
+Contributors:
+- kloudkl@github, 2014.
+*/
+
+#include <vector>
+
+#include "hdf5.h"
+#include "hdf5_hl.h"
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/util/io.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+using std::vector;
+
+template <typename Dtype>
+HDF5OutputLayer<Dtype>::HDF5OutputLayer(const LayerParameter& param)
+ : Layer<Dtype>(param),
+ file_name_(param.hdf5_output_param().file_name()) {
+ /* create a HDF5 file */
+ file_id_ = H5Fcreate(file_name_.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT,
+ H5P_DEFAULT);
+ CHECK_GE(file_id_, 0) << "Failed to open HDF5 file" << file_name_;
+}
+
+template <typename Dtype>
+HDF5OutputLayer<Dtype>::~HDF5OutputLayer<Dtype>() {
+ herr_t status = H5Fclose(file_id_);
+ CHECK_GE(status, 0) << "Failed to close HDF5 file " << file_name_;
+}
+
+template <typename Dtype>
+void HDF5OutputLayer<Dtype>::SaveBlobs() {
+ // TODO: no limit on the number of blobs
+ LOG(INFO) << "Saving HDF5 file" << file_name_;
+ CHECK_EQ(data_blob_.num(), label_blob_.num()) <<
+ "data blob and label blob must have the same batch size";
+ hdf5_save_nd_dataset(file_id_, HDF5_DATA_DATASET_NAME, data_blob_);
+ hdf5_save_nd_dataset(file_id_, HDF5_DATA_LABEL_NAME, label_blob_);
+ LOG(INFO) << "Successfully saved " << data_blob_.num() << " rows";
+}
+
+template <typename Dtype>
+void HDF5OutputLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) {
+ // TODO: no limit on the number of blobs
+ CHECK_EQ(bottom.size(), 2) << "HDF5OutputLayer takes two blobs as input.";
+ CHECK_EQ(top->size(), 0) << "HDF5OutputLayer takes no output blobs.";
+}
+
+template <typename Dtype>
+void HDF5OutputLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) {
+ CHECK_GE(bottom.size(), 2);
+ CHECK_EQ(bottom[0]->num(), bottom[1]->num());
+ data_blob_.Reshape(bottom[0]->num(), bottom[0]->channels(),
+ bottom[0]->height(), bottom[0]->width());
+ label_blob_.Reshape(bottom[1]->num(), bottom[1]->channels(),
+ bottom[1]->height(), bottom[1]->width());
+ const int data_datum_dim = bottom[0]->count() / bottom[0]->num();
+ const int label_datum_dim = bottom[1]->count() / bottom[1]->num();
+
+ for (int i = 0; i < bottom[0]->num(); ++i) {
+ memcpy(&data_blob_.mutable_cpu_data()[i * data_datum_dim],
+ &bottom[0]->cpu_data()[i * data_datum_dim],
+ sizeof(Dtype) * data_datum_dim);
+ memcpy(&label_blob_.mutable_cpu_data()[i * label_datum_dim],
+ &bottom[1]->cpu_data()[i * label_datum_dim],
+ sizeof(Dtype) * label_datum_dim);
+ }
+ SaveBlobs();
+}
+
+template <typename Dtype>
+void HDF5OutputLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) {
+ CHECK_GE(bottom.size(), 2);
+ CHECK_EQ(bottom[0]->num(), bottom[1]->num());
+ data_blob_.Reshape(bottom[0]->num(), bottom[0]->channels(),
+ bottom[0]->height(), bottom[0]->width());
+ label_blob_.Reshape(bottom[1]->num(), bottom[1]->channels(),
+ bottom[1]->height(), bottom[1]->width());
+ const int data_datum_dim = bottom[0]->count() / bottom[0]->num();
+ const int label_datum_dim = bottom[1]->count() / bottom[1]->num();
+
+ for (int i = 0; i < bottom[0]->num(); ++i) {
+ CUDA_CHECK(cudaMemcpy(&data_blob_.mutable_cpu_data()[i * data_datum_dim],
+ &bottom[0]->gpu_data()[i * data_datum_dim],
+ sizeof(Dtype) * data_datum_dim, cudaMemcpyDeviceToHost));
+ CUDA_CHECK(cudaMemcpy(&label_blob_.mutable_cpu_data()[i * label_datum_dim],
+ &bottom[1]->gpu_data()[i * label_datum_dim],
+ sizeof(Dtype) * label_datum_dim, cudaMemcpyDeviceToHost));
+ }
+ SaveBlobs();
+}
+
+template <typename Dtype>
+Dtype HDF5OutputLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+ return Dtype(0.);
+}
+
+template <typename Dtype>
+Dtype HDF5OutputLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+ return Dtype(0.);
+}
+
+INSTANTIATE_CLASS(HDF5OutputLayer);
+
+} // namespace caffe
// the other dimensions must be the same for all the bottom blobs.
// By default it will concatenate blobs along the channels dimension.
optional uint32 concat_dim = 65 [default = 1];
+
+ optional HDF5OutputParameter hdf5_output_param = 1001;
+}
+
+message HDF5OutputParameter {
+ optional string file_name = 1;
}
message LayerConnection {
--- /dev/null
+// Copyright 2014 kloudkl@github
+
+#include <cuda_runtime.h>
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/util/io.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/proto/caffe.pb.h"
+#include "caffe/test/test_caffe_main.hpp"
+
+namespace caffe {
+using std::string;
+using std::vector;
+
+extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
+
+template <typename Dtype>
+class HDF5OutputLayerTest : public ::testing::Test {
+ protected:
+ HDF5OutputLayerTest()
+ : output_file_name_("/tmp/test_hdf5_output_layer-sample_data.hdf5"),
+ input_file_name_("src/caffe/test/test_data/sample_data.h5"),
+ blob_data_(new Blob<Dtype>()),
+ blob_label_(new Blob<Dtype>()),
+ num_(5),
+ channels_(8),
+ height_(5),
+ width_(5) {
+ }
+ virtual void SetUp() {
+ }
+
+ virtual ~HDF5OutputLayerTest() {
+ delete blob_data_;
+ delete blob_label_;
+ }
+
+ void CheckBlobEqual(const Blob<Dtype>& b1, const Blob<Dtype>& b2);
+
+ string output_file_name_;
+ string input_file_name_;
+ Blob<Dtype>* const blob_data_;
+ Blob<Dtype>* const blob_label_;
+ vector<Blob<Dtype>*> blob_bottom_vec_;
+ vector<Blob<Dtype>*> blob_top_vec_;
+ int num_;
+ int channels_;
+ int height_;
+ int width_;
+};
+
+template <typename Dtype>
+void HDF5OutputLayerTest<Dtype>::CheckBlobEqual(
+ const Blob<Dtype>& b1, const Blob<Dtype>& b2) {
+ EXPECT_EQ(b1.num(), b2.num());
+ EXPECT_EQ(b1.channels(), b2.channels());
+ EXPECT_EQ(b1.height(), b2.height());
+ EXPECT_EQ(b1.width(), b2.width());
+ for (int n = 0; n < b1.num(); ++n) {
+ for (int c = 0; c < b1.channels(); ++c) {
+ for (int h = 0; h < b1.height(); ++h) {
+ for (int w = 0; w < b1.width(); ++w) {
+ EXPECT_EQ(b1.data_at(n, c, h, w), b1.data_at(n, c, h, w));
+ }
+ }
+ }
+ }
+}
+
+typedef ::testing::Types<float, double> Dtypes;
+TYPED_TEST_CASE(HDF5OutputLayerTest, Dtypes);
+
+TYPED_TEST(HDF5OutputLayerTest, TestForward) {
+ LOG(INFO) << "Loading HDF5 file " << this->input_file_name_;
+ hid_t file_id = H5Fopen(this->input_file_name_.c_str(), H5F_ACC_RDONLY,
+ H5P_DEFAULT);
+ ASSERT_GE(file_id, 0) << "Failed to open HDF5 file" <<
+ this->input_file_name_;
+ hdf5_load_nd_dataset(file_id, HDF5_DATA_DATASET_NAME, 0, 4,
+ this->blob_data_);
+ hdf5_load_nd_dataset(file_id, HDF5_DATA_LABEL_NAME, 0, 4,
+ this->blob_label_);
+ herr_t status = H5Fclose(file_id);
+ EXPECT_GE(status, 0) << "Failed to close HDF5 file " <<
+ this->input_file_name_;
+ this->blob_bottom_vec_.push_back(this->blob_data_);
+ this->blob_bottom_vec_.push_back(this->blob_label_);
+
+ Caffe::Brew modes[] = { Caffe::CPU, Caffe::GPU };
+ for (int m = 0; m < 2; ++m) {
+ Caffe::set_mode(modes[m]);
+ LayerParameter param;
+ param.mutable_hdf5_output_param()->set_file_name(this->output_file_name_);
+ // This code block ensures that the layer is deconstructed and
+ // the output hdf5 file is closed.
+ {
+ HDF5OutputLayer<TypeParam> layer(param);
+ EXPECT_EQ(layer.file_name(), this->output_file_name_);
+ layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
+ layer.Forward(this->blob_bottom_vec_, &this->blob_top_vec_);
+ }
+ hid_t file_id = H5Fopen(this->output_file_name_.c_str(), H5F_ACC_RDONLY,
+ H5P_DEFAULT);
+ ASSERT_GE(file_id, 0) << "Failed to open HDF5 file" <<
+ this->input_file_name_;
+
+ Blob<TypeParam>* blob_data = new Blob<TypeParam>();
+ hdf5_load_nd_dataset(file_id, HDF5_DATA_DATASET_NAME, 0, 4,
+ blob_data);
+ this->CheckBlobEqual(*(this->blob_data_), *blob_data);
+
+ Blob<TypeParam>* blob_label = new Blob<TypeParam>();
+ hdf5_load_nd_dataset(file_id, HDF5_DATA_LABEL_NAME, 0, 4,
+ blob_label);
+ this->CheckBlobEqual(*(this->blob_label_), *blob_label);
+
+ herr_t status = H5Fclose(file_id);
+ EXPECT_GE(status, 0) << "Failed to close HDF5 file " <<
+ this->output_file_name_;
+ }
+}
+
+} // namespace caffe