/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #ifndef __NNFW_API_TEST_CIRCLE_GEN_H__
18 #define __NNFW_API_TEST_CIRCLE_GEN_H__
#include <circle_schema_generated.h>

#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <utility>
#include <vector>
25 * @brief Class for storing flatbuffer buffer
27 * This is a simple wrapper for a finished FlatBufferBuilder. It owns the buffer and a user can
28 * get the buffer pointer and size.
33 CircleBuffer() = default;
34 explicit CircleBuffer(flatbuffers::FlatBufferBuilder &&fbb) : _fbb{std::move(fbb)}
36 _fbb.Finished(); // The build must have been finished, so check that here
39 uint8_t *buffer() const { return _fbb.GetBufferPointer(); }
40 size_t size() const { return _fbb.GetSize(); }
43 flatbuffers::FlatBufferBuilder _fbb;
/**
 * @brief Circle flatbuffer file generator
 *
 * This is a helper class for generating a circle file.
 */
57 std::vector<int32_t> shape;
58 circle::TensorType tensor_type = circle::TensorType::TensorType_FLOAT32;
// NOTE(review): these are the surviving members of OperatorParams; the struct header, braces and
// any other members are not present in this copy of the file.
// Indices of the operator's input tensors (presumably into the current subgraph's tensors).
std::vector<int32_t> inputs;
// Indices of the operator's output tensors.
std::vector<int32_t> outputs;
70 struct SubgraphContext
72 std::vector<int> inputs;
73 std::vector<int> outputs;
74 std::vector<flatbuffers::Offset<circle::Tensor>> tensors;
75 std::vector<flatbuffers::Offset<circle::Operator>> operators;
/**
 * @brief Add a buffer made from the elements of @p buf_vec
 *
 * The vector's storage is passed as raw bytes to addBuffer(const uint8_t *, size_t).
 *
 * @tparam T Element type; must be trivially copyable, since its object representation is used
 * @param buf_vec Elements to store
 * @return The value returned by addBuffer(const uint8_t *, size_t)
 */
template <typename T> uint32_t addBuffer(const std::vector<T> &buf_vec)
{
  // Reinterpreting element storage as bytes is only meaningful for trivially copyable types;
  // e.g. std::vector<std::string> would serialize internal pointers, so reject it at compile time.
  static_assert(std::is_trivially_copyable<T>::value,
                "addBuffer requires a trivially copyable element type");
  auto bytes = reinterpret_cast<const uint8_t *>(buf_vec.data());
  auto byte_size = buf_vec.size() * sizeof(T);
  return addBuffer(bytes, byte_size);
}
87 uint32_t addBuffer(const uint8_t *buf, size_t size);
88 uint32_t addTensor(const TensorParams ¶ms);
89 void setInputsAndOutputs(const std::vector<int> &inputs, const std::vector<int> &outputs);
90 uint32_t nextSubgraph();
91 CircleBuffer finish();
93 // ===== Add Operator methods begin =====
95 uint32_t addOperatorAdd(const OperatorParams ¶ms, circle::ActivationFunctionType actfn);
96 uint32_t addOperatorAveragePool2D(const OperatorParams ¶ms, circle::Padding padding,
97 int stride_w, int stride_h, int filter_w, int filter_h,
98 circle::ActivationFunctionType actfn);
99 uint32_t addOperatorConcatenation(const OperatorParams ¶ms, int axis,
100 circle::ActivationFunctionType actfn);
101 uint32_t addOperatorCos(const OperatorParams ¶ms);
102 uint32_t addOperatorL2Normalization(const OperatorParams ¶ms);
103 uint32_t addOperatorLeakyRelu(const OperatorParams ¶ms, float alpha);
104 uint32_t addOperatorLess(const OperatorParams ¶ms);
105 uint32_t addOperatorNeg(const OperatorParams ¶ms);
106 uint32_t addOperatorPad(const OperatorParams ¶ms);
107 uint32_t addOperatorPadV2(const OperatorParams ¶ms);
108 uint32_t addOperatorRank(const OperatorParams ¶ms);
109 uint32_t addOperatorResizeNearestNeighbor(const OperatorParams ¶ms);
110 uint32_t addOperatorWhile(const OperatorParams ¶ms, uint32_t cond_subg, uint32_t body_subg);
112 // NOTE Please add addOperator functions ABOVE this lie
113 // ===== Add Operator methods end =====
116 uint32_t addOperatorWithOptions(const OperatorParams ¶ms, circle::BuiltinOperator opcode,
117 circle::BuiltinOptions options_type,
118 flatbuffers::Offset<void> options);
119 uint32_t addOperatorCode(circle::BuiltinOperator opcode);
120 flatbuffers::Offset<circle::Buffer> buildBuffer(const uint8_t *buf, size_t size);
121 flatbuffers::Offset<circle::Tensor> buildTensor(const TensorParams ¶ms);
122 flatbuffers::Offset<circle::SubGraph> buildSubGraph(const SubgraphContext &ctx);
124 SubgraphContext &curSubgCtx() { return _subgraph_contexts.back(); }
127 flatbuffers::FlatBufferBuilder _fbb{1024};
128 std::vector<flatbuffers::Offset<circle::Buffer>> _buffers;
129 std::vector<flatbuffers::Offset<circle::OperatorCode>> _opcodes;
130 std::vector<SubgraphContext> _subgraph_contexts;
133 #endif // __NNFW_API_TEST_CIRCLE_GEN_H__