2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #ifndef __NNFW_API_TEST_CIRCLE_GEN_H__
18 #define __NNFW_API_TEST_CIRCLE_GEN_H__
#include <circle_schema_generated.h>

#include <type_traits>
#include <utility>
 * @brief Class for storing a flatbuffer buffer
 *
 * This is a simple wrapper around a finished FlatBufferBuilder. It owns the buffer, and a user
 * can get the buffer pointer and size.
33 CircleBuffer() = default;
34 explicit CircleBuffer(flatbuffers::FlatBufferBuilder &&fbb) : _fbb{std::move(fbb)}
36 _fbb.Finished(); // The build must have been finished, so check that here
39 uint8_t *buffer() const { return _fbb.GetBufferPointer(); }
40 size_t size() const { return _fbb.GetSize(); }
43 flatbuffers::FlatBufferBuilder _fbb;
 * @brief Circle flatbuffer file generator
 *
 * This is a helper class for generating a circle file.
55 using Shape = std::vector<int32_t>;
57 using SparseIndexVectorType = circle::SparseIndexVector;
58 using SparseDimensionType = circle::DimensionType;
60 struct SparseIndexVector
62 std::vector<uint16_t> u16;
67 DimMetaData() = delete;
68 DimMetaData(SparseDimensionType format, std::vector<uint16_t> array_segments,
69 std::vector<uint16_t> array_indices)
71 _array_segments_type(SparseIndexVectorType::SparseIndexVector_Uint16Vector),
72 _array_indices_type(SparseIndexVectorType::SparseIndexVector_Uint16Vector)
74 _array_segments.u16 = array_segments;
75 _array_indices.u16 = array_indices;
77 DimMetaData(SparseDimensionType format, int32_t dense_size)
78 : _format{format}, _dense_size{dense_size}
81 SparseDimensionType _format{circle::DimensionType_DENSE};
82 int32_t _dense_size{0};
83 SparseIndexVectorType _array_segments_type{circle::SparseIndexVector_NONE};
84 SparseIndexVector _array_segments;
85 SparseIndexVectorType _array_indices_type{circle::SparseIndexVector_NONE};
86 SparseIndexVector _array_indices;
91 std::vector<int32_t> traversal_order;
92 std::vector<int32_t> block_map;
93 std::vector<DimMetaData> dim_metadata;
98 std::vector<int32_t> shape;
99 circle::TensorType tensor_type = circle::TensorType::TensorType_FLOAT32;
104 struct OperatorParams
106 std::vector<int32_t> inputs;
107 std::vector<int32_t> outputs;
111 struct SubgraphContext
113 std::vector<int> inputs;
114 std::vector<int> outputs;
115 std::vector<flatbuffers::Offset<circle::Tensor>> tensors;
116 std::vector<flatbuffers::Offset<circle::Operator>> operators;
122 template <typename T> uint32_t addBuffer(const std::vector<T> &buf_vec)
124 auto buf = reinterpret_cast<const uint8_t *>(buf_vec.data());
125 auto size = buf_vec.size() * sizeof(T);
126 return addBuffer(buf, size);
128 uint32_t addBuffer(const uint8_t *buf, size_t size);
129 uint32_t addTensor(const TensorParams ¶ms);
130 uint32_t addTensor(const TensorParams ¶ms, float scale, int64_t zero_point);
131 uint32_t addTensor(const TensorParams ¶ms, const SparsityParams &sp);
132 void setInputsAndOutputs(const std::vector<int> &inputs, const std::vector<int> &outputs);
133 uint32_t nextSubgraph();
134 CircleBuffer finish();
136 // ===== Add Operator methods begin (SORTED IN ALPHABETICAL ORDER) =====
138 uint32_t addOperatorAdd(const OperatorParams ¶ms, circle::ActivationFunctionType actfn);
139 uint32_t addOperatorAddN(const OperatorParams ¶ms);
140 uint32_t addOperatorArgMax(const OperatorParams ¶ms,
141 circle::TensorType output_type = circle::TensorType::TensorType_INT32);
142 uint32_t addOperatorAveragePool2D(const OperatorParams ¶ms, circle::Padding padding,
143 int stride_w, int stride_h, int filter_w, int filter_h,
144 circle::ActivationFunctionType actfn);
145 uint32_t addOperatorCast(const OperatorParams ¶ms, circle::TensorType input_type,
146 circle::TensorType output_type);
147 uint32_t addOperatorConcatenation(const OperatorParams ¶ms, int axis,
148 circle::ActivationFunctionType actfn);
149 uint32_t addOperatorCos(const OperatorParams ¶ms);
150 uint32_t addOperatorDepthwiseConv2D(const OperatorParams ¶ms, circle::Padding padding,
151 int stride_w, int stride_h, int depth_multiplier,
152 circle::ActivationFunctionType actfn, int dilation_w = 1,
154 uint32_t addOperatorEqual(const OperatorParams ¶ms);
155 uint32_t addOperatorFill(const OperatorParams ¶ms);
156 uint32_t addOperatorFloor(const OperatorParams ¶ms);
157 uint32_t addOperatorFullyConnected(const OperatorParams ¶ms,
158 circle::FullyConnectedOptionsWeightsFormat weights_format =
159 circle::FullyConnectedOptionsWeightsFormat_DEFAULT);
160 uint32_t addOperatorIf(const OperatorParams ¶ms, uint32_t then_subg, uint32_t else_subg);
161 uint32_t addOperatorInstanceNorm(const OperatorParams ¶ms, float epsilon,
162 circle::ActivationFunctionType actfn);
163 uint32_t addOperatorL2Normalization(const OperatorParams ¶ms);
164 uint32_t addOperatorLeakyRelu(const OperatorParams ¶ms, float alpha);
165 uint32_t addOperatorLess(const OperatorParams ¶ms);
166 uint32_t addOperatorLogSoftmax(const OperatorParams ¶ms);
167 uint32_t addOperatorNeg(const OperatorParams ¶ms);
168 uint32_t addOperatorOneHot(const OperatorParams ¶ms, int32_t axis);
169 uint32_t addOperatorPad(const OperatorParams ¶ms);
170 uint32_t addOperatorPadV2(const OperatorParams ¶ms);
171 uint32_t addOperatorRank(const OperatorParams ¶ms);
172 uint32_t addOperatorReduce(const OperatorParams ¶ms, circle::BuiltinOperator reduce_op,
   * @brief Create a circle Reshape op
   *        The second parameter, new_shape, is optional (it defaults to nullptr), just like in
   *        circle::CreateReshapeOptionsDirect
178 uint32_t addOperatorReshape(const OperatorParams ¶ms, const Shape *new_shape = nullptr);
179 uint32_t addOperatorResizeBilinear(const OperatorParams ¶ms, bool align_corners = false,
180 bool half_pixel_centers = false);
181 uint32_t addOperatorResizeNearestNeighbor(const OperatorParams ¶ms);
182 uint32_t addOperatorReverseV2(const OperatorParams ¶ms);
183 uint32_t addOperatorShape(const OperatorParams ¶ms,
184 circle::TensorType type = circle::TensorType::TensorType_INT32);
185 uint32_t addOperatorSelect(const OperatorParams ¶ms);
186 uint32_t addOperatorSelectV2(const OperatorParams ¶ms);
187 uint32_t addOperatorSplit(const OperatorParams ¶ms, int32_t num_split);
188 uint32_t addOperatorStridedSlice(const OperatorParams ¶ms, int32_t begin_mask = 0,
189 int32_t end_mask = 0, int32_t ellipsis_mask = 0,
190 int32_t new_axis_mask = 0, int32_t shrink_axis_mask = 0);
191 uint32_t addOperatorTile(const OperatorParams ¶ms);
192 uint32_t addOperatorTranspose(const OperatorParams ¶ms);
193 uint32_t addOperatorWhile(const OperatorParams ¶ms, uint32_t cond_subg, uint32_t body_subg);
195 // NOTE Please add addOperator functions ABOVE this line in ALPHABETICAL ORDER
196 // ===== Add Operator methods end =====
199 uint32_t addOperatorWithOptions(const OperatorParams ¶ms, circle::BuiltinOperator opcode,
200 circle::BuiltinOptions options_type,
201 flatbuffers::Offset<void> options);
202 uint32_t addOperatorCode(circle::BuiltinOperator opcode);
203 flatbuffers::Offset<circle::Buffer> buildBuffer(const uint8_t *buf, size_t size);
204 flatbuffers::Offset<circle::Tensor> buildTensor(const TensorParams ¶ms);
205 flatbuffers::Offset<circle::Tensor> buildTensor(const TensorParams ¶ms, float scale,
207 flatbuffers::Offset<circle::SparsityParameters> buildSparsityParameters(const SparsityParams &sp);
208 flatbuffers::Offset<circle::Tensor> buildTensor(const TensorParams ¶ms,
209 const SparsityParams &sp);
210 flatbuffers::Offset<circle::SubGraph> buildSubGraph(const SubgraphContext &ctx);
212 SubgraphContext &curSubgCtx() { return _subgraph_contexts.back(); }
215 flatbuffers::FlatBufferBuilder _fbb{1024};
216 std::vector<flatbuffers::Offset<circle::Buffer>> _buffers;
217 std::vector<flatbuffers::Offset<circle::OperatorCode>> _opcodes;
218 std::vector<SubgraphContext> _subgraph_contexts;
221 #endif // __NNFW_API_TEST_CIRCLE_GEN_H__