1 #include "caffe/util/hdf5.hpp"
8 // Verifies format of data stored in HDF5 file and reshapes blob accordingly.
9 template <typename Dtype>
10 void hdf5_load_nd_dataset_helper(
11 hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
13 // Verify that the dataset exists.
14 CHECK(H5LTfind_dataset(file_id, dataset_name_))
15 << "Failed to find HDF5 dataset " << dataset_name_;
16 // Verify that the number of dimensions is in the accepted range.
19 status = H5LTget_dataset_ndims(file_id, dataset_name_, &ndims);
20 CHECK_GE(status, 0) << "Failed to get dataset ndims for " << dataset_name_;
21 CHECK_GE(ndims, min_dim);
22 CHECK_LE(ndims, max_dim);
24 // Verify that the data format is what we expect: float or double.
25 std::vector<hsize_t> dims(ndims);
27 status = H5LTget_dataset_info(
28 file_id, dataset_name_, dims.data(), &class_, NULL);
29 CHECK_GE(status, 0) << "Failed to get dataset info for " << dataset_name_;
30 CHECK_EQ(class_, H5T_FLOAT) << "Expected float or double data";
32 vector<int> blob_dims(dims.size());
33 for (int i = 0; i < dims.size(); ++i) {
34 blob_dims[i] = dims[i];
36 blob->Reshape(blob_dims);
40 void hdf5_load_nd_dataset<float>(hid_t file_id, const char* dataset_name_,
41 int min_dim, int max_dim, Blob<float>* blob) {
42 hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob);
43 herr_t status = H5LTread_dataset_float(
44 file_id, dataset_name_, blob->mutable_cpu_data());
45 CHECK_GE(status, 0) << "Failed to read float dataset " << dataset_name_;
49 void hdf5_load_nd_dataset<double>(hid_t file_id, const char* dataset_name_,
50 int min_dim, int max_dim, Blob<double>* blob) {
51 hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob);
52 herr_t status = H5LTread_dataset_double(
53 file_id, dataset_name_, blob->mutable_cpu_data());
54 CHECK_GE(status, 0) << "Failed to read double dataset " << dataset_name_;
58 void hdf5_save_nd_dataset<float>(
59 const hid_t file_id, const string& dataset_name, const Blob<float>& blob,
61 int num_axes = blob.num_axes();
62 hsize_t *dims = new hsize_t[num_axes];
63 for (int i = 0; i < num_axes; ++i) {
64 dims[i] = blob.shape(i);
68 data = blob.cpu_diff();
70 data = blob.cpu_data();
72 herr_t status = H5LTmake_dataset_float(
73 file_id, dataset_name.c_str(), num_axes, dims, data);
74 CHECK_GE(status, 0) << "Failed to make float dataset " << dataset_name;
79 void hdf5_save_nd_dataset<double>(
80 hid_t file_id, const string& dataset_name, const Blob<double>& blob,
82 int num_axes = blob.num_axes();
83 hsize_t *dims = new hsize_t[num_axes];
84 for (int i = 0; i < num_axes; ++i) {
85 dims[i] = blob.shape(i);
89 data = blob.cpu_diff();
91 data = blob.cpu_data();
93 herr_t status = H5LTmake_dataset_double(
94 file_id, dataset_name.c_str(), num_axes, dims, data);
95 CHECK_GE(status, 0) << "Failed to make double dataset " << dataset_name;
99 string hdf5_load_string(hid_t loc_id, const string& dataset_name) {
100 // Get size of dataset
104 H5LTget_dataset_info(loc_id, dataset_name.c_str(), NULL, &class_, &size);
105 CHECK_GE(status, 0) << "Failed to get dataset info for " << dataset_name;
106 char *buf = new char[size];
107 status = H5LTread_dataset_string(loc_id, dataset_name.c_str(), buf);
109 << "Failed to load int dataset with name " << dataset_name;
115 void hdf5_save_string(hid_t loc_id, const string& dataset_name,
118 H5LTmake_dataset_string(loc_id, dataset_name.c_str(), s.c_str());
120 << "Failed to save string dataset with name " << dataset_name;
123 int hdf5_load_int(hid_t loc_id, const string& dataset_name) {
125 herr_t status = H5LTread_dataset_int(loc_id, dataset_name.c_str(), &val);
127 << "Failed to load int dataset with name " << dataset_name;
131 void hdf5_save_int(hid_t loc_id, const string& dataset_name, int i) {
134 H5LTmake_dataset_int(loc_id, dataset_name.c_str(), 1, &one, &i);
136 << "Failed to save int dataset with name " << dataset_name;
139 int hdf5_get_num_links(hid_t loc_id) {
141 herr_t status = H5Gget_info(loc_id, &info);
142 CHECK_GE(status, 0) << "Error while counting HDF5 links.";
146 string hdf5_get_name_by_idx(hid_t loc_id, int idx) {
147 ssize_t str_size = H5Lget_name_by_idx(
148 loc_id, ".", H5_INDEX_NAME, H5_ITER_NATIVE, idx, NULL, 0, H5P_DEFAULT);
149 CHECK_GE(str_size, 0) << "Error retrieving HDF5 dataset at index " << idx;
150 char *c_str = new char[str_size+1];
151 ssize_t status = H5Lget_name_by_idx(
152 loc_id, ".", H5_INDEX_NAME, H5_ITER_NATIVE, idx, c_str, str_size+1,
154 CHECK_GE(status, 0) << "Error retrieving HDF5 dataset at index " << idx;
155 string result(c_str);