Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / loco / include / loco / IR / Nodes.h
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __LOCO_IR_NODES_H__
18 #define __LOCO_IR_NODES_H__
19
20 #include "loco/IR/Node.h"
21 #include "loco/IR/Use.h"
22 #include "loco/IR/Domain.h"
23 #include "loco/IR/DataType.h"
24 #include "loco/IR/DataTypeTraits.h"
25 #include "loco/IR/Dimension.h"
26 #include "loco/IR/Window.h"
27 #include "loco/IR/Stride.h"
28 #include "loco/IR/Padding2D.h"
29 #include "loco/IR/PaddingND.h"
30 #include "loco/IR/TensorAxis.h"
31 #include "loco/IR/TensorAxisSet.h"
32 #include "loco/IR/FeatureCodec.h"
33 #include "loco/IR/FilterCodec.h"
34 #include "loco/IR/DepthwiseFilterCodec.h"
35 #include "loco/IR/MatrixCodec.h"
36 #include "loco/IR/NodeMixins.h"
37 #include "loco/IR/CanonicalNodeDecl.h"
38 #include "loco/IR/GraphInputIndex.h"
39 #include "loco/IR/GraphOutputIndex.h"
40
41 namespace loco
42 {
43
44 class Graph;
45 class GraphInput;
46 class GraphOutput;
47
48 /**
49  * @brief Make a value visible to user
50  */
51 class Push /* to user */ final
52     : public CanonicalNodeDef<CanonicalOpcode::Push, FixedArity<1>::Mixin>
53 {
54 public:
55   Push() = default;
56
57 public:
58   Node *from(void) const { return at(0)->node(); }
59   void from(Node *node) { at(0)->node(node); }
60
61 public:
62   void index(const GraphOutputIndex &index);
63
64   /**
65    * @brief Get associated output index
66    *
67    * The behavior of this method is undefined when "index" is not set before.
68    *
69    * NOTE This method intentionally returns "GraphOutputIndex" instead of "const GraphOutputIndex &"
70    *      not to expose the internal implementation details.
71    */
72   GraphOutputIndex index(void) const;
73
74   /**
75    * @brief Check whether index is initialized
76    *
77    * NOTE "indexed" method does not validate whether index is in a valid range
78    */
79   bool indexed(void) const { return _index != -1; }
80
81 private:
82   int64_t _index = -1; // Uninitialized
83 };
84
85 void link(GraphOutput *, Push *push);
86
87 /// @brief Find a Push node with a given output index
88 Push *push_node(Graph *g, const GraphOutputIndex &index);
89
90 /**
91  * @brief Create a value from user data
92  */
93 class Pull /* from user */ final
94     : public CanonicalNodeDef<CanonicalOpcode::Pull, FixedArity<0>::Mixin,
95                               With<NodeTrait::TensorShape>::Mixin>
96 {
97 public:
98   Pull() = default;
99
100 public:
101   void index(const GraphInputIndex &index);
102
103   /**
104    * @brief Get associated input index
105    *
106    * The behavior of this method is undefined when "index" is not set before.
107    *
108    * NOTE This method intentionally returns "GraphInputIndex" instead of "const GraphInputIndex &"
109    *      not to expose the internal implementation details.
110    */
111   GraphInputIndex index(void) const;
112
113   /**
114    * @brief Check whether index is initialized
115    *
116    * NOTE "indexed" method does not validate whether index is in a valid range
117    */
118   bool indexed(void) const { return _index != -1; }
119
120 public:
121   void dtype(const DataType &d);
122   DataType dtype(void) const;
123
124 private:
125   int64_t _index = -1; // Uninitialized
126
127   /**
128    * @brief Locally cached data type attribute
129    *
130    * TODO Remove this cache once all the clients are updated
131    */
132   DataType _dtype = DataType::Unknown;
133 };
134
135 void link(GraphInput *, Pull *pull);
136
137 /// @brief Find a Pull node with a given input index
138 Pull *pull_node(Graph *g, const GraphInputIndex &index);
139
140 /**
141  * @brief Create a new value identical to its input
142  *
143  * This node may encode memory transfer (such as CPU -> GPU or GPU -> CPU)
144  */
145 class Forward final : public CanonicalNodeDef<CanonicalOpcode::Forward, FixedArity<1>::Mixin>
146 {
147 public:
148   Forward() = default;
149
150 public:
151   Node *input(void) const { return at(0)->node(); }
152   void input(Node *node) { at(0)->node(node); }
153 };
154
155 /**
156  * @brief Create a new value that rectifies its input
157  */
158 class ReLU final : public CanonicalNodeDef<CanonicalOpcode::ReLU, FixedArity<1>::Mixin>
159 {
160 public:
161   ReLU() = default;
162
163 public:
164   Node *input(void) const { return at(0)->node(); }
165   void input(Node *node) { at(0)->node(node); }
166 };
167
168 /**
169  * @brief Create a new value that rectifies its input capping the units at 6.
170  */
171 class ReLU6 final : public CanonicalNodeDef<CanonicalOpcode::ReLU6, FixedArity<1>::Mixin>
172 {
173 public:
174   ReLU6() = default;
175
176 public:
177   Node *input(void) const { return at(0)->node(); }
178   void input(Node *node) { at(0)->node(node); }
179 };
180
181 /**
182  * @brief Create a new value that rectifies its input by tanh
183  */
184 class Tanh final : public CanonicalNodeDef<CanonicalOpcode::Tanh, FixedArity<1>::Mixin>
185 {
186 public:
187   Tanh() = default;
188
189 public:
190   Node *input(void) const { return at(0)->node(); }
191   void input(Node *node) { at(0)->node(node); }
192 };
193
194 /**
195  * @brief Create a value from constant byte array
196  *
197  * @note ConstGen assumes "lexical memory layout".
198  *
199  * Let us assume that a 'ConstGen' generates a constant tensor of shape "S".
200  * for each valid index I, the corresponding value comes from offset(S, I)
201  * where the implementation of "offset" is given as follows:
202  *
203  * uint32_t stride(TensorShape shape, uint32_t axis) {
204  *   uint32_t res = 1;
205  *   for (uint32_t n = rank(shape) - 1; n > axis; --n) { res *= shape.dim(n); }
206  *   return res;
207  * }
208  *
209  * uint32_t offset(TensorShape shape, TensorIndex index) {
210  *   uint32_t res = 0;
211  *   for (uint32_t n = 0; n < rank(shape); ++n) { res += index.at(n) * stride(shape, n); }
212  *   return res;
213  * }
214  */
215 class ConstGen final
216     : public CanonicalNodeDef<CanonicalOpcode::ConstGen, FixedArity<0>::Mixin,
217                               With<NodeTrait::DataType>::Mixin, With<NodeTrait::TensorShape>::Mixin>
218 {
219 public:
220   ConstGen() = default;
221
222 public:
223   /**
224    * @brief Return the number of reserved elements
225    * @note This method returns the number of ELEMENT (not BYTE).
226    */
227   template <DataType DT> uint32_t size(void) const;
228
229   /**
230    * @brief Adjust the number of reserved elements
231    */
232   template <DataType DT> void size(uint32_t size);
233
234   /**
235    * @brief Get the element at a given position
236    * @require at(n) is valid only when n < size()
237    */
238   template <DataType DT> const typename DataTypeImpl<DT>::Type &at(uint32_t n) const;
239
240   /**
241    * @brief Update the element at a given position
242    * @require at(n) is valid only when n < size()
243    */
244   template <DataType DT> typename DataTypeImpl<DT>::Type &at(uint32_t n);
245
246 private:
247   /// @brief Data
248   std::vector<uint8_t> _data;
249 };
250
251 /**
252  * @brief 2D Max Pooling
253  *
254  * MaxPool2D takes as input a feature map, and produces another feature map
255  *
256  * ---
257  * Any valid MaxPool2D nodes SHOULD satisfy the following conditions.
258  *
259  * Let us define several helper functions that takes a MaxPool2D nodes first:
260  * - IFM_DOMAIN returns the domain of its input
261  * - IFM_H returns the height of its input.
262  * - IFM_W returns the width of its input.
263  * - PAD_T returns the top padding required over its input
264  * - PAD_B returns the bottom padding required over its input
265  * - PAD_L returns the left padding required over its input
266  * - PAD_R returns the right padding required over its input
267  * - WIN_H returns the height of its receptive field.
268  * - WIN_W returns the width of its receptive field.
269  * - STRIDE_H returns the vertical(= on height) stride.
270  * - STRIDE_W returns the horizontal(= on width) stride.
271  *
272  * Condition 1
273  *   Statement
274  *
275  *   A valid MaxPool2D node M SHOULD satisfy the following condition:
276  *   - IFM_DOMAIN(M) == Feature
277  *
278  *   Motivation
279  *
280  *   There are many possible ways to encode a feature map as a tensor.
281  *   - e.g. NCHW/NHWC/...
282  *
283  *   In order to give some freedom on memory layout to backend, loco requires a feature map
284  *   value to be explicitly encoded via FeatureEncode.
285  *
286  * Condition 2:
287  *   Statement
288  *
289  *   A valid MaxPool2D node M SHOULD satisfy the following conditions:
290  *   - (IFM_H(M) + PAD_T(M) + PAD_B(M) - WIN_H(M)) % STRIDE_H(M) == 0
291  *   - (IFM_W(M) + PAD_L(M) + PAD_R(M) - WIN_W(M)) % STRIDE_W(M) == 0
292  *
293  *   Motivation
294  *
295  *   The output shape may differ for each NN framework when these conditions do not hold.
296  *
297  *   In order to mitigate such a difference among NN frameworks, loco requires these conditions
298  *   for MaxPool2D nodes.
299  *
300  *   This means that each frontend implementation SHOULD insert appropriate padding/trimming node
301  *   before/after MaxPool2D node according to the semantics of the corresponding NN framework.
302  * ---
303  */
304 class MaxPool2D final : public CanonicalNodeDef<CanonicalOpcode::MaxPool2D, FixedArity<1>::Mixin>
305 {
306 public:
307   Node *ifm(void) const { return at(0)->node(); }
308   void ifm(Node *node) { at(0)->node(node); }
309
310 public:
311   const Padding2D *pad(void) const { return &_pad; }
312   Padding2D *pad(void) { return &_pad; }
313
314 public:
315   const Window<2> *window(void) const { return &_window; }
316   Window<2> *window(void) { return &_window; }
317
318 public:
319   const Stride<2> *stride(void) const { return &_stride; }
320   Stride<2> *stride(void) { return &_stride; }
321
322 private:
323   // Pad
324   Padding2D _pad;
325   // Window
326   Window<2> _window;
327   // Stride
328   Stride<2> _stride;
329 };
330
331 /**
332  * @brief 2D Average Pooling
333  *
334  * @note Follows MaxPool2D (TODO: describe difference)
335  */
336 class AvgPool2D final : public CanonicalNodeDef<CanonicalOpcode::AvgPool2D, FixedArity<1>::Mixin>
337 {
338 public:
339   enum class Convention
340   {
341     Unknown,
342     // Use the number of elements in each receptive field as a divisor
343     Full,
344     // Use the number of valid (non-padding) elements in each receptive field as a divisor
345     Valid
346   };
347
348 public:
349   Node *ifm(void) const { return at(0)->node(); }
350   void ifm(Node *node) { at(0)->node(node); }
351
352 public:
353   Convention convention(void) const { return _convention; }
354   void convention(const Convention &convention) { _convention = convention; }
355
356 public:
357   const Padding2D *pad(void) const { return &_pad; }
358   Padding2D *pad(void) { return &_pad; }
359
360 public:
361   const Window<2> *window(void) const { return &_window; }
362   Window<2> *window(void) { return &_window; }
363
364 public:
365   const Stride<2> *stride(void) const { return &_stride; }
366   Stride<2> *stride(void) { return &_stride; }
367
368 private:
369   Convention _convention = Convention::Unknown;
370   Padding2D _pad;
371   Window<2> _window;
372   Stride<2> _stride;
373 };
374
375 /**
376  * @brief Create a feature map from a tensor
377  */
378 class FeatureEncode final
379     : public CanonicalNodeDef<CanonicalOpcode::FeatureEncode, FixedArity<1>::Mixin>
380 {
381 public:
382   Node *input(void) const { return at(0)->node(); }
383   void input(Node *node) { at(0)->node(node); }
384
385 public:
386   FeatureEncoder *encoder(void) const { return _enc.get(); }
387   void encoder(std::unique_ptr<FeatureEncoder> &&enc) { _enc = std::move(enc); }
388
389 private:
390   /// @note "encoder" is mandatory
391   std::unique_ptr<FeatureEncoder> _enc{nullptr};
392 };
393
394 /**
395  * @brief Create a tensor from a feature map
396  */
397 class FeatureDecode final
398     : public CanonicalNodeDef<CanonicalOpcode::FeatureDecode, FixedArity<1>::Mixin>
399 {
400 public:
401   Node *input(void) const { return at(0)->node(); }
402   void input(Node *node) { at(0)->node(node); }
403
404 public:
405   FeatureDecoder *decoder(void) const { return _dec.get(); }
406   void decoder(std::unique_ptr<FeatureDecoder> &&dec) { _dec = std::move(dec); }
407
408 private:
409   /// @NOTE "decoder" is mandatory
410   std::unique_ptr<FeatureDecoder> _dec{nullptr};
411 };
412
413 /**
414  * @brief Create a filter from a tensor
415  */
416 class FilterEncode final
417     : public CanonicalNodeDef<CanonicalOpcode::FilterEncode, FixedArity<1>::Mixin>
418 {
419 public:
420   Node *input(void) const { return at(0)->node(); }
421   void input(Node *node) { at(0)->node(node); }
422
423 public:
424   FilterEncoder *encoder(void) const { return _enc.get(); }
425   void encoder(std::unique_ptr<FilterEncoder> &&enc) { _enc = std::move(enc); }
426
427 private:
428   /// @note "encoder" is mandatory
429   std::unique_ptr<FilterEncoder> _enc{nullptr};
430 };
431
432 /**
433  * @brief Create a tensor from a filter
434  */
435 class FilterDecode final
436     : public CanonicalNodeDef<CanonicalOpcode::FilterDecode, FixedArity<1>::Mixin>
437 {
438 public:
439   Node *input(void) const { return at(0)->node(); }
440   void input(Node *node) { at(0)->node(node); }
441
442 public:
443   FilterDecoder *decoder(void) const { return _dec.get(); }
444   void decoder(std::unique_ptr<FilterDecoder> &&dec) { _dec = std::move(dec); }
445
446 private:
447   /// @note "decoder" is mandatory
448   std::unique_ptr<FilterDecoder> _dec{nullptr};
449 };
450
451 /**
452  * @brief Create a depthwise filter from a tensor
453  */
454 class DepthwiseFilterEncode final
455     : public CanonicalNodeDef<CanonicalOpcode::DepthwiseFilterEncode, FixedArity<1>::Mixin>
456 {
457 public:
458   Node *input(void) const { return at(0)->node(); }
459   void input(Node *node) { at(0)->node(node); }
460
461 public:
462   DepthwiseFilterEncoder *encoder(void) const { return _enc.get(); }
463   void encoder(std::unique_ptr<DepthwiseFilterEncoder> &&enc) { _enc = std::move(enc); }
464
465 private:
466   /// @note "encoder" is mandatory
467   std::unique_ptr<DepthwiseFilterEncoder> _enc{nullptr};
468 };
469
470 /**
471  * @brief Create a tensor from a depthwise filter
472  */
473 class DepthwiseFilterDecode final
474     : public CanonicalNodeDef<CanonicalOpcode::DepthwiseFilterDecode, FixedArity<1>::Mixin>
475 {
476 public:
477   Node *input(void) const { return at(0)->node(); }
478   void input(Node *node) { at(0)->node(node); }
479
480 public:
481   DepthwiseFilterDecoder *decoder(void) const { return _dec.get(); }
482   void decoder(std::unique_ptr<DepthwiseFilterDecoder> &&dec) { _dec = std::move(dec); }
483
484 private:
485   /// @note "decoder" is mandatory
486   std::unique_ptr<DepthwiseFilterDecoder> _dec{nullptr};
487 };
488
489 enum class ReshapeType
490 {
491   Fixed, // shape is known at compile time
492   // Add another type for a case when shape is not known at compile time
493 };
494
495 template <ReshapeType RT> class Reshape;
496
497 /**
498  * @brief Reshape a tensor to another tensor whose shape is known at compile time
499  *
500  * @note This class reshapes the shape of an input tensor to _shape.
501  *       Each dimension of _shape should be known at compile time.
502  *       Any dimension of _shape should be greater than 0.
503  *
504  *       Interpreter or runtime should lexicographically copy an input tensor into an output tensor.
505  *       For example, values of an input tesor of shape [2, 2, 2, 2] will be copied into an output
506  *       tensor of new shape [4, 4] like the following:
507  *         input[0, 0, 0, 0] => output [0, 0]
508  *         input[0, 0, 0, 1] => output [0, 1]
509  *         input[0, 0, 1, 0] => output [0, 2]
510  *         ...
511  *         input[1, 1, 1, 1] => output [3, 3]
512  */
513 template <>
514 class Reshape<ReshapeType::Fixed> final
515     : public CanonicalNodeDef<CanonicalOpcode::FixedReshape, FixedArity<1>::Mixin,
516                               With<NodeTrait::TensorShape>::Mixin>
517 {
518 public:
519   Node *input(void) const { return at(0)->node(); }
520   void input(Node *node) { at(0)->node(node); }
521 };
522
523 using FixedReshape = Reshape<ReshapeType::Fixed>;
524
525 /**
526  * @brief Concatenate two tensors
527  *
528  * Given an axis, TensorConcat takes as input two tensors and produces a tensor
529  * concatenated along the given axis.
530  */
531 class TensorConcat final
532     : public CanonicalNodeDef<CanonicalOpcode::TensorConcat, FixedArity<2>::Mixin>
533 {
534 public:
535   Node *lhs(void) const { return at(0)->node(); }
536   void lhs(Node *node) { at(0)->node(node); }
537
538   Node *rhs(void) const { return at(1)->node(); }
539   void rhs(Node *node) { at(1)->node(node); }
540
541 public:
542   uint32_t axis(void) const { return _axis; }
543   void axis(uint32_t val) { _axis = val; }
544
545 private:
546   // Axis
547   uint32_t _axis{0};
548 };
549
550 /**
551  * @brief 2D Spatial Convolution
552  */
553 class Conv2D final : public CanonicalNodeDef<CanonicalOpcode::Conv2D, FixedArity<2>::Mixin>
554 {
555 public:
556   Node *ifm(void) const { return at(0)->node(); }
557   void ifm(Node *node) { at(0)->node(node); }
558
559   Node *ker(void) const { return at(1)->node(); }
560   void ker(Node *node) { at(1)->node(node); }
561
562 public:
563   const Padding2D *pad(void) const { return &_pad; }
564   Padding2D *pad(void) { return &_pad; }
565
566 public:
567   const Stride<2> *stride(void) const { return &_stride; }
568   Stride<2> *stride(void) { return &_stride; }
569
570 private:
571   Padding2D _pad;
572   Stride<2> _stride;
573
574   // TODO Support "Dilation"
575 };
576
577 /**
578  * @brief Depthwise 2D Convolution
579  */
580 class DepthwiseConv2D final
581     : public CanonicalNodeDef<CanonicalOpcode::DepthwiseConv2D, FixedArity<2>::Mixin>
582 {
583 public:
584   Node *ifm(void) const { return at(0)->node(); }
585   void ifm(Node *node) { at(0)->node(node); }
586
587   Node *ker(void) const { return at(1)->node(); }
588   void ker(Node *node) { at(1)->node(node); }
589
590 public:
591   const Padding2D *pad(void) const { return &_pad; }
592   Padding2D *pad(void) { return &_pad; }
593
594 public:
595   const Stride<2> *stride(void) const { return &_stride; }
596   Stride<2> *stride(void) { return &_stride; }
597
598 private:
599   Padding2D _pad;
600   Stride<2> _stride;
601
602   // TODO Support "Dilation"
603 };
604
605 /**
606  * @brief Reduce type functions
607  */
608 enum class ReduceFunc
609 {
610   Mean, // ReduceMean
611   // TODO Support other reduce operations
612 };
613
614 /**
615  * @brief Computes ReduceFunc operations for Tensor domain
616  * @note  All the reduce functions always keep dimensions
617  */
618 class TensorReduce final
619     : public CanonicalNodeDef<CanonicalOpcode::TensorReduce, FixedArity<1>::Mixin>
620 {
621 public:
622   Node *input(void) const { return at(0)->node(); }
623   void input(Node *node) { at(0)->node(node); }
624
625 public:
626   const TensorAxisSet *axes(void) const { return &_axes; }
627   TensorAxisSet *axes(void) { return &_axes; }
628
629 public:
630   ReduceFunc func(void) const { return _func; }
631   void func(ReduceFunc func) { _func = func; }
632
633 private:
634   TensorAxisSet _axes;
635   ReduceFunc _func{ReduceFunc::Mean};
636 };
637
638 /**
639  * @brief 2D Transposed Convolution
640  *
641  * @note  TransposedConv2D have a few important conventions that IR users should
642  *        understand and follow, so please check below notice carefully.
643  *
644  *
645  * 1. What is 'input' and 'output'
646  *
647  * For loco canonical TransposedConv2D, 'input' and 'output' mean actual input
648  * and output node of TransposedConv2D node. Be careful that some other
649  * frameworks may use opposite sense, especially TensorFlow which is inspired by
650  * backpropagation of convolution.
651  * For example, loco::TransposedConv2D::ifm() means actual input feature map
652  * node that is sourced into TransposedConv2D.
653  *
654  * 2. How to read kernel representation
655  *
656  * TransposedConv2D::ker() should be a node of Filter domain. Following is what
657  * each FilterAxis means as a kernel of TransposedConv2D:
658  *   - FilterAxis::Height : kernel's height
659  *   - FilterAxis::Width  : kernel's width
660  *   - FilterAxis::Depth  : IFM's channel depth
661  *   - FilterAxis::Count  : OFM's channel depth
662  * TODO We may refactor FilterAxis as follow to reduce ambiguity:
663  *   - FilterAxis::Height -> FilterAxis::H
664  *   - FilterAxis::Width  -> FilterAxis::W
665  *   - FilterAxis::Depth  -> FilterAxis::I
666  *   - FilterAxis::Count  -> FilterAxis::O
667  *
668  *
669  * 3. Tight fit rule
670  *
671  * TransposedConv2D have no information about its output shape. Instead, it
672  * always satisfy following 'tight fit' rule for horizontal and vertical
673  * dimension:
674  *
675  *   O = S * ( I - 1 ) + F - P
676  *
677  *   where
678  *     O: output size
679  *     S: stride
680  *     I: input size
681  *     F: effective kernal(filter) size
682  *     P: whole pad size (= front + rear pad)
683  *
684  * With this, output shape is uniquely determined by all inputs and attributes.
685  */
686 class TransposedConv2D final
687     : public CanonicalNodeDef<CanonicalOpcode::TransposedConv2D, FixedArity<2>::Mixin>
688 {
689 public:
690   Node *ifm(void) const { return at(0)->node(); }
691   void ifm(Node *node) { at(0)->node(node); }
692
693   Node *ker(void) const { return at(1)->node(); }
694   void ker(Node *node) { at(1)->node(node); }
695
696 public:
697   const Padding2D *pad(void) const { return &_pad; }
698   Padding2D *pad(void) { return &_pad; }
699
700 public:
701   const Stride<2> *stride(void) const { return &_stride; }
702   Stride<2> *stride(void) { return &_stride; }
703
704 private:
705   Padding2D _pad;
706   Stride<2> _stride;
707
708   // TODO Support "Dilation"
709 };
710
711 /**
712  * @brief Computes softmax activations
713  */
714 template <Domain D> class Softmax;
715
716 /**
717 * @brief Computes softmax activations for Tensor domain
718 */
719 template <>
720 class Softmax<Domain::Tensor> final
721     : public CanonicalNodeDef<CanonicalOpcode::TensorSoftmax, FixedArity<1>::Mixin>
722 {
723 public:
724   Softmax() = default;
725
726 public:
727   Node *input(void) const { return at(0)->node(); }
728   void input(Node *node) { return at(0)->node(node); }
729
730   uint32_t axis(void) const { return _axis; }
731   void axis(uint32_t axis) { _axis = axis; }
732
733 private:
734   uint32_t _axis = 0;
735 };
736
737 using TensorSoftmax = Softmax<Domain::Tensor>;
738
739 /**
740  * @brief Create a "Tensor" from a "Bias"
741  */
742 class BiasDecode final : public CanonicalNodeDef<CanonicalOpcode::BiasDecode, FixedArity<1>::Mixin>
743 {
744 public:
745   BiasDecode() = default;
746
747 public:
748   Node *input(void) const { return at(0)->node(); }
749   void input(Node *node) { at(0)->node(node); }
750 };
751
752 /**
753  * @brief Create a "Bias" from a "Tensor"
754  *
755  * BiasEncode currently requires a rank-1 tensor as its input.
756  */
757 class BiasEncode final : public CanonicalNodeDef<CanonicalOpcode::BiasEncode, FixedArity<1>::Mixin>
758 {
759 public:
760   BiasEncode() = default;
761
762 public:
763   Node *input(void) const { return at(0)->node(); }
764   void input(Node *node) { at(0)->node(node); }
765 };
766
767 /**
768  * @brief Produce a value of domain D from an input value (of domain D) and a bias
769  */
770 template <Domain D> class BiasAdd;
771
772 /**
773  * @brief Add Tensor and Bias
774  *
775  * for each valid tensor index I
776  *   out(I) = value(I) + bias(I.at(axis))
777  */
778 template <>
779 class BiasAdd<Domain::Tensor> final
780     : public CanonicalNodeDef<CanonicalOpcode::TensorBiasAdd, FixedArity<2>::Mixin>
781 {
782 public:
783   BiasAdd() = default;
784
785 public:
786   Node *value(void) const { return at(0)->node(); }
787   void value(Node *node) { return at(0)->node(node); }
788
789   Node *bias(void) const { return at(1)->node(); }
790   void bias(Node *node) { return at(1)->node(node); }
791
792   uint32_t axis(void) const { return _axis; }
793   void axis(uint32_t axis) { _axis = axis; }
794
795 private:
796   uint32_t _axis = 0;
797 };
798
799 //
800 // Alias for external users
801 //
802 // loco::TensorBiasAdd
803 //        vs.
804 // loco::BiasAdd<loco::Domain::Tensor>
805 //
806 using TensorBiasAdd = BiasAdd<Domain::Tensor>;
807
808 /**
809  * @brief Add Feature and Bias along "depth" axis
810  *
811  * for each valid feature index (b, ch, row, col)
812  *   out(b, ch, row, col) = value(b, ch, row, col) + bias(ch)
813  */
814 template <>
815 class BiasAdd<Domain::Feature> final
816     : public CanonicalNodeDef<CanonicalOpcode::FeatureBiasAdd, FixedArity<2>::Mixin>
817 {
818 public:
819   BiasAdd() = default;
820
821 public:
822   Node *value(void) const { return at(0)->node(); }
823   void value(Node *node) { return at(0)->node(node); }
824
825   Node *bias(void) const { return at(1)->node(); }
826   void bias(Node *node) { return at(1)->node(node); }
827 };
828
829 using FeatureBiasAdd = BiasAdd<Domain::Feature>;
830
831 /**
832  * @brief Pads a tensor with constant value
833  *
834  * Pads a input tensor according to the padding with constant value.
835  *
836  * The dimension of each axis n of the output is
837  * output.dim(n) = padding.front(n) + input.dim(n) + padding.back(n)
838  *
839  * For example, input tensor of shape [1, 2] with
840  *
841  * padding.front(0) = 1;
842  * padding.back(0) = 2;
843  *
844  * padding.front(1) = 3;
845  * padding.back(1) = 4;
846  *
847  * will be a output tensor of shape
848  * [padding.front(0) + 1 + padding.back(0), padding.front(1) + 2 + padding.back(1)] = [4,9].
849  */
850 class TensorConstantPad final
851     : public CanonicalNodeDef<CanonicalOpcode::TensorConstantPad, FixedArity<2>::Mixin>
852 {
853 public:
854   Node *input(void) const { return at(0)->node(); }
855   void input(Node *node) { at(0)->node(node); }
856
857   Node *constant(void) const { return at(1)->node(); }
858   void constant(Node *node) { at(1)->node(node); }
859
860 public:
861   const PaddingND *padding(void) const { return &_padding; }
862   PaddingND *padding(void) { return &_padding; }
863
864 private:
865   PaddingND _padding;
866 };
867
868 /**
869  * @brief Elementwise Add lhs and rhs
870  */
871 class EltwiseAdd final : public CanonicalNodeDef<CanonicalOpcode::EltwiseAdd, FixedArity<2>::Mixin>
872 {
873 public:
874   EltwiseAdd() = default;
875
876 public:
877   Node *lhs(void) const { return at(0)->node(); }
878   void lhs(Node *node) { return at(0)->node(node); }
879
880   Node *rhs(void) const { return at(1)->node(); }
881   void rhs(Node *node) { return at(1)->node(node); }
882 };
883
884 /**
885  * @brief Elementwise Maximum of lhs and rhs
886  *
887  * o = (l > r) ? l : r (element-wise)
888  */
889 class EltwiseMax final : public CanonicalNodeDef<CanonicalOpcode::EltwiseMax, FixedArity<2>::Mixin>
890 {
891 public:
892   EltwiseMax() = default;
893
894 public:
895   Node *lhs(void) const { return at(0)->node(); }
896   void lhs(Node *node) { return at(0)->node(node); }
897
898   Node *rhs(void) const { return at(1)->node(); }
899   void rhs(Node *node) { return at(1)->node(node); }
900 };
901
902 /**
903  * @brief Elementwise Mul lhs and rhs
904  */
905 class EltwiseMul final : public CanonicalNodeDef<CanonicalOpcode::EltwiseMul, FixedArity<2>::Mixin>
906 {
907 public:
908   EltwiseMul() = default;
909
910 public:
911   Node *lhs(void) const { return at(0)->node(); }
912   void lhs(Node *node) { return at(0)->node(node); }
913
914   Node *rhs(void) const { return at(1)->node(); }
915   void rhs(Node *node) { return at(1)->node(node); }
916 };
917
918 /**
919  * @brief Elementwise Sub lhs and rhs
920  */
921 class EltwiseSub final : public CanonicalNodeDef<CanonicalOpcode::EltwiseSub, FixedArity<2>::Mixin>
922 {
923 public:
924   EltwiseSub() = default;
925
926 public:
927   Node *lhs(void) const { return at(0)->node(); }
928   void lhs(Node *node) { return at(0)->node(node); }
929
930   Node *rhs(void) const { return at(1)->node(); }
931   void rhs(Node *node) { return at(1)->node(node); }
932 };
933
934 /**
935  * @brief Elementwise Div lhs and rhs
936  */
937 class EltwiseDiv final : public CanonicalNodeDef<CanonicalOpcode::EltwiseDiv, FixedArity<2>::Mixin>
938 {
939 public:
940   EltwiseDiv() = default;
941
942 public:
943   Node *lhs(void) const { return at(0)->node(); }
944   void lhs(Node *node) { return at(0)->node(node); }
945
946   Node *rhs(void) const { return at(1)->node(); }
947   void rhs(Node *node) { return at(1)->node(node); }
948 };
949
950 /**
951  * @brief Elementwise Sqrt of input
952  */
953 class EltwiseSqrt final
954     : public CanonicalNodeDef<CanonicalOpcode::EltwiseSqrt, FixedArity<1>::Mixin>
955 {
956 public:
957   EltwiseSqrt() = default;
958
959 public:
960   Node *input(void) const { return at(0)->node(); }
961   void input(Node *node) { at(0)->node(node); }
962 };
963
964 /**
965  * @brief Duplicate elements along specified axes
966  *
967  * TensorBroadcast takes a tensor and produces another tensor with the same rank but HIGHER
968  * dimensionality.
969  *
970  * To create such a tensor. TensorBroadcast duplicates the element along the specified axes.
971  *
972  * It is possible to control the degree of duplication with a partial map from TensorAxis to
973  * Dimension.
974  *
975  * TODO Explain the constraints (The dimension of inputs for specified axes SHOULD BE 1).
976  * TODO Explain the operation semantics
977  */
978 class TensorBroadcast final
979     : public CanonicalNodeDef<CanonicalOpcode::TensorBroadcast, FixedArity<1>::Mixin>
980 {
981 public:
982   TensorBroadcast() = default;
983
984 public:
985   Node *input(void) const { return at(0)->node(); }
986   void input(Node *node) { at(0)->node(node); }
987
988 public:
989   class Mapping final
990   {
991   public:
992     Mapping() = default;
993
994   public:
995     bool defined(const TensorAxis &axis) const;
996
997     const Dimension &dim(const TensorAxis &axis) const;
998     Dimension &dim(const TensorAxis &axis);
999
1000   private:
1001     std::map<TensorAxis, Dimension> _content;
1002   };
1003
1004   Mapping *mapping(void) { return &_mapping; }
1005   const Mapping *mapping(void) const { return &_mapping; }
1006
1007 private:
1008   Mapping _mapping;
1009 };
1010
1011 /**
1012  * @brief Create Matrix from Tensor
1013  *
1014  * MatrixEncode currently requires a rank-2 Tensor as its input.
1015  */
1016 class MatrixEncode final
1017     : public CanonicalNodeDef<CanonicalOpcode::MatrixEncode, FixedArity<1>::Mixin>
1018 {
1019 public:
1020   MatrixEncode() = default;
1021
1022 public:
1023   Node *input(void) const { return at(0)->node(); }
1024   void input(Node *node) { at(0)->node(node); }
1025
1026 public:
1027   MatrixEncoder *encoder(void) const { return _enc.get(); }
1028   void encoder(std::unique_ptr<MatrixEncoder> &&enc) { _enc = std::move(enc); }
1029
1030 private:
1031   /// @note "encoder" is mandatory
1032   std::unique_ptr<MatrixEncoder> _enc{nullptr};
1033 };
1034
1035 /**
1036  * @brief Create Tensor from Matrix
1037  *
1038  * MatrixDecode currently requires a Matrix as its input.
1039  */
1040 class MatrixDecode final
1041     : public CanonicalNodeDef<CanonicalOpcode::MatrixDecode, FixedArity<1>::Mixin>
1042 {
1043 public:
1044   MatrixDecode() = default;
1045
1046 public:
1047   Node *input(void) const { return at(0)->node(); }
1048   void input(Node *node) { at(0)->node(node); }
1049
1050 public:
1051   MatrixDecoder *decoder(void) const { return _dec.get(); }
1052   void decoder(std::unique_ptr<MatrixDecoder> &&dec) { _dec = std::move(dec); }
1053
1054 private:
1055   /// @note "decoder" is mandatory
1056   std::unique_ptr<MatrixDecoder> _dec{nullptr};
1057 };
1058
1059 /**
1060  * @brief Matrix Multiplication lhs and rhs
1061  *
1062  * LHS and RHS must be on Matrix domain
1063  */
1064 class MatMul final : public CanonicalNodeDef<CanonicalOpcode::MatMul, FixedArity<2>::Mixin>
1065 {
1066 public:
1067   MatMul() = default;
1068
1069 public:
1070   Node *lhs(void) const { return at(0)->node(); }
1071   void lhs(Node *node) { return at(0)->node(node); }
1072
1073   Node *rhs(void) const { return at(1)->node(); }
1074   void rhs(Node *node) { return at(1)->node(node); }
1075 };
1076
1077 /**
1078  * @brief Permute an input
1079  *
1080  * In the following case,
1081  *
1082  *    output = loco::TensorTranspose(input)
1083  *
1084  * perm()->axis(output's axis) = input's axis
1085  *
1086  * Input and output belong to tensor domain.
1087  */
1088 class TensorTranspose final
1089     : public CanonicalNodeDef<CanonicalOpcode::TensorTranspose, FixedArity<1>::Mixin>
1090 {
1091 public:
1092   TensorTranspose() = default;
1093
1094 public:
1095   Node *input(void) const { return at(0)->node(); }
1096   void input(Node *node) { return at(0)->node(node); }
1097
1098   class Perm final
1099   {
1100   public:
1101     Perm() = default;
1102
1103   public:
1104     uint32_t size() const { return _vals.size(); }
1105     void size(uint32_t size) { _vals.resize(size); }
1106
1107     const TensorAxis &axis(TensorAxis n) const { return _vals[n]; }
1108     TensorAxis &axis(TensorAxis n) { return _vals[n]; }
1109
1110   private:
1111     std::vector<TensorAxis> _vals;
1112   };
1113
1114   Perm *perm(void) { return &_perm; }
1115   const Perm *perm(void) const { return &_perm; }
1116
1117 private:
1118   Perm _perm;
1119 };
1120
1121 } // namespace loco
1122
1123 #endif // __LOCO_IR_NODES_H__