From: Павел Ильютченко/AI Tools Lab /SRR/Engineer/삼성전자
Date: Wed, 18 Sep 2019 07:34:44 +0000 (+0300)
Subject: [loco] Introduce MatrixEncode, MatrixDecode and MatrixMul operations (#7406)
X-Git-Tag: submit/tizen/20191205.083104~1192
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=b593efd991bcf650e832f41f3dd7b9db481e0c55;p=platform%2Fcore%2Fml%2Fnnfw.git
[loco] Introduce MatrixEncode, MatrixDecode and MatrixMul operations (#7406)
* Added MatrixEncode, MatrixDecode and MatrixMul nodes
* Added tests for them
* Defined CanonicalShapeInferenceRule and TypeInference
Signed-off-by: Pavel Iliutchenko
---
diff --git a/compiler/loco/include/loco/IR/CanonicalNodes.lst b/compiler/loco/include/loco/IR/CanonicalNodes.lst
index 71fc8d0..540de70 100644
--- a/compiler/loco/include/loco/IR/CanonicalNodes.lst
+++ b/compiler/loco/include/loco/IR/CanonicalNodes.lst
@@ -39,3 +39,6 @@ CANONICAL_NODE(TensorBroadcast, TensorBroadcast)
CANONICAL_NODE(TensorReduce, TensorReduce)
CANONICAL_NODE(TensorSoftmax, Softmax)
CANONICAL_NODE(TransposedConv2D, TransposedConv2D)
+CANONICAL_NODE(MatrixEncode, MatrixEncode)
+CANONICAL_NODE(MatrixDecode, MatrixDecode)
+CANONICAL_NODE(MatMul, MatMul)
diff --git a/compiler/loco/include/loco/IR/Nodes.h b/compiler/loco/include/loco/IR/Nodes.h
index 6ffad19..a617386 100644
--- a/compiler/loco/include/loco/IR/Nodes.h
+++ b/compiler/loco/include/loco/IR/Nodes.h
@@ -31,6 +31,7 @@
#include "loco/IR/FeatureCodec.h"
#include "loco/IR/FilterCodec.h"
#include "loco/IR/DepthwiseFilterCodec.h"
+#include "loco/IR/MatrixCodec.h"
#include "loco/IR/NodeMixins.h"
#include "loco/IR/CanonicalNodeDecl.h"
#include "loco/IR/GraphInputIndex.h"
@@ -917,6 +918,72 @@ private:
Mapping _mapping;
};
+/**
+ * @brief Create Matrix from Tensor
+ *
+ * MatrixEncode currently requires a rank-2 Tensor as its input.
+ */
+class MatrixEncode final
+ : public CanonicalNodeDef::Mixin>
+{
+public:
+ MatrixEncode() = default;
+
+public:
+ Node *input(void) const { return at(0)->node(); }
+ void input(Node *node) { at(0)->node(node); }
+
+public:
+ MatrixEncoder *encoder(void) const { return _enc.get(); }
+ void encoder(std::unique_ptr &&enc) { _enc = std::move(enc); }
+
+private:
+ /// @note "encoder" is mandatory
+ std::unique_ptr _enc{nullptr};
+};
+
+/**
+ * @brief Create Tensor from Matrix
+ *
+ * MatrixDecode currently requires a Matrix as its input.
+ */
+class MatrixDecode final
+ : public CanonicalNodeDef::Mixin>
+{
+public:
+ MatrixDecode() = default;
+
+public:
+ Node *input(void) const { return at(0)->node(); }
+ void input(Node *node) { at(0)->node(node); }
+
+public:
+ MatrixDecoder *decoder(void) const { return _dec.get(); }
+ void decoder(std::unique_ptr &&dec) { _dec = std::move(dec); }
+
+private:
+ /// @note "decoder" is mandatory
+ std::unique_ptr _dec{nullptr};
+};
+
+/**
+ * @brief Matrix Multiplication lhs and rhs
+ *
+ * LHS and RHS must be on Matrix domain
+ */
+class MatMul final : public CanonicalNodeDef::Mixin>
+{
+public:
+ MatMul() = default;
+
+public:
+ Node *lhs(void) const { return at(0)->node(); }
+ void lhs(Node *node) { return at(0)->node(node); }
+
+ Node *rhs(void) const { return at(1)->node(); }
+ void rhs(Node *node) { return at(1)->node(node); }
+};
+
} // namespace loco
#endif // __LOCO_IR_NODES_H__
diff --git a/compiler/loco/src/IR/Nodes.test.cpp b/compiler/loco/src/IR/Nodes.test.cpp
index 2f3c8cb..35846d5 100644
--- a/compiler/loco/src/IR/Nodes.test.cpp
+++ b/compiler/loco/src/IR/Nodes.test.cpp
@@ -501,3 +501,34 @@ TEST(TensorBroadcastTest, mapping)
ASSERT_EQ(tensor_broadcast_node.mapping()->defined(0), true);
ASSERT_EQ(tensor_broadcast_node.mapping()->dim(0), 3);
}
+
+TEST(MatrixEncodeTest, constructor)
+{
+ loco::MatrixEncode matrix_encode;
+
+ ASSERT_EQ(matrix_encode.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(matrix_encode.opcode(), loco::CanonicalOpcode::MatrixEncode);
+
+ ASSERT_EQ(matrix_encode.input(), nullptr);
+}
+
+TEST(MatrixDecodeTest, constructor)
+{
+ loco::MatrixDecode matrix_decode;
+
+ ASSERT_EQ(matrix_decode.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(matrix_decode.opcode(), loco::CanonicalOpcode::MatrixDecode);
+
+ ASSERT_EQ(matrix_decode.input(), nullptr);
+}
+
+TEST(MatMulTest, constructor)
+{
+ loco::MatMul mat_mul;
+
+ ASSERT_EQ(mat_mul.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(mat_mul.opcode(), loco::CanonicalOpcode::MatMul);
+
+ ASSERT_EQ(mat_mul.lhs(), nullptr);
+ ASSERT_EQ(mat_mul.rhs(), nullptr);
+}
diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
index f0e4ec6..331150f 100644
--- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
+++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
@@ -439,6 +439,39 @@ public:
return loco::NodeShape{tensor_shape};
}
+ // CASE: MatMul
+ loco::NodeShape visit(const loco::MatMul *node) final
+ {
+ assert(shape_known(node->lhs()));
+ assert(shape_known(node->rhs()));
+ auto const lhs_shape = node_shape(node->lhs()).as();
+ auto const rhs_shape = node_shape(node->rhs()).as();
+
+ loco::MatrixShape out_shape;
+
+ // Checking shape capability for multiplication
+ assert(lhs_shape.width() == rhs_shape.height());
+
+ out_shape.height() = lhs_shape.height();
+ out_shape.width() = rhs_shape.width();
+
+ return out_shape;
+ }
+
+ // CASE: MatrixDecode
+ loco::NodeShape visit(const loco::MatrixDecode *node) final
+ {
+ auto input_node_shape = node_shape(node->input());
+ return loco::NodeShape{node->decoder()->shape(input_node_shape.as())};
+ }
+
+ // CASE: MatrixEncode
+ loco::NodeShape visit(const loco::MatrixEncode *node) final
+ {
+ auto input_node_shape = node_shape(node->input());
+ return loco::NodeShape{node->encoder()->shape(input_node_shape.as())};
+ }
+
// CASE: MaxPool2D
loco::NodeShape visit(const loco::MaxPool2D *node) final
{
diff --git a/compiler/loco/src/Service/TypeInference.cpp b/compiler/loco/src/Service/TypeInference.cpp
index 53915a9..8827113 100644
--- a/compiler/loco/src/Service/TypeInference.cpp
+++ b/compiler/loco/src/Service/TypeInference.cpp
@@ -137,6 +137,9 @@ struct CanonicalTypeForwardAlgorithm final : public loco::CanonicalNodeVisitorinput()); }
loco::DataType visit(const loco::FilterEncode *node) { return loco::dtype_get(node->input()); }
loco::DataType visit(const loco::FixedReshape *node) { return loco::dtype_get(node->input()); }
+ loco::DataType visit(const loco::MatrixDecode *node) { return loco::dtype_get(node->input()); }
+ loco::DataType visit(const loco::MatrixEncode *node) { return loco::dtype_get(node->input()); }
+ loco::DataType visit(const loco::MatMul *node) { return loco::dtype_get(node->lhs()); }
loco::DataType visit(const loco::MaxPool2D *node) { return loco::dtype_get(node->ifm()); }
loco::DataType visit(const loco::Push *node) { return loco::dtype_get(node->from()); }
loco::DataType visit(const loco::Pull *node) { return node->dtype(); }