Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / record-minmax / src / HDF5Importer.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __RECORD_MINMAX_HDF5IMPORTER_H__
18 #define __RECORD_MINMAX_HDF5IMPORTER_H__
19
20 #include <luci_interpreter/core/Tensor.h>
21
22 #include <H5Cpp.h>
23
24 using Shape = luci_interpreter::Shape;
25 using DataType = luci_interpreter::DataType;
26
27 namespace record_minmax
28 {
29
30 // HDF5Importer reads an input data saved in the hdf5 file in the given path
31 // The hierarchy of the hdf5 file is as follows.
32 // Group "/"
33 //  > Group "value"
34 //    > Group <record_idx>
35 //      > Dataset <input_idx>
36 // record_idx : index of the record (dataset file can contain multiple records)
37 // input_idx : index of the input (DNN model can have multiple inputs)
38 // Ex: the j'th input of the i'th record can be accessed by "/value/i/j"
39 class HDF5Importer
40 {
41 public:
42   explicit HDF5Importer(const std::string &path) : _file{path, H5F_ACC_RDONLY}
43   {
44     // Do nothing
45   }
46
47 public:
48   /**
49    * @brief importGroup has to be called before readTensor is called
50    *        Otherwise, readTensor will throw an exception
51    */
52   void importGroup() { _value_grp = _file.openGroup("value"); }
53
54   /**
55    * @brief Read tensor data from file and store it into buffer
56    * @details A tensor in the file can be retrieved with (record_idx, input_idx)
57    * @param record_idx : index of the record
58    * @param input_idx : index of the input
59    * @param dtype : pointer to write the tensor's data type
60    * @param shape : pointer to write the tensor's shape
61    * @param buffer : pointer to write the tensor's data
62    */
63   void readTensor(int32_t record_idx, int32_t input_idx, DataType *dtype, Shape *shape,
64                   void *buffer);
65
66   // Read a raw tensor (no type/shape is specified)
67   void readTensor(int32_t record_idx, int32_t input_idx, void *buffer);
68
69   bool isRawData() { return _value_grp.attrExists("rawData"); }
70
71   int32_t numRecords() { return _value_grp.getNumObjs(); }
72
73   int32_t numInputs(int32_t record_idx);
74
75 private:
76   H5::H5File _file;
77   H5::Group _value_grp;
78 };
79
80 } // namespace record_minmax
81
82 #endif // __RECORD_MINMAX_HDF5IMPORTER_H__