Merge pull request #4630 from BlGene/load_hdf5_fix
[platform/upstream/caffeonacl.git] / src / caffe / util / hdf5.cpp
index d255877..ed73742 100644 (file)
@@ -9,7 +9,7 @@ namespace caffe {
 template <typename Dtype>
 void hdf5_load_nd_dataset_helper(
     hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
-    Blob<Dtype>* blob) {
+    Blob<Dtype>* blob, bool reshape) {
   // Verify that the dataset exists.
   CHECK(H5LTfind_dataset(file_id, dataset_name_))
       << "Failed to find HDF5 dataset " << dataset_name_;
@@ -56,17 +56,38 @@ void hdf5_load_nd_dataset_helper(
     LOG(FATAL) << "Datatype class unknown";
   }
 
+
   vector<int> blob_dims(dims.size());
   for (int i = 0; i < dims.size(); ++i) {
     blob_dims[i] = dims[i];
   }
-  blob->Reshape(blob_dims);
+
+  if (reshape) {
+    blob->Reshape(blob_dims);
+  } else {
+    if (blob_dims != blob->shape()) {
+      // create shape string for error message
+      ostringstream stream;
+      int count = 1;
+      for (int i = 0; i < blob_dims.size(); ++i) {
+        stream << blob_dims[i] << " ";
+        count = count * blob_dims[i];
+      }
+      stream << "(" << count << ")";
+      string source_shape_string = stream.str();
+
+      CHECK(blob_dims == blob->shape()) << "Cannot load blob from hdf5; shape "
+            << "mismatch. Source shape is " << source_shape_string
+            << " target shape is " << blob->shape_string();
+    }
+  }
 }
 
 template <>
 void hdf5_load_nd_dataset<float>(hid_t file_id, const char* dataset_name_,
-        int min_dim, int max_dim, Blob<float>* blob) {
-  hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob);
+        int min_dim, int max_dim, Blob<float>* blob, bool reshape) {
+  hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob,
+                              reshape);
   herr_t status = H5LTread_dataset_float(
     file_id, dataset_name_, blob->mutable_cpu_data());
   CHECK_GE(status, 0) << "Failed to read float dataset " << dataset_name_;
@@ -74,8 +95,9 @@ void hdf5_load_nd_dataset<float>(hid_t file_id, const char* dataset_name_,
 
 template <>
 void hdf5_load_nd_dataset<double>(hid_t file_id, const char* dataset_name_,
-        int min_dim, int max_dim, Blob<double>* blob) {
-  hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob);
+        int min_dim, int max_dim, Blob<double>* blob, bool reshape) {
+  hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob,
+                              reshape);
   herr_t status = H5LTread_dataset_double(
     file_id, dataset_name_, blob->mutable_cpu_data());
   CHECK_GE(status, 0) << "Failed to read double dataset " << dataset_name_;