/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef __LOCOEX_IR_TFLNODES_H__
#define __LOCOEX_IR_TFLNODES_H__

#include "TFLNodeDecl.h"
#include "TFLOpcode.h"

#include "FusedActFunc.h"
#include "NodeMixins.h"

#include <loco/IR/Node.h>
#include <loco/IR/NodeMixins.h>
#include <loco/IR/DataTypeTraits.h>

#include <locoex/VariadicArityNode.h>

#include <cassert>
#include <cstdint>
#include <vector>
39 UNDEFINED, // This is not defined by TFLite. This was added to prevent programming error.
/// @brief Width/height pair used as the window size of pooling nodes.
///        Both dimensions default to 1.
class Filter
{
public:
  Filter() : _w(1), _h(1) {}

public:
  int32_t w() const { return _w; }
  void w(int32_t w) { _w = w; }

  int32_t h() const { return _h; }
  void h(int32_t h) { _h = h; }

private:
  int32_t _w;
  int32_t _h;
};
/// @brief Width/height stride attribute of convolution and pooling nodes.
///        Both dimensions default to 1.
class Stride
{
public:
  Stride() : _w(1), _h(1) {}

public:
  int32_t w() const { return _w; }
  void w(int32_t w) { _w = w; }

  int32_t h() const { return _h; }
  void h(int32_t h) { _h = h; }

private:
  int32_t _w;
  int32_t _h;
};
/// @brief enumeration of mixin class
enum class TFLNodeTrait
{
  FusedActFunc, // node carries a fused activation function attribute
  Bias,         // node exposes a bias input
};

/// @brief Mixin providing the feature named by @p T; specialized per trait below
template <TFLNodeTrait T> class TFLNodeMixin;
85 template <> class TFLNodeMixin<TFLNodeTrait::FusedActFunc>
88 TFLNodeMixin() = default;
91 FusedActFunc fusedActivationFunction() const { return _fused_act_fun; }
92 void fusedActivationFunction(FusedActFunc fused_act_fun) { _fused_act_fun = fused_act_fun; }
95 FusedActFunc _fused_act_fun = FusedActFunc::UNDEFINED;
99 * @brief Mixin class for nodes that has a bias input
101 template <> class TFLNodeMixin<TFLNodeTrait::Bias>
104 TFLNodeMixin() = default;
107 virtual loco::Node *bias(void) const = 0; /// @brief get the input for bias.
108 virtual void bias(loco::Node *node) = 0; /// @brief set the input for bias.
112 * @brief ADD in TensorFlow Lite
114 class TFLAdd final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::ADD>>,
115 public TFLNodeMixin<TFLNodeTrait::FusedActFunc>
118 loco::Node *x(void) const { return at(0)->node(); }
119 void x(loco::Node *node) { at(0)->node(node); }
121 loco::Node *y(void) const { return at(1)->node(); }
122 void y(loco::Node *node) { at(1)->node(node); }
126 * @brief AVERAGE_POOL_2D in TensorFlow Lite
128 class TFLAveragePool2D final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::AVERAGE_POOL_2D>>,
129 public TFLNodeMixin<TFLNodeTrait::FusedActFunc>
132 TFLAveragePool2D() : _padding(Padding::UNDEFINED) { /* empty */}
135 loco::Node *value(void) const { return at(0)->node(); }
136 void value(loco::Node *node) { at(0)->node(node); }
138 Padding padding() const { return _padding; }
139 void padding(Padding padding) { _padding = padding; }
141 const Filter *filter(void) const { return &_filter; }
142 Filter *filter(void) { return &_filter; }
144 const Stride *stride(void) const { return &_stride; }
145 Stride *stride(void) { return &_stride; }
154 * @brief CONCATENATION in TensorFlow Lite
156 class TFLConcatenation final : public VariadicArityNode<TFLNodeImpl<TFLOpcode::CONCATENATION>>,
157 public TFLNodeMixin<TFLNodeTrait::FusedActFunc>
160 TFLConcatenation(uint32_t arity) : VariadicArityNode<TFLNodeImpl<TFLOpcode::CONCATENATION>>(arity)
162 // TODO Support when arity is 0
167 uint32_t numValues(void) const { return arity(); }
170 Node *values(uint32_t index) const
172 assert(index < numValues());
173 return at(index)->node();
175 void values(uint32_t index, Node *node)
177 assert(index < numValues());
178 at(index)->node(node);
182 uint32_t axis(void) const { return _axis; }
183 void axis(uint32_t axis) { _axis = axis; }
190 * @brief Class to build tensor data
191 * @note This will not be exported as a specific op
193 class TFLConst final : public FixedArityNode<0, TFLNodeImpl<TFLOpcode::CONST>>,
194 public loco::NodeMixin<loco::NodeTrait::DataType>,
195 public loco::NodeMixin<loco::NodeTrait::TensorShape>
198 TFLConst() = default;
201 template <loco::DataType DT> uint32_t size(void) const;
202 template <loco::DataType DT> void size(uint32_t size);
203 template <loco::DataType DT> const typename loco::DataTypeImpl<DT>::Type &at(uint32_t n) const;
204 template <loco::DataType DT> typename loco::DataTypeImpl<DT>::Type &at(uint32_t n);
207 std::vector<uint8_t> _data;
211 * @brief CONV_2D in TensorFlow Lite
213 class TFLConv2D final : public FixedArityNode<3, TFLNodeImpl<TFLOpcode::CONV_2D>>,
214 public TFLNodeMixin<TFLNodeTrait::FusedActFunc>,
215 public TFLNodeMixin<TFLNodeTrait::Bias>
218 loco::Node *input(void) const { return at(0)->node(); }
219 void input(loco::Node *node) { at(0)->node(node); }
221 loco::Node *filter(void) const { return at(1)->node(); }
222 void filter(loco::Node *node) { at(1)->node(node); }
224 loco::Node *bias(void) const override { return at(2)->node(); }
225 void bias(loco::Node *node) override { at(2)->node(node); }
228 Padding padding() const { return _padding; }
229 void padding(Padding padding) { _padding = padding; }
231 const Stride *stride(void) const { return &_stride; }
232 Stride *stride(void) { return &_stride; }
235 Padding _padding = Padding::UNDEFINED;
240 * @brief DEPTHWISE_CONV_2D in TensorFlow Lite
242 class TFLDepthwiseConv2D final
243 : public FixedArityNode<3, TFLNodeImpl<TFLOpcode::DEPTHWISE_CONV_2D>>,
244 public TFLNodeMixin<TFLNodeTrait::FusedActFunc>,
245 public TFLNodeMixin<TFLNodeTrait::Bias>
248 loco::Node *input(void) const { return at(0)->node(); }
249 void input(loco::Node *node) { at(0)->node(node); }
251 loco::Node *filter(void) const { return at(1)->node(); }
252 void filter(loco::Node *node) { at(1)->node(node); }
254 loco::Node *bias(void) const override { return at(2)->node(); }
255 void bias(loco::Node *node) override { at(2)->node(node); }
258 Padding padding() const { return _padding; }
259 void padding(Padding padding) { _padding = padding; }
261 const Stride *stride(void) const { return &_stride; }
262 Stride *stride(void) { return &_stride; }
264 int32_t depthMultiplier(void) const { return _depth_multiplier; }
265 void depthMultiplier(int32_t arg) { _depth_multiplier = arg; }
268 Padding _padding = Padding::UNDEFINED;
270 int32_t _depth_multiplier = 0;
274 * @brief DIV in TensorFlow Lite
276 class TFLDiv final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::DIV>>,
277 public TFLNodeMixin<TFLNodeTrait::FusedActFunc>
283 loco::Node *x(void) const { return at(0)->node(); }
284 void x(loco::Node *node) { at(0)->node(node); }
286 loco::Node *y(void) const { return at(1)->node(); }
287 void y(loco::Node *node) { at(1)->node(node); }
291 * @brief FULLY_CONNECTED in TensorFlow Lite
293 class TFLFullyConnected final : public FixedArityNode<3, TFLNodeImpl<TFLOpcode::FULLY_CONNECTED>>,
294 public TFLNodeMixin<TFLNodeTrait::FusedActFunc>,
295 public TFLNodeMixin<TFLNodeTrait::Bias>
298 loco::Node *input(void) const { return at(0)->node(); }
299 void input(loco::Node *node) { at(0)->node(node); }
301 loco::Node *weights(void) const { return at(1)->node(); }
302 void weights(loco::Node *node) { at(1)->node(node); }
304 loco::Node *bias(void) const override { return at(2)->node(); }
305 void bias(loco::Node *node) override { at(2)->node(node); }
309 * @brief MAXIMUM in TensorFlow Lite
311 class TFLMaximum final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::MAXIMUM>>
314 loco::Node *x(void) const { return at(0)->node(); }
315 void x(loco::Node *node) { at(0)->node(node); }
317 loco::Node *y(void) const { return at(1)->node(); }
318 void y(loco::Node *node) { at(1)->node(node); }
322 * @brief MAX_POOL_2D in TensorFlow Lite
324 class TFLMaxPool2D final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::MAX_POOL_2D>>,
325 public TFLNodeMixin<TFLNodeTrait::FusedActFunc>
328 TFLMaxPool2D() : _padding(Padding::UNDEFINED) { /* empty */}
331 loco::Node *value(void) const { return at(0)->node(); }
332 void value(loco::Node *node) { at(0)->node(node); }
334 Padding padding() const { return _padding; }
335 void padding(Padding padding) { _padding = padding; }
337 const Filter *filter(void) const { return &_filter; }
338 Filter *filter(void) { return &_filter; }
340 const Stride *stride(void) const { return &_stride; }
341 Stride *stride(void) { return &_stride; }
349 class TFLMean final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::MEAN>>
352 loco::Node *input(void) const { return at(0)->node(); }
353 void input(loco::Node *node) { at(0)->node(node); }
355 loco::Node *reduction_indices(void) const { return at(1)->node(); }
356 void reduction_indices(loco::Node *node) { at(1)->node(node); }
359 bool keep_dims(void) const { return _keep_dims; }
360 void keep_dims(bool keep_dims) { _keep_dims = keep_dims; }
363 bool _keep_dims = false;
367 * @brief MUL in TensorFlow Lite
369 class TFLMul final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::MUL>>,
370 public TFLNodeMixin<TFLNodeTrait::FusedActFunc>
373 loco::Node *x(void) const { return at(0)->node(); }
374 void x(loco::Node *node) { at(0)->node(node); }
376 loco::Node *y(void) const { return at(1)->node(); }
377 void y(loco::Node *node) { at(1)->node(node); }
380 class TFLRelu final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::RELU>>
386 loco::Node *features(void) const { return at(0)->node(); }
387 void features(loco::Node *node) { at(0)->node(node); }
390 class TFLRelu6 final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::RELU6>>
393 TFLRelu6() = default;
396 loco::Node *features(void) const { return at(0)->node(); }
397 void features(loco::Node *node) { at(0)->node(node); }
400 class TFLReshape final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::RESHAPE>>
403 TFLReshape() = default;
406 loco::Node *tensor(void) const { return at(0)->node(); }
407 void tensor(loco::Node *node) { at(0)->node(node); }
409 // TODO Make this input optional. That is, loco system does not emit error
410 // with this input being null
411 loco::Node *shape(void) const { return at(1)->node(); }
412 void shape(loco::Node *node) { at(1)->node(node); }
418 uint32_t rank(void) const { return _shape.size(); }
419 void rank(uint32_t rank) { _shape.resize(rank); }
421 int32_t dim(uint32_t n) const { return _shape.at(n); }
422 int32_t &dim(uint32_t n) { return _shape.at(n); }
425 std::vector<int32_t> _shape;
428 const Shape *newShape(void) const { return &_new_shape; }
429 Shape *newShape(void) { return &_new_shape; }
436 * @brief Set both TFLReshape's 2nd input as TFLConst, and newShape attribute
438 * @note Shape inference for TFLReshape forces them to be same
439 * TODO find better place for this helper
441 void set_new_shape(locoex::TFLReshape *node, int32_t *base, uint32_t size);
443 class TFLRsqrt final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::RSQRT>>
446 TFLRsqrt() = default;
449 loco::Node *x(void) const { return at(0)->node(); }
450 void x(loco::Node *node) { at(0)->node(node); }
455 class TFLSqrt final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::SQRT>>
461 loco::Node *x(void) const { return at(0)->node(); }
462 void x(loco::Node *node) { at(0)->node(node); }
465 class TFLSquaredDifference final
466 : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::SQUARED_DIFFERENCE>>
469 TFLSquaredDifference() = default;
472 loco::Node *x(void) const { return at(0)->node(); }
473 void x(loco::Node *node) { at(0)->node(node); }
475 loco::Node *y(void) const { return at(1)->node(); }
476 void y(loco::Node *node) { at(1)->node(node); }
480 * @brief SUB in TensorFlow Lite
482 class TFLSub final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::SUB>>,
483 public TFLNodeMixin<TFLNodeTrait::FusedActFunc>
489 loco::Node *x(void) const { return at(0)->node(); }
490 void x(loco::Node *node) { at(0)->node(node); }
492 loco::Node *y(void) const { return at(1)->node(); }
493 void y(loco::Node *node) { at(1)->node(node); }
499 * @brief TRANSPOSE in TensorFlow Lite
501 class TFLTranspose final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::TRANSPOSE>>
504 TFLTranspose() = default;
507 /// @brief Get the input node to transpose
508 loco::Node *a(void) const { return at(0)->node(); }
510 /// @brief Set the input node to transpose
511 void a(loco::Node *node) { at(0)->node(node); }
513 loco::Node *perm(void) const { return at(1)->node(); }
514 void perm(loco::Node *node) { at(1)->node(node); }
518 * @brief TRANSPOSE_CONV in TensorFlow Lite
520 * @note Argument node function names are from TensorFlow. So refering 'in' and
521 * 'out' acutally means 'out' and 'in' of the this node.
523 class TFLTransposeConv final : public FixedArityNode<3, TFLNodeImpl<TFLOpcode::TRANSPOSE_CONV>>
526 loco::Node *inputSizes(void) const { return at(0)->node(); }
527 void inputSizes(Node *node) { at(0)->node(node); }
529 loco::Node *filter(void) const { return at(1)->node(); }
530 void filter(Node *node) { at(1)->node(node); }
532 loco::Node *outBackprop(void) const { return at(2)->node(); }
533 void outBackprop(Node *node) { at(2)->node(node); }
536 const Padding &padding(void) const { return _padding; }
537 void padding(const Padding &padding) { _padding = padding; }
539 const Stride *stride(void) const { return &_stride; }
540 Stride *stride(void) { return &_stride; }
543 Padding _padding{Padding::UNDEFINED};
547 // TODO define more children of TFLNode
549 } // namespace locoex
551 #endif // __LOCOEX_IR_TFLNODES_H__