1 #include "caffe/util/hdf5.hpp"
8 // Verifies format of data stored in HDF5 file and reshapes blob accordingly.
//
// Visible steps: confirm that dataset_name_ exists under file_id, require its
// rank to lie in [min_dim, max_dim], fetch its extents and datatype class,
// then either Reshape() the blob to those extents or CHECK that the blob
// already has that shape.
//
// NOTE(review): the numbers embedded in this listing skip (16 -> 19,
// 29 -> 32, 62 -> 66, ...), so several lines are not shown: the declarations
// of status, ndims, class_, stream and count, the case labels around the LOG
// statements, and whatever selects between the Reshape() call and the shape
// comparison (presumably the `reshape` parameter -- confirm against the
// full file).
9 template <typename Dtype>
10 void hdf5_load_nd_dataset_helper(
11 hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
12 Blob<Dtype>* blob, bool reshape) {
13 // Verify that the dataset exists.
14 CHECK(H5LTfind_dataset(file_id, dataset_name_))
15 << "Failed to find HDF5 dataset " << dataset_name_;
16 // Verify that the number of dimensions is in the accepted range.
// The rank is written to ndims; a negative status fails the CHECK.
19 status = H5LTget_dataset_ndims(file_id, dataset_name_, &ndims);
20 CHECK_GE(status, 0) << "Failed to get dataset ndims for " << dataset_name_;
21 CHECK_GE(ndims, min_dim);
22 CHECK_LE(ndims, max_dim);
24 // Verify that the data format is what we expect: float or double.
// dims holds one hsize_t per dimension; its storage and &class_ are handed to
// H5LTget_dataset_info as outputs, while the type-size output is NULL.
25 std::vector<hsize_t> dims(ndims);
27 status = H5LTget_dataset_info(
28 file_id, dataset_name_, dims.data(), &class_, NULL);
29 CHECK_GE(status, 0) << "Failed to get dataset info for " << dataset_name_;
// Of the messages below, only the H5T_FLOAT and H5T_INTEGER ones are
// informational, and LOG_FIRST_N(INFO, 1) limits each to its first
// occurrence. Every other class name appears in a LOG(FATAL) message.
// NOTE(review): the labels that route class_ to these statements are not
// shown in this listing.
32 LOG_FIRST_N(INFO, 1) << "Datatype class: H5T_FLOAT";
35 LOG_FIRST_N(INFO, 1) << "Datatype class: H5T_INTEGER";
38 LOG(FATAL) << "Unsupported datatype class: H5T_TIME";
40 LOG(FATAL) << "Unsupported datatype class: H5T_STRING";
42 LOG(FATAL) << "Unsupported datatype class: H5T_BITFIELD";
44 LOG(FATAL) << "Unsupported datatype class: H5T_OPAQUE";
46 LOG(FATAL) << "Unsupported datatype class: H5T_COMPOUND";
48 LOG(FATAL) << "Unsupported datatype class: H5T_REFERENCE";
50 LOG(FATAL) << "Unsupported datatype class: H5T_ENUM";
52 LOG(FATAL) << "Unsupported datatype class: H5T_VLEN";
54 LOG(FATAL) << "Unsupported datatype class: H5T_ARRAY";
56 LOG(FATAL) << "Datatype class unknown";
// Copy the hsize_t extents into a vector<int>, the form passed to Reshape()
// and compared against blob->shape() below. The hsize_t -> int conversion is
// an implicit assignment with no range check.
60 vector<int> blob_dims(dims.size());
61 for (int i = 0; i < dims.size(); ++i) {
62 blob_dims[i] = dims[i];
66 blob->Reshape(blob_dims);
// On a shape mismatch, format the source shape as "d0 d1 ... (count)" so the
// failure message can show both shapes. count accumulates the product of the
// dimensions.
68 if (blob_dims != blob->shape()) {
69 // create shape string for error message
72 for (int i = 0; i < blob_dims.size(); ++i) {
73 stream << blob_dims[i] << " ";
74 count = count * blob_dims[i];
76 stream << "(" << count << ")";
77 string source_shape_string = stream.str();
// This CHECK re-tests the same comparison as the `if` above.
79 CHECK(blob_dims == blob->shape()) << "Cannot load blob from hdf5; shape "
80 << "mismatch. Source shape is " << source_shape_string
81 << " target shape is " << blob->shape_string();
// float version of hdf5_load_nd_dataset: passes file_id, dataset_name_,
// min_dim, max_dim and blob to hdf5_load_nd_dataset_helper, then reads the
// dataset with H5LTread_dataset_float into blob->mutable_cpu_data(). A
// negative read status fails the CHECK.
// NOTE(review): the end of the helper call (embedded line 90) is not shown in
// this listing.
87 void hdf5_load_nd_dataset<float>(hid_t file_id, const char* dataset_name_,
88 int min_dim, int max_dim, Blob<float>* blob, bool reshape) {
89 hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob,
91 herr_t status = H5LTread_dataset_float(
92 file_id, dataset_name_, blob->mutable_cpu_data());
93 CHECK_GE(status, 0) << "Failed to read float dataset " << dataset_name_;
// double version of hdf5_load_nd_dataset: passes file_id, dataset_name_,
// min_dim, max_dim and blob to hdf5_load_nd_dataset_helper, then reads the
// dataset with H5LTread_dataset_double into blob->mutable_cpu_data(). A
// negative read status fails the CHECK.
// NOTE(review): the end of the helper call (embedded line 100) is not shown
// in this listing.
97 void hdf5_load_nd_dataset<double>(hid_t file_id, const char* dataset_name_,
98 int min_dim, int max_dim, Blob<double>* blob, bool reshape) {
99 hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob,
101 herr_t status = H5LTread_dataset_double(
102 file_id, dataset_name_, blob->mutable_cpu_data());
103 CHECK_GE(status, 0) << "Failed to read double dataset " << dataset_name_;
// Creates a float dataset named dataset_name under file_id. The dataset has
// blob.num_axes() dimensions, with extents copied from blob.shape(i).
// NOTE(review): the final line of the signature and the condition that picks
// between blob.cpu_diff() and blob.cpu_data() are not shown in this listing.
107 void hdf5_save_nd_dataset<float>(
108 const hid_t file_id, const string& dataset_name, const Blob<float>& blob,
110 int num_axes = blob.num_axes();
// NOTE(review): dims comes from new[] and no delete[] is visible here --
// confirm it is released after the H5LTmake_dataset_float call.
111 hsize_t *dims = new hsize_t[num_axes];
112 for (int i = 0; i < num_axes; ++i) {
113 dims[i] = blob.shape(i);
// data is pointed at either the diff buffer or the data buffer of the blob.
117 data = blob.cpu_diff();
119 data = blob.cpu_data();
121 herr_t status = H5LTmake_dataset_float(
122 file_id, dataset_name.c_str(), num_axes, dims, data);
123 CHECK_GE(status, 0) << "Failed to make float dataset " << dataset_name;
// Creates a double dataset named dataset_name under file_id. The dataset has
// blob.num_axes() dimensions, with extents copied from blob.shape(i).
// NOTE(review): the final line of the signature and the condition that picks
// between blob.cpu_diff() and blob.cpu_data() are not shown in this listing.
128 void hdf5_save_nd_dataset<double>(
129 hid_t file_id, const string& dataset_name, const Blob<double>& blob,
131 int num_axes = blob.num_axes();
// NOTE(review): dims comes from new[] and no delete[] is visible here --
// confirm it is released after the H5LTmake_dataset_double call.
132 hsize_t *dims = new hsize_t[num_axes];
133 for (int i = 0; i < num_axes; ++i) {
134 dims[i] = blob.shape(i);
// data is pointed at either the diff buffer or the data buffer of the blob.
138 data = blob.cpu_diff();
140 data = blob.cpu_data();
142 herr_t status = H5LTmake_dataset_double(
143 file_id, dataset_name.c_str(), num_axes, dims, data);
144 CHECK_GE(status, 0) << "Failed to make double dataset " << dataset_name;
// Reads the string dataset dataset_name located under loc_id.
// NOTE(review): the declarations of size, class_ and status, the CHECK that
// owns the final message, and the statements that build and return the
// result are not shown in this listing.
148 string hdf5_load_string(hid_t loc_id, const string& dataset_name) {
149 // Get size of dataset
// The dims output is NULL; only the class and size outputs are requested.
153 H5LTget_dataset_info(loc_id, dataset_name.c_str(), NULL, &class_, &size);
154 CHECK_GE(status, 0) << "Failed to get dataset info for " << dataset_name;
// The buffer is sized from the value H5LTget_dataset_info stored in size.
// NOTE(review): no delete[] for buf is visible here -- confirm it is freed.
155 char *buf = new char[size];
156 status = H5LTread_dataset_string(loc_id, dataset_name.c_str(), buf);
// NOTE(review): the message below says "int" although this function reads a
// string dataset; it looks copied from hdf5_load_int. Left unchanged because
// the text is emitted at runtime.
158 << "Failed to load int dataset with name " << dataset_name;
// Creates a string dataset named dataset_name under loc_id, holding the
// contents of s.
// NOTE(review): the rest of the signature (presumably where s is declared),
// the variable receiving the H5LTmake_dataset_string result and the check
// that owns the failure message are not shown in this listing.
164 void hdf5_save_string(hid_t loc_id, const string& dataset_name,
167 H5LTmake_dataset_string(loc_id, dataset_name.c_str(), s.c_str());
169 << "Failed to save string dataset with name " << dataset_name;
// Reads the integer dataset dataset_name under loc_id into val using
// H5LTread_dataset_int.
// NOTE(review): the declaration of val, the check that owns the failure
// message and the return statement are not shown in this listing.
172 int hdf5_load_int(hid_t loc_id, const string& dataset_name) {
174 herr_t status = H5LTread_dataset_int(loc_id, dataset_name.c_str(), &val);
176 << "Failed to load int dataset with name " << dataset_name;
// Creates an integer dataset named dataset_name under loc_id from the single
// value i. The dataset is created with rank 1 and the extents at &one.
// NOTE(review): the declaration of one (presumably a single extent of 1 --
// confirm), the variable receiving the H5LTmake_dataset_int result and the
// check that owns the failure message are not shown in this listing.
180 void hdf5_save_int(hid_t loc_id, const string& dataset_name, int i) {
183 H5LTmake_dataset_int(loc_id, dataset_name.c_str(), 1, &one, &i);
185 << "Failed to save int dataset with name " << dataset_name;
// Queries group information for loc_id with H5Gget_info; a negative status
// fails the CHECK.
// NOTE(review): the declaration of info and the return statement are not
// shown in this listing, so which field of info is returned cannot be
// confirmed here.
188 int hdf5_get_num_links(hid_t loc_id) {
190 herr_t status = H5Gget_info(loc_id, &info);
191 CHECK_GE(status, 0) << "Error while counting HDF5 links.";
// Looks up the name of the link at position idx under loc_id, indexing by
// name (H5_INDEX_NAME) in native iteration order (H5_ITER_NATIVE).
// NOTE(review): the last argument of the second call and everything after
// `result` is constructed (cleanup of c_str, the return) are not shown in
// this listing.
195 string hdf5_get_name_by_idx(hid_t loc_id, int idx) {
// First call passes a NULL buffer of size 0, so only the returned length is
// used; a negative value fails the CHECK.
196 ssize_t str_size = H5Lget_name_by_idx(
197 loc_id, ".", H5_INDEX_NAME, H5_ITER_NATIVE, idx, NULL, 0, H5P_DEFAULT);
198 CHECK_GE(str_size, 0) << "Error retrieving HDF5 dataset at index " << idx;
// Second call fills a buffer of str_size+1 bytes (presumably the extra byte
// is for the terminating NUL -- confirm against the HDF5 reference).
// NOTE(review): c_str comes from new[] and no delete[] is visible here.
199 char *c_str = new char[str_size+1];
200 ssize_t status = H5Lget_name_by_idx(
201 loc_id, ".", H5_INDEX_NAME, H5_ITER_NATIVE, idx, c_str, str_size+1,
203 CHECK_GE(status, 0) << "Error retrieving HDF5 dataset at index " << idx;
204 string result(c_str);