template <typename Dtype>
void hdf5_load_nd_dataset_helper(
hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
- Blob<Dtype>* blob);
+ Blob<Dtype>* blob, bool reshape);
template <typename Dtype>
void hdf5_load_nd_dataset(
hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
- Blob<Dtype>* blob);
+ Blob<Dtype>* blob, bool reshape = false);
template <typename Dtype>
void hdf5_save_nd_dataset(
H5P_DEFAULT);
ASSERT_GE(file_id, 0)<< "Failed to open HDF5 file" <<
this->input_file_name_;
+ // Allow reshape here as we are loading data not params
+ bool reshape = true;
hdf5_load_nd_dataset(file_id, HDF5_DATA_DATASET_NAME, 0, 4,
- this->blob_data_);
+ this->blob_data_, reshape);
hdf5_load_nd_dataset(file_id, HDF5_DATA_LABEL_NAME, 0, 4,
- this->blob_label_);
+ this->blob_label_, reshape);
herr_t status = H5Fclose(file_id);
EXPECT_GE(status, 0)<< "Failed to close HDF5 file " <<
this->input_file_name_;
Blob<Dtype>* blob_data = new Blob<Dtype>();
hdf5_load_nd_dataset(file_id, HDF5_DATA_DATASET_NAME, 0, 4,
- blob_data);
+ blob_data, reshape);
this->CheckBlobEqual(*(this->blob_data_), *blob_data);
Blob<Dtype>* blob_label = new Blob<Dtype>();
hdf5_load_nd_dataset(file_id, HDF5_DATA_LABEL_NAME, 0, 4,
- blob_label);
+ blob_label, reshape);
this->CheckBlobEqual(*(this->blob_label_), *blob_label);
status = H5Fclose(file_id);
template <typename Dtype>
void hdf5_load_nd_dataset_helper(
hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
- Blob<Dtype>* blob) {
+ Blob<Dtype>* blob, bool reshape) {
// Verify that the dataset exists.
CHECK(H5LTfind_dataset(file_id, dataset_name_))
<< "Failed to find HDF5 dataset " << dataset_name_;
LOG(FATAL) << "Datatype class unknown";
}
+
vector<int> blob_dims(dims.size());
for (int i = 0; i < dims.size(); ++i) {
blob_dims[i] = dims[i];
}
- blob->Reshape(blob_dims);
+
+ if (reshape) {
+ blob->Reshape(blob_dims);
+ } else {
+ if (blob_dims != blob->shape()) {
+ // create shape string for error message
+ ostringstream stream;
+ int count = 1;
+ for (int i = 0; i < blob_dims.size(); ++i) {
+ stream << blob_dims[i] << " ";
+ count = count * blob_dims[i];
+ }
+ stream << "(" << count << ")";
+ string source_shape_string = stream.str();
+
+ CHECK(blob_dims == blob->shape()) << "Cannot load blob from hdf5; shape "
+ << "mismatch. Source shape is " << source_shape_string
+ << " target shape is " << blob->shape_string();
+ }
+ }
}
template <>
void hdf5_load_nd_dataset<float>(hid_t file_id, const char* dataset_name_,
- int min_dim, int max_dim, Blob<float>* blob) {
- hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob);
+ int min_dim, int max_dim, Blob<float>* blob, bool reshape) {
+ hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob,
+ reshape);
herr_t status = H5LTread_dataset_float(
file_id, dataset_name_, blob->mutable_cpu_data());
CHECK_GE(status, 0) << "Failed to read float dataset " << dataset_name_;
template <>
void hdf5_load_nd_dataset<double>(hid_t file_id, const char* dataset_name_,
- int min_dim, int max_dim, Blob<double>* blob) {
- hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob);
+ int min_dim, int max_dim, Blob<double>* blob, bool reshape) {
+ hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob,
+ reshape);
herr_t status = H5LTread_dataset_double(
file_id, dataset_name_, blob->mutable_cpu_data());
CHECK_GE(status, 0) << "Failed to read double dataset " << dataset_name_;