2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #ifndef __CIRCLE_EVAL_DIFF_INPUT_DATA_LOADER_H__
18 #define __CIRCLE_EVAL_DIFF_INPUT_DATA_LOADER_H__
#include <dio_hdf5/HDF5Importer.h>
#include <loco/IR/Node.h>
#include <luci/IR/CircleNodes.h>

#include <cstdint>
#include <memory>
#include <string>
#include <vector>
29 namespace circle_eval_diff
32 void verifyTypeShape(const luci::CircleInput *input_node, const loco::DataType &dtype,
33 const std::vector<loco::Dimension> &shape);
35 } // namespace circle_eval_diff
37 namespace circle_eval_diff
40 enum class InputFormat
42 Undefined, // For debugging
45 // TODO Implement Random, Directory
51 using Data = std::vector<Tensor>;
54 virtual ~InputDataLoader() = default;
57 virtual uint32_t size(void) const = 0;
60 virtual Data get(uint32_t data_idx) const = 0;
63 class HDF5Loader final : public InputDataLoader
66 HDF5Loader(const std::string &file_path, const std::vector<loco::Node *> &input_nodes);
69 uint32_t size(void) const final;
70 Data get(uint32_t data_idx) const final;
73 const std::vector<loco::Node *> _input_nodes;
74 std::unique_ptr<dio::hdf5::HDF5Importer> _hdf5;
77 // This class loads the directory that has raw data binary files.
78 class DirectoryLoader final : public InputDataLoader
81 DirectoryLoader(const std::string &dir_path, const std::vector<loco::Node *> &input_nodes);
84 uint32_t size(void) const final;
85 Data get(uint32_t data_idx) const final;
88 const std::vector<loco::Node *> _input_nodes;
89 std::vector<std::string> _data_paths;
92 std::unique_ptr<InputDataLoader> makeDataLoader(const std::string &file_path,
93 const InputFormat &format,
94 const std::vector<loco::Node *> &input_nodes);
96 } // namespace circle_eval_diff
98 #endif // __CIRCLE_EVAL_DIFF_INPUT_DATA_LOADER_H__