CANONICAL_NODE(TensorReduce, TensorReduce)
CANONICAL_NODE(TensorSoftmax, Softmax<Domain::Tensor>)
CANONICAL_NODE(TransposedConv2D, TransposedConv2D)
+CANONICAL_NODE(MatrixEncode, MatrixEncode)
+CANONICAL_NODE(MatrixDecode, MatrixDecode)
+CANONICAL_NODE(MatMul, MatMul)
#include "loco/IR/FeatureCodec.h"
#include "loco/IR/FilterCodec.h"
#include "loco/IR/DepthwiseFilterCodec.h"
+#include "loco/IR/MatrixCodec.h"
#include "loco/IR/NodeMixins.h"
#include "loco/IR/CanonicalNodeDecl.h"
#include "loco/IR/GraphInputIndex.h"
Mapping _mapping;
};
+/**
+ * @brief Create Matrix from Tensor
+ *
+ * MatrixEncode currently requires a rank-2 Tensor as its input.
+ */
+class MatrixEncode final
+    : public CanonicalNodeDef<CanonicalOpcode::MatrixEncode, FixedArity<1>::Mixin>
+{
+public:
+  MatrixEncode() = default;
+
+public:
+  /// @brief Tensor argument (input slot 0); nullptr until connected
+  Node *input(void) const { return at(0)->node(); }
+  void input(Node *node) { at(0)->node(node); }
+
+public:
+  /// @brief Encoder that maps the input tensor onto the Matrix domain; nullptr until set
+  MatrixEncoder *encoder(void) const { return _enc.get(); }
+  /// @brief Takes ownership of "enc"
+  void encoder(std::unique_ptr<MatrixEncoder> &&enc) { _enc = std::move(enc); }
+
+private:
+  /// @note "encoder" is mandatory
+  std::unique_ptr<MatrixEncoder> _enc{nullptr};
+};
+
+/**
+ * @brief Create Tensor from Matrix
+ *
+ * MatrixDecode currently requires a Matrix as its input.
+ */
+class MatrixDecode final
+    : public CanonicalNodeDef<CanonicalOpcode::MatrixDecode, FixedArity<1>::Mixin>
+{
+public:
+  MatrixDecode() = default;
+
+public:
+  /// @brief Matrix argument (input slot 0); nullptr until connected
+  Node *input(void) const { return at(0)->node(); }
+  void input(Node *node) { at(0)->node(node); }
+
+public:
+  /// @brief Decoder that maps the input matrix back to the Tensor domain; nullptr until set
+  MatrixDecoder *decoder(void) const { return _dec.get(); }
+  /// @brief Takes ownership of "dec"
+  void decoder(std::unique_ptr<MatrixDecoder> &&dec) { _dec = std::move(dec); }
+
+private:
+  /// @note "decoder" is mandatory
+  std::unique_ptr<MatrixDecoder> _dec{nullptr};
+};
+
+/**
+ * @brief Matrix Multiplication lhs and rhs
+ *
+ * LHS and RHS must be on Matrix domain
+ */
+class MatMul final : public CanonicalNodeDef<CanonicalOpcode::MatMul, FixedArity<2>::Mixin>
+{
+public:
+  MatMul() = default;
+
+public:
+  /// @brief Left-hand-side operand (input slot 0); nullptr until connected
+  Node *lhs(void) const { return at(0)->node(); }
+  void lhs(Node *node) { at(0)->node(node); }
+
+  /// @brief Right-hand-side operand (input slot 1); nullptr until connected
+  Node *rhs(void) const { return at(1)->node(); }
+  void rhs(Node *node) { at(1)->node(node); }
+};
+
} // namespace loco
#endif // __LOCO_IR_NODES_H__
ASSERT_EQ(tensor_broadcast_node.mapping()->defined(0), true);
ASSERT_EQ(tensor_broadcast_node.mapping()->dim(0), 3);
}
+
+TEST(MatrixEncodeTest, constructor)
+{
+  loco::MatrixEncode matrix_encode;
+
+  ASSERT_EQ(matrix_encode.dialect(), loco::CanonicalDialect::get());
+  ASSERT_EQ(matrix_encode.opcode(), loco::CanonicalOpcode::MatrixEncode);
+
+  // Both the input and the (mandatory) encoder start out unset
+  ASSERT_EQ(matrix_encode.input(), nullptr);
+  ASSERT_EQ(matrix_encode.encoder(), nullptr);
+}
+
+TEST(MatrixDecodeTest, constructor)
+{
+  loco::MatrixDecode matrix_decode;
+
+  ASSERT_EQ(matrix_decode.dialect(), loco::CanonicalDialect::get());
+  ASSERT_EQ(matrix_decode.opcode(), loco::CanonicalOpcode::MatrixDecode);
+
+  // Both the input and the (mandatory) decoder start out unset
+  ASSERT_EQ(matrix_decode.input(), nullptr);
+  ASSERT_EQ(matrix_decode.decoder(), nullptr);
+}
+
+TEST(MatMulTest, constructor)
+{
+  loco::MatMul mat_mul;
+
+  ASSERT_EQ(mat_mul.dialect(), loco::CanonicalDialect::get());
+  ASSERT_EQ(mat_mul.opcode(), loco::CanonicalOpcode::MatMul);
+
+  // Both operands are unset right after construction
+  ASSERT_EQ(mat_mul.lhs(), nullptr);
+  ASSERT_EQ(mat_mul.rhs(), nullptr);
+}
return loco::NodeShape{tensor_shape};
}
+  // CASE: MatMul
+  loco::NodeShape visit(const loco::MatMul *node) final
+  {
+    // Both operands must already have an inferred shape
+    assert(shape_known(node->lhs()));
+    assert(shape_known(node->rhs()));
+    auto const lhs_shape = node_shape(node->lhs()).as<loco::MatrixShape>();
+    auto const rhs_shape = node_shape(node->rhs()).as<loco::MatrixShape>();
+
+    loco::MatrixShape out_shape;
+
+    // Checking shape capability for multiplication
+    // NOTE This is a debug-only assert; incompatible shapes are not diagnosed in release builds
+    assert(lhs_shape.width() == rhs_shape.height());
+
+    // [H x W] = [H x K] * [K x W]
+    out_shape.height() = lhs_shape.height();
+    out_shape.width() = rhs_shape.width();
+
+    return out_shape;
+  }
+
+  // CASE: MatrixDecode
+  loco::NodeShape visit(const loco::MatrixDecode *node) final
+  {
+    // Input shape must already be inferred (same precondition as the MatMul case)
+    assert(shape_known(node->input()));
+    auto input_node_shape = node_shape(node->input());
+    // The decoder translates the Matrix-domain shape back into a Tensor shape
+    return loco::NodeShape{node->decoder()->shape(input_node_shape.as<loco::MatrixShape>())};
+  }
+
+  // CASE: MatrixEncode
+  loco::NodeShape visit(const loco::MatrixEncode *node) final
+  {
+    // Input shape must already be inferred (same precondition as the MatMul case)
+    assert(shape_known(node->input()));
+    auto input_node_shape = node_shape(node->input());
+    // The encoder translates the Tensor-domain shape into a Matrix shape
+    return loco::NodeShape{node->encoder()->shape(input_node_shape.as<loco::TensorShape>())};
+  }
+
// CASE: MaxPool2D
loco::NodeShape visit(const loco::MaxPool2D *node) final
{
loco::DataType visit(const loco::FeatureEncode *node) { return loco::dtype_get(node->input()); }
loco::DataType visit(const loco::FilterEncode *node) { return loco::dtype_get(node->input()); }
loco::DataType visit(const loco::FixedReshape *node) { return loco::dtype_get(node->input()); }
+  // Matrix ops: encode/decode preserve the element type of their input;
+  // MatMul inherits its element type from the LHS operand.
+  loco::DataType visit(const loco::MatrixDecode *node) { return loco::dtype_get(node->input()); }
+  loco::DataType visit(const loco::MatrixEncode *node) { return loco::dtype_get(node->input()); }
+  loco::DataType visit(const loco::MatMul *node) { return loco::dtype_get(node->lhs()); }
loco::DataType visit(const loco::MaxPool2D *node) { return loco::dtype_get(node->ifm()); }
loco::DataType visit(const loco::Push *node) { return loco::dtype_get(node->from()); }
loco::DataType visit(const loco::Pull *node) { return node->dtype(); }