[nGraph] MatMul - Remove fused op and align output shape inference (#2866)
author Katarzyna Mitrus <katarzyna.mitrus@intel.com>
Fri, 13 Nov 2020 12:15:22 +0000 (13:15 +0100)
committer GitHub <noreply@github.com>
Fri, 13 Nov 2020 12:15:22 +0000 (15:15 +0300)
To follow the MatMul spec update for 1D tensors, this PR removes the FusedOp decomposition from MatMul without changing the current MatMul output shape inference logic (numpy/ONNX aligned).
Based on the previous PR #2212, which largely follows the current spec logic.
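
For context, a minimal standalone sketch (not part of this PR's code) of the numpy/ONNX-aligned output shape inference that the new validate_and_infer_types() performs, restricted to fully static shapes; the function name and the static-shape simplification are illustrative only:

    // Hedged sketch: static-shape analogue of the MatMul output shape inference
    // added in this PR (dynamic Dimension handling omitted for brevity).
    #include <algorithm>
    #include <cassert>
    #include <cstddef>
    #include <utility>
    #include <vector>

    std::vector<std::size_t> matmul_output_shape(std::vector<std::size_t> a,
                                                 std::vector<std::size_t> b,
                                                 bool transpose_a,
                                                 bool transpose_b)
    {
        assert(!a.empty() && !b.empty() && "Scalars are not supported as MatMul inputs");
        const bool a_was_1d = (a.size() == 1);
        const bool b_was_1d = (b.size() == 1);

        // 1. Transpose the two right-most dims; ignored for 1D inputs.
        if (transpose_a && a.size() > 1) std::swap(a[a.size() - 2], a[a.size() - 1]);
        if (transpose_b && b.size() > 1) std::swap(b[b.size() - 2], b[b.size() - 1]);

        // 2. Unsqueeze 1D inputs: {S} -> {1, S} for the first, {S} -> {S, 1} for the second.
        if (a_was_1d) a.insert(a.begin(), 1);
        if (b_was_1d) b.push_back(1);

        // COL_INDEX_DIM of the first matrix must match ROW_INDEX_DIM of the second.
        assert(a.back() == b[b.size() - 2] && "Incompatible MatMul matrix dimension");

        // 3. Pad the lower-rank shape with leading 1s until both ranks match.
        if (a.size() < b.size()) a.insert(a.begin(), b.size() - a.size(), 1);
        if (b.size() < a.size()) b.insert(b.begin(), a.size() - b.size(), 1);

        // 4. Broadcast batch dims, then take ROW_INDEX_DIM of a and COL_INDEX_DIM of b.
        std::vector<std::size_t> out;
        for (std::size_t i = 0; i + 2 < a.size(); ++i)
        {
            assert(a[i] == b[i] || a[i] == 1 || b[i] == 1);
            out.push_back(std::max(a[i], b[i]));
        }
        out.push_back(a[a.size() - 2]);
        out.push_back(b[b.size() - 1]);

        // 5. Drop the temporary axes introduced for originally 1D inputs, e.g.
        // {3} x {3} -> {} (scalar), {1, 2, 3, 4} x {4} -> {1, 2, 3}.
        if (a_was_1d) out.erase(out.end() - 2);
        if (b_was_1d) out.erase(out.end() - 1);
        return out;
    }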

ngraph/core/include/ngraph/op/matmul.hpp
ngraph/core/src/op/matmul.cpp
ngraph/test/backend/matmul.in.cpp
ngraph/test/type_prop/matmul.cpp

index a920fab..e8a60f9 100644
@@ -18,9 +18,6 @@
 
 #include "ngraph/node.hpp"
 #include "ngraph/op/op.hpp"
-#include "ngraph/op/util/fused_op.hpp"
-
-NGRAPH_SUPPRESS_DEPRECATED_START
 
 namespace ngraph
 {
@@ -29,7 +26,7 @@ namespace ngraph
         namespace v0
         {
             /// \brief Operator performing Matrix Multiplication.
-            class NGRAPH_API MatMul : public ngraph::op::util::FusedOp
+            class NGRAPH_API MatMul : public Op
             {
             public:
                 NGRAPH_RTTI_DECLARATION;
@@ -46,9 +43,7 @@ namespace ngraph
                        const bool& transpose_b = 0);
 
                 bool visit_attributes(AttributeVisitor& visitor) override;
-                virtual void pre_validate_and_infer_types() override;
-
-                virtual OutputVector decompose_op() const override;
+                void validate_and_infer_types() override;
 
                 virtual std::shared_ptr<Node>
                     clone_with_new_inputs(const OutputVector& new_args) const override;
@@ -66,5 +61,3 @@ namespace ngraph
         using v0::MatMul;
     } // namespace op
 } // namespace ngraph
-
-NGRAPH_SUPPRESS_DEPRECATED_END
index 9c09cd3..eefda07 100644
 #include "itt.hpp"
 #include "matmul.hpp"
 #include "ngraph/attribute_visitor.hpp"
-#include "ngraph/builder/matmul_factory.hpp"
-#include "ngraph/builder/reshape.hpp"
 #include "ngraph/op/reshape.hpp"
 #include "ngraph/runtime/reference/matmul.hpp"
 
 using namespace std;
 using namespace ngraph;
 
-NGRAPH_SUPPRESS_DEPRECATED_START
-
 NGRAPH_RTTI_DEFINITION(op::MatMul, "MatMul", 0);
 
 op::MatMul::MatMul(const Output<Node>& A,
                    const Output<Node>& B,
                    const bool& transpose_a,
                    const bool& transpose_b)
-    : FusedOp(OutputVector{A, B})
+    : Op(OutputVector{A, B})
     , m_transpose_a{transpose_a}
     , m_transpose_b{transpose_b}
 {
@@ -49,62 +45,6 @@ bool ngraph::op::v0::MatMul::visit_attributes(AttributeVisitor& visitor)
     return true;
 }
 
-void op::MatMul::pre_validate_and_infer_types()
-{
-    element::Type result_et;
-    NODE_VALIDATION_CHECK(
-        this,
-        element::Type::merge(result_et, get_input_element_type(0), get_input_element_type(1)),
-        "Arguments do not have the same element type (arg0 element type: ",
-        get_input_element_type(0),
-        ", arg1 element type: ",
-        get_input_element_type(1),
-        ").");
-
-    const Rank& A_rank = get_input_partial_shape(0).rank();
-    const Rank& B_rank = get_input_partial_shape(1).rank();
-
-    if (A_rank.is_static() && B_rank.is_static())
-    {
-        Rank max_rank = A_rank.get_length() > B_rank.get_length() ? A_rank : B_rank;
-        set_output_type(0, result_et, PartialShape::dynamic(max_rank));
-    }
-    else
-    {
-        set_output_type(0, result_et, PartialShape::dynamic());
-    }
-}
-
-OutputVector op::MatMul::decompose_op() const
-{
-    auto A = input_value(0);
-    auto B = input_value(1);
-
-    const auto a_rank = A.get_shape().size();
-    const auto b_rank = B.get_shape().size();
-
-    if (m_transpose_a && a_rank >= 2)
-    {
-        vector<size_t> axes_order(a_rank);
-        // generate default axes_order.
-        iota(axes_order.begin(), axes_order.end(), 0);
-        // transpose the last 2 spatial dims
-        swap(axes_order[a_rank - 1], axes_order[a_rank - 2]);
-        A = builder::opset1::reorder_axes(A, axes_order);
-    }
-
-    if (m_transpose_b && b_rank >= 2)
-    {
-        vector<size_t> axes_order(b_rank);
-        iota(axes_order.begin(), axes_order.end(), 0);
-        swap(axes_order[b_rank - 1], axes_order[b_rank - 2]);
-        B = builder::opset1::reorder_axes(B, axes_order);
-    }
-
-    builder::MatmulFactory factory({A, B});
-    return factory.make_matmul_op();
-}
-
 shared_ptr<Node> op::MatMul::clone_with_new_inputs(const OutputVector& new_args) const
 {
     check_new_args_count(this, new_args);
@@ -113,68 +53,157 @@ shared_ptr<Node> op::MatMul::clone_with_new_inputs(const OutputVector& new_args)
 
 namespace matmul
 {
-    Shape evaluate_matmul_output_shape(const Shape& arg0_shape,
-                                       const Shape& arg1_shape,
-                                       bool transpose_a,
-                                       bool transpose_b)
+    PartialShape validate_matmul_output_shape(const PartialShape& arg0_shape,
+                                              const PartialShape& arg1_shape,
+                                              bool transpose_a,
+                                              bool transpose_b)
     {
-        Shape output_shape;
-        Shape arg0_shape_update = arg0_shape;
-        Shape arg1_shape_update = arg1_shape;
+        auto arg0_rank = arg0_shape.rank().get_length();
+        auto arg1_rank = arg1_shape.rank().get_length();
 
-        size_t arg0_rank = arg0_shape.size();
-        size_t arg1_rank = arg1_shape.size();
+        NGRAPH_CHECK((arg0_rank != 0 && arg1_rank != 0),
+                     "Scalars are not supported as MatMul inputs.");
 
+        // Temporary Dimension vectors to calculate output shape
+        std::vector<Dimension> arg0_shape_tmp(arg0_shape);
+        std::vector<Dimension> arg1_shape_tmp(arg1_shape);
+
+        // 1. Applying transpositions specified by the optional `transpose_a` and `transpose_b`.
+        // Only the two right-most dimensions are swapped; other dimensions remain the same.
+        // Transpose attributes are ignored for 1D tensors.
         if (transpose_a && arg0_rank > 1)
         {
-            swap(arg0_shape_update[arg0_rank - 2], arg0_shape_update[arg0_rank - 1]);
+            swap(arg0_shape_tmp[arg0_rank - 2], arg0_shape_tmp[arg0_rank - 1]);
         }
         if (transpose_b && arg1_rank > 1)
         {
-            swap(arg1_shape_update[arg1_rank - 2], arg1_shape_update[arg1_rank - 1]);
+            swap(arg1_shape_tmp[arg1_rank - 2], arg1_shape_tmp[arg1_rank - 1]);
         }
 
-        if (arg0_rank == 1 && arg1_rank == 1)
+        // 2. Unsqueezing of one-dimensional tensors is applied to each input independently.
+        if (arg0_rank == 1)
         {
-            NGRAPH_CHECK(arg0_shape_update == arg1_shape_update, "Incompatible arg shapes");
-            output_shape = Shape{};
+            // If the first input is a 1D tensor, it is unsqueezed to a 2D tensor (row vector)
+            // by adding an axis with size 1 at ROW_INDEX_DIM, to the left of the shape.
+            // For example, {S} will be reshaped to {1, S}.
+            arg0_shape_tmp.insert(arg0_shape_tmp.begin(), 1);
+            arg0_rank = arg0_shape_tmp.size();
         }
-        else if (arg0_rank == 1)
+        if (arg1_rank == 1)
         {
-            // i.e., arg0 shape {3}, arg1 shape{2, 3, 2}, output shape {2, 2}
-            NGRAPH_CHECK(arg0_shape_update[0] == arg1_shape_update[arg1_rank - 2],
-                         "Incompatible arg shapes");
-            arg1_shape_update.erase(arg1_shape_update.begin() + arg1_rank - 2);
-            output_shape = arg1_shape_update;
+            // If the second input is a 1D tensor, it is unsqueezed to a 2D tensor (column vector)
+            // by adding an axis with size 1 at COL_INDEX_DIM, to the right of the shape.
+            // For example, {S} will be reshaped to {S, 1}.
+            arg1_shape_tmp.insert(arg1_shape_tmp.end(), 1);
+            arg1_rank = arg1_shape_tmp.size();
         }
-        else if (arg1_rank == 1)
+
+        // Check matrix dimensions compatibility:
+        // COL_INDEX_DIM of the first matrix has to match ROW_INDEX_DIM of the second matrix.
+        // No error is thrown for dynamic dimension bounds without intersection,
+        // to ensure MatMul backward compatibility.
+        auto merged_dimension = Dimension::dynamic();
+        auto arg0_col_dim = arg0_shape_tmp[arg0_rank - 1];
+        auto arg1_row_dim = arg1_shape_tmp[arg1_rank - 2];
+        NGRAPH_CHECK(Dimension::merge(merged_dimension, arg0_col_dim, arg1_row_dim) ||
+                         arg0_col_dim.is_dynamic() || arg1_row_dim.is_dynamic(),
+                     "Incompatible MatMul matrix dimension. ",
+                     "First input dimension=",
+                     arg0_col_dim,
+                     " at COL_INDEX_DIM=",
+                     (arg0_rank - 1),
+                     " doesn't match the second input dimension=",
+                     arg1_row_dim,
+                     " at ROW_INDEX_DIM=",
+                     (arg1_rank - 2));
+
+        // 3. If the ranks of the input arguments differ after steps 1 and 2,
+        // the smaller tensor is unsqueezed from the left side of the shape
+        // by the necessary number of axes to make both shapes the same rank.
+        if (arg0_rank < arg1_rank)
+            arg0_shape_tmp.insert(arg0_shape_tmp.begin(), arg1_rank - arg0_rank, 1);
+        else if (arg0_rank > arg1_rank)
+            arg1_shape_tmp.insert(arg1_shape_tmp.begin(), arg0_rank - arg1_rank, 1);
+        // Both arg0_shape_tmp and arg1_shape_tmp have identical size now
+        auto max_rank = arg0_shape_tmp.size();
+        std::vector<Dimension> output_shape(max_rank);
+
+        // 4. Usual broadcasting rules are applied to the batch dimensions.
+        // Broadcast all batch dimensions (the last two dimensions represent the matrices);
+        // a dimension with value 1 is expanded to the bigger one if the dimensions are not equal.
+        for (auto i = 0; i < max_rank - 2; i++)
         {
-            // i.e., arg0 shape {2, 2, 3}, arg1 shape{3}, output shape {2, 2}
-            NGRAPH_CHECK(arg1_shape_update[0] == arg0_shape_update[arg0_rank - 1],
-                         "Incompatible arg shapes");
-            arg0_shape_update.erase(arg0_shape_update.begin() + arg0_rank - 1);
-            output_shape = arg0_shape_update;
+            auto min_dim_val =
+                std::min(arg0_shape_tmp[i].get_min_length(), arg1_shape_tmp[i].get_min_length());
+
+            // If neither dimension has 1 in its range, a usual merge is enough.
+            if (min_dim_val > 1)
+            {
+                // No error is thrown for dynamic dimension bounds without intersection,
+                // to ensure MatMul backward compatibility.
+                // Instead, a fully dynamic dimension is set as the default for such a case.
+                auto merged_dimension = Dimension::dynamic();
+                NGRAPH_CHECK(
+                    Dimension::merge(merged_dimension, arg0_shape_tmp[i], arg1_shape_tmp[i]) ||
+                        arg0_shape_tmp[i].is_dynamic() || arg1_shape_tmp[i].is_dynamic(),
+                    "Incompatible MatMul batch dimension. ",
+                    "Can't merge first input dimension=",
+                    arg0_shape_tmp[i],
+                    " with second input dimension=",
+                    arg1_shape_tmp[i],
+                    " at index=",
+                    i);
+
+                output_shape[i] = merged_dimension;
+            }
+            else
+            {
+                // A dimension with value 1 can be expanded to any bigger one.
+                Dimension::value_type lower_bound; // The lowest possible value of output dimension
+                Dimension::value_type upper_bound; // The highest possible value of output dimension
+
+                // The output dimension lower_bound is the maximum of the
+                // corresponding input dimensions' lower bounds.
+                lower_bound = std::max(arg0_shape_tmp[i].get_min_length(),
+                                       arg1_shape_tmp[i].get_min_length());
+                if (lower_bound <= 1)
+                {
+                    // If both dimensions have 1 in their range, the output dimension upper_bound
+                    // is the maximum of the corresponding input dimensions' upper bounds.
+                    upper_bound = std::max(arg0_shape_tmp[i].get_interval().get_max_val(),
+                                           arg1_shape_tmp[i].get_interval().get_max_val());
+                }
+                else
+                {
+                    // Otherwise the output dimension upper_bound is the same as the upper bound
+                    // of the dimension without 1 in its range.
+                    upper_bound = arg0_shape_tmp[i].get_min_length() <= 1
+                                      ? arg1_shape_tmp[i].get_max_length()
+                                      : arg0_shape_tmp[i].get_max_length();
+                }
+                output_shape[i] = Dimension(lower_bound, upper_bound);
+            }
         }
-        else if (arg0_rank == 2 && arg1_rank == 2)
+
+        // In output_shape, replace the last 2 axes with ROW_INDEX_DIM from the arg0 matrix
+        // and COL_INDEX_DIM from the arg1 matrix.
+        output_shape.at(output_shape.size() - 2) = arg0_shape_tmp.at(arg0_shape_tmp.size() - 2);
+        output_shape.at(output_shape.size() - 1) = arg1_shape_tmp.at(arg1_shape_tmp.size() - 1);
+
+        // 5. Removing the temporary axes from originally 1D tensors.
+        // The output shape of a multiplication of two 1D tensors is a 0D tensor (scalar).
+        if (arg0_shape.rank().get_length() == 1)
         {
-            NGRAPH_CHECK(arg0_shape_update[1] == arg1_shape_update[0], "Incompatible arg shapes");
-            output_shape = Shape{arg0_shape_update[0], arg1_shape_update[1]};
+            // arg0 input temporary axis inserted at ROW_INDEX_DIM is removed
+            output_shape.erase(output_shape.begin() + output_shape.size() - 2);
         }
-        else
+        if (arg1_shape.rank().get_length() == 1)
         {
-            NGRAPH_CHECK(arg0_shape_update[arg0_rank - 1] == arg1_shape_update[arg1_rank - 2],
-                         "Incompatible arg shapes");
-
-            const auto& broadcast_shapes = builder::get_numpy_broadcast_shapes(
-                {Shape{begin(arg0_shape_update), next(end(arg0_shape_update), -2)},
-                 Shape{begin(arg1_shape_update), next(end(arg1_shape_update), -2)}});
-
-            output_shape = broadcast_shapes.first;
-            output_shape.insert(output_shape.end(), arg0_shape_update[arg0_rank - 2]);
-            output_shape.insert(output_shape.end(), arg1_shape_update[arg1_rank - 1]);
+            // arg1 input temporary axis inserted at COL_INDEX_DIM is removed
+            output_shape.erase(output_shape.begin() + output_shape.size() - 1);
         }
 
-        return output_shape;
+        return PartialShape(output_shape);
     }
 
     template <element::Type_t ET>
@@ -189,8 +218,9 @@ namespace matmul
         Shape arg0_shape = arg0->get_shape();
         Shape arg1_shape = arg1->get_shape();
 
-        Shape output_shape =
-            evaluate_matmul_output_shape(arg0_shape, arg1_shape, transpose_a, transpose_b);
+        PartialShape output_partial_shape = validate_matmul_output_shape(
+            PartialShape(arg0_shape), PartialShape(arg1_shape), transpose_a, transpose_b);
+        Shape output_shape = output_partial_shape.to_shape();
         output->set_element_type(arg0->get_element_type());
         output->set_shape(output_shape);
 
@@ -231,7 +261,7 @@ namespace matmul
         }
         return rc;
     }
-}
+} // namespace
 
 bool op::MatMul::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const
 {
@@ -239,3 +269,37 @@ bool op::MatMul::evaluate(const HostTensorVector& outputs, const HostTensorVecto
     return matmul::evaluate_matmul(
         inputs[0], inputs[1], outputs[0], get_transpose_a(), get_transpose_b());
 }
+
+void ngraph::op::v0::MatMul::validate_and_infer_types()
+{
+    element::Type result_et;
+
+    NODE_VALIDATION_CHECK(
+        this,
+        element::Type::merge(result_et, get_input_element_type(0), get_input_element_type(1)),
+        "Arguments do not have the same element type (arg0 element type: ",
+        get_input_element_type(0),
+        ", arg1 element type: ",
+        get_input_element_type(1),
+        ").");
+
+    const auto& A_partial_shape = get_input_partial_shape(0);
+    const auto& B_partial_shape = get_input_partial_shape(1);
+
+    if (A_partial_shape.rank().is_static() && B_partial_shape.rank().is_static())
+    {
+        PartialShape output_shape;
+
+        const bool transpose_a = get_transpose_a();
+        const bool transpose_b = get_transpose_b();
+
+        output_shape = matmul::validate_matmul_output_shape(
+            A_partial_shape, B_partial_shape, transpose_a, transpose_b);
+
+        set_output_type(0, result_et, output_shape);
+    }
+    else
+    {
+        set_output_type(0, result_et, PartialShape::dynamic());
+    }
+}
index cfb67bf..a134de1 100644
@@ -28,7 +28,9 @@
 #include "runtime/backend.hpp"
 #include "util/all_close.hpp"
 #include "util/all_close_f.hpp"
+#include "util/engine/test_engines.hpp"
 #include "util/ndarray.hpp"
+#include "util/test_case.hpp"
 #include "util/test_control.hpp"
 #include "util/test_tools.hpp"
 
@@ -36,6 +38,7 @@ using namespace std;
 using namespace ngraph;
 
 static string s_manifest = "${MANIFEST}";
+using TestEngine = test::ENGINE_CLASS_NAME(${BACKEND_NAME});
 
 NGRAPH_TEST(${BACKEND_NAME}, matmul_2x0_0x2)
 {
@@ -437,3 +440,109 @@ NGRAPH_TEST(${BACKEND_NAME}, matmul_1x2x3_1x4x3x2)
                                                 244.f,
                                                 256.f}));
 }
+
+// 2D x 1D
+NGRAPH_TEST(${BACKEND_NAME}, matmul_1_3_x_3_false_false_param)
+{
+    Shape shape_in1{1, 3};
+    Shape shape_in2{3};
+    Shape shape_out{1};
+
+    bool transpose_a = false;
+    bool transpose_b = false;
+
+    std::vector<float> inputs_a{1, 2, 3};
+    std::vector<float> inputs_b{1, 2, 3};
+    std::vector<float> expected_result{14.};
+
+    auto A = make_shared<op::Parameter>(element::f32, shape_in1);
+    auto B = make_shared<op::Parameter>(element::f32, shape_in2);
+    auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
+    auto f = make_shared<Function>(matmul, ParameterVector{A, B});
+
+    auto test_case = test::TestCase<TestEngine>(f);
+    test_case.add_input<float>(inputs_a);
+    test_case.add_input<float>(inputs_b);
+
+    test_case.add_expected_output<float>(shape_out, expected_result);
+    test_case.run();
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, matmul_3_1_x_3_true_false_param)
+{
+    Shape shape_in1{3, 1};
+    Shape shape_in2{3};
+    Shape shape_out{1};
+
+    bool transpose_a = true;
+    bool transpose_b = false;
+
+    std::vector<float> inputs_a{1, 2, 3};
+    std::vector<float> inputs_b{1, 2, 3};
+    std::vector<float> expected_result{14.};
+
+    auto A = make_shared<op::Parameter>(element::f32, shape_in1);
+    auto B = make_shared<op::Parameter>(element::f32, shape_in2);
+    auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
+    auto f = make_shared<Function>(matmul, ParameterVector{A, B});
+
+    auto test_case = test::TestCase<TestEngine>(f);
+    test_case.add_input<float>(inputs_a);
+    test_case.add_input<float>(inputs_b);
+
+    test_case.add_expected_output<float>(shape_out, expected_result);
+    test_case.run();
+}
+
+// 1D x 2D
+NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_3_1_false_false_param)
+{
+    Shape shape_in1{3};
+    Shape shape_in2{3, 1};
+    Shape shape_out{1};
+
+    bool transpose_a = false;
+    bool transpose_b = false;
+
+    std::vector<float> inputs_a{1, 2, 3};
+    std::vector<float> inputs_b{1, 2, 3};
+    std::vector<float> expected_result{14.};
+
+    auto A = make_shared<op::Parameter>(element::f32, shape_in1);
+    auto B = make_shared<op::Parameter>(element::f32, shape_in2);
+    auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
+    auto f = make_shared<Function>(matmul, ParameterVector{A, B});
+
+    auto test_case = test::TestCase<TestEngine>(f);
+    test_case.add_input<float>(inputs_a);
+    test_case.add_input<float>(inputs_b);
+
+    test_case.add_expected_output<float>(shape_out, expected_result);
+    test_case.run();
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_1_3_false_true_param)
+{
+    Shape shape_in1{3};
+    Shape shape_in2{1, 3};
+    Shape shape_out{1};
+
+    bool transpose_a = false;
+    bool transpose_b = true;
+
+    std::vector<float> inputs_a{1, 2, 3};
+    std::vector<float> inputs_b{1, 2, 3};
+    std::vector<float> expected_result{14.};
+
+    auto A = make_shared<op::Parameter>(element::f32, shape_in1);
+    auto B = make_shared<op::Parameter>(element::f32, shape_in2);
+    auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
+    auto f = make_shared<Function>(matmul, ParameterVector{A, B});
+
+    auto test_case = test::TestCase<TestEngine>(f);
+    test_case.add_input<float>(inputs_a);
+    test_case.add_input<float>(inputs_b);
+
+    test_case.add_expected_output<float>(shape_out, expected_result);
+    test_case.run();
+}
index 24a403b..eb452d7 100644
@@ -65,6 +65,17 @@ TEST(type_prop, matmul_4D)
     ASSERT_EQ(matmul->get_shape(), (Shape{2, 2, 3, 4}));
 }
 
+TEST(type_prop, matmul_5D_x_3D_transpose_a_transpose_b)
+{
+    auto A = make_shared<op::Parameter>(element::f32, Shape{2, 1, 6, 3});
+    auto B = make_shared<op::Parameter>(element::f32, Shape{7, 1, 5, 4, 6});
+
+    auto matmul = make_shared<op::MatMul>(A, B, true, true);
+
+    ASSERT_EQ(matmul->get_element_type(), element::f32);
+    ASSERT_EQ(matmul->get_shape(), (Shape{7, 2, 5, 3, 4}));
+}
+
 TEST(type_prop, matmul_2D_transpose_a)
 {
     auto A = make_shared<op::Parameter>(element::f32, Shape{6, 3});
@@ -108,3 +119,440 @@ TEST(type_prop, matmul_4D_transpose_b)
     ASSERT_EQ(matmul->get_element_type(), element::f32);
     ASSERT_EQ(matmul->get_shape(), (Shape{2, 2, 3, 4}));
 }
+
+TEST(type_prop, matmul_dynamic_5D_transpose_b)
+{
+    Dimension dynamic = Dimension::dynamic();
+    auto A =
+        make_shared<op::Parameter>(element::f32, PartialShape{dynamic, 4, dynamic, dynamic, 6});
+    auto B = make_shared<op::Parameter>(element::f32, PartialShape{1, dynamic, dynamic, 4, 6});
+
+    auto matmul = make_shared<op::MatMul>(A, B, 0, 1);
+
+    ASSERT_EQ(matmul->get_element_type(), element::f32);
+    ASSERT_EQ(matmul->get_output_partial_shape(0),
+              (PartialShape{Dimension(1, -1), 4, dynamic, dynamic, 4}));
+}
+
+TEST(type_prop, matmul_dynamic_2D_transpose_a)
+{
+    Dimension dynamic = Dimension::dynamic();
+    auto A = make_shared<op::Parameter>(element::f32, PartialShape{dynamic, 3});
+    auto B = make_shared<op::Parameter>(element::f32, PartialShape{4, dynamic});
+
+    auto matmul = make_shared<op::MatMul>(A, B, 1, 0);
+
+    ASSERT_EQ(matmul->get_element_type(), element::f32);
+    ASSERT_EQ(matmul->get_output_partial_shape(0), (PartialShape{3, dynamic}));
+}
+
+TEST(type_prop, matmul_dynamic_1D_3D)
+{
+    Dimension dynamic = Dimension::dynamic();
+    auto A = make_shared<op::Parameter>(element::f32, PartialShape{dynamic});
+    auto B = make_shared<op::Parameter>(element::f32, PartialShape{2, 4, dynamic});
+
+    auto matmul = make_shared<op::MatMul>(A, B);
+
+    ASSERT_EQ(matmul->get_element_type(), element::f32);
+    ASSERT_EQ(matmul->get_output_partial_shape(0), (PartialShape{2, dynamic}));
+}
+
+// Transpose attributes are ignored for 1D
+// 1D x 1D
+TEST(type_prop, matmul_1D_x_1D_false_false)
+{
+    auto A = make_shared<op::Parameter>(element::f32, Shape{1});
+    auto B = make_shared<op::Parameter>(element::f32, Shape{1});
+
+    auto matmul = make_shared<op::MatMul>(A, B, false, false);
+
+    ASSERT_EQ(matmul->get_element_type(), element::f32);
+    ASSERT_EQ(matmul->get_shape(), (Shape{}));
+}
+
+TEST(type_prop, matmul_1D_x_1D_false_true)
+{
+    auto A = make_shared<op::Parameter>(element::f32, Shape{1});
+    auto B = make_shared<op::Parameter>(element::f32, Shape{1});
+
+    auto matmul = make_shared<op::MatMul>(A, B, false, true);
+
+    ASSERT_EQ(matmul->get_element_type(), element::f32);
+    ASSERT_EQ(matmul->get_shape(), (Shape{}));
+}
+
+TEST(type_prop, matmul_1D_x_1D_true_false)
+{
+    auto A = make_shared<op::Parameter>(element::f32, Shape{1});
+    auto B = make_shared<op::Parameter>(element::f32, Shape{1});
+
+    auto matmul = make_shared<op::MatMul>(A, B, true, false);
+
+    ASSERT_EQ(matmul->get_element_type(), element::f32);
+    ASSERT_EQ(matmul->get_shape(), (Shape{}));
+}
+
+TEST(type_prop, matmul_1D_x_1D_true_true)
+{
+    auto A = make_shared<op::Parameter>(element::f32, Shape{1});
+    auto B = make_shared<op::Parameter>(element::f32, Shape{1});
+
+    auto matmul = make_shared<op::MatMul>(A, B, true, true);
+
+    ASSERT_EQ(matmul->get_element_type(), element::f32);
+    ASSERT_EQ(matmul->get_shape(), (Shape{}));
+}
+
+TEST(type_prop, matmul_1D_x_1D_incompatible)
+{
+    auto A = make_shared<op::Parameter>(element::f32, Shape{3});
+    auto B = make_shared<op::Parameter>(element::f32, Shape{4});
+
+    try
+    {
+        auto matmul = make_shared<op::MatMul>(A, B);
+        // Should have thrown, so fail if it didn't
+        FAIL() << "Incompatible matrix dimensions not detected. ";
+    }
+    catch (const ngraph_error& error)
+    {
+        EXPECT_HAS_SUBSTRING(error.what(), std::string("Incompatible MatMul matrix dimension"));
+    }
+    catch (...)
+    {
+        FAIL() << "MatMul shape validation failed for unexpected reason";
+    }
+}
+
+// 2D x 1D
+TEST(type_prop, matmul_2D_x_1D_false_false)
+{
+    auto A = make_shared<op::Parameter>(element::f32, Shape{1, 2});
+    auto B = make_shared<op::Parameter>(element::f32, Shape{2});
+
+    auto matmul = make_shared<op::MatMul>(A, B, false, false);
+
+    ASSERT_EQ(matmul->get_element_type(), element::f32);
+    ASSERT_EQ(matmul->get_shape(), (Shape{1}));
+}
+
+TEST(type_prop, matmul_2D_x_1D_false_true)
+{
+    auto A = make_shared<op::Parameter>(element::f32, Shape{1, 2});
+    auto B = make_shared<op::Parameter>(element::f32, Shape{2});
+
+    auto matmul = make_shared<op::MatMul>(A, B, false, true);
+
+    ASSERT_EQ(matmul->get_element_type(), element::f32);
+    ASSERT_EQ(matmul->get_shape(), (Shape{1}));
+}
+
+TEST(type_prop, matmul_2D_x_1D_true_false)
+{
+    auto A = make_shared<op::Parameter>(element::f32, Shape{1, 2});
+    auto B = make_shared<op::Parameter>(element::f32, Shape{2});
+
+    try
+    {
+        auto matmul = make_shared<op::MatMul>(A, B, true, false);
+        // Should have thrown, so fail if it didn't
+        FAIL() << "Incompatible matrix dimensions not detected. ";
+    }
+    catch (const ngraph_error& error)
+    {
+        EXPECT_HAS_SUBSTRING(error.what(), std::string("Incompatible MatMul matrix dimension"));
+    }
+    catch (...)
+    {
+        FAIL() << "MatMul shape validation failed for unexpected reason";
+    }
+}
+
+TEST(type_prop, matmul_2D_x_1D_true_true)
+{
+    auto A = make_shared<op::Parameter>(element::f32, Shape{1, 2});
+    auto B = make_shared<op::Parameter>(element::f32, Shape{2});
+
+    try
+    {
+        auto matmul = make_shared<op::MatMul>(A, B, true, true);
+        // Should have thrown, so fail if it didn't
+        FAIL() << "Incompatible matrix dimensions not detected. ";
+    }
+    catch (const ngraph_error& error)
+    {
+        EXPECT_HAS_SUBSTRING(error.what(), std::string("Incompatible MatMul matrix dimension"));
+    }
+    catch (...)
+    {
+        FAIL() << "MatMul shape validation failed for unexpected reason";
+    }
+}
+
+// 1D x 2D
+TEST(type_prop, matmul_1D_x_2D_false_false)
+{
+    auto A = make_shared<op::Parameter>(element::f32, Shape{2});
+    auto B = make_shared<op::Parameter>(element::f32, Shape{2, 1});
+
+    auto matmul = make_shared<op::MatMul>(A, B, false, false);
+
+    ASSERT_EQ(matmul->get_element_type(), element::f32);
+    ASSERT_EQ(matmul->get_shape(), (Shape{1}));
+}
+
+TEST(type_prop, matmul_1D_x_2D_false_true)
+{
+    auto A = make_shared<op::Parameter>(element::f32, Shape{2});
+    auto B = make_shared<op::Parameter>(element::f32, Shape{2, 1});
+
+    try
+    {
+        auto matmul = make_shared<op::MatMul>(A, B, false, true);
+        // Should have thrown, so fail if it didn't
+        FAIL() << "Incompatible matrix dimensions not detected. ";
+    }
+    catch (const ngraph_error& error)
+    {
+        EXPECT_HAS_SUBSTRING(error.what(), std::string("Incompatible MatMul matrix dimension"));
+    }
+    catch (...)
+    {
+        FAIL() << "MatMul shape validation failed for unexpected reason";
+    }
+}
+
+TEST(type_prop, matmul_1D_x_2D_true_false)
+{
+    auto A = make_shared<op::Parameter>(element::f32, Shape{2});
+    auto B = make_shared<op::Parameter>(element::f32, Shape{2, 1});
+    auto matmul = make_shared<op::MatMul>(A, B, true, false);
+
+    ASSERT_EQ(matmul->get_element_type(), element::f32);
+    ASSERT_EQ(matmul->get_shape(), (Shape{1}));
+}
+
+TEST(type_prop, matmul_1D_x_2D_true_true)
+{
+    auto A = make_shared<op::Parameter>(element::f32, Shape{2});
+    auto B = make_shared<op::Parameter>(element::f32, Shape{2, 1});
+
+    try
+    {
+        auto matmul = make_shared<op::MatMul>(A, B, true, true);
+        // Should have thrown, so fail if it didn't
+        FAIL() << "Incompatible matrix dimensions not detected. ";
+    }
+    catch (const ngraph_error& error)
+    {
+        EXPECT_HAS_SUBSTRING(error.what(), std::string("Incompatible MatMul matrix dimension"));
+    }
+    catch (...)
+    {
+        FAIL() << "MatMul shape validation failed for unexpected reason";
+    }
+}
+
+// 1D x 4D
+TEST(type_prop, matmul_1D_x_4D_false_false)
+{
+    auto A = make_shared<op::Parameter>(element::f32, Shape{3});
+    auto B = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
+
+    auto matmul = make_shared<op::MatMul>(A, B, false, false);
+
+    ASSERT_EQ(matmul->get_element_type(), element::f32);
+    ASSERT_EQ(matmul->get_shape(), (Shape{1, 2, 4}));
+}
+
+// 4D x 1D
+TEST(type_prop, matmul_4D_x_1D_false_false)
+{
+    auto A = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
+    auto B = make_shared<op::Parameter>(element::f32, Shape{4});
+
+    auto matmul = make_shared<op::MatMul>(A, B, false, false);
+
+    ASSERT_EQ(matmul->get_element_type(), element::f32);
+    ASSERT_EQ(matmul->get_shape(), (Shape{1, 2, 3}));
+}
+
+// Batch broadcast
+TEST(type_prop, matmul_batch_broadcast)
+{
+    auto A = make_shared<op::Parameter>(element::f32, Shape{5, 1, 1, 4, 3});
+    auto B = make_shared<op::Parameter>(element::f32, Shape{1, 1, 6, 3, 2});
+
+    auto matmul = make_shared<op::MatMul>(A, B, false, false);
+
+    ASSERT_EQ(matmul->get_element_type(), element::f32);
+    ASSERT_EQ(matmul->get_shape(), (Shape{5, 1, 6, 4, 2}));
+}
+
+TEST(type_prop, matmul_batch_broadcast_expand_to_A)
+{
+    auto A = make_shared<op::Parameter>(element::f32, Shape{1, 4, 3});
+    auto B = make_shared<op::Parameter>(element::f32, Shape{7, 8, 5, 3, 2});
+
+    auto matmul = make_shared<op::MatMul>(A, B, false, false);
+
+    ASSERT_EQ(matmul->get_element_type(), element::f32);
+    ASSERT_EQ(matmul->get_shape(), (Shape{7, 8, 5, 4, 2}));
+}
+
+TEST(type_prop, matmul_batch_broadcast_expand_to_B)
+{
+    auto A = make_shared<op::Parameter>(element::f32, Shape{8, 7, 6, 1, 4, 3});
+    auto B = make_shared<op::Parameter>(element::f32, Shape{1, 5, 3, 2});
+
+    auto matmul = make_shared<op::MatMul>(A, B, false, false);
+
+    ASSERT_EQ(matmul->get_element_type(), element::f32);
+    ASSERT_EQ(matmul->get_shape(), (Shape{8, 7, 6, 5, 4, 2}));
+}
+
+TEST(type_prop, matmul_incompatible_batch_dims)
+{
+    auto A = make_shared<op::Parameter>(element::f32, Shape{7, 4, 3});
+    auto B = make_shared<op::Parameter>(element::f32, Shape{6, 3, 2});
+
+    try
+    {
+        auto matmul = make_shared<op::MatMul>(A, B);
+        // Should have thrown, so fail if it didn't
+        FAIL() << "Incompatible batch dimensions not detected. ";
+    }
+    catch (const ngraph_error& error)
+    {
+        EXPECT_HAS_SUBSTRING(error.what(), std::string("Incompatible MatMul batch dimension"));
+    }
+    catch (...)
+    {
+        FAIL() << "MatMul shape validation failed for unexpected reason";
+    }
+}
+
+TEST(type_prop, matmul_matrix_dynamic_bounds)
+{
+    auto A =
+        make_shared<op::Parameter>(element::f32, PartialShape{Dimension(2, 5), Dimension(6, 10)});
+    auto B =
+        make_shared<op::Parameter>(element::f32, PartialShape{Dimension(7, 8), Dimension(15, 20)});
+
+    auto matmul = make_shared<op::MatMul>(A, B, false, false);
+
+    ASSERT_EQ(matmul->get_element_type(), element::f32);
+    ASSERT_EQ(matmul->get_output_partial_shape(0),
+              (PartialShape{Dimension(2, 5), Dimension(15, 20)}));
+}
+
+TEST(type_prop, matmul_batch_dynamic_bounds)
+{
+    // Input A and input B dim bounds => output dim bound
+    // Dimension 1 can be expanded to any bigger
+
+    Dimension dynamic = Dimension::dynamic();
+
+    auto A_shape = PartialShape{dynamic,          // 0
+                                Dimension(1, 5),  // 1
+                                Dimension(2, 10), // 2
+                                Dimension(5, 7),  // 3
+                                Dimension(4, 7),  // 4
+                                Dimension(5, 10), // 5
+                                Dimension(1, 4),  // 6
+                                Dimension(0, 1),  // 7
+                                Dimension(0, 3),  // 8
+                                1,                // 9
+                                Dimension(1, -1), // 10
+                                Dimension(1, 10), // 11
+                                Dimension(2, -1), // 12
+                                Dimension(1, -1), // 13
+                                Dimension(2, -1), // 14
+                                Dimension(1, -1), // 15
+                                1,                // 16
+                                1,                // 17
+                                5,                // 18
+                                6};               // 19
+
+    auto B_shape = PartialShape{dynamic,           // 0
+                                Dimension(10, 20), // 1
+                                Dimension(10, 20), // 2
+                                Dimension(4, 10),  // 3
+                                Dimension(5, 10),  // 4
+                                Dimension(4, 7),   // 5
+                                dynamic,           // 6
+                                Dimension(0, 1),   // 7
+                                Dimension(2, 5),   // 8
+                                Dimension(5, 10),  // 9
+                                Dimension(1, 5),   // 10
+                                Dimension(1, 5),   // 11
+                                Dimension(1, 5),   // 12
+                                Dimension(2, -1),  // 13
+                                Dimension(2, -1),  // 14
+                                Dimension(1, -1),  // 15
+                                dynamic,           // 16
+                                3,                 // 17
+                                6,                 // 18
+                                4};                // 19
+
+    auto expected_output_shape = PartialShape{dynamic,           // 0
+                                              Dimension(10, 20), // 1
+                                              10,                // 2
+                                              Dimension(5, 7),   // 3
+                                              Dimension(5, 7),   // 4
+                                              Dimension(5, 7),   // 5
+                                              Dimension(1, -1),  // 6
+                                              Dimension(0, 1),   // 7
+                                              Dimension(2, 5),   // 8
+                                              Dimension(5, 10),  // 9
+                                              Dimension(1, -1),  // 10
+                                              Dimension(1, 10),  // 11
+                                              Dimension(2, -1),  // 12
+                                              Dimension(2, -1),  // 13
+                                              Dimension(2, -1),  // 14
+                                              Dimension(1, -1),  // 15
+                                              Dimension(1, -1),  // 16
+                                              3,                 // 17
+                                              5,                 // 18
+                                              4};                // 19
+
+    auto A = make_shared<op::Parameter>(element::f32, A_shape);
+    auto B = make_shared<op::Parameter>(element::f32, B_shape);
+
+    auto matmul = make_shared<op::MatMul>(A, B);
+
+    ASSERT_EQ(matmul->get_element_type(), element::f32);
+    ASSERT_EQ(matmul->get_output_partial_shape(0), expected_output_shape);
+}
+
+TEST(type_prop, matmul_incompatible_matrix_dim_bounds)
+{
+    auto A =
+        make_shared<op::Parameter>(element::f32, PartialShape{Dimension(2, 5), Dimension(3, 4)});
+    auto B =
+        make_shared<op::Parameter>(element::f32, PartialShape{Dimension(1, 2), Dimension(15, 20)});
+
+    auto expected_output_shape = PartialShape{Dimension(2, 5), Dimension(15, 20)};
+
+    // No error for backward compatibility
+    auto matmul = make_shared<op::MatMul>(A, B, false, false);
+
+    ASSERT_EQ(matmul->get_element_type(), element::f32);
+    ASSERT_EQ(matmul->get_output_partial_shape(0), expected_output_shape);
+}
+
+TEST(type_prop, matmul_incompatible_batch_dim_bounds)
+{
+    auto A = make_shared<op::Parameter>(element::f32, PartialShape{Dimension(2, 5), 4, 3});
+    auto B = make_shared<op::Parameter>(element::f32, PartialShape{Dimension(6, 10), 3, 2});
+
+    Dimension dynamic = Dimension::dynamic();
+    auto expected_output_shape = PartialShape{dynamic, 4, 2};
+
+    // No error for backward compatibility
+    auto matmul = make_shared<op::MatMul>(A, B, false, false);
+
+    ASSERT_EQ(matmul->get_element_type(), element::f32);
+    ASSERT_EQ(matmul->get_output_partial_shape(0), expected_output_shape);
+}