[mlir][vector][nvgpu] Move MMA contraction preparation to VectorUtils
authorJakub Kuderski <kubak@google.com>
Thu, 9 Mar 2023 19:56:20 +0000 (14:56 -0500)
committerJakub Kuderski <kubak@google.com>
Thu, 9 Mar 2023 19:56:21 +0000 (14:56 -0500)
This pattern is not specific to nvgpu; I intend to use in SPIR-V codegen. `VectorTransforms` seems like a more generally useful place.

In addition:
-  Fix a bug in the second condition (the dimensions were swapped for RHS).
-  Add tests.
-  Add support for externally provided filter functions, similar to other vector transforms.
-  Prefer to transpose before zero/sign-extending inputs.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D145638

mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-contract-matmul-transforms.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

index 5880b09..003a160 100644 (file)
@@ -93,18 +93,6 @@ FailureOr<AffineMap>
 getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
                                const LdMatrixParams &params);
 
-/// Transform `vector.contract` into (m,k)x(n,k)x(m,n) form so that it can be
-/// converted to `nvgpu.mma.sync`. This specific form is meant to indicate that
-/// the vector operands are organized such that the reduction dimension is
-/// contiguous.
-struct PrepareContractToGPUMMASync
-    : public OpRewritePattern<vector::ContractionOp> {
-  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override;
-};
-
 } // namespace nvgpu
 } // namespace mlir
 
index 775bfbf..1d57243 100644 (file)
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
 
 namespace mlir {
 class RewritePatternSet;
@@ -147,6 +148,27 @@ void populateVectorContractLoweringPatterns(
     VectorTransformsOptions options = VectorTransformsOptions(),
     PatternBenefit benefit = 1);
 
+/// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
+/// semantics to a contraction with MMT semantics (matrix matrix multiplication
+/// with the RHS transposed). This specific form is meant to have the vector
+/// operands are organized such that the reduction dimension is contiguous.
+/// Example:
+/// ```
+/// vector.contract {indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+///                                   affine_map<(m, n, k) -> (n, k)>,
+///                                   affine_map<(m, n, k) -> (m, n)>],
+///                  iterator_types = ["parallel", "parallel", "reduction"],
+///                  kind = #vector.kind<add>} %a, %b, %c : ...
+/// ```
+///
+///  The `constraint` predicate is used to decide which `vector.contraction` ops
+///  to filter out.
+void populateVectorContractCanonicalizeMatmulToMMT(
+    RewritePatternSet &patterns,
+    std::function<LogicalResult(vector::ContractionOp)> constraint =
+        [](vector::ContractionOp) { return success(); },
+    PatternBenefit = 1);
+
 /// Collect patterns to convert reduction op to vector.contract and fold
 /// transpose/broadcast ops into the contract.
 void populateVectorReductionToContractPatterns(RewritePatternSet &patterns,
index d9533b4..cc98136 100644 (file)
@@ -24,6 +24,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -1173,9 +1174,8 @@ void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
         patterns.getContext());
     return;
   }
-  patterns
-      .add<nvgpu::PrepareContractToGPUMMASync, CombineTransferReadOpTranspose>(
-          patterns.getContext());
+  vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
+  patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
 }
 
 LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
index 44f9b6d..7525f9f 100644 (file)
@@ -272,60 +272,3 @@ nvgpu::getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
 
   return failure();
 }
-
-LogicalResult nvgpu::PrepareContractToGPUMMASync::matchAndRewrite(
-    vector::ContractionOp op, PatternRewriter &rewriter) const {
-  Location loc = op.getLoc();
-  Value lhs = op.getLhs();
-  Value rhs = op.getRhs();
-  Value res = op.getAcc();
-
-  // Set up the parallel/reduction structure in right form.
-  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
-  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
-  AffineExpr m;
-  AffineExpr n;
-  AffineExpr k;
-  bindDims(rewriter.getContext(), m, n, k);
-  static constexpr std::array<int64_t, 2> perm = {1, 0};
-  auto iteratorTypes = op.getIteratorTypes().getValue();
-  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
-  if (iteratorTypes.size() != 3)
-    return failure();
-  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
-        vector::isParallelIterator(iteratorTypes[1]) &&
-        vector::isReductionIterator(iteratorTypes[2])))
-    return failure();
-
-  // The canonical form is "TNT" = A row-major, B col-major, C row-major.
-  const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
-  if (maps == canonicalForm) {
-    return failure();
-  }
-  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
-    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
-    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
-    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
-    std::swap(rhs, lhs);
-    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
-    std::swap(rhs, lhs);
-    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
-    std::swap(lhs, rhs);
-    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
-    std::swap(lhs, rhs);
-  } else {
-    return failure();
-  }
-  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
-      op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
-      op.getIteratorTypes());
-  return success();
-}
index 0844fda..9e9e999 100644 (file)
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 
+#include <functional>
 #include <optional>
 #include <type_traits>
 
@@ -24,6 +25,8 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -31,6 +34,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
 
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/MapVector.h"
@@ -3053,6 +3057,104 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
   }
 };
 
+/// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
+/// semantics to a contraction suitable for MMT (matrix matrix multiplication
+/// with the RHS transposed) lowering.
+struct CanonicalizeContractMatmulToMMT final
+    : OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+
+  CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
+                                  FilterConstraintType constraint)
+      : OpRewritePattern<vector::ContractionOp>(context, benefit),
+        filter(std::move(constraint)) {}
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+    // TODO: Remove native masks from contraction op?
+    if (!op.getMasks().empty())
+      return failure();
+
+    if (failed(filter(op)))
+      return failure();
+
+    Location loc = op.getLoc();
+    Value lhs = op.getLhs();
+    Value rhs = op.getRhs();
+    Value res = op.getAcc();
+
+    // Set up the parallel/reduction structure in right form.
+    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+    AffineExpr m;
+    AffineExpr n;
+    AffineExpr k;
+    bindDims(rewriter.getContext(), m, n, k);
+    static constexpr std::array<int64_t, 2> perm = {1, 0};
+    auto iteratorTypes = op.getIteratorTypes().getValue();
+    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
+    if (iteratorTypes.size() != 3 ||
+        !vector::isParallelIterator(iteratorTypes[0]) ||
+        !vector::isParallelIterator(iteratorTypes[1]) ||
+        !vector::isReductionIterator(iteratorTypes[2]))
+      return rewriter.notifyMatchFailure(op, "contraction is not a gemm");
+
+    // The canonical form is "TNT" = A row-major, B col-major, C row-major.
+    const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
+    if (maps == canonicalForm)
+      return rewriter.notifyMatchFailure(op, "already in the canonical form");
+
+    // Create a vector transpose making sure to emit zero/sign-extend at the
+    // end.
+    auto createTranspose = [&rewriter, loc](Value mat) -> Value {
+      if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
+        Value trans =
+            rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
+        return rewriter.create<arith::ExtSIOp>(loc, mat.getType(), trans);
+      }
+      if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
+        Value trans =
+            rewriter.create<vector::TransposeOp>(loc, zext.getIn(), perm);
+        return rewriter.create<arith::ExtUIOp>(loc, mat.getType(), trans);
+      }
+      return rewriter.create<vector::TransposeOp>(loc, mat, perm);
+    };
+
+    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
+      rhs = createTranspose(rhs);
+    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
+      lhs = createTranspose(lhs);
+    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
+      rhs = createTranspose(rhs);
+      lhs = createTranspose(lhs);
+    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
+      std::swap(rhs, lhs);
+      rhs = createTranspose(rhs);
+      lhs = createTranspose(lhs);
+    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
+      std::swap(rhs, lhs);
+      rhs = createTranspose(rhs);
+    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
+      std::swap(lhs, rhs);
+      lhs = createTranspose(lhs);
+    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
+      std::swap(lhs, rhs);
+    } else {
+      return rewriter.notifyMatchFailure(op, "unhandled contraction form");
+    }
+    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+        op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
+        op.getIteratorTypes());
+    return success();
+  };
+
+private:
+  FilterConstraintType filter;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorMaskMaterializationPatterns(
@@ -3104,6 +3206,14 @@ void mlir::vector::populateVectorContractLoweringPatterns(
       options, patterns.getContext(), benefit);
 }
 
+void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
+    RewritePatternSet &patterns,
+    std::function<LogicalResult(vector::ContractionOp)> constraint,
+    PatternBenefit benefit) {
+  patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
+                                                std::move(constraint));
+}
+
 void mlir::vector::populateVectorTransposeLoweringPatterns(
     RewritePatternSet &patterns, VectorTransformsOptions options,
     PatternBenefit benefit) {
diff --git a/mlir/test/Dialect/Vector/vector-contract-matmul-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matmul-transforms.mlir
new file mode 100644 (file)
index 0000000..d0be123
--- /dev/null
@@ -0,0 +1,198 @@
+// RUN: mlir-opt %s -test-vector-contraction-prepare-for-mmt-lowering | FileCheck %s
+
+// CHECK-LABEL: func.func @not_matmul
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<4xf32>, [[ARG1:%.+]]: vector<4xf32>, [[ARG2:%.+]]: f32)
+// CHECK-NEXT:    vector.contract
+// CHECK-NEXT:    return
+func.func @not_matmul(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 {
+  %0 = vector.contract {indexing_maps = [affine_map<(d0) -> (d0)>,
+                                         affine_map<(d0) -> (d0)>,
+                                         affine_map<(d0) -> ()>],
+                        iterator_types = ["reduction"],
+                        kind = #vector.kind<add>} %arg0, %arg1, %arg2 :
+         vector<4xf32>, vector<4xf32> into f32
+  return %0 : f32
+}
+
+// This contraction is already in the canonical form.
+// CHECK-LABEL: func.func @matmul_mk_nk_mn_4x4xi32
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT:    [[RES:%.+]]   = vector.contract {{.+}} [[ARG0]], [[ARG1]], [[ARG2]]
+// CHECK-NEXT:    return [[RES]]
+func.func @matmul_mk_nk_mn_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                           affine_map<(d0, d1, d2) -> (d1, d2)>,
+                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
+                          iterator_types = ["parallel", "parallel", "reduction"],
+                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_mk_kn_mn_4x4xi32
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT:    [[TRANS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-NEXT:    [[RES:%.+]]   = vector.contract {{.+}} [[ARG0]], [[TRANS]], [[ARG2]]
+// CHECK-NEXT:    return [[RES]]
+func.func @matmul_mk_kn_mn_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
+                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
+                          iterator_types = ["parallel", "parallel", "reduction"],
+                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_mk_kn_mn_4x4xi8_extsi_i32
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<4x4xi8>, [[ARG1:%.+]]: vector<4x4xi8>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT:    [[LHS:%.+]]   = arith.extsi [[ARG0]] : vector<4x4xi8> to vector<4x4xi32>
+// CHECK-NEXT:    [[TRANS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi8> to vector<4x4xi8>
+// CHECK-NEXT:    [[RHS:%.+]]   = arith.extsi [[TRANS]] : vector<4x4xi8> to vector<4x4xi32>
+// CHECK-NEXT:    [[RES:%.+]]   = vector.contract {{.+}} [[LHS]], [[RHS]], [[ARG2]]
+// CHECK-NEXT:    return [[RES]]
+func.func @matmul_mk_kn_mn_4x4xi8_extsi_i32(%arg0: vector<4x4xi8>, %arg1: vector<4x4xi8>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+  %lhs = arith.extsi %arg0: vector<4x4xi8> to vector<4x4xi32>
+  %rhs = arith.extsi %arg1: vector<4x4xi8> to vector<4x4xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
+                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
+                          iterator_types = ["parallel", "parallel", "reduction"],
+                          kind = #vector.kind<add>} %lhs, %rhs, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
+
+// Check that non-square shapes are also handled.
+// CHECK-LABEL: func.func @matmul_mk_kn_mn_4x16xi32
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<4x16xi32>, [[ARG1:%.+]]: vector<16x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT:    [[TRANS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<16x4xi32> to vector<4x16xi32>
+// CHECK-NEXT:    [[RES:%.+]]   = vector.contract {{.+}} [[ARG0]], [[TRANS]], [[ARG2]]
+// CHECK-NEXT:    return [[RES]]
+func.func @matmul_mk_kn_mn_4x16xi32(%arg0: vector<4x16xi32>, %arg1: vector<16x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
+                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
+                          iterator_types = ["parallel", "parallel", "reduction"],
+                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x16xi32>, vector<16x4xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_mk_kn_mn_4x4xi8_extui_i32
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<4x4xi8>, [[ARG1:%.+]]: vector<4x4xi8>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT:    [[LHS:%.+]]   = arith.extui [[ARG0]] : vector<4x4xi8> to vector<4x4xi32>
+// CHECK-NEXT:    [[TRANS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi8> to vector<4x4xi8>
+// CHECK-NEXT:    [[RHS:%.+]]   = arith.extui [[TRANS]] : vector<4x4xi8> to vector<4x4xi32>
+// CHECK-NEXT:    [[RES:%.+]]   = vector.contract {{.+}} [[LHS]], [[RHS]], [[ARG2]]
+// CHECK-NEXT:    return [[RES]]
+func.func @matmul_mk_kn_mn_4x4xi8_extui_i32(%arg0: vector<4x4xi8>, %arg1: vector<4x4xi8>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+  %lhs = arith.extui %arg0: vector<4x4xi8> to vector<4x4xi32>
+  %rhs = arith.extui %arg1: vector<4x4xi8> to vector<4x4xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
+                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
+                          iterator_types = ["parallel", "parallel", "reduction"],
+                          kind = #vector.kind<add>} %lhs, %rhs, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_km_nk_mn_4x4xi32
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT:    [[TRANS:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-NEXT:    [[RES:%.+]]   = vector.contract {{.+}} [[TRANS]], [[ARG1]], [[ARG2]]
+// CHECK-NEXT:    return [[RES]]
+func.func @matmul_km_nk_mn_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
+                                           affine_map<(d0, d1, d2) -> (d1, d2)>,
+                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
+                          iterator_types = ["parallel", "parallel", "reduction"],
+                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_km_kn_mn_4x4xi32
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-DAG:     [[LHS:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-DAG:     [[RHS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-NEXT:    [[RES:%.+]] = vector.contract {{.+}} [[LHS]], [[RHS]], [[ARG2]]
+// CHECK-NEXT:    return [[RES]]
+func.func @matmul_km_kn_mn_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
+                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
+                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
+                          iterator_types = ["parallel", "parallel", "reduction"],
+                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_km_kn_mn_4x4xi8_mixed_ext_i32
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<4x4xi8>, [[ARG1:%.+]]: vector<4x4xi8>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-DAG:     [[LHST:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi8> to vector<4x4xi8>
+// CHECK-DAG:     [[LHS:%.+]]  = arith.extsi [[LHST]] : vector<4x4xi8> to vector<4x4xi32>
+// CHECK-DAG:     [[RHST:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi8> to vector<4x4xi8>
+// CHECK-DAG:     [[RHS:%.+]]  = arith.extui [[RHST]] : vector<4x4xi8> to vector<4x4xi32>
+// CHECK-NEXT:    [[RES:%.+]]  = vector.contract {{.+}} [[LHS]], [[RHS]], [[ARG2]]
+// CHECK-NEXT:    return [[RES]]
+func.func @matmul_km_kn_mn_4x4xi8_mixed_ext_i32(%arg0: vector<4x4xi8>, %arg1: vector<4x4xi8>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+  %lhs = arith.extsi %arg0 : vector<4x4xi8> to vector<4x4xi32>
+  %rhs = arith.extui %arg1 : vector<4x4xi8> to vector<4x4xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
+                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
+                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
+                          iterator_types = ["parallel", "parallel", "reduction"],
+                          kind = #vector.kind<add>} %lhs, %rhs, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_mk_nk_nm_4x4xi32
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT:    [[RES:%.+]]   = vector.contract {{.+}} [[ARG1]], [[ARG0]], [[ARG2]]
+// CHECK-NEXT:    return [[RES]]
+func.func @matmul_mk_nk_nm_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                           affine_map<(d0, d1, d2) -> (d1, d2)>,
+                                           affine_map<(d0, d1, d2) -> (d1, d0)>],
+                          iterator_types = ["parallel", "parallel", "reduction"],
+                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_km_kn_nm_4x4xi32
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-DAG:     [[LHS:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-DAG:     [[RHS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-NEXT:    [[RES:%.+]] = vector.contract {{.+}} [[RHS]], [[LHS]], [[ARG2]]
+// CHECK-NEXT:    return [[RES]]
+func.func @matmul_km_kn_nm_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
+                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
+                                           affine_map<(d0, d1, d2) -> (d1, d0)>],
+                          iterator_types = ["parallel", "parallel", "reduction"],
+                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_mk_kn_nm_4x4xi32
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-DAG:     [[RHS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-NEXT:    [[RES:%.+]] = vector.contract {{.+}} [[RHS]], [[ARG0]], [[ARG2]]
+// CHECK-NEXT:    return [[RES]]
+func.func @matmul_mk_kn_nm_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
+                                           affine_map<(d0, d1, d2) -> (d1, d0)>],
+                          iterator_types = ["parallel", "parallel", "reduction"],
+                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_km_nk_nm_4x4xi32
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-DAG:     [[LHS:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-NEXT:    [[RES:%.+]] = vector.contract {{.+}} [[ARG1]], [[LHS]], [[ARG2]]
+// CHECK-NEXT:    return [[RES]]
+func.func @matmul_km_nk_nm_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
+                                           affine_map<(d0, d1, d2) -> (d1, d2)>,
+                                           affine_map<(d0, d1, d2) -> (d1, d0)>],
+                          iterator_types = ["parallel", "parallel", "reduction"],
+                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
index 68edea5..93736da 100644 (file)
@@ -11,6 +11,7 @@
 
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -199,6 +200,33 @@ struct TestVectorContractionLowering
   }
 };
 
+struct TestVectorContractionPrepareForMMTLowering
+    : public PassWrapper<TestVectorContractionPrepareForMMTLowering,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestVectorContractionPrepareForMMTLowering)
+
+  StringRef getArgument() const final {
+    return "test-vector-contraction-prepare-for-mmt-lowering";
+  }
+  StringRef getDescription() const final {
+    return "Test vector.contraction matmul canonicalization for MMT lowering.";
+  }
+  TestVectorContractionPrepareForMMTLowering() = default;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<AffineDialect, arith::ArithDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct TestVectorTransposeLowering
     : public PassWrapper<TestVectorTransposeLowering,
                          OperationPass<func::FuncOp>> {
@@ -892,6 +920,8 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorContractionLowering>();
 
+  PassRegistration<TestVectorContractionPrepareForMMTLowering>();
+
   PassRegistration<TestVectorTransposeLowering>();
 
   PassRegistration<TestVectorUnrollingPatterns>();