2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "dio_hdf5/HDF5Importer.h"
// Shape is the tensor shape handed back to callers: one loco::Dimension per axis.
using Shape = std::vector<loco::Dimension>;
// DataType is loco's element-type enum; HDF5 element types are mapped onto it below.
using DataType = loco::DataType;
32 Shape toInternalShape(const H5::DataSpace &dataspace)
34 int rank = dataspace.getSimpleExtentNdims();
36 std::vector<hsize_t> dims;
38 dataspace.getSimpleExtentDims(dims.data());
41 for (int axis = 0; axis < rank; ++axis)
43 res.emplace_back(dims[axis]);
49 DataType toInternalDtype(const H5::DataType &h5_type)
51 if (h5_type == H5::PredType::IEEE_F32BE || h5_type == H5::PredType::IEEE_F32LE)
53 return DataType::FLOAT32;
55 if (h5_type == H5::PredType::STD_I32BE || h5_type == H5::PredType::STD_I32LE)
59 if (h5_type == H5::PredType::STD_I64BE || h5_type == H5::PredType::STD_I64LE)
63 if (h5_type.getClass() == H5T_class_t::H5T_ENUM)
65 // We follow the numpy format
66 // In numpy 1.19.0, np.bool_ is saved as H5T_ENUM
67 // - (name, value) -> (FALSE, 0) and (TRUE, 1)
68 // - value dtype is H5T_STD_I8LE
69 // TODO Find a general way to recognize BOOL type
71 int8_t value[2] = {0, 1};
72 if (H5Tenum_nameof(h5_type.getId(), value, name, 10) < 0)
73 return DataType::Unknown;
75 if (std::string(name) != "FALSE")
76 return DataType::Unknown;
78 if (H5Tenum_nameof(h5_type.getId(), value + 1, name, 10) < 0)
79 return DataType::Unknown;
81 if (std::string(name) != "TRUE")
82 return DataType::Unknown;
84 return DataType::BOOL;
86 // TODO Support more datatypes
87 return DataType::Unknown;
90 void readTensorData(H5::DataSet &tensor, uint8_t *buffer)
92 tensor.read(buffer, H5::PredType::NATIVE_UINT8);
95 void readTensorData(H5::DataSet &tensor, float *buffer)
97 tensor.read(buffer, H5::PredType::NATIVE_FLOAT);
100 void readTensorData(H5::DataSet &tensor, int32_t *buffer)
102 tensor.read(buffer, H5::PredType::NATIVE_INT);
105 void readTensorData(H5::DataSet &tensor, int64_t *buffer)
107 tensor.read(buffer, H5::PredType::NATIVE_LONG);
117 HDF5Importer::HDF5Importer(const std::string &path)
119 if (_file.isHdf5(path) == false)
120 throw std::runtime_error("Given data file is not HDF5");
122 _file = H5::H5File(path, H5F_ACC_RDONLY);
125 int32_t HDF5Importer::numInputs(int32_t record_idx)
127 auto records = _group.openGroup(std::to_string(record_idx));
128 return records.getNumObjs();
131 void HDF5Importer::readTensor(int32_t record_idx, int32_t input_idx, void *buffer)
133 auto record = _group.openGroup(std::to_string(record_idx));
134 auto tensor = record.openDataSet(std::to_string(input_idx));
136 readTensorData(tensor, static_cast<uint8_t *>(buffer));
139 void HDF5Importer::readTensor(int32_t record_idx, int32_t input_idx, DataType *dtype, Shape *shape,
142 auto record = _group.openGroup(std::to_string(record_idx));
143 auto tensor = record.openDataSet(std::to_string(input_idx));
145 auto tensor_dtype = tensor.getDataType();
146 *dtype = toInternalDtype(tensor_dtype);
148 auto tensor_shape = tensor.getSpace();
149 *shape = toInternalShape(tensor_shape);
153 case DataType::FLOAT32:
154 readTensorData(tensor, static_cast<float *>(buffer));
157 readTensorData(tensor, static_cast<int32_t *>(buffer));
160 readTensorData(tensor, static_cast<int64_t *>(buffer));
163 readTensorData(tensor, static_cast<uint8_t *>(buffer));
166 throw std::runtime_error{"Unsupported data type for input data (.h5)"};