4 #include "gtest/gtest.h"
6 #include "caffe/blob.hpp"
7 #include "caffe/common.hpp"
8 #include "caffe/layers/hdf5_output_layer.hpp"
9 #include "caffe/proto/caffe.pb.h"
10 #include "caffe/util/hdf5.hpp"
11 #include "caffe/util/io.hpp"
13 #include "caffe/test/test_caffe_main.hpp"
17 template<typename TypeParam>
18 class HDF5OutputLayerTest : public MultiDeviceTest<TypeParam> {
19 typedef typename TypeParam::Dtype Dtype;
24 CMAKE_SOURCE_DIR "caffe/test/test_data/sample_data.h5"),
25 blob_data_(new Blob<Dtype>()),
26 blob_label_(new Blob<Dtype>()),
31 MakeTempFilename(&output_file_name_);
34 virtual ~HDF5OutputLayerTest() {
39 void CheckBlobEqual(const Blob<Dtype>& b1, const Blob<Dtype>& b2);
41 string output_file_name_;
42 string input_file_name_;
43 Blob<Dtype>* const blob_data_;
44 Blob<Dtype>* const blob_label_;
45 vector<Blob<Dtype>*> blob_bottom_vec_;
46 vector<Blob<Dtype>*> blob_top_vec_;
53 template<typename TypeParam>
54 void HDF5OutputLayerTest<TypeParam>::CheckBlobEqual(const Blob<Dtype>& b1,
55 const Blob<Dtype>& b2) {
56 EXPECT_EQ(b1.num(), b2.num());
57 EXPECT_EQ(b1.channels(), b2.channels());
58 EXPECT_EQ(b1.height(), b2.height());
59 EXPECT_EQ(b1.width(), b2.width());
60 for (int n = 0; n < b1.num(); ++n) {
61 for (int c = 0; c < b1.channels(); ++c) {
62 for (int h = 0; h < b1.height(); ++h) {
63 for (int w = 0; w < b1.width(); ++w) {
64 EXPECT_EQ(b1.data_at(n, c, h, w), b2.data_at(n, c, h, w));
71 TYPED_TEST_CASE(HDF5OutputLayerTest, TestDtypesAndDevices);  // Registers the HDF5OutputLayerTest fixture so each TYPED_TEST below runs once per type in TestDtypesAndDevices (type list declared outside this view, presumably in caffe/test/test_caffe_main.hpp — confirm).
73 TYPED_TEST(HDF5OutputLayerTest, TestForward) {
74 typedef typename TypeParam::Dtype Dtype;
75 LOG(INFO) << "Loading HDF5 file " << this->input_file_name_;
76 hid_t file_id = H5Fopen(this->input_file_name_.c_str(), H5F_ACC_RDONLY,
78 ASSERT_GE(file_id, 0)<< "Failed to open HDF5 file" <<
79 this->input_file_name_;
80 hdf5_load_nd_dataset(file_id, HDF5_DATA_DATASET_NAME, 0, 4,
82 hdf5_load_nd_dataset(file_id, HDF5_DATA_LABEL_NAME, 0, 4,
84 herr_t status = H5Fclose(file_id);
85 EXPECT_GE(status, 0)<< "Failed to close HDF5 file " <<
86 this->input_file_name_;
87 this->blob_bottom_vec_.push_back(this->blob_data_);
88 this->blob_bottom_vec_.push_back(this->blob_label_);
91 param.mutable_hdf5_output_param()->set_file_name(this->output_file_name_);
92 // This code block ensures that the layer is deconstructed and
93 // the output hdf5 file is closed.
95 HDF5OutputLayer<Dtype> layer(param);
96 layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
97 EXPECT_EQ(layer.file_name(), this->output_file_name_);
98 layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
100 file_id = H5Fopen(this->output_file_name_.c_str(), H5F_ACC_RDONLY,
103 file_id, 0)<< "Failed to open HDF5 file" <<
104 this->input_file_name_;
106 Blob<Dtype>* blob_data = new Blob<Dtype>();
107 hdf5_load_nd_dataset(file_id, HDF5_DATA_DATASET_NAME, 0, 4,
109 this->CheckBlobEqual(*(this->blob_data_), *blob_data);
111 Blob<Dtype>* blob_label = new Blob<Dtype>();
112 hdf5_load_nd_dataset(file_id, HDF5_DATA_LABEL_NAME, 0, 4,
114 this->CheckBlobEqual(*(this->blob_label_), *blob_label);
116 status = H5Fclose(file_id);
117 EXPECT_GE(status, 0) << "Failed to close HDF5 file " <<
118 this->output_file_name_;