 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
#ifndef __LOCO_IR_NODES_H__
#define __LOCO_IR_NODES_H__

#include "loco/IR/Node.h"
#include "loco/IR/Use.h"
#include "loco/IR/Domain.h"
#include "loco/IR/DataType.h"
#include "loco/IR/DataTypeTraits.h"
#include "loco/IR/Dimension.h"
#include "loco/IR/Window.h"
#include "loco/IR/Stride.h"
#include "loco/IR/Padding2D.h"
#include "loco/IR/PaddingND.h"
#include "loco/IR/TensorAxis.h"
#include "loco/IR/TensorAxisSet.h"
#include "loco/IR/FeatureCodec.h"
#include "loco/IR/FilterCodec.h"
#include "loco/IR/DepthwiseFilterCodec.h"
#include "loco/IR/MatrixCodec.h"
#include "loco/IR/NodeMixins.h"
#include "loco/IR/CanonicalNodeDecl.h"
#include "loco/IR/GraphInputIndex.h"
#include "loco/IR/GraphOutputIndex.h"

#include <cstdint>
#include <map>
#include <memory>
#include <vector>

namespace loco
{
/**
 * @brief Make a value visible to the user
 */
class Push /* to user */ final
    : public CanonicalNodeDef<CanonicalOpcode::Push, FixedArity<1>::Mixin>
{
public:
  Node *from(void) const { return at(0)->node(); }
  void from(Node *node) { at(0)->node(node); }

public:
  void index(const GraphOutputIndex &index);

  /**
   * @brief Get the associated output index
   *
   * The behavior of this method is undefined when "index" has not been set beforehand.
   *
   * NOTE This method intentionally returns "GraphOutputIndex" instead of
   *      "const GraphOutputIndex &" in order not to expose internal implementation details.
   */
  GraphOutputIndex index(void) const;

  /**
   * @brief Check whether the index is initialized
   *
   * NOTE "indexed" does not validate whether the index is in a valid range.
   */
  bool indexed(void) const { return _index != -1; }

private:
  int64_t _index = -1; // Uninitialized
};

void link(GraphOutput *, Push *push);

/// @brief Find the Push node with a given output index
Push *push_node(Graph *g, const GraphOutputIndex &index);
/**
 * @brief Create a value from user data
 */
class Pull /* from user */ final
    : public CanonicalNodeDef<CanonicalOpcode::Pull, FixedArity<0>::Mixin,
                              With<NodeTrait::TensorShape>::Mixin>
{
public:
  void index(const GraphInputIndex &index);

  /**
   * @brief Get the associated input index
   *
   * The behavior of this method is undefined when "index" has not been set beforehand.
   *
   * NOTE This method intentionally returns "GraphInputIndex" instead of
   *      "const GraphInputIndex &" in order not to expose internal implementation details.
   */
  GraphInputIndex index(void) const;

  /**
   * @brief Check whether the index is initialized
   *
   * NOTE "indexed" does not validate whether the index is in a valid range.
   */
  bool indexed(void) const { return _index != -1; }

public:
  void dtype(const DataType &d);
  DataType dtype(void) const;

private:
  int64_t _index = -1; // Uninitialized

  /**
   * @brief Locally cached data type attribute
   *
   * TODO Remove this cache once all the clients are updated
   */
  DataType _dtype = DataType::Unknown;
};

void link(GraphInput *, Pull *pull);

/// @brief Find the Pull node with a given input index
Pull *pull_node(Graph *g, const GraphInputIndex &index);
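/*
 * EXAMPLE (sketch only, not part of this header)
 *
 * How Pull/Push tie a graph to its external inputs and outputs. The graph
 * construction calls below ("make_graph", "nodes()->create",
 * "inputs()->create", "outputs()->create") are assumptions about the
 * surrounding Graph API, shown purely for illustration:
 *
 *   auto g = loco::make_graph();
 *
 *   auto pull = g->nodes()->create<loco::Pull>();
 *   loco::link(g->inputs()->create(), pull); // now pull->indexed() holds
 *
 *   auto push = g->nodes()->create<loco::Push>();
 *   push->from(pull);
 *   loco::link(g->outputs()->create(), push);
 *
 *   assert(loco::push_node(g.get(), 0) == push);
 */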
/**
 * @brief Create a new value identical to its input
 *
 * This node may encode memory transfer (such as CPU -> GPU or GPU -> CPU).
 */
class Forward final : public CanonicalNodeDef<CanonicalOpcode::Forward, FixedArity<1>::Mixin>
{
public:
  Node *input(void) const { return at(0)->node(); }
  void input(Node *node) { at(0)->node(node); }
};
/**
 * @brief Create a new value that rectifies its input
 */
class ReLU final : public CanonicalNodeDef<CanonicalOpcode::ReLU, FixedArity<1>::Mixin>
{
public:
  Node *input(void) const { return at(0)->node(); }
  void input(Node *node) { at(0)->node(node); }
};

/**
 * @brief Create a new value that rectifies its input, capping the units at 6
 */
class ReLU6 final : public CanonicalNodeDef<CanonicalOpcode::ReLU6, FixedArity<1>::Mixin>
{
public:
  Node *input(void) const { return at(0)->node(); }
  void input(Node *node) { at(0)->node(node); }
};

/**
 * @brief Create a new value that applies tanh to its input
 */
class Tanh final : public CanonicalNodeDef<CanonicalOpcode::Tanh, FixedArity<1>::Mixin>
{
public:
  Node *input(void) const { return at(0)->node(); }
  void input(Node *node) { at(0)->node(node); }
};
/**
 * @brief Create a value from a constant byte array
 *
 * @note ConstGen assumes "lexical memory layout".
 *
 * Let us assume that a ConstGen generates a constant tensor of shape "S".
 * For each valid index I, the corresponding value comes from offset(S, I),
 * where the implementation of "offset" is given as follows:
 *
 *    uint32_t stride(TensorShape shape, uint32_t axis) {
 *      uint32_t res = 1;
 *      for (uint32_t n = rank(shape) - 1; n > axis; --n) { res *= shape.dim(n); }
 *      return res;
 *    }
 *
 *    uint32_t offset(TensorShape shape, TensorIndex index) {
 *      uint32_t res = 0;
 *      for (uint32_t n = 0; n < rank(shape); ++n) { res += index.at(n) * stride(shape, n); }
 *      return res;
 *    }
 */
class ConstGen final
    : public CanonicalNodeDef<CanonicalOpcode::ConstGen, FixedArity<0>::Mixin,
                              With<NodeTrait::DataType>::Mixin, With<NodeTrait::TensorShape>::Mixin>
{
public:
  ConstGen() = default;

public:
  /**
   * @brief Return the number of reserved elements
   * @note This method returns the number of ELEMENTS (not BYTES).
   */
  template <DataType DT> uint32_t size(void) const;

  /**
   * @brief Adjust the number of reserved elements
   */
  template <DataType DT> void size(uint32_t size);

  /**
   * @brief Get the element at a given position
   * @require at(n) is valid only when n < size()
   */
  template <DataType DT> const typename DataTypeImpl<DT>::Type &at(uint32_t n) const;

  /**
   * @brief Update the element at a given position
   * @require at(n) is valid only when n < size()
   */
  template <DataType DT> typename DataTypeImpl<DT>::Type &at(uint32_t n);

private:
  std::vector<uint8_t> _data;
};
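/*
 * EXAMPLE (sketch only)
 *
 * Filling a 2x2 FLOAT32 constant. "dtype", "rank" and "dim" are assumed to be
 * provided by the DataType/TensorShape mixins (see NodeMixins.h); node
 * creation via "g->nodes()->create" is an assumption about the Graph API:
 *
 *   auto constgen = g->nodes()->create<loco::ConstGen>();
 *   constgen->dtype(loco::DataType::FLOAT32);
 *   constgen->rank(2);
 *   constgen->dim(0) = 2;
 *   constgen->dim(1) = 2;
 *   constgen->size<loco::DataType::FLOAT32>(4); // 4 ELEMENTS, not bytes
 *   for (uint32_t n = 0; n < 4; ++n)
 *   {
 *     // lexical layout: element n corresponds to index (n / 2, n % 2)
 *     constgen->at<loco::DataType::FLOAT32>(n) = static_cast<float>(n);
 *   }
 */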
252 * @brief 2D Max Pooling
254 * MaxPool2D takes as input a feature map, and produces another feature map
257 * Any valid MaxPool2D nodes SHOULD satisfy the following conditions.
259 * Let us define several helper functions that takes a MaxPool2D nodes first:
260 * - IFM_DOMAIN returns the domain of its input
261 * - IFM_H returns the height of its input.
262 * - IFM_W returns the width of its input.
263 * - PAD_T returns the top padding required over its input
264 * - PAD_B returns the bottom padding required over its input
265 * - PAD_L returns the left padding required over its input
266 * - PAD_R returns the right padding required over its input
267 * - WIN_H returns the height of its receptive field.
268 * - WIN_W returns the width of its receptive field.
269 * - STRIDE_H returns the vertical(= on height) stride.
270 * - STRIDE_W returns the horizontal(= on width) stride.
275 * A valid MaxPool2D node M SHOULD satisfy the following condition:
276 * - IFM_DOMAIN(M) == Feature
280 * There are many possible ways to encode a feature map as a tensor.
281 * - e.g. NCHW/NHWC/...
283 * In order to give some freedom on memory layout to backend, loco requires a feature map
284 * value to be explicitly encoded via FeatureEncode.
289 * A valid MaxPool2D node M SHOULD satisfy the following conditions:
290 * - (IFM_H(M) + PAD_T(M) + PAD_B(M) - WIN_H(M)) % STRIDE_H(M) == 0
291 * - (IFM_W(M) + PAD_L(M) + PAD_R(M) - WIN_W(M)) % STRIDE_W(M) == 0
295 * The output shape may differ for each NN framework when these conditions do not hold.
297 * In order to mitigate such a difference among NN frameworks, loco requires these conditions
298 * for MaxPool2D nodes.
300 * This means that each frontend implementation SHOULD insert appropriate padding/trimming node
301 * before/after MaxPool2D node according to the semantics of the corresponding NN framework.
304 class MaxPool2D final : public CanonicalNodeDef<CanonicalOpcode::MaxPool2D, FixedArity<1>::Mixin>
307 Node *ifm(void) const { return at(0)->node(); }
308 void ifm(Node *node) { at(0)->node(node); }
311 const Padding2D *pad(void) const { return &_pad; }
312 Padding2D *pad(void) { return &_pad; }
315 const Window<2> *window(void) const { return &_window; }
316 Window<2> *window(void) { return &_window; }
319 const Stride<2> *stride(void) const { return &_stride; }
320 Stride<2> *stride(void) { return &_stride; }
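/*
 * EXAMPLE (shape arithmetic under the conditions above)
 *
 * With IFM_H = 5, PAD_T = PAD_B = 1, WIN_H = 3, and STRIDE_H = 2:
 *
 *   (5 + 1 + 1 - 3) % 2 == 0, so the node is valid, and - assuming the usual
 *   pooling output-size rule, which this header does not spell out - the
 *   output height is (5 + 1 + 1 - 3) / 2 + 1 = 3.
 *
 * The same reasoning applies to the width axis with PAD_L/PAD_R/WIN_W/STRIDE_W.
 */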
332 * @brief 2D Average Pooling
334 * @note Follows MaxPool2D (TODO: describe difference)
336 class AvgPool2D final : public CanonicalNodeDef<CanonicalOpcode::AvgPool2D, FixedArity<1>::Mixin>
339 enum class Convention
342 // Use the number of elements in each receptive field as a divisor
344 // Use the number of valid (non-padding) elements in each receptive field as a divisor
349 Node *ifm(void) const { return at(0)->node(); }
350 void ifm(Node *node) { at(0)->node(node); }
353 Convention convention(void) const { return _convention; }
354 void convention(const Convention &convention) { _convention = convention; }
357 const Padding2D *pad(void) const { return &_pad; }
358 Padding2D *pad(void) { return &_pad; }
361 const Window<2> *window(void) const { return &_window; }
362 Window<2> *window(void) { return &_window; }
365 const Stride<2> *stride(void) const { return &_stride; }
366 Stride<2> *stride(void) { return &_stride; }
369 Convention _convention = Convention::Unknown;
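/*
 * EXAMPLE (difference between the two conventions)
 *
 * Consider a 3x3 window placed at the top-left corner of a padded feature map
 * where only a 2x2 portion of the window overlaps real (non-padding) input:
 *
 *   - Convention::Full  divides the window sum by 9 (every element)
 *   - Convention::Valid divides the window sum by 4 (non-padding elements only)
 */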
376 * @brief Create a feature map from a tensor
378 class FeatureEncode final
379 : public CanonicalNodeDef<CanonicalOpcode::FeatureEncode, FixedArity<1>::Mixin>
382 Node *input(void) const { return at(0)->node(); }
383 void input(Node *node) { at(0)->node(node); }
386 FeatureEncoder *encoder(void) const { return _enc.get(); }
387 void encoder(std::unique_ptr<FeatureEncoder> &&enc) { _enc = std::move(enc); }
390 /// @note "encoder" is mandatory
391 std::unique_ptr<FeatureEncoder> _enc{nullptr};
395 * @brief Create a tensor from a feature map
397 class FeatureDecode final
398 : public CanonicalNodeDef<CanonicalOpcode::FeatureDecode, FixedArity<1>::Mixin>
401 Node *input(void) const { return at(0)->node(); }
402 void input(Node *node) { at(0)->node(node); }
405 FeatureDecoder *decoder(void) const { return _dec.get(); }
406 void decoder(std::unique_ptr<FeatureDecoder> &&dec) { _dec = std::move(dec); }
409 /// @NOTE "decoder" is mandatory
410 std::unique_ptr<FeatureDecoder> _dec{nullptr};
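/*
 * EXAMPLE (sketch only)
 *
 * Wiring an NHWC tensor into the Feature domain. "PermutingEncoder" and
 * "FeatureAxis" are assumed to come from loco's permutation-based codec
 * (PermutingCodec.h); any codec implementing FeatureEncoder would do:
 *
 *   auto enc = std::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>();
 *   enc->perm()->axis(loco::FeatureAxis::Count) = 0;  // N
 *   enc->perm()->axis(loco::FeatureAxis::Height) = 1; // H
 *   enc->perm()->axis(loco::FeatureAxis::Width) = 2;  // W
 *   enc->perm()->axis(loco::FeatureAxis::Depth) = 3;  // C
 *
 *   auto encode = g->nodes()->create<loco::FeatureEncode>();
 *   encode->input(tensor_node); // hypothetical upstream tensor-domain node
 *   encode->encoder(std::move(enc));
 */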
414 * @brief Create a filter from a tensor
416 class FilterEncode final
417 : public CanonicalNodeDef<CanonicalOpcode::FilterEncode, FixedArity<1>::Mixin>
420 Node *input(void) const { return at(0)->node(); }
421 void input(Node *node) { at(0)->node(node); }
424 FilterEncoder *encoder(void) const { return _enc.get(); }
425 void encoder(std::unique_ptr<FilterEncoder> &&enc) { _enc = std::move(enc); }
428 /// @note "encoder" is mandatory
429 std::unique_ptr<FilterEncoder> _enc{nullptr};
433 * @brief Create a tensor from a filter
435 class FilterDecode final
436 : public CanonicalNodeDef<CanonicalOpcode::FilterDecode, FixedArity<1>::Mixin>
439 Node *input(void) const { return at(0)->node(); }
440 void input(Node *node) { at(0)->node(node); }
443 FilterDecoder *decoder(void) const { return _dec.get(); }
444 void decoder(std::unique_ptr<FilterDecoder> &&dec) { _dec = std::move(dec); }
447 /// @note "decoder" is mandatory
448 std::unique_ptr<FilterDecoder> _dec{nullptr};
452 * @brief Create a depthwise filter from a tensor
454 class DepthwiseFilterEncode final
455 : public CanonicalNodeDef<CanonicalOpcode::DepthwiseFilterEncode, FixedArity<1>::Mixin>
458 Node *input(void) const { return at(0)->node(); }
459 void input(Node *node) { at(0)->node(node); }
462 DepthwiseFilterEncoder *encoder(void) const { return _enc.get(); }
463 void encoder(std::unique_ptr<DepthwiseFilterEncoder> &&enc) { _enc = std::move(enc); }
466 /// @note "encoder" is mandatory
467 std::unique_ptr<DepthwiseFilterEncoder> _enc{nullptr};
471 * @brief Create a tensor from a depthwise filter
473 class DepthwiseFilterDecode final
474 : public CanonicalNodeDef<CanonicalOpcode::DepthwiseFilterDecode, FixedArity<1>::Mixin>
477 Node *input(void) const { return at(0)->node(); }
478 void input(Node *node) { at(0)->node(node); }
481 DepthwiseFilterDecoder *decoder(void) const { return _dec.get(); }
482 void decoder(std::unique_ptr<DepthwiseFilterDecoder> &&dec) { _dec = std::move(dec); }
485 /// @note "decoder" is mandatory
486 std::unique_ptr<DepthwiseFilterDecoder> _dec{nullptr};
enum class ReshapeType
{
  Fixed, // shape is known at compile time
  // TODO Add another type for the case when the shape is not known at compile time
};

template <ReshapeType RT> class Reshape;

/**
 * @brief Reshape a tensor to another tensor whose shape is known at compile time
 *
 * @note This class reshapes the shape of an input tensor to _shape.
 *       Each dimension of _shape should be known at compile time.
 *       Every dimension of _shape should be greater than 0.
 *
 * Interpreters and runtimes should lexicographically copy an input tensor into an output tensor.
 * For example, the values of an input tensor of shape [2, 2, 2, 2] will be copied into an output
 * tensor of new shape [4, 4] as follows:
 *    input[0, 0, 0, 0] => output[0, 0]
 *    input[0, 0, 0, 1] => output[0, 1]
 *    input[0, 0, 1, 0] => output[0, 2]
 *    ...
 *    input[1, 1, 1, 1] => output[3, 3]
 */
template <>
class Reshape<ReshapeType::Fixed> final
    : public CanonicalNodeDef<CanonicalOpcode::FixedReshape, FixedArity<1>::Mixin,
                              With<NodeTrait::TensorShape>::Mixin>
{
public:
  Node *input(void) const { return at(0)->node(); }
  void input(Node *node) { at(0)->node(node); }
};

using FixedReshape = Reshape<ReshapeType::Fixed>;
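/*
 * EXAMPLE (offset equivalence of the lexicographic copy)
 *
 * Reshaping [2, 2, 2, 2] into [4, 4]: input index (1, 0, 1, 1) has lexical
 * offset 1*8 + 0*4 + 1*2 + 1 = 11, and 11 == 2*4 + 3, so the value lands at
 * output index (2, 3). In general, input and output share the same
 * offset(shape, index) value as defined in the ConstGen note above.
 */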
526 * @brief Concatenate two tensors
528 * Given an axis, TensorConcat takes as input two tensors and produces a tensor
529 * concatenated along the given axis.
531 class TensorConcat final
532 : public CanonicalNodeDef<CanonicalOpcode::TensorConcat, FixedArity<2>::Mixin>
535 Node *lhs(void) const { return at(0)->node(); }
536 void lhs(Node *node) { at(0)->node(node); }
538 Node *rhs(void) const { return at(1)->node(); }
539 void rhs(Node *node) { at(1)->node(node); }
542 uint32_t axis(void) const { return _axis; }
543 void axis(uint32_t val) { _axis = val; }
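/*
 * EXAMPLE (shape of the result)
 *
 *   lhs : [2, 3, 4]
 *   rhs : [2, 5, 4]
 *   axis: 1
 *   out : [2, 3 + 5, 4] = [2, 8, 4]
 *
 * All dimensions other than "axis" are expected to match between lhs and rhs.
 */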
551 * @brief 2D Spatial Convolution
553 class Conv2D final : public CanonicalNodeDef<CanonicalOpcode::Conv2D, FixedArity<2>::Mixin>
556 Node *ifm(void) const { return at(0)->node(); }
557 void ifm(Node *node) { at(0)->node(node); }
559 Node *ker(void) const { return at(1)->node(); }
560 void ker(Node *node) { at(1)->node(node); }
563 const Padding2D *pad(void) const { return &_pad; }
564 Padding2D *pad(void) { return &_pad; }
567 const Stride<2> *stride(void) const { return &_stride; }
568 Stride<2> *stride(void) { return &_stride; }
574 // TODO Support "Dilation"
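/*
 * EXAMPLE (spatial arithmetic, by analogy with the MaxPool2D conditions above)
 *
 * With a 5x5 IFM, a 3x3 kernel, stride 1, and no padding - and assuming the
 * usual convolution output-size rule, which this header does not spell out -
 * the OFM is 3x3: (5 + 0 + 0 - 3) / 1 + 1 = 3 along each spatial axis.
 *
 * ifm() is expected to be a Feature-domain value (via FeatureEncode) and
 * ker() a Filter-domain value (via FilterEncode).
 */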
578 * @brief Depthwise 2D Convolution
580 class DepthwiseConv2D final
581 : public CanonicalNodeDef<CanonicalOpcode::DepthwiseConv2D, FixedArity<2>::Mixin>
584 Node *ifm(void) const { return at(0)->node(); }
585 void ifm(Node *node) { at(0)->node(node); }
587 Node *ker(void) const { return at(1)->node(); }
588 void ker(Node *node) { at(1)->node(node); }
591 const Padding2D *pad(void) const { return &_pad; }
592 Padding2D *pad(void) { return &_pad; }
595 const Stride<2> *stride(void) const { return &_stride; }
596 Stride<2> *stride(void) { return &_stride; }
602 // TODO Support "Dilation"
606 * @brief Reduce type functions
608 enum class ReduceFunc
611 // TODO Support other reduce operations
615 * @brief Computes ReduceFunc operations for Tensor domain
616 * @note All the reduce functions always keep dimensions
618 class TensorReduce final
619 : public CanonicalNodeDef<CanonicalOpcode::TensorReduce, FixedArity<1>::Mixin>
622 Node *input(void) const { return at(0)->node(); }
623 void input(Node *node) { at(0)->node(node); }
626 const TensorAxisSet *axes(void) const { return &_axes; }
627 TensorAxisSet *axes(void) { return &_axes; }
630 ReduceFunc func(void) const { return _func; }
631 void func(ReduceFunc func) { _func = func; }
635 ReduceFunc _func{ReduceFunc::Mean};
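/*
 * EXAMPLE (dimensions are kept)
 *
 *   input : [2, 3, 4], axes = {1}, func = ReduceFunc::Mean
 *   output: [2, 1, 4] // the reduced axis remains, with dimension 1
 */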
639 * @brief 2D Transposed Convolution
641 * @note TransposedConv2D have a few important conventions that IR users should
642 * understand and follow, so please check below notice carefully.
645 * 1. What is 'input' and 'output'
647 * For loco canonical TransposedConv2D, 'input' and 'output' mean actual input
648 * and output node of TransposedConv2D node. Be careful that some other
649 * frameworks may use opposite sense, especially TensorFlow which is inspired by
650 * backpropagation of convolution.
651 * For example, loco::TransposedConv2D::ifm() means actual input feature map
652 * node that is sourced into TransposedConv2D.
654 * 2. How to read kernel representation
656 * TransposedConv2D::ker() should be a node of Filter domain. Following is what
657 * each FilterAxis means as a kernel of TransposedConv2D:
658 * - FilterAxis::Height : kernel's height
659 * - FilterAxis::Width : kernel's width
660 * - FilterAxis::Depth : IFM's channel depth
661 * - FilterAxis::Count : OFM's channel depth
662 * TODO We may refactor FilterAxis as follow to reduce ambiguity:
663 * - FilterAxis::Height -> FilterAxis::H
664 * - FilterAxis::Width -> FilterAxis::W
665 * - FilterAxis::Depth -> FilterAxis::I
666 * - FilterAxis::Count -> FilterAxis::O
671 * TransposedConv2D have no information about its output shape. Instead, it
672 * always satisfy following 'tight fit' rule for horizontal and vertical
675 * O = S * ( I - 1 ) + F - P
681 * F: effective kernal(filter) size
682 * P: whole pad size (= front + rear pad)
684 * With this, output shape is uniquely determined by all inputs and attributes.
686 class TransposedConv2D final
687 : public CanonicalNodeDef<CanonicalOpcode::TransposedConv2D, FixedArity<2>::Mixin>
690 Node *ifm(void) const { return at(0)->node(); }
691 void ifm(Node *node) { at(0)->node(node); }
693 Node *ker(void) const { return at(1)->node(); }
694 void ker(Node *node) { at(1)->node(node); }
697 const Padding2D *pad(void) const { return &_pad; }
698 Padding2D *pad(void) { return &_pad; }
701 const Stride<2> *stride(void) const { return &_stride; }
702 Stride<2> *stride(void) { return &_stride; }
708 // TODO Support "Dilation"
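/*
 * EXAMPLE (the 'tight fit' rule above, worked out)
 *
 * With I = 3, S = 2, F = 3, and P = 2 (1 front + 1 rear):
 *
 *   O = 2 * (3 - 1) + 3 - 2 = 5
 *
 * So a 3x3 input under this configuration always yields a 5x5 output;
 * no explicit output-shape attribute is needed.
 */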
712 * @brief Computes softmax activations
714 template <Domain D> class Softmax;
717 * @brief Computes softmax activations for Tensor domain
720 class Softmax<Domain::Tensor> final
721 : public CanonicalNodeDef<CanonicalOpcode::TensorSoftmax, FixedArity<1>::Mixin>
727 Node *input(void) const { return at(0)->node(); }
728 void input(Node *node) { return at(0)->node(node); }
730 uint32_t axis(void) const { return _axis; }
731 void axis(uint32_t axis) { _axis = axis; }
737 using TensorSoftmax = Softmax<Domain::Tensor>;
740 * @brief Create a "Tensor" from a "Bias"
742 class BiasDecode final : public CanonicalNodeDef<CanonicalOpcode::BiasDecode, FixedArity<1>::Mixin>
745 BiasDecode() = default;
748 Node *input(void) const { return at(0)->node(); }
749 void input(Node *node) { at(0)->node(node); }
753 * @brief Create a "Bias" from a "Tensor"
755 * BiasEncode currently requires a rank-1 tensor as its input.
757 class BiasEncode final : public CanonicalNodeDef<CanonicalOpcode::BiasEncode, FixedArity<1>::Mixin>
760 BiasEncode() = default;
763 Node *input(void) const { return at(0)->node(); }
764 void input(Node *node) { at(0)->node(node); }
768 * @brief Produce a value of domain D from an input value (of domain D) and a bias
770 template <Domain D> class BiasAdd;
773 * @brief Add Tensor and Bias
775 * for each valid tensor index I
776 * out(I) = value(I) + bias(I.at(axis))
779 class BiasAdd<Domain::Tensor> final
780 : public CanonicalNodeDef<CanonicalOpcode::TensorBiasAdd, FixedArity<2>::Mixin>
786 Node *value(void) const { return at(0)->node(); }
787 void value(Node *node) { return at(0)->node(node); }
789 Node *bias(void) const { return at(1)->node(); }
790 void bias(Node *node) { return at(1)->node(node); }
792 uint32_t axis(void) const { return _axis; }
793 void axis(uint32_t axis) { _axis = axis; }
800 // Alias for external users
802 // loco::TensorBiasAdd
804 // loco::BiasAdd<loco::Domain::Tensor>
806 using TensorBiasAdd = BiasAdd<Domain::Tensor>;
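/*
 * EXAMPLE (the indexing rule above, worked out)
 *
 *   value: a [2, 3] tensor, bias: a length-3 Bias, axis = 1
 *   out(0, 2) = value(0, 2) + bias(2) // since I.at(axis) == 2
 *
 * The bias length is expected to match value's dimension along "axis".
 */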
809 * @brief Add Feature and Bias along "depth" axis
811 * for each valid feature index (b, ch, row, col)
812 * out(b, ch, row, col) = value(b, ch, row, col) + bias(ch)
815 class BiasAdd<Domain::Feature> final
816 : public CanonicalNodeDef<CanonicalOpcode::FeatureBiasAdd, FixedArity<2>::Mixin>
822 Node *value(void) const { return at(0)->node(); }
823 void value(Node *node) { return at(0)->node(node); }
825 Node *bias(void) const { return at(1)->node(); }
826 void bias(Node *node) { return at(1)->node(node); }
829 using FeatureBiasAdd = BiasAdd<Domain::Feature>;
832 * @brief Pads a tensor with constant value
834 * Pads a input tensor according to the padding with constant value.
836 * The dimension of each axis n of the output is
837 * output.dim(n) = padding.front(n) + input.dim(n) + padding.back(n)
839 * For example, input tensor of shape [1, 2] with
841 * padding.front(0) = 1;
842 * padding.back(0) = 2;
844 * padding.front(1) = 3;
845 * padding.back(1) = 4;
847 * will be a output tensor of shape
848 * [padding.front(0) + 1 + padding.back(0), padding.front(1) + 2 + padding.back(1)] = [4,9].
850 class TensorConstantPad final
851 : public CanonicalNodeDef<CanonicalOpcode::TensorConstantPad, FixedArity<2>::Mixin>
854 Node *input(void) const { return at(0)->node(); }
855 void input(Node *node) { at(0)->node(node); }
857 Node *constant(void) const { return at(1)->node(); }
858 void constant(Node *node) { at(1)->node(node); }
861 const PaddingND *padding(void) const { return &_padding; }
862 PaddingND *padding(void) { return &_padding; }
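/*
 * EXAMPLE (sketch only, reproducing the [1, 2] -> [4, 9] case above)
 *
 * PaddingND is assumed to expose rank/front/back accessors (see PaddingND.h);
 * node creation via "g->nodes()->create" is an assumption about the Graph API:
 *
 *   auto pad = g->nodes()->create<loco::TensorConstantPad>();
 *   pad->input(input_node);       // hypothetical node of shape [1, 2]
 *   pad->constant(zero_constgen); // hypothetical ConstGen holding the pad value
 *   pad->padding()->rank(2);
 *   pad->padding()->front(0) = 1;
 *   pad->padding()->back(0) = 2;
 *   pad->padding()->front(1) = 3;
 *   pad->padding()->back(1) = 4;
 *   // output shape: [1 + 1 + 2, 3 + 2 + 4] = [4, 9]
 */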
869 * @brief Elementwise Add lhs and rhs
871 class EltwiseAdd final : public CanonicalNodeDef<CanonicalOpcode::EltwiseAdd, FixedArity<2>::Mixin>
874 EltwiseAdd() = default;
877 Node *lhs(void) const { return at(0)->node(); }
878 void lhs(Node *node) { return at(0)->node(node); }
880 Node *rhs(void) const { return at(1)->node(); }
881 void rhs(Node *node) { return at(1)->node(node); }
885 * @brief Elementwise Maximum of lhs and rhs
887 * o = (l > r) ? l : r (element-wise)
889 class EltwiseMax final : public CanonicalNodeDef<CanonicalOpcode::EltwiseMax, FixedArity<2>::Mixin>
892 EltwiseMax() = default;
895 Node *lhs(void) const { return at(0)->node(); }
896 void lhs(Node *node) { return at(0)->node(node); }
898 Node *rhs(void) const { return at(1)->node(); }
899 void rhs(Node *node) { return at(1)->node(node); }
903 * @brief Elementwise Mul lhs and rhs
905 class EltwiseMul final : public CanonicalNodeDef<CanonicalOpcode::EltwiseMul, FixedArity<2>::Mixin>
908 EltwiseMul() = default;
911 Node *lhs(void) const { return at(0)->node(); }
912 void lhs(Node *node) { return at(0)->node(node); }
914 Node *rhs(void) const { return at(1)->node(); }
915 void rhs(Node *node) { return at(1)->node(node); }
919 * @brief Elementwise Sub lhs and rhs
921 class EltwiseSub final : public CanonicalNodeDef<CanonicalOpcode::EltwiseSub, FixedArity<2>::Mixin>
924 EltwiseSub() = default;
927 Node *lhs(void) const { return at(0)->node(); }
928 void lhs(Node *node) { return at(0)->node(node); }
930 Node *rhs(void) const { return at(1)->node(); }
931 void rhs(Node *node) { return at(1)->node(node); }
935 * @brief Elementwise Div lhs and rhs
937 class EltwiseDiv final : public CanonicalNodeDef<CanonicalOpcode::EltwiseDiv, FixedArity<2>::Mixin>
940 EltwiseDiv() = default;
943 Node *lhs(void) const { return at(0)->node(); }
944 void lhs(Node *node) { return at(0)->node(node); }
946 Node *rhs(void) const { return at(1)->node(); }
947 void rhs(Node *node) { return at(1)->node(node); }
951 * @brief Elementwise Sqrt of input
953 class EltwiseSqrt final
954 : public CanonicalNodeDef<CanonicalOpcode::EltwiseSqrt, FixedArity<1>::Mixin>
957 EltwiseSqrt() = default;
960 Node *input(void) const { return at(0)->node(); }
961 void input(Node *node) { at(0)->node(node); }
965 * @brief Duplicate elements along specified axes
967 * TensorBroadcast takes a tensor and produces another tensor with the same rank but HIGHER
970 * To create such a tensor. TensorBroadcast duplicates the element along the specified axes.
972 * It is possible to control the degree of duplication with a partial map from TensorAxis to
975 * TODO Explain the constraints (The dimension of inputs for specified axes SHOULD BE 1).
976 * TODO Explain the operation semantics
978 class TensorBroadcast final
979 : public CanonicalNodeDef<CanonicalOpcode::TensorBroadcast, FixedArity<1>::Mixin>
982 TensorBroadcast() = default;
985 Node *input(void) const { return at(0)->node(); }
986 void input(Node *node) { at(0)->node(node); }
995 bool defined(const TensorAxis &axis) const;
997 const Dimension &dim(const TensorAxis &axis) const;
998 Dimension &dim(const TensorAxis &axis);
1001 std::map<TensorAxis, Dimension> _content;
1004 Mapping *mapping(void) { return &_mapping; }
1005 const Mapping *mapping(void) const { return &_mapping; }
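/*
 * EXAMPLE (sketch only)
 *
 * Broadcasting a [1, 3] tensor to [4, 3]. Node creation via
 * "g->nodes()->create" is an assumption about the Graph API:
 *
 *   auto broadcast = g->nodes()->create<loco::TensorBroadcast>();
 *   broadcast->input(input_node);     // hypothetical node of shape [1, 3]
 *   broadcast->mapping()->dim(0) = 4; // duplicate along axis 0
 *
 * Axis 1 is left untouched since it is not mapped; the dimension of the
 * input along every mapped axis is expected to be 1.
 */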
1012 * @brief Create Matrix from Tensor
1014 * MatrixEncode currently requires a rank-2 Tensor as its input.
1016 class MatrixEncode final
1017 : public CanonicalNodeDef<CanonicalOpcode::MatrixEncode, FixedArity<1>::Mixin>
1020 MatrixEncode() = default;
1023 Node *input(void) const { return at(0)->node(); }
1024 void input(Node *node) { at(0)->node(node); }
1027 MatrixEncoder *encoder(void) const { return _enc.get(); }
1028 void encoder(std::unique_ptr<MatrixEncoder> &&enc) { _enc = std::move(enc); }
1031 /// @note "encoder" is mandatory
1032 std::unique_ptr<MatrixEncoder> _enc{nullptr};
1036 * @brief Create Tensor from Matrix
1038 * MatrixDecode currently requires a Matrix as its input.
1040 class MatrixDecode final
1041 : public CanonicalNodeDef<CanonicalOpcode::MatrixDecode, FixedArity<1>::Mixin>
1044 MatrixDecode() = default;
1047 Node *input(void) const { return at(0)->node(); }
1048 void input(Node *node) { at(0)->node(node); }
1051 MatrixDecoder *decoder(void) const { return _dec.get(); }
1052 void decoder(std::unique_ptr<MatrixDecoder> &&dec) { _dec = std::move(dec); }
1055 /// @note "decoder" is mandatory
1056 std::unique_ptr<MatrixDecoder> _dec{nullptr};
1060 * @brief Matrix Multiplication lhs and rhs
1062 * LHS and RHS must be on Matrix domain
1064 class MatMul final : public CanonicalNodeDef<CanonicalOpcode::MatMul, FixedArity<2>::Mixin>
1070 Node *lhs(void) const { return at(0)->node(); }
1071 void lhs(Node *node) { return at(0)->node(node); }
1073 Node *rhs(void) const { return at(1)->node(); }
1074 void rhs(Node *node) { return at(1)->node(node); }
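/*
 * EXAMPLE (shape of the result)
 *
 *   lhs: a Matrix of height M and width K
 *   rhs: a Matrix of height K and width N
 *   out: a Matrix of height M and width N
 *
 * Both operands are expected to be produced by MatrixEncode.
 */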
1078 * @brief Permute an input
1080 * In the following case,
1082 * output = loco::TensorTranspose(input)
1084 * perm()->axis(output's axis) = input's axis
1086 * Input and output belong to tensor domain.
1088 class TensorTranspose final
1089 : public CanonicalNodeDef<CanonicalOpcode::TensorTranspose, FixedArity<1>::Mixin>
1092 TensorTranspose() = default;
1095 Node *input(void) const { return at(0)->node(); }
1096 void input(Node *node) { return at(0)->node(node); }
1104 uint32_t size() const { return _vals.size(); }
1105 void size(uint32_t size) { _vals.resize(size); }
1107 const TensorAxis &axis(TensorAxis n) const { return _vals[n]; }
1108 TensorAxis &axis(TensorAxis n) { return _vals[n]; }
1111 std::vector<TensorAxis> _vals;
1114 Perm *perm(void) { return &_perm; }
1115 const Perm *perm(void) const { return &_perm; }
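/*
 * EXAMPLE (sketch only, following the perm() convention above)
 *
 * NHWC -> NCHW: for each output axis, perm stores the corresponding input
 * axis. Node creation via "g->nodes()->create" is an assumption about the
 * Graph API:
 *
 *   auto transpose = g->nodes()->create<loco::TensorTranspose>();
 *   transpose->input(input_node);   // hypothetical node of shape [N, H, W, C]
 *   transpose->perm()->size(4);
 *   transpose->perm()->axis(0) = 0; // output axis 0 <- input axis 0 (N)
 *   transpose->perm()->axis(1) = 3; // output axis 1 <- input axis 3 (C)
 *   transpose->perm()->axis(2) = 1; // output axis 2 <- input axis 1 (H)
 *   transpose->perm()->axis(3) = 2; // output axis 3 <- input axis 2 (W)
 */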
} // namespace loco

#endif // __LOCO_IR_NODES_H__