Scope macros inside switch
[platform/upstream/caffeonacl.git] / src / caffe / util / hdf5.cpp
1 #include "caffe/util/hdf5.hpp"
2
3 #include <string>
4 #include <vector>
5
6 namespace caffe {
7
8 // Verifies format of data stored in HDF5 file and reshapes blob accordingly.
9 template <typename Dtype>
10 void hdf5_load_nd_dataset_helper(
11     hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
12     Blob<Dtype>* blob) {
13   // Verify that the dataset exists.
14   CHECK(H5LTfind_dataset(file_id, dataset_name_))
15       << "Failed to find HDF5 dataset " << dataset_name_;
16   // Verify that the number of dimensions is in the accepted range.
17   herr_t status;
18   int ndims;
19   status = H5LTget_dataset_ndims(file_id, dataset_name_, &ndims);
20   CHECK_GE(status, 0) << "Failed to get dataset ndims for " << dataset_name_;
21   CHECK_GE(ndims, min_dim);
22   CHECK_LE(ndims, max_dim);
23
24   // Verify that the data format is what we expect: float or double.
25   std::vector<hsize_t> dims(ndims);
26   H5T_class_t class_;
27   status = H5LTget_dataset_info(
28       file_id, dataset_name_, dims.data(), &class_, NULL);
29   CHECK_GE(status, 0) << "Failed to get dataset info for " << dataset_name_;
30   switch (class_) {
31   case H5T_FLOAT:
32     { LOG_FIRST_N(INFO, 1) << "Datatype class: H5T_FLOAT"; }
33     break;
34   case H5T_INTEGER:
35     { LOG_FIRST_N(INFO, 1) << "Datatype class: H5T_INTEGER"; }
36     break;
37   case H5T_TIME:
38     LOG(FATAL) << "Unsupported datatype class: H5T_TIME";
39   case H5T_STRING:
40     LOG(FATAL) << "Unsupported datatype class: H5T_STRING";
41   case H5T_BITFIELD:
42     LOG(FATAL) << "Unsupported datatype class: H5T_BITFIELD";
43   case H5T_OPAQUE:
44     LOG(FATAL) << "Unsupported datatype class: H5T_OPAQUE";
45   case H5T_COMPOUND:
46     LOG(FATAL) << "Unsupported datatype class: H5T_COMPOUND";
47   case H5T_REFERENCE:
48     LOG(FATAL) << "Unsupported datatype class: H5T_REFERENCE";
49   case H5T_ENUM:
50     LOG(FATAL) << "Unsupported datatype class: H5T_ENUM";
51   case H5T_VLEN:
52     LOG(FATAL) << "Unsupported datatype class: H5T_VLEN";
53   case H5T_ARRAY:
54     LOG(FATAL) << "Unsupported datatype class: H5T_ARRAY";
55   default:
56     LOG(FATAL) << "Datatype class unknown";
57   }
58
59   vector<int> blob_dims(dims.size());
60   for (int i = 0; i < dims.size(); ++i) {
61     blob_dims[i] = dims[i];
62   }
63   blob->Reshape(blob_dims);
64 }
65
66 template <>
67 void hdf5_load_nd_dataset<float>(hid_t file_id, const char* dataset_name_,
68         int min_dim, int max_dim, Blob<float>* blob) {
69   hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob);
70   herr_t status = H5LTread_dataset_float(
71     file_id, dataset_name_, blob->mutable_cpu_data());
72   CHECK_GE(status, 0) << "Failed to read float dataset " << dataset_name_;
73 }
74
75 template <>
76 void hdf5_load_nd_dataset<double>(hid_t file_id, const char* dataset_name_,
77         int min_dim, int max_dim, Blob<double>* blob) {
78   hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob);
79   herr_t status = H5LTread_dataset_double(
80     file_id, dataset_name_, blob->mutable_cpu_data());
81   CHECK_GE(status, 0) << "Failed to read double dataset " << dataset_name_;
82 }
83
84 template <>
85 void hdf5_save_nd_dataset<float>(
86     const hid_t file_id, const string& dataset_name, const Blob<float>& blob,
87     bool write_diff) {
88   int num_axes = blob.num_axes();
89   hsize_t *dims = new hsize_t[num_axes];
90   for (int i = 0; i < num_axes; ++i) {
91     dims[i] = blob.shape(i);
92   }
93   const float* data;
94   if (write_diff) {
95     data = blob.cpu_diff();
96   } else {
97     data = blob.cpu_data();
98   }
99   herr_t status = H5LTmake_dataset_float(
100       file_id, dataset_name.c_str(), num_axes, dims, data);
101   CHECK_GE(status, 0) << "Failed to make float dataset " << dataset_name;
102   delete[] dims;
103 }
104
105 template <>
106 void hdf5_save_nd_dataset<double>(
107     hid_t file_id, const string& dataset_name, const Blob<double>& blob,
108     bool write_diff) {
109   int num_axes = blob.num_axes();
110   hsize_t *dims = new hsize_t[num_axes];
111   for (int i = 0; i < num_axes; ++i) {
112     dims[i] = blob.shape(i);
113   }
114   const double* data;
115   if (write_diff) {
116     data = blob.cpu_diff();
117   } else {
118     data = blob.cpu_data();
119   }
120   herr_t status = H5LTmake_dataset_double(
121       file_id, dataset_name.c_str(), num_axes, dims, data);
122   CHECK_GE(status, 0) << "Failed to make double dataset " << dataset_name;
123   delete[] dims;
124 }
125
126 string hdf5_load_string(hid_t loc_id, const string& dataset_name) {
127   // Get size of dataset
128   size_t size;
129   H5T_class_t class_;
130   herr_t status = \
131     H5LTget_dataset_info(loc_id, dataset_name.c_str(), NULL, &class_, &size);
132   CHECK_GE(status, 0) << "Failed to get dataset info for " << dataset_name;
133   char *buf = new char[size];
134   status = H5LTread_dataset_string(loc_id, dataset_name.c_str(), buf);
135   CHECK_GE(status, 0)
136     << "Failed to load int dataset with name " << dataset_name;
137   string val(buf);
138   delete[] buf;
139   return val;
140 }
141
142 void hdf5_save_string(hid_t loc_id, const string& dataset_name,
143                       const string& s) {
144   herr_t status = \
145     H5LTmake_dataset_string(loc_id, dataset_name.c_str(), s.c_str());
146   CHECK_GE(status, 0)
147     << "Failed to save string dataset with name " << dataset_name;
148 }
149
150 int hdf5_load_int(hid_t loc_id, const string& dataset_name) {
151   int val;
152   herr_t status = H5LTread_dataset_int(loc_id, dataset_name.c_str(), &val);
153   CHECK_GE(status, 0)
154     << "Failed to load int dataset with name " << dataset_name;
155   return val;
156 }
157
158 void hdf5_save_int(hid_t loc_id, const string& dataset_name, int i) {
159   hsize_t one = 1;
160   herr_t status = \
161     H5LTmake_dataset_int(loc_id, dataset_name.c_str(), 1, &one, &i);
162   CHECK_GE(status, 0)
163     << "Failed to save int dataset with name " << dataset_name;
164 }
165
166 int hdf5_get_num_links(hid_t loc_id) {
167   H5G_info_t info;
168   herr_t status = H5Gget_info(loc_id, &info);
169   CHECK_GE(status, 0) << "Error while counting HDF5 links.";
170   return info.nlinks;
171 }
172
173 string hdf5_get_name_by_idx(hid_t loc_id, int idx) {
174   ssize_t str_size = H5Lget_name_by_idx(
175       loc_id, ".", H5_INDEX_NAME, H5_ITER_NATIVE, idx, NULL, 0, H5P_DEFAULT);
176   CHECK_GE(str_size, 0) << "Error retrieving HDF5 dataset at index " << idx;
177   char *c_str = new char[str_size+1];
178   ssize_t status = H5Lget_name_by_idx(
179       loc_id, ".", H5_INDEX_NAME, H5_ITER_NATIVE, idx, c_str, str_size+1,
180       H5P_DEFAULT);
181   CHECK_GE(status, 0) << "Error retrieving HDF5 dataset at index " << idx;
182   string result(c_str);
183   delete[] c_str;
184   return result;
185 }
186
187 }  // namespace caffe