return failure();
}
-
-/// Rewrites a matmul-like `vector.contract` into the canonical
-/// (m, k) x (n, k) -> (m, n) form by transposing and/or swapping its LHS and
-/// RHS operands; the accumulator is passed through untouched.
-///
-/// Returns failure when the op does not have exactly the iterators
-/// (parallel, parallel, reduction), when it is already in the canonical form,
-/// or when its indexing maps are not one of the combinations listed below.
-LogicalResult nvgpu::PrepareContractToGPUMMASync::matchAndRewrite(
-    vector::ContractionOp op, PatternRewriter &rewriter) const {
-  Location loc = op.getLoc();
-  Value lhs = op.getLhs();
-  Value rhs = op.getRhs();
-  Value res = op.getAcc();
-
-  // Set up the parallel/reduction structure in right form.
-  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
-  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
-  AffineExpr m;
-  AffineExpr n;
-  AffineExpr k;
-  bindDims(rewriter.getContext(), m, n, k);
-  static constexpr std::array<int64_t, 2> perm = {1, 0};
-  auto iteratorTypes = op.getIteratorTypes().getValue();
-  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
-  if (iteratorTypes.size() != 3)
-    return failure();
-  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
-        vector::isParallelIterator(iteratorTypes[1]) &&
-        vector::isReductionIterator(iteratorTypes[2])))
-    return failure();
-
-  // The canonical form is "TNT" = A row-major, B col-major, C row-major.
-  const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
-  if (maps == canonicalForm) {
-    return failure();
-  }
-  // Each branch handles one non-canonical layout. A transposed result
-  // (n, m) is handled by swapping the operands.
-  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
-    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
-    // Only the LHS is transposed here, so the RHS must already be (n, k).
-    // This condition previously duplicated the (k, n) case below, which made
-    // that branch unreachable and left its RHS untransposed.
-    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
-    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
-    std::swap(rhs, lhs);
-    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
-    std::swap(rhs, lhs);
-    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
-    std::swap(lhs, rhs);
-    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
-    std::swap(lhs, rhs);
-  } else {
-    return failure();
-  }
-  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
-      op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
-      op.getIteratorTypes());
-  return success();
-}
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include <functional>
#include <optional>
#include <type_traits>
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
}
};
+/// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
+/// semantics to a contraction suitable for MMT (matrix matrix multiplication
+/// with the RHS transposed) lowering.
+///
+/// The pattern only applies to unmasked contractions whose iterators are
+/// exactly (parallel, parallel, reduction). It rewrites the op to the indexing
+/// maps (m, k) x (n, k) -> (m, n) by transposing and/or swapping the LHS and
+/// RHS operands; the accumulator is passed through untouched.
+///
+/// An optional `filter` callback can veto the rewrite for a given op: the
+/// pattern bails out whenever the callback returns failure.
+struct CanonicalizeContractMatmulToMMT final
+    : OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+
+  /// Creates the pattern with the given `benefit` and a `constraint` that is
+  /// consulted before any rewrite is attempted.
+  CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
+                                  FilterConstraintType constraint)
+      : OpRewritePattern<vector::ContractionOp>(context, benefit),
+        filter(std::move(constraint)) {}
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+    // TODO: Remove native masks from contraction op?
+    if (!op.getMasks().empty())
+      return failure();
+
+    // The inherited constructors leave `filter` empty; calling an empty
+    // std::function is an error, so treat it as "no constraint".
+    if (filter && failed(filter(op)))
+      return failure();
+
+    Location loc = op.getLoc();
+    Value lhs = op.getLhs();
+    Value rhs = op.getRhs();
+    Value res = op.getAcc();
+
+    // Set up the parallel/reduction structure in right form.
+    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+    AffineExpr m;
+    AffineExpr n;
+    AffineExpr k;
+    bindDims(rewriter.getContext(), m, n, k);
+    static constexpr std::array<int64_t, 2> perm = {1, 0};
+    auto iteratorTypes = op.getIteratorTypes().getValue();
+    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
+    if (iteratorTypes.size() != 3 ||
+        !vector::isParallelIterator(iteratorTypes[0]) ||
+        !vector::isParallelIterator(iteratorTypes[1]) ||
+        !vector::isReductionIterator(iteratorTypes[2]))
+      return rewriter.notifyMatchFailure(op, "contraction is not a gemm");
+
+    // The canonical form is "TNT" = A row-major, B col-major, C row-major.
+    const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
+    if (maps == canonicalForm)
+      return rewriter.notifyMatchFailure(op, "already in the canonical form");
+
+    // Create a vector transpose making sure to emit zero/sign-extend at the
+    // end.
+    auto createTranspose = [&rewriter, loc](Value mat) -> Value {
+      // Type of the re-created extension: the *transposed* shape combined
+      // with the extended element type of `mat`. Reusing `mat.getType()`
+      // directly would keep the untransposed shape, which is only correct
+      // for square operands.
+      auto extendedTypeOf = [&mat](Value trans) -> VectorType {
+        Type wideElemTy = cast<VectorType>(mat.getType()).getElementType();
+        return cast<VectorType>(trans.getType()).clone(wideElemTy);
+      };
+      if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
+        Value trans =
+            rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
+        return rewriter.create<arith::ExtSIOp>(loc, extendedTypeOf(trans),
+                                               trans);
+      }
+      if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
+        Value trans =
+            rewriter.create<vector::TransposeOp>(loc, zext.getIn(), perm);
+        return rewriter.create<arith::ExtUIOp>(loc, extendedTypeOf(trans),
+                                               trans);
+      }
+      return rewriter.create<vector::TransposeOp>(loc, mat, perm);
+    };
+
+    // Each branch handles one non-canonical layout. A transposed result
+    // (n, m) is handled by swapping the operands.
+    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
+      rhs = createTranspose(rhs);
+    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
+      lhs = createTranspose(lhs);
+    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
+      rhs = createTranspose(rhs);
+      lhs = createTranspose(lhs);
+    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
+      std::swap(rhs, lhs);
+      rhs = createTranspose(rhs);
+      lhs = createTranspose(lhs);
+    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
+      std::swap(rhs, lhs);
+      rhs = createTranspose(rhs);
+    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
+      std::swap(lhs, rhs);
+      lhs = createTranspose(lhs);
+    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
+      std::swap(lhs, rhs);
+    } else {
+      return rewriter.notifyMatchFailure(op, "unhandled contraction form");
+    }
+    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+        op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
+        op.getIteratorTypes());
+    return success();
+  }
+
+private:
+  /// User-supplied predicate; may be empty, in which case every op passes.
+  FilterConstraintType filter;
+};
+
} // namespace
void mlir::vector::populateVectorMaskMaterializationPatterns(
options, patterns.getContext(), benefit);
}
+/// Adds the `CanonicalizeContractMatmulToMMT` pattern to `patterns`, forwarding
+/// `benefit` and handing ownership of `constraint` to the pattern.
+void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
+    RewritePatternSet &patterns,
+    std::function<LogicalResult(vector::ContractionOp)> constraint,
+    PatternBenefit benefit) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<CanonicalizeContractMatmulToMMT>(ctx, benefit,
+                                                std::move(constraint));
+}
+
void mlir::vector::populateVectorTransposeLoweringPatterns(
RewritePatternSet &patterns, VectorTransformsOptions options,
PatternBenefit benefit) {
--- /dev/null
+// RUN: mlir-opt %s -test-vector-contraction-prepare-for-mmt-lowering | FileCheck %s
+
+// Negative test: a 1-D contraction with a single "reduction" iterator. The
+// expected output keeps the `vector.contract` and has no other op before the
+// return.
+// CHECK-LABEL: func.func @not_matmul
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4xf32>, [[ARG1:%.+]]: vector<4xf32>, [[ARG2:%.+]]: f32)
+// CHECK-NEXT: vector.contract
+// CHECK-NEXT: return
+func.func @not_matmul(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 {
+ %0 = vector.contract {indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> ()>],
+ iterator_types = ["reduction"],
+ kind = #vector.kind<add>} %arg0, %arg1, %arg2 :
+ vector<4xf32>, vector<4xf32> into f32
+ return %0 : f32
+}
+
+// This contraction is already in the canonical form.
+// With d0 = m, d1 = n, d2 = k the maps are (m, k) x (n, k) -> (m, n); the
+// operands are expected to reach the contract unchanged and in order.
+// CHECK-LABEL: func.func @matmul_mk_nk_mn_4x4xi32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[ARG0]], [[ARG1]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_mk_nk_mn_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+ return %res : vector<4x4xi32>
+}
+
+// Maps are (m, k) x (k, n) -> (m, n) with d0 = m, d1 = n, d2 = k. Expected
+// output: a [1, 0] transpose of the RHS feeding the contract; the LHS and
+// accumulator are used as is.
+// CHECK-LABEL: func.func @matmul_mk_kn_mn_4x4xi32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT: [[TRANS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[ARG0]], [[TRANS]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_mk_kn_mn_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+ return %res : vector<4x4xi32>
+}
+
+// Maps are (m, k) x (k, n) -> (m, n); both operands come from `arith.extsi`
+// of i8 vectors. Expected output: the RHS transpose is applied to the i8
+// value and the sign extension to i32 is emitted after the transpose.
+// CHECK-LABEL: func.func @matmul_mk_kn_mn_4x4xi8_extsi_i32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi8>, [[ARG1:%.+]]: vector<4x4xi8>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT: [[LHS:%.+]] = arith.extsi [[ARG0]] : vector<4x4xi8> to vector<4x4xi32>
+// CHECK-NEXT: [[TRANS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi8> to vector<4x4xi8>
+// CHECK-NEXT: [[RHS:%.+]] = arith.extsi [[TRANS]] : vector<4x4xi8> to vector<4x4xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[LHS]], [[RHS]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_mk_kn_mn_4x4xi8_extsi_i32(%arg0: vector<4x4xi8>, %arg1: vector<4x4xi8>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+ %lhs = arith.extsi %arg0: vector<4x4xi8> to vector<4x4xi32>
+ %rhs = arith.extsi %arg1: vector<4x4xi8> to vector<4x4xi32>
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+ return %res : vector<4x4xi32>
+}
+
+// Check that non-square shapes are also handled.
+// Maps are (m, k) x (k, n) -> (m, n) with a 4x16 LHS and a 16x4 RHS; the
+// transposed RHS is expected to have type vector<4x16xi32>.
+// CHECK-LABEL: func.func @matmul_mk_kn_mn_4x16xi32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x16xi32>, [[ARG1:%.+]]: vector<16x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT: [[TRANS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<16x4xi32> to vector<4x16xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[ARG0]], [[TRANS]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_mk_kn_mn_4x16xi32(%arg0: vector<4x16xi32>, %arg1: vector<16x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x16xi32>, vector<16x4xi32> into vector<4x4xi32>
+ return %res : vector<4x4xi32>
+}
+
+// Maps are (m, k) x (k, n) -> (m, n); both operands come from `arith.extui`
+// of i8 vectors. Expected output: the RHS transpose is applied to the i8
+// value and the zero extension to i32 is emitted after the transpose.
+// CHECK-LABEL: func.func @matmul_mk_kn_mn_4x4xi8_extui_i32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi8>, [[ARG1:%.+]]: vector<4x4xi8>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT: [[LHS:%.+]] = arith.extui [[ARG0]] : vector<4x4xi8> to vector<4x4xi32>
+// CHECK-NEXT: [[TRANS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi8> to vector<4x4xi8>
+// CHECK-NEXT: [[RHS:%.+]] = arith.extui [[TRANS]] : vector<4x4xi8> to vector<4x4xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[LHS]], [[RHS]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_mk_kn_mn_4x4xi8_extui_i32(%arg0: vector<4x4xi8>, %arg1: vector<4x4xi8>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+ %lhs = arith.extui %arg0: vector<4x4xi8> to vector<4x4xi32>
+ %rhs = arith.extui %arg1: vector<4x4xi8> to vector<4x4xi32>
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+ return %res : vector<4x4xi32>
+}
+
+// Maps are (k, m) x (n, k) -> (m, n) with d0 = m, d1 = n, d2 = k. Expected
+// output: only the LHS is transposed; the RHS and accumulator are used as is.
+// CHECK-LABEL: func.func @matmul_km_nk_mn_4x4xi32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT: [[TRANS:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[TRANS]], [[ARG1]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_km_nk_mn_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+ return %res : vector<4x4xi32>
+}
+
+// Maps are (k, m) x (k, n) -> (m, n) with d0 = m, d1 = n, d2 = k. Expected
+// output: both the LHS and the RHS are transposed (the two transposes are
+// matched in either order) before feeding the contract.
+// CHECK-LABEL: func.func @matmul_km_kn_mn_4x4xi32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-DAG: [[LHS:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-DAG: [[RHS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[LHS]], [[RHS]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_km_kn_mn_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+ return %res : vector<4x4xi32>
+}
+
+// Maps are (k, m) x (k, n) -> (m, n); the LHS comes from `arith.extsi` and the
+// RHS from `arith.extui`. Expected output: each operand is transposed as i8
+// and then extended with its own extension kind (extsi for the LHS, extui for
+// the RHS).
+// CHECK-LABEL: func.func @matmul_km_kn_mn_4x4xi8_mixed_ext_i32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi8>, [[ARG1:%.+]]: vector<4x4xi8>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-DAG: [[LHST:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi8> to vector<4x4xi8>
+// CHECK-DAG: [[LHS:%.+]] = arith.extsi [[LHST]] : vector<4x4xi8> to vector<4x4xi32>
+// CHECK-DAG: [[RHST:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi8> to vector<4x4xi8>
+// CHECK-DAG: [[RHS:%.+]] = arith.extui [[RHST]] : vector<4x4xi8> to vector<4x4xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[LHS]], [[RHS]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_km_kn_mn_4x4xi8_mixed_ext_i32(%arg0: vector<4x4xi8>, %arg1: vector<4x4xi8>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+ %lhs = arith.extsi %arg0 : vector<4x4xi8> to vector<4x4xi32>
+ %rhs = arith.extui %arg1 : vector<4x4xi8> to vector<4x4xi32>
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+ return %res : vector<4x4xi32>
+}
+
+// Maps are (m, k) x (n, k) -> (n, m) with d0 = m, d1 = n, d2 = k. Expected
+// output: no transpose; the two operands are passed to the contract in
+// swapped order.
+// CHECK-LABEL: func.func @matmul_mk_nk_nm_4x4xi32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[ARG1]], [[ARG0]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_mk_nk_nm_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d0)>],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+ return %res : vector<4x4xi32>
+}
+
+// Maps are (k, m) x (k, n) -> (n, m) with d0 = m, d1 = n, d2 = k. Expected
+// output: both operands are transposed and then passed to the contract in
+// swapped order (transposed second argument first).
+// CHECK-LABEL: func.func @matmul_km_kn_nm_4x4xi32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-DAG: [[LHS:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-DAG: [[RHS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[RHS]], [[LHS]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_km_kn_nm_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d1, d0)>],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+ return %res : vector<4x4xi32>
+}
+
+// Maps are (m, k) x (k, n) -> (n, m) with d0 = m, d1 = n, d2 = k. Expected
+// output: the second argument is transposed and becomes the first contract
+// operand; the untouched first argument becomes the second operand.
+// CHECK-LABEL: func.func @matmul_mk_kn_nm_4x4xi32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-DAG: [[RHS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[RHS]], [[ARG0]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_mk_kn_nm_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d1, d0)>],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+ return %res : vector<4x4xi32>
+}
+
+// Maps are (k, m) x (n, k) -> (n, m) with d0 = m, d1 = n, d2 = k. Expected
+// output: the untouched second argument becomes the first contract operand
+// and the transposed first argument becomes the second operand.
+// CHECK-LABEL: func.func @matmul_km_nk_nm_4x4xi32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-DAG: [[LHS:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[ARG1]], [[LHS]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_km_nk_nm_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d0)>],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+ return %res : vector<4x4xi32>
+}