7da245975bd0407aa07054d862cfb6d5da50d2c4
[platform/core/ml/nnfw.git] / tests / nnfw_api / src / CircleGen.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __NNFW_API_TEST_CIRCLE_GEN_H__
18 #define __NNFW_API_TEST_CIRCLE_GEN_H__
19
20 #include <circle_schema_generated.h>
21
22 #include <vector>
23
24 /**
25  * @brief Class for storing flatbuffer buffer
26  *
27  * This is a simple wrapper for a finished FlatBufferBuilder. It owns the buffer and a user can
28  * get the buffer pointer and size.
29  */
30 class CircleBuffer
31 {
32 public:
33   CircleBuffer() = default;
34   explicit CircleBuffer(flatbuffers::FlatBufferBuilder &&fbb) : _fbb{std::move(fbb)}
35   {
36     _fbb.Finished(); // The build must have been finished, so check that here
37   }
38
39   uint8_t *buffer() const { return _fbb.GetBufferPointer(); }
40   size_t size() const { return _fbb.GetSize(); }
41
42 private:
43   flatbuffers::FlatBufferBuilder _fbb;
44 };
45
46 /**
47  * @brief Circle flatbuffer file generator
48  *
49  * This is a helper class for generating circle file.
50  *
51  */
52 class CircleGen
53 {
54 public:
55   using Shape = std::vector<int32_t>;
56
57   using SparseIndexVectorType = circle::SparseIndexVector;
58   using SparseDimensionType = circle::DimensionType;
59
60   struct SparseIndexVector
61   {
62     std::vector<uint16_t> u16;
63   };
64
65   struct DimMetaData
66   {
67     DimMetaData() = delete;
68     DimMetaData(SparseDimensionType format, std::vector<uint16_t> array_segments,
69                 std::vector<uint16_t> array_indices)
70         : _format{format},
71           _array_segments_type(SparseIndexVectorType::SparseIndexVector_Uint16Vector),
72           _array_indices_type(SparseIndexVectorType::SparseIndexVector_Uint16Vector)
73     {
74       _array_segments.u16 = array_segments;
75       _array_indices.u16 = array_indices;
76     }
77     DimMetaData(SparseDimensionType format, int32_t dense_size)
78         : _format{format}, _dense_size{dense_size}
79     {
80     }
81     SparseDimensionType _format{circle::DimensionType_DENSE};
82     int32_t _dense_size{0};
83     SparseIndexVectorType _array_segments_type{circle::SparseIndexVector_NONE};
84     SparseIndexVector _array_segments;
85     SparseIndexVectorType _array_indices_type{circle::SparseIndexVector_NONE};
86     SparseIndexVector _array_indices;
87   };
88
89   struct SparsityParams
90   {
91     std::vector<int32_t> traversal_order;
92     std::vector<int32_t> block_map;
93     std::vector<DimMetaData> dim_metadata;
94   };
95
96   struct TensorParams
97   {
98     std::vector<int32_t> shape;
99     circle::TensorType tensor_type = circle::TensorType::TensorType_FLOAT32;
100     uint32_t buffer = 0;
101     std::string name;
102   };
103
104   struct OperatorParams
105   {
106     std::vector<int32_t> inputs;
107     std::vector<int32_t> outputs;
108     int version = 1;
109   };
110
111   struct SubgraphContext
112   {
113     std::vector<int> inputs;
114     std::vector<int> outputs;
115     std::vector<flatbuffers::Offset<circle::Tensor>> tensors;
116     std::vector<flatbuffers::Offset<circle::Operator>> operators;
117   };
118
119 public:
120   CircleGen();
121
122   template <typename T> uint32_t addBuffer(const std::vector<T> &buf_vec)
123   {
124     auto buf = reinterpret_cast<const uint8_t *>(buf_vec.data());
125     auto size = buf_vec.size() * sizeof(T);
126     return addBuffer(buf, size);
127   }
128   uint32_t addBuffer(const uint8_t *buf, size_t size);
129   uint32_t addTensor(const TensorParams &params);
130   uint32_t addTensor(const TensorParams &params, float scale, int64_t zero_point);
131   uint32_t addTensor(const TensorParams &params, const SparsityParams &sp);
132   void setInputsAndOutputs(const std::vector<int> &inputs, const std::vector<int> &outputs);
133   uint32_t nextSubgraph();
134   CircleBuffer finish();
135
136   // ===== Add Operator methods begin (SORTED IN ALPHABETICAL ORDER) =====
137
138   uint32_t addOperatorAdd(const OperatorParams &params, circle::ActivationFunctionType actfn);
139   uint32_t addOperatorAddN(const OperatorParams &params);
140   uint32_t addOperatorArgMax(const OperatorParams &params,
141                              circle::TensorType output_type = circle::TensorType::TensorType_INT32);
142   uint32_t addOperatorAveragePool2D(const OperatorParams &params, circle::Padding padding,
143                                     int stride_w, int stride_h, int filter_w, int filter_h,
144                                     circle::ActivationFunctionType actfn);
145   uint32_t addOperatorCast(const OperatorParams &params, circle::TensorType input_type,
146                            circle::TensorType output_type);
147   uint32_t addOperatorConcatenation(const OperatorParams &params, int axis,
148                                     circle::ActivationFunctionType actfn);
149   uint32_t addOperatorCos(const OperatorParams &params);
150   uint32_t addOperatorDepthwiseConv2D(const OperatorParams &params, circle::Padding padding,
151                                       int stride_w, int stride_h, int depth_multiplier,
152                                       circle::ActivationFunctionType actfn, int dilation_w = 1,
153                                       int dilation_h = 1);
154   uint32_t addOperatorEqual(const OperatorParams &params);
155   uint32_t addOperatorFill(const OperatorParams &params);
156   uint32_t addOperatorFloor(const OperatorParams &params);
157   uint32_t addOperatorFullyConnected(const OperatorParams &params,
158                                      circle::FullyConnectedOptionsWeightsFormat weights_format =
159                                          circle::FullyConnectedOptionsWeightsFormat_DEFAULT);
160   uint32_t addOperatorIf(const OperatorParams &params, uint32_t then_subg, uint32_t else_subg);
161   uint32_t addOperatorInstanceNorm(const OperatorParams &params, float epsilon,
162                                    circle::ActivationFunctionType actfn);
163   uint32_t addOperatorL2Normalization(const OperatorParams &params);
164   uint32_t addOperatorLeakyRelu(const OperatorParams &params, float alpha);
165   uint32_t addOperatorLess(const OperatorParams &params);
166   uint32_t addOperatorLogSoftmax(const OperatorParams &params);
167   uint32_t addOperatorNeg(const OperatorParams &params);
168   uint32_t addOperatorOneHot(const OperatorParams &params, int32_t axis);
169   uint32_t addOperatorPad(const OperatorParams &params);
170   uint32_t addOperatorPadV2(const OperatorParams &params);
171   uint32_t addOperatorRank(const OperatorParams &params);
172   uint32_t addOperatorReduce(const OperatorParams &params, circle::BuiltinOperator reduce_op,
173                              bool keep_dims);
174   /**
175    * @brief Create circle Reshape op
176    *        the second param new_shape can be optional just like circle::CreateReshapeOptionsDirect
177    */
178   uint32_t addOperatorReshape(const OperatorParams &params, const Shape *new_shape = nullptr);
179   uint32_t addOperatorResizeBilinear(const OperatorParams &params, bool align_corners = false,
180                                      bool half_pixel_centers = false);
181   uint32_t addOperatorResizeNearestNeighbor(const OperatorParams &params);
182   uint32_t addOperatorReverseV2(const OperatorParams &params);
183   uint32_t addOperatorShape(const OperatorParams &params,
184                             circle::TensorType type = circle::TensorType::TensorType_INT32);
185   uint32_t addOperatorSelect(const OperatorParams &params);
186   uint32_t addOperatorSelectV2(const OperatorParams &params);
187   uint32_t addOperatorSplit(const OperatorParams &params, int32_t num_split);
188   uint32_t addOperatorStridedSlice(const OperatorParams &params, int32_t begin_mask = 0,
189                                    int32_t end_mask = 0, int32_t ellipsis_mask = 0,
190                                    int32_t new_axis_mask = 0, int32_t shrink_axis_mask = 0);
191   uint32_t addOperatorTile(const OperatorParams &params);
192   uint32_t addOperatorTranspose(const OperatorParams &params);
193   uint32_t addOperatorWhile(const OperatorParams &params, uint32_t cond_subg, uint32_t body_subg);
194
195   // NOTE Please add addOperator functions ABOVE this line in ALPHABETICAL ORDER
196   // ===== Add Operator methods end =====
197
198 private:
199   uint32_t addOperatorWithOptions(const OperatorParams &params, circle::BuiltinOperator opcode,
200                                   circle::BuiltinOptions options_type,
201                                   flatbuffers::Offset<void> options);
202   uint32_t addOperatorCode(circle::BuiltinOperator opcode);
203   flatbuffers::Offset<circle::Buffer> buildBuffer(const uint8_t *buf, size_t size);
204   flatbuffers::Offset<circle::Tensor> buildTensor(const TensorParams &params);
205   flatbuffers::Offset<circle::Tensor> buildTensor(const TensorParams &params, float scale,
206                                                   int64_t zero_point);
207   flatbuffers::Offset<circle::SparsityParameters> buildSparsityParameters(const SparsityParams &sp);
208   flatbuffers::Offset<circle::Tensor> buildTensor(const TensorParams &params,
209                                                   const SparsityParams &sp);
210   flatbuffers::Offset<circle::SubGraph> buildSubGraph(const SubgraphContext &ctx);
211
212   SubgraphContext &curSubgCtx() { return _subgraph_contexts.back(); }
213
214 private:
215   flatbuffers::FlatBufferBuilder _fbb{1024};
216   std::vector<flatbuffers::Offset<circle::Buffer>> _buffers;
217   std::vector<flatbuffers::Offset<circle::OperatorCode>> _opcodes;
218   std::vector<SubgraphContext> _subgraph_contexts;
219 };
220
221 #endif // __NNFW_API_TEST_CIRCLE_GEN_H__