3833ebff78e89a42595b681ea1be08f6319ece7c
[platform/upstream/caffeonacl.git] / src / caffe / test / test_hdf5_output_layer.cpp
1 #include <string>
2 #include <vector>
3
4 #include "gtest/gtest.h"
5
6 #include "caffe/blob.hpp"
7 #include "caffe/common.hpp"
8 #include "caffe/layers/hdf5_output_layer.hpp"
9 #include "caffe/proto/caffe.pb.h"
10 #include "caffe/util/hdf5.hpp"
11 #include "caffe/util/io.hpp"
12
13 #include "caffe/test/test_caffe_main.hpp"
14
15 namespace caffe {
16
17 template<typename TypeParam>
18 class HDF5OutputLayerTest : public MultiDeviceTest<TypeParam> {
19   typedef typename TypeParam::Dtype Dtype;
20
21  protected:
22   HDF5OutputLayerTest()
23       : input_file_name_(
24         CMAKE_SOURCE_DIR "caffe/test/test_data/sample_data.h5"),
25         blob_data_(new Blob<Dtype>()),
26         blob_label_(new Blob<Dtype>()),
27         num_(5),
28         channels_(8),
29         height_(5),
30         width_(5) {
31     MakeTempFilename(&output_file_name_);
32   }
33
34   virtual ~HDF5OutputLayerTest() {
35     delete blob_data_;
36     delete blob_label_;
37   }
38
39   void CheckBlobEqual(const Blob<Dtype>& b1, const Blob<Dtype>& b2);
40
41   string output_file_name_;
42   string input_file_name_;
43   Blob<Dtype>* const blob_data_;
44   Blob<Dtype>* const blob_label_;
45   vector<Blob<Dtype>*> blob_bottom_vec_;
46   vector<Blob<Dtype>*> blob_top_vec_;
47   int num_;
48   int channels_;
49   int height_;
50   int width_;
51 };
52
53 template<typename TypeParam>
54 void HDF5OutputLayerTest<TypeParam>::CheckBlobEqual(const Blob<Dtype>& b1,
55                                                     const Blob<Dtype>& b2) {
56   EXPECT_EQ(b1.num(), b2.num());
57   EXPECT_EQ(b1.channels(), b2.channels());
58   EXPECT_EQ(b1.height(), b2.height());
59   EXPECT_EQ(b1.width(), b2.width());
60   for (int n = 0; n < b1.num(); ++n) {
61     for (int c = 0; c < b1.channels(); ++c) {
62       for (int h = 0; h < b1.height(); ++h) {
63         for (int w = 0; w < b1.width(); ++w) {
64           EXPECT_EQ(b1.data_at(n, c, h, w), b2.data_at(n, c, h, w));
65         }
66       }
67     }
68   }
69 }
70
71 TYPED_TEST_CASE(HDF5OutputLayerTest, TestDtypesAndDevices);
72
73 TYPED_TEST(HDF5OutputLayerTest, TestForward) {
74   typedef typename TypeParam::Dtype Dtype;
75   LOG(INFO) << "Loading HDF5 file " << this->input_file_name_;
76   hid_t file_id = H5Fopen(this->input_file_name_.c_str(), H5F_ACC_RDONLY,
77                           H5P_DEFAULT);
78   ASSERT_GE(file_id, 0)<< "Failed to open HDF5 file" <<
79       this->input_file_name_;
80   hdf5_load_nd_dataset(file_id, HDF5_DATA_DATASET_NAME, 0, 4,
81                        this->blob_data_);
82   hdf5_load_nd_dataset(file_id, HDF5_DATA_LABEL_NAME, 0, 4,
83                        this->blob_label_);
84   herr_t status = H5Fclose(file_id);
85   EXPECT_GE(status, 0)<< "Failed to close HDF5 file " <<
86       this->input_file_name_;
87   this->blob_bottom_vec_.push_back(this->blob_data_);
88   this->blob_bottom_vec_.push_back(this->blob_label_);
89
90   LayerParameter param;
91   param.mutable_hdf5_output_param()->set_file_name(this->output_file_name_);
92   // This code block ensures that the layer is deconstructed and
93   //   the output hdf5 file is closed.
94   {
95     HDF5OutputLayer<Dtype> layer(param);
96     layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
97     EXPECT_EQ(layer.file_name(), this->output_file_name_);
98     layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
99   }
100   file_id = H5Fopen(this->output_file_name_.c_str(), H5F_ACC_RDONLY,
101                           H5P_DEFAULT);
102   ASSERT_GE(
103     file_id, 0)<< "Failed to open HDF5 file" <<
104           this->input_file_name_;
105
106   Blob<Dtype>* blob_data = new Blob<Dtype>();
107   hdf5_load_nd_dataset(file_id, HDF5_DATA_DATASET_NAME, 0, 4,
108                        blob_data);
109   this->CheckBlobEqual(*(this->blob_data_), *blob_data);
110
111   Blob<Dtype>* blob_label = new Blob<Dtype>();
112   hdf5_load_nd_dataset(file_id, HDF5_DATA_LABEL_NAME, 0, 4,
113                        blob_label);
114   this->CheckBlobEqual(*(this->blob_label_), *blob_label);
115
116   status = H5Fclose(file_id);
117   EXPECT_GE(status, 0) << "Failed to close HDF5 file " <<
118       this->output_file_name_;
119 }
120
121 }  // namespace caffe