Imported Upstream version 1.21.0
[platform/core/ml/nnfw.git] / compiler / circle-eval-diff / src / InputDataLoader.h
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __CIRCLE_EVAL_DIFF_INPUT_DATA_LOADER_H__
18 #define __CIRCLE_EVAL_DIFF_INPUT_DATA_LOADER_H__
19
20 #include <dio_hdf5/HDF5Importer.h>
21 #include <loco/IR/Node.h>
22 #include <luci/IR/CircleNodes.h>
23
24 #include "Tensor.h"
25
26 #include <memory>
27 #include <string>
28
29 namespace circle_eval_diff
30 {
31
32 void verifyTypeShape(const luci::CircleInput *input_node, const loco::DataType &dtype,
33                      const std::vector<loco::Dimension> &shape);
34
35 } // namespace circle_eval_diff
36
37 namespace circle_eval_diff
38 {
39
40 enum class InputFormat
41 {
42   Undefined, // For debugging
43   H5,
44   DIR, // directory
45   // TODO Implement Random, Directory
46 };
47
48 class InputDataLoader
49 {
50 public:
51   using Data = std::vector<Tensor>;
52
53 public:
54   virtual ~InputDataLoader() = default;
55
56 public:
57   virtual uint32_t size(void) const = 0;
58
59 public:
60   virtual Data get(uint32_t data_idx) const = 0;
61 };
62
63 class HDF5Loader final : public InputDataLoader
64 {
65 public:
66   HDF5Loader(const std::string &file_path, const std::vector<loco::Node *> &input_nodes);
67
68 public:
69   uint32_t size(void) const final;
70   Data get(uint32_t data_idx) const final;
71
72 private:
73   const std::vector<loco::Node *> _input_nodes;
74   std::unique_ptr<dio::hdf5::HDF5Importer> _hdf5;
75 };
76
77 // This class loads the directory that has raw data binary files.
78 class DirectoryLoader final : public InputDataLoader
79 {
80 public:
81   DirectoryLoader(const std::string &dir_path, const std::vector<loco::Node *> &input_nodes);
82
83 public:
84   uint32_t size(void) const final;
85   Data get(uint32_t data_idx) const final;
86
87 private:
88   const std::vector<loco::Node *> _input_nodes;
89   std::vector<std::string> _data_paths;
90 };
91
92 std::unique_ptr<InputDataLoader> makeDataLoader(const std::string &file_path,
93                                                 const InputFormat &format,
94                                                 const std::vector<loco::Node *> &input_nodes);
95
96 } // namespace circle_eval_diff
97
98 #endif // __CIRCLE_EVAL_DIFF_INPUT_DATA_LOADER_H__