[loco] Introduce MatrixEncode, MatrixDecode and MatrixMul operations (#7406)
authorПавел Ильютченко/AI Tools Lab /SRR/Engineer/삼성전자 <p.iliutchenk@samsung.com>
Wed, 18 Sep 2019 07:34:44 +0000 (10:34 +0300)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Wed, 18 Sep 2019 07:34:44 +0000 (16:34 +0900)
* Added MatrixEncode, MatrixDecode and MatrixMul nodes
* Added tests for them
* Defined CanonicalShapeInferenceRule and TypeInference

Signed-off-by: Pavel Iliutchenko <p.iliutchenk@samsung.com>
compiler/loco/include/loco/IR/CanonicalNodes.lst
compiler/loco/include/loco/IR/Nodes.h
compiler/loco/src/IR/Nodes.test.cpp
compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
compiler/loco/src/Service/TypeInference.cpp

index 71fc8d0..540de70 100644 (file)
@@ -39,3 +39,6 @@ CANONICAL_NODE(TensorBroadcast, TensorBroadcast)
 CANONICAL_NODE(TensorReduce, TensorReduce)
 CANONICAL_NODE(TensorSoftmax, Softmax<Domain::Tensor>)
 CANONICAL_NODE(TransposedConv2D, TransposedConv2D)
+CANONICAL_NODE(MatrixEncode, MatrixEncode)
+CANONICAL_NODE(MatrixDecode, MatrixDecode)
+CANONICAL_NODE(MatMul, MatMul)
index 6ffad19..a617386 100644 (file)
@@ -31,6 +31,7 @@
 #include "loco/IR/FeatureCodec.h"
 #include "loco/IR/FilterCodec.h"
 #include "loco/IR/DepthwiseFilterCodec.h"
+#include "loco/IR/MatrixCodec.h"
 #include "loco/IR/NodeMixins.h"
 #include "loco/IR/CanonicalNodeDecl.h"
 #include "loco/IR/GraphInputIndex.h"
@@ -917,6 +918,72 @@ private:
   Mapping _mapping;
 };
 
+/**
+ * @brief Create Matrix from Tensor
+ *
+ * MatrixEncode currently requires a rank-2 Tensor as its input.
+ */
+class MatrixEncode final
+    : public CanonicalNodeDef<CanonicalOpcode::MatrixEncode, FixedArity<1>::Mixin>
+{
+public:
+  MatrixEncode() = default;
+
+public:
+  Node *input(void) const { return at(0)->node(); }
+  void input(Node *node) { at(0)->node(node); }
+
+public:
+  MatrixEncoder *encoder(void) const { return _enc.get(); }
+  void encoder(std::unique_ptr<MatrixEncoder> &&enc) { _enc = std::move(enc); }
+
+private:
+  /// @note "encoder" is mandatory
+  std::unique_ptr<MatrixEncoder> _enc{nullptr};
+};
+
+/**
+ * @brief Create Tensor from Matrix
+ *
+ * MatrixDecode currently requires a Matrix as its input.
+ */
+class MatrixDecode final
+    : public CanonicalNodeDef<CanonicalOpcode::MatrixDecode, FixedArity<1>::Mixin>
+{
+public:
+  MatrixDecode() = default;
+
+public:
+  Node *input(void) const { return at(0)->node(); }
+  void input(Node *node) { at(0)->node(node); }
+
+public:
+  MatrixDecoder *decoder(void) const { return _dec.get(); }
+  void decoder(std::unique_ptr<MatrixDecoder> &&dec) { _dec = std::move(dec); }
+
+private:
+  /// @note "decoder" is mandatory
+  std::unique_ptr<MatrixDecoder> _dec{nullptr};
+};
+
+/**
+ * @brief Matrix Multiplication lhs and rhs
+ *
+ * LHS and RHS must be on Matrix domain
+ */
+class MatMul final : public CanonicalNodeDef<CanonicalOpcode::MatMul, FixedArity<2>::Mixin>
+{
+public:
+  MatMul() = default;
+
+public:
+  Node *lhs(void) const { return at(0)->node(); }
+  void lhs(Node *node) { return at(0)->node(node); }
+
+  Node *rhs(void) const { return at(1)->node(); }
+  void rhs(Node *node) { return at(1)->node(node); }
+};
+
 } // namespace loco
 
 #endif // __LOCO_IR_NODES_H__
index 2f3c8cb..35846d5 100644 (file)
@@ -501,3 +501,34 @@ TEST(TensorBroadcastTest, mapping)
   ASSERT_EQ(tensor_broadcast_node.mapping()->defined(0), true);
   ASSERT_EQ(tensor_broadcast_node.mapping()->dim(0), 3);
 }
+
+TEST(MatrixEncodeTest, constructor)
+{
+  loco::MatrixEncode matrix_encode;
+
+  ASSERT_EQ(matrix_encode.dialect(), loco::CanonicalDialect::get());
+  ASSERT_EQ(matrix_encode.opcode(), loco::CanonicalOpcode::MatrixEncode);
+
+  ASSERT_EQ(matrix_encode.input(), nullptr);
+}
+
+TEST(MatrixDecodeTest, constructor)
+{
+  loco::MatrixDecode matrix_decode;
+
+  ASSERT_EQ(matrix_decode.dialect(), loco::CanonicalDialect::get());
+  ASSERT_EQ(matrix_decode.opcode(), loco::CanonicalOpcode::MatrixDecode);
+
+  ASSERT_EQ(matrix_decode.input(), nullptr);
+}
+
+TEST(MatMulTest, constructor)
+{
+  loco::MatMul mat_mul;
+
+  ASSERT_EQ(mat_mul.dialect(), loco::CanonicalDialect::get());
+  ASSERT_EQ(mat_mul.opcode(), loco::CanonicalOpcode::MatMul);
+
+  ASSERT_EQ(mat_mul.lhs(), nullptr);
+  ASSERT_EQ(mat_mul.rhs(), nullptr);
+}
index f0e4ec6..331150f 100644 (file)
@@ -439,6 +439,39 @@ public:
     return loco::NodeShape{tensor_shape};
   }
 
+  // CASE: MatMul
+  loco::NodeShape visit(const loco::MatMul *node) final
+  {
+    assert(shape_known(node->lhs()));
+    assert(shape_known(node->rhs()));
+    auto const lhs_shape = node_shape(node->lhs()).as<loco::MatrixShape>();
+    auto const rhs_shape = node_shape(node->rhs()).as<loco::MatrixShape>();
+
+    loco::MatrixShape out_shape;
+
+    // Checking shape capability for multiplication
+    assert(lhs_shape.width() == rhs_shape.height());
+
+    out_shape.height() = lhs_shape.height();
+    out_shape.width() = rhs_shape.width();
+
+    return out_shape;
+  }
+
+  // CASE: MatrixDecode
+  loco::NodeShape visit(const loco::MatrixDecode *node) final
+  {
+    auto input_node_shape = node_shape(node->input());
+    return loco::NodeShape{node->decoder()->shape(input_node_shape.as<loco::MatrixShape>())};
+  }
+
+  // CASE: MatrixEncode
+  loco::NodeShape visit(const loco::MatrixEncode *node) final
+  {
+    auto input_node_shape = node_shape(node->input());
+    return loco::NodeShape{node->encoder()->shape(input_node_shape.as<loco::TensorShape>())};
+  }
+
   // CASE: MaxPool2D
   loco::NodeShape visit(const loco::MaxPool2D *node) final
   {
index 53915a9..8827113 100644 (file)
@@ -137,6 +137,9 @@ struct CanonicalTypeForwardAlgorithm final : public loco::CanonicalNodeVisitor<l
   loco::DataType visit(const loco::FeatureEncode *node) { return loco::dtype_get(node->input()); }
   loco::DataType visit(const loco::FilterEncode *node) { return loco::dtype_get(node->input()); }
   loco::DataType visit(const loco::FixedReshape *node) { return loco::dtype_get(node->input()); }
+  loco::DataType visit(const loco::MatrixDecode *node) { return loco::dtype_get(node->input()); }
+  loco::DataType visit(const loco::MatrixEncode *node) { return loco::dtype_get(node->input()); }
+  loco::DataType visit(const loco::MatMul *node) { return loco::dtype_get(node->lhs()); }
   loco::DataType visit(const loco::MaxPool2D *node) { return loco::dtype_get(node->ifm()); }
   loco::DataType visit(const loco::Push *node) { return loco::dtype_get(node->from()); }
   loco::DataType visit(const loco::Pull *node) { return node->dtype(); }