Fix: made load_hd5 check blob dims by default.
authormax argus <argus.max@gmail.com>
Thu, 25 Aug 2016 09:20:24 +0000 (09:20 +0000)
committermax argus <argus.max@gmail.com>
Sat, 22 Oct 2016 00:41:19 +0000 (00:41 +0000)
Size checks are needed for loading parameters to avoid strange bugs
when loading data we continue to reshape.

include/caffe/util/hdf5.hpp
src/caffe/layers/hdf5_data_layer.cpp
src/caffe/test/test_hdf5_output_layer.cpp
src/caffe/test/test_hdf5data_layer.cpp
src/caffe/util/hdf5.cpp

index ce568c5..71549c1 100644 (file)
@@ -13,12 +13,12 @@ namespace caffe {
 template <typename Dtype>
 void hdf5_load_nd_dataset_helper(
     hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
-    Blob<Dtype>* blob);
+    Blob<Dtype>* blob, bool reshape);
 
 template <typename Dtype>
 void hdf5_load_nd_dataset(
     hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
-    Blob<Dtype>* blob);
+    Blob<Dtype>* blob, bool reshape = false);
 
 template <typename Dtype>
 void hdf5_save_nd_dataset(
index 2f13dc6..0099129 100644 (file)
@@ -39,8 +39,9 @@ void HDF5DataLayer<Dtype>::LoadHDF5FileData(const char* filename) {
 
   for (int i = 0; i < top_size; ++i) {
     hdf_blobs_[i] = shared_ptr<Blob<Dtype> >(new Blob<Dtype>());
+    // Allow reshape here, as we are loading data not params
     hdf5_load_nd_dataset(file_id, this->layer_param_.top(i).c_str(),
-        MIN_DATA_DIM, MAX_DATA_DIM, hdf_blobs_[i].get());
+        MIN_DATA_DIM, MAX_DATA_DIM, hdf_blobs_[i].get(), true);
   }
 
   herr_t status = H5Fclose(file_id);
index 3833ebf..2bc2de1 100644 (file)
@@ -77,10 +77,12 @@ TYPED_TEST(HDF5OutputLayerTest, TestForward) {
                           H5P_DEFAULT);
   ASSERT_GE(file_id, 0)<< "Failed to open HDF5 file" <<
       this->input_file_name_;
+  // Allow reshape here as we are loading data not params
+  bool reshape = true;
   hdf5_load_nd_dataset(file_id, HDF5_DATA_DATASET_NAME, 0, 4,
-                       this->blob_data_);
+                       this->blob_data_, reshape);
   hdf5_load_nd_dataset(file_id, HDF5_DATA_LABEL_NAME, 0, 4,
-                       this->blob_label_);
+                       this->blob_label_, reshape);
   herr_t status = H5Fclose(file_id);
   EXPECT_GE(status, 0)<< "Failed to close HDF5 file " <<
       this->input_file_name_;
@@ -105,12 +107,12 @@ TYPED_TEST(HDF5OutputLayerTest, TestForward) {
 
   Blob<Dtype>* blob_data = new Blob<Dtype>();
   hdf5_load_nd_dataset(file_id, HDF5_DATA_DATASET_NAME, 0, 4,
-                       blob_data);
+                       blob_data, reshape);
   this->CheckBlobEqual(*(this->blob_data_), *blob_data);
 
   Blob<Dtype>* blob_label = new Blob<Dtype>();
   hdf5_load_nd_dataset(file_id, HDF5_DATA_LABEL_NAME, 0, 4,
-                       blob_label);
+                       blob_label, reshape);
   this->CheckBlobEqual(*(this->blob_label_), *blob_label);
 
   status = H5Fclose(file_id);
index 8884ce9..e0fd621 100644 (file)
@@ -70,7 +70,7 @@ TYPED_TEST(HDF5DataLayerTest, TestRead) {
   int height = 6;
   int width = 5;
 
-  // Test that the layer setup got the correct parameters.
+  // Test that the layer setup gives correct parameters.
   HDF5DataLayer<Dtype> layer(param);
   layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
   EXPECT_EQ(this->blob_top_data_->num(), batch_size);
index 7730e76..0003f1b 100644 (file)
@@ -9,7 +9,7 @@ namespace caffe {
 template <typename Dtype>
 void hdf5_load_nd_dataset_helper(
     hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
-    Blob<Dtype>* blob) {
+    Blob<Dtype>* blob, bool reshape) {
   // Verify that the dataset exists.
   CHECK(H5LTfind_dataset(file_id, dataset_name_))
       << "Failed to find HDF5 dataset " << dataset_name_;
@@ -56,17 +56,38 @@ void hdf5_load_nd_dataset_helper(
     LOG(FATAL) << "Datatype class unknown";
   }
 
+
   vector<int> blob_dims(dims.size());
   for (int i = 0; i < dims.size(); ++i) {
     blob_dims[i] = dims[i];
   }
-  blob->Reshape(blob_dims);
+
+  if (reshape) {
+    blob->Reshape(blob_dims);
+  } else {
+    if (blob_dims != blob->shape()) {
+      // create shape string for error message
+      ostringstream stream;
+      int count = 1;
+      for (int i = 0; i < blob_dims.size(); ++i) {
+        stream << blob_dims[i] << " ";
+        count = count * blob_dims[i];
+      }
+      stream << "(" << count << ")";
+      string source_shape_string = stream.str();
+
+      CHECK(blob_dims == blob->shape()) << "Cannot load blob from hdf5; shape "
+            << "mismatch. Source shape is " << source_shape_string
+            << " target shape is " << blob->shape_string();
+    }
+  }
 }
 
 template <>
 void hdf5_load_nd_dataset<float>(hid_t file_id, const char* dataset_name_,
-        int min_dim, int max_dim, Blob<float>* blob) {
-  hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob);
+        int min_dim, int max_dim, Blob<float>* blob, bool reshape) {
+  hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob,
+                              reshape);
   herr_t status = H5LTread_dataset_float(
     file_id, dataset_name_, blob->mutable_cpu_data());
   CHECK_GE(status, 0) << "Failed to read float dataset " << dataset_name_;
@@ -74,8 +95,9 @@ void hdf5_load_nd_dataset<float>(hid_t file_id, const char* dataset_name_,
 
 template <>
 void hdf5_load_nd_dataset<double>(hid_t file_id, const char* dataset_name_,
-        int min_dim, int max_dim, Blob<double>* blob) {
-  hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob);
+        int min_dim, int max_dim, Blob<double>* blob, bool reshape) {
+  hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob,
+                              reshape);
   herr_t status = H5LTread_dataset_double(
     file_id, dataset_name_, blob->mutable_cpu_data());
   CHECK_GE(status, 0) << "Failed to read double dataset " << dataset_name_;