From: Павел Ильютченко/AI Tools Lab /SRR/Engineer/삼성전자 Date: Wed, 18 Sep 2019 07:34:44 +0000 (+0300) Subject: [loco] Introduce MatrixEncode, MatrixDecode and MatrixMul operations (#7406) X-Git-Tag: submit/tizen/20191205.083104~1192 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=b593efd991bcf650e832f41f3dd7b9db481e0c55;p=platform%2Fcore%2Fml%2Fnnfw.git [loco] Introduce MatrixEncode, MatrixDecode and MatrixMul operations (#7406) * Added MatrixEncode, MatrixDecode and MatrixMul nodes * Added tests for them * Defined CanonicalShapeInferenceRule and TypeInference Signed-off-by: Pavel Iliutchenko --- diff --git a/compiler/loco/include/loco/IR/CanonicalNodes.lst b/compiler/loco/include/loco/IR/CanonicalNodes.lst index 71fc8d0..540de70 100644 --- a/compiler/loco/include/loco/IR/CanonicalNodes.lst +++ b/compiler/loco/include/loco/IR/CanonicalNodes.lst @@ -39,3 +39,6 @@ CANONICAL_NODE(TensorBroadcast, TensorBroadcast) CANONICAL_NODE(TensorReduce, TensorReduce) CANONICAL_NODE(TensorSoftmax, Softmax) CANONICAL_NODE(TransposedConv2D, TransposedConv2D) +CANONICAL_NODE(MatrixEncode, MatrixEncode) +CANONICAL_NODE(MatrixDecode, MatrixDecode) +CANONICAL_NODE(MatMul, MatMul) diff --git a/compiler/loco/include/loco/IR/Nodes.h b/compiler/loco/include/loco/IR/Nodes.h index 6ffad19..a617386 100644 --- a/compiler/loco/include/loco/IR/Nodes.h +++ b/compiler/loco/include/loco/IR/Nodes.h @@ -31,6 +31,7 @@ #include "loco/IR/FeatureCodec.h" #include "loco/IR/FilterCodec.h" #include "loco/IR/DepthwiseFilterCodec.h" +#include "loco/IR/MatrixCodec.h" #include "loco/IR/NodeMixins.h" #include "loco/IR/CanonicalNodeDecl.h" #include "loco/IR/GraphInputIndex.h" @@ -917,6 +918,72 @@ private: Mapping _mapping; }; +/** + * @brief Create Matrix from Tensor + * + * MatrixEncode currently requires a rank-2 Tensor as its input. + */ +class MatrixEncode final + : public CanonicalNodeDef::Mixin> +{ +public: + MatrixEncode() = default; + +public: + Node *input(void) const { return at(0)->node(); } + void input(Node *node) { at(0)->node(node); } + +public: + MatrixEncoder *encoder(void) const { return _enc.get(); } + void encoder(std::unique_ptr &&enc) { _enc = std::move(enc); } + +private: + /// @note "encoder" is mandatory + std::unique_ptr _enc{nullptr}; +}; + +/** + * @brief Create Tensor from Matrix + * + * MatrixDecode currently requires a Matrix as its input. + */ +class MatrixDecode final + : public CanonicalNodeDef::Mixin> +{ +public: + MatrixDecode() = default; + +public: + Node *input(void) const { return at(0)->node(); } + void input(Node *node) { at(0)->node(node); } + +public: + MatrixDecoder *decoder(void) const { return _dec.get(); } + void decoder(std::unique_ptr &&dec) { _dec = std::move(dec); } + +private: + /// @note "decoder" is mandatory + std::unique_ptr _dec{nullptr}; +}; + +/** + * @brief Matrix Multiplication lhs and rhs + * + * LHS and RHS must be on Matrix domain + */ +class MatMul final : public CanonicalNodeDef::Mixin> +{ +public: + MatMul() = default; + +public: + Node *lhs(void) const { return at(0)->node(); } + void lhs(Node *node) { return at(0)->node(node); } + + Node *rhs(void) const { return at(1)->node(); } + void rhs(Node *node) { return at(1)->node(node); } +}; + } // namespace loco #endif // __LOCO_IR_NODES_H__ diff --git a/compiler/loco/src/IR/Nodes.test.cpp b/compiler/loco/src/IR/Nodes.test.cpp index 2f3c8cb..35846d5 100644 --- a/compiler/loco/src/IR/Nodes.test.cpp +++ b/compiler/loco/src/IR/Nodes.test.cpp @@ -501,3 +501,34 @@ TEST(TensorBroadcastTest, mapping) ASSERT_EQ(tensor_broadcast_node.mapping()->defined(0), true); ASSERT_EQ(tensor_broadcast_node.mapping()->dim(0), 3); } + +TEST(MatrixEncodeTest, constructor) +{ + loco::MatrixEncode matrix_encode; + + ASSERT_EQ(matrix_encode.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(matrix_encode.opcode(), loco::CanonicalOpcode::MatrixEncode); + + ASSERT_EQ(matrix_encode.input(), nullptr); +} + +TEST(MatrixDecodeTest, constructor) +{ + loco::MatrixDecode matrix_decode; + + ASSERT_EQ(matrix_decode.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(matrix_decode.opcode(), loco::CanonicalOpcode::MatrixDecode); + + ASSERT_EQ(matrix_decode.input(), nullptr); +} + +TEST(MatMulTest, constructor) +{ + loco::MatMul mat_mul; + + ASSERT_EQ(mat_mul.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(mat_mul.opcode(), loco::CanonicalOpcode::MatMul); + + ASSERT_EQ(mat_mul.lhs(), nullptr); + ASSERT_EQ(mat_mul.rhs(), nullptr); +} diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp index f0e4ec6..331150f 100644 --- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp @@ -439,6 +439,39 @@ public: return loco::NodeShape{tensor_shape}; } + // CASE: MatMul + loco::NodeShape visit(const loco::MatMul *node) final + { + assert(shape_known(node->lhs())); + assert(shape_known(node->rhs())); + auto const lhs_shape = node_shape(node->lhs()).as(); + auto const rhs_shape = node_shape(node->rhs()).as(); + + loco::MatrixShape out_shape; + + // Checking shape capability for multiplication + assert(lhs_shape.width() == rhs_shape.height()); + + out_shape.height() = lhs_shape.height(); + out_shape.width() = rhs_shape.width(); + + return out_shape; + } + + // CASE: MatrixDecode + loco::NodeShape visit(const loco::MatrixDecode *node) final + { + auto input_node_shape = node_shape(node->input()); + return loco::NodeShape{node->decoder()->shape(input_node_shape.as())}; + } + + // CASE: MatrixEncode + loco::NodeShape visit(const loco::MatrixEncode *node) final + { + auto input_node_shape = node_shape(node->input()); + return loco::NodeShape{node->encoder()->shape(input_node_shape.as())}; + } + // CASE: MaxPool2D loco::NodeShape visit(const loco::MaxPool2D *node) final { diff --git a/compiler/loco/src/Service/TypeInference.cpp b/compiler/loco/src/Service/TypeInference.cpp index 53915a9..8827113 100644 --- a/compiler/loco/src/Service/TypeInference.cpp +++ b/compiler/loco/src/Service/TypeInference.cpp @@ -137,6 +137,9 @@ struct CanonicalTypeForwardAlgorithm final : public loco::CanonicalNodeVisitorinput()); } loco::DataType visit(const loco::FilterEncode *node) { return loco::dtype_get(node->input()); } loco::DataType visit(const loco::FixedReshape *node) { return loco::dtype_get(node->input()); } + loco::DataType visit(const loco::MatrixDecode *node) { return loco::dtype_get(node->input()); } + loco::DataType visit(const loco::MatrixEncode *node) { return loco::dtype_get(node->input()); } + loco::DataType visit(const loco::MatMul *node) { return loco::dtype_get(node->lhs()); } loco::DataType visit(const loco::MaxPool2D *node) { return loco::dtype_get(node->ifm()); } loco::DataType visit(const loco::Push *node) { return loco::dtype_get(node->from()); } loco::DataType visit(const loco::Pull *node) { return node->dtype(); }