Merge pull request #2836 from erictzeng/hdf5_snapshot
[platform/upstream/caffeonacl.git] / src / caffe / util / hdf5.cpp
1 #include "caffe/util/hdf5.hpp"
2
3 #include <string>
4 #include <vector>
5
6 namespace caffe {
7
8 // Verifies format of data stored in HDF5 file and reshapes blob accordingly.
9 template <typename Dtype>
10 void hdf5_load_nd_dataset_helper(
11     hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
12     Blob<Dtype>* blob) {
13   // Verify that the dataset exists.
14   CHECK(H5LTfind_dataset(file_id, dataset_name_))
15       << "Failed to find HDF5 dataset " << dataset_name_;
16   // Verify that the number of dimensions is in the accepted range.
17   herr_t status;
18   int ndims;
19   status = H5LTget_dataset_ndims(file_id, dataset_name_, &ndims);
20   CHECK_GE(status, 0) << "Failed to get dataset ndims for " << dataset_name_;
21   CHECK_GE(ndims, min_dim);
22   CHECK_LE(ndims, max_dim);
23
24   // Verify that the data format is what we expect: float or double.
25   std::vector<hsize_t> dims(ndims);
26   H5T_class_t class_;
27   status = H5LTget_dataset_info(
28       file_id, dataset_name_, dims.data(), &class_, NULL);
29   CHECK_GE(status, 0) << "Failed to get dataset info for " << dataset_name_;
30   CHECK_EQ(class_, H5T_FLOAT) << "Expected float or double data";
31
32   vector<int> blob_dims(dims.size());
33   for (int i = 0; i < dims.size(); ++i) {
34     blob_dims[i] = dims[i];
35   }
36   blob->Reshape(blob_dims);
37 }
38
39 template <>
40 void hdf5_load_nd_dataset<float>(hid_t file_id, const char* dataset_name_,
41         int min_dim, int max_dim, Blob<float>* blob) {
42   hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob);
43   herr_t status = H5LTread_dataset_float(
44     file_id, dataset_name_, blob->mutable_cpu_data());
45   CHECK_GE(status, 0) << "Failed to read float dataset " << dataset_name_;
46 }
47
48 template <>
49 void hdf5_load_nd_dataset<double>(hid_t file_id, const char* dataset_name_,
50         int min_dim, int max_dim, Blob<double>* blob) {
51   hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob);
52   herr_t status = H5LTread_dataset_double(
53     file_id, dataset_name_, blob->mutable_cpu_data());
54   CHECK_GE(status, 0) << "Failed to read double dataset " << dataset_name_;
55 }
56
57 template <>
58 void hdf5_save_nd_dataset<float>(
59     const hid_t file_id, const string& dataset_name, const Blob<float>& blob,
60     bool write_diff) {
61   int num_axes = blob.num_axes();
62   hsize_t *dims = new hsize_t[num_axes];
63   for (int i = 0; i < num_axes; ++i) {
64     dims[i] = blob.shape(i);
65   }
66   const float* data;
67   if (write_diff) {
68     data = blob.cpu_diff();
69   } else {
70     data = blob.cpu_data();
71   }
72   herr_t status = H5LTmake_dataset_float(
73       file_id, dataset_name.c_str(), num_axes, dims, data);
74   CHECK_GE(status, 0) << "Failed to make float dataset " << dataset_name;
75   delete[] dims;
76 }
77
78 template <>
79 void hdf5_save_nd_dataset<double>(
80     hid_t file_id, const string& dataset_name, const Blob<double>& blob,
81     bool write_diff) {
82   int num_axes = blob.num_axes();
83   hsize_t *dims = new hsize_t[num_axes];
84   for (int i = 0; i < num_axes; ++i) {
85     dims[i] = blob.shape(i);
86   }
87   const double* data;
88   if (write_diff) {
89     data = blob.cpu_diff();
90   } else {
91     data = blob.cpu_data();
92   }
93   herr_t status = H5LTmake_dataset_double(
94       file_id, dataset_name.c_str(), num_axes, dims, data);
95   CHECK_GE(status, 0) << "Failed to make double dataset " << dataset_name;
96   delete[] dims;
97 }
98
99 string hdf5_load_string(hid_t loc_id, const string& dataset_name) {
100   // Get size of dataset
101   size_t size;
102   H5T_class_t class_;
103   herr_t status = \
104     H5LTget_dataset_info(loc_id, dataset_name.c_str(), NULL, &class_, &size);
105   CHECK_GE(status, 0) << "Failed to get dataset info for " << dataset_name;
106   char *buf = new char[size];
107   status = H5LTread_dataset_string(loc_id, dataset_name.c_str(), buf);
108   CHECK_GE(status, 0)
109     << "Failed to load int dataset with name " << dataset_name;
110   string val(buf);
111   delete[] buf;
112   return val;
113 }
114
115 void hdf5_save_string(hid_t loc_id, const string& dataset_name,
116                       const string& s) {
117   herr_t status = \
118     H5LTmake_dataset_string(loc_id, dataset_name.c_str(), s.c_str());
119   CHECK_GE(status, 0)
120     << "Failed to save string dataset with name " << dataset_name;
121 }
122
123 int hdf5_load_int(hid_t loc_id, const string& dataset_name) {
124   int val;
125   herr_t status = H5LTread_dataset_int(loc_id, dataset_name.c_str(), &val);
126   CHECK_GE(status, 0)
127     << "Failed to load int dataset with name " << dataset_name;
128   return val;
129 }
130
131 void hdf5_save_int(hid_t loc_id, const string& dataset_name, int i) {
132   hsize_t one = 1;
133   herr_t status = \
134     H5LTmake_dataset_int(loc_id, dataset_name.c_str(), 1, &one, &i);
135   CHECK_GE(status, 0)
136     << "Failed to save int dataset with name " << dataset_name;
137 }
138
139 int hdf5_get_num_links(hid_t loc_id) {
140   H5G_info_t info;
141   herr_t status = H5Gget_info(loc_id, &info);
142   CHECK_GE(status, 0) << "Error while counting HDF5 links.";
143   return info.nlinks;
144 }
145
146 string hdf5_get_name_by_idx(hid_t loc_id, int idx) {
147   ssize_t str_size = H5Lget_name_by_idx(
148       loc_id, ".", H5_INDEX_NAME, H5_ITER_NATIVE, idx, NULL, 0, H5P_DEFAULT);
149   CHECK_GE(str_size, 0) << "Error retrieving HDF5 dataset at index " << idx;
150   char *c_str = new char[str_size+1];
151   ssize_t status = H5Lget_name_by_idx(
152       loc_id, ".", H5_INDEX_NAME, H5_ITER_NATIVE, idx, c_str, str_size+1,
153       H5P_DEFAULT);
154   CHECK_GE(status, 0) << "Error retrieving HDF5 dataset at index " << idx;
155   string result(c_str);
156   delete[] c_str;
157   return result;
158 }
159
160 }  // namespace caffe