9ae556b774585a28552ecb67162bfab088db3f71
[platform/core/ml/nnfw.git] / compiler / dio-hdf5 / src / HDF5Importer.cpp
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "dio_hdf5/HDF5Importer.h"
18
19 #include <H5Cpp.h>
20
21 #include <string>
22 #include <vector>
23 #include <cassert>
24 #include <stdexcept>
25
26 using Shape = std::vector<loco::Dimension>;
27 using DataType = loco::DataType;
28
29 namespace
30 {
31
32 Shape toInternalShape(const H5::DataSpace &dataspace)
33 {
34   int rank = dataspace.getSimpleExtentNdims();
35
36   std::vector<hsize_t> dims;
37   dims.resize(rank, 0);
38   dataspace.getSimpleExtentDims(dims.data());
39
40   Shape res;
41   for (int axis = 0; axis < rank; ++axis)
42   {
43     res.emplace_back(dims[axis]);
44   }
45
46   return res;
47 }
48
49 DataType toInternalDtype(const H5::DataType &h5_type)
50 {
51   if (h5_type == H5::PredType::IEEE_F32BE || h5_type == H5::PredType::IEEE_F32LE)
52   {
53     return DataType::FLOAT32;
54   }
55   if (h5_type == H5::PredType::STD_I32BE || h5_type == H5::PredType::STD_I32LE)
56   {
57     return DataType::S32;
58   }
59   if (h5_type == H5::PredType::STD_I64BE || h5_type == H5::PredType::STD_I64LE)
60   {
61     return DataType::S64;
62   }
63   if (h5_type.getClass() == H5T_class_t::H5T_ENUM)
64   {
65     // We follow the numpy format
66     // In numpy 1.19.0, np.bool_ is saved as H5T_ENUM
67     // - (name, value) -> (FALSE, 0) and (TRUE, 1)
68     // - value dtype is H5T_STD_I8LE
69     // TODO Find a general way to recognize BOOL type
70     char name[10];
71     int8_t value[2] = {0, 1};
72     if (H5Tenum_nameof(h5_type.getId(), value, name, 10) < 0)
73       return DataType::Unknown;
74
75     if (std::string(name) != "FALSE")
76       return DataType::Unknown;
77
78     if (H5Tenum_nameof(h5_type.getId(), value + 1, name, 10) < 0)
79       return DataType::Unknown;
80
81     if (std::string(name) != "TRUE")
82       return DataType::Unknown;
83
84     return DataType::BOOL;
85   }
86   // TODO Support more datatypes
87   return DataType::Unknown;
88 }
89
90 void readTensorData(H5::DataSet &tensor, uint8_t *buffer)
91 {
92   tensor.read(buffer, H5::PredType::NATIVE_UINT8);
93 }
94
95 void readTensorData(H5::DataSet &tensor, float *buffer)
96 {
97   tensor.read(buffer, H5::PredType::NATIVE_FLOAT);
98 }
99
100 void readTensorData(H5::DataSet &tensor, int32_t *buffer)
101 {
102   tensor.read(buffer, H5::PredType::NATIVE_INT);
103 }
104
105 void readTensorData(H5::DataSet &tensor, int64_t *buffer)
106 {
107   tensor.read(buffer, H5::PredType::NATIVE_LONG);
108 }
109
110 } // namespace
111
112 namespace dio
113 {
114 namespace hdf5
115 {
116
117 HDF5Importer::HDF5Importer(const std::string &path)
118 {
119   if (_file.isHdf5(path) == false)
120     throw std::runtime_error("Given data file is not HDF5");
121
122   _file = H5::H5File(path, H5F_ACC_RDONLY);
123 }
124
125 int32_t HDF5Importer::numInputs(int32_t record_idx)
126 {
127   auto records = _group.openGroup(std::to_string(record_idx));
128   return records.getNumObjs();
129 }
130
131 void HDF5Importer::readTensor(int32_t record_idx, int32_t input_idx, void *buffer)
132 {
133   auto record = _group.openGroup(std::to_string(record_idx));
134   auto tensor = record.openDataSet(std::to_string(input_idx));
135
136   readTensorData(tensor, static_cast<uint8_t *>(buffer));
137 }
138
139 void HDF5Importer::readTensor(int32_t record_idx, int32_t input_idx, DataType *dtype, Shape *shape,
140                               void *buffer)
141 {
142   auto record = _group.openGroup(std::to_string(record_idx));
143   auto tensor = record.openDataSet(std::to_string(input_idx));
144
145   auto tensor_dtype = tensor.getDataType();
146   *dtype = toInternalDtype(tensor_dtype);
147
148   auto tensor_shape = tensor.getSpace();
149   *shape = toInternalShape(tensor_shape);
150
151   switch (*dtype)
152   {
153     case DataType::FLOAT32:
154       readTensorData(tensor, static_cast<float *>(buffer));
155       break;
156     case DataType::S32:
157       readTensorData(tensor, static_cast<int32_t *>(buffer));
158       break;
159     case DataType::S64:
160       readTensorData(tensor, static_cast<int64_t *>(buffer));
161       break;
162     case DataType::BOOL:
163       readTensorData(tensor, static_cast<uint8_t *>(buffer));
164       break;
165     default:
166       throw std::runtime_error{"Unsupported data type for input data (.h5)"};
167   }
168 }
169
170 } // namespace hdf5
171 } // namespace dio