Fix: made load_hd5 check blob dims by default.
[platform/upstream/caffeonacl.git] / src / caffe / util / hdf5.cpp
1 #include "caffe/util/hdf5.hpp"
2
3 #include <string>
4 #include <vector>
5
6 namespace caffe {
7
8 // Verifies format of data stored in HDF5 file and reshapes blob accordingly.
9 template <typename Dtype>
10 void hdf5_load_nd_dataset_helper(
11     hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
12     Blob<Dtype>* blob, bool reshape) {
13   // Verify that the dataset exists.
14   CHECK(H5LTfind_dataset(file_id, dataset_name_))
15       << "Failed to find HDF5 dataset " << dataset_name_;
16   // Verify that the number of dimensions is in the accepted range.
17   herr_t status;
18   int ndims;
19   status = H5LTget_dataset_ndims(file_id, dataset_name_, &ndims);
20   CHECK_GE(status, 0) << "Failed to get dataset ndims for " << dataset_name_;
21   CHECK_GE(ndims, min_dim);
22   CHECK_LE(ndims, max_dim);
23
24   // Verify that the data format is what we expect: float or double.
25   std::vector<hsize_t> dims(ndims);
26   H5T_class_t class_;
27   status = H5LTget_dataset_info(
28       file_id, dataset_name_, dims.data(), &class_, NULL);
29   CHECK_GE(status, 0) << "Failed to get dataset info for " << dataset_name_;
30   switch (class_) {
31   case H5T_FLOAT:
32     LOG_FIRST_N(INFO, 1) << "Datatype class: H5T_FLOAT";
33     break;
34   case H5T_INTEGER:
35     LOG_FIRST_N(INFO, 1) << "Datatype class: H5T_INTEGER";
36     break;
37   case H5T_TIME:
38     LOG(FATAL) << "Unsupported datatype class: H5T_TIME";
39   case H5T_STRING:
40     LOG(FATAL) << "Unsupported datatype class: H5T_STRING";
41   case H5T_BITFIELD:
42     LOG(FATAL) << "Unsupported datatype class: H5T_BITFIELD";
43   case H5T_OPAQUE:
44     LOG(FATAL) << "Unsupported datatype class: H5T_OPAQUE";
45   case H5T_COMPOUND:
46     LOG(FATAL) << "Unsupported datatype class: H5T_COMPOUND";
47   case H5T_REFERENCE:
48     LOG(FATAL) << "Unsupported datatype class: H5T_REFERENCE";
49   case H5T_ENUM:
50     LOG(FATAL) << "Unsupported datatype class: H5T_ENUM";
51   case H5T_VLEN:
52     LOG(FATAL) << "Unsupported datatype class: H5T_VLEN";
53   case H5T_ARRAY:
54     LOG(FATAL) << "Unsupported datatype class: H5T_ARRAY";
55   default:
56     LOG(FATAL) << "Datatype class unknown";
57   }
58
59
60   vector<int> blob_dims(dims.size());
61   for (int i = 0; i < dims.size(); ++i) {
62     blob_dims[i] = dims[i];
63   }
64
65   if (reshape) {
66     blob->Reshape(blob_dims);
67   } else {
68     if (blob_dims != blob->shape()) {
69       // create shape string for error message
70       ostringstream stream;
71       int count = 1;
72       for (int i = 0; i < blob_dims.size(); ++i) {
73         stream << blob_dims[i] << " ";
74         count = count * blob_dims[i];
75       }
76       stream << "(" << count << ")";
77       string source_shape_string = stream.str();
78
79       CHECK(blob_dims == blob->shape()) << "Cannot load blob from hdf5; shape "
80             << "mismatch. Source shape is " << source_shape_string
81             << " target shape is " << blob->shape_string();
82     }
83   }
84 }
85
86 template <>
87 void hdf5_load_nd_dataset<float>(hid_t file_id, const char* dataset_name_,
88         int min_dim, int max_dim, Blob<float>* blob, bool reshape) {
89   hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob,
90                               reshape);
91   herr_t status = H5LTread_dataset_float(
92     file_id, dataset_name_, blob->mutable_cpu_data());
93   CHECK_GE(status, 0) << "Failed to read float dataset " << dataset_name_;
94 }
95
96 template <>
97 void hdf5_load_nd_dataset<double>(hid_t file_id, const char* dataset_name_,
98         int min_dim, int max_dim, Blob<double>* blob, bool reshape) {
99   hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob,
100                               reshape);
101   herr_t status = H5LTread_dataset_double(
102     file_id, dataset_name_, blob->mutable_cpu_data());
103   CHECK_GE(status, 0) << "Failed to read double dataset " << dataset_name_;
104 }
105
106 template <>
107 void hdf5_save_nd_dataset<float>(
108     const hid_t file_id, const string& dataset_name, const Blob<float>& blob,
109     bool write_diff) {
110   int num_axes = blob.num_axes();
111   hsize_t *dims = new hsize_t[num_axes];
112   for (int i = 0; i < num_axes; ++i) {
113     dims[i] = blob.shape(i);
114   }
115   const float* data;
116   if (write_diff) {
117     data = blob.cpu_diff();
118   } else {
119     data = blob.cpu_data();
120   }
121   herr_t status = H5LTmake_dataset_float(
122       file_id, dataset_name.c_str(), num_axes, dims, data);
123   CHECK_GE(status, 0) << "Failed to make float dataset " << dataset_name;
124   delete[] dims;
125 }
126
127 template <>
128 void hdf5_save_nd_dataset<double>(
129     hid_t file_id, const string& dataset_name, const Blob<double>& blob,
130     bool write_diff) {
131   int num_axes = blob.num_axes();
132   hsize_t *dims = new hsize_t[num_axes];
133   for (int i = 0; i < num_axes; ++i) {
134     dims[i] = blob.shape(i);
135   }
136   const double* data;
137   if (write_diff) {
138     data = blob.cpu_diff();
139   } else {
140     data = blob.cpu_data();
141   }
142   herr_t status = H5LTmake_dataset_double(
143       file_id, dataset_name.c_str(), num_axes, dims, data);
144   CHECK_GE(status, 0) << "Failed to make double dataset " << dataset_name;
145   delete[] dims;
146 }
147
148 string hdf5_load_string(hid_t loc_id, const string& dataset_name) {
149   // Get size of dataset
150   size_t size;
151   H5T_class_t class_;
152   herr_t status = \
153     H5LTget_dataset_info(loc_id, dataset_name.c_str(), NULL, &class_, &size);
154   CHECK_GE(status, 0) << "Failed to get dataset info for " << dataset_name;
155   char *buf = new char[size];
156   status = H5LTread_dataset_string(loc_id, dataset_name.c_str(), buf);
157   CHECK_GE(status, 0)
158     << "Failed to load int dataset with name " << dataset_name;
159   string val(buf);
160   delete[] buf;
161   return val;
162 }
163
164 void hdf5_save_string(hid_t loc_id, const string& dataset_name,
165                       const string& s) {
166   herr_t status = \
167     H5LTmake_dataset_string(loc_id, dataset_name.c_str(), s.c_str());
168   CHECK_GE(status, 0)
169     << "Failed to save string dataset with name " << dataset_name;
170 }
171
172 int hdf5_load_int(hid_t loc_id, const string& dataset_name) {
173   int val;
174   herr_t status = H5LTread_dataset_int(loc_id, dataset_name.c_str(), &val);
175   CHECK_GE(status, 0)
176     << "Failed to load int dataset with name " << dataset_name;
177   return val;
178 }
179
180 void hdf5_save_int(hid_t loc_id, const string& dataset_name, int i) {
181   hsize_t one = 1;
182   herr_t status = \
183     H5LTmake_dataset_int(loc_id, dataset_name.c_str(), 1, &one, &i);
184   CHECK_GE(status, 0)
185     << "Failed to save int dataset with name " << dataset_name;
186 }
187
188 int hdf5_get_num_links(hid_t loc_id) {
189   H5G_info_t info;
190   herr_t status = H5Gget_info(loc_id, &info);
191   CHECK_GE(status, 0) << "Error while counting HDF5 links.";
192   return info.nlinks;
193 }
194
195 string hdf5_get_name_by_idx(hid_t loc_id, int idx) {
196   ssize_t str_size = H5Lget_name_by_idx(
197       loc_id, ".", H5_INDEX_NAME, H5_ITER_NATIVE, idx, NULL, 0, H5P_DEFAULT);
198   CHECK_GE(str_size, 0) << "Error retrieving HDF5 dataset at index " << idx;
199   char *c_str = new char[str_size+1];
200   ssize_t status = H5Lget_name_by_idx(
201       loc_id, ".", H5_INDEX_NAME, H5_ITER_NATIVE, idx, c_str, str_size+1,
202       H5P_DEFAULT);
203   CHECK_GE(status, 0) << "Error retrieving HDF5 dataset at index " << idx;
204   string result(c_str);
205   delete[] c_str;
206   return result;
207 }
208
209 }  // namespace caffe