// RUN: mlir-opt %s -test-vector-contraction-lowering | FileCheck %s
// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX
// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT
-// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT
// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-parallel-arith=1 | FileCheck %s --check-prefix=PARALLEL
#dotp_accesses = [
return %0 : vector<3x2xf32>
}
-// FILTEROUTERPRODUCT-LABEL: func @matmul_4_filtered
-// FILTEROUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<4x4xf32>,
-// FILTEROUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x4xf32>,
-// FILTEROUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<4x4xf32>
-// FILTEROUTERPRODUCT: %[[c0:.*]] = vector.contract {{{.*}}} %[[A]], %[[B]], %[[C]]
-func.func @matmul_4_filtered(%arg0: vector<4x4xf32>, %arg1: vector<4x4xf32>, %arg2: vector<4x4xf32>)
--> vector<4x4xf32>
-{
- %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
- : vector<4x4xf32>, vector<4x4xf32> into vector<4x4xf32>
- return %0 : vector<4x4xf32>
-}
-
-// FILTEROUTERPRODUCT-LABEL: func @matmul_4_not_filtered
-// FILTEROUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<3x4xf32>,
-// FILTEROUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x4xf32>,
-// FILTEROUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x4xf32>
-// FILTEROUTERPRODUCT: %[[c0:.*]] = vector.contract {{{.*}}} %[[A]], %[[B]], %[[C]]
-func.func @matmul_4_not_filtered(%arg0: vector<3x4xf32>, %arg1: vector<4x4xf32>, %arg2: vector<3x4xf32>)
--> vector<3x4xf32>
-{
- %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
- : vector<3x4xf32>, vector<4x4xf32> into vector<3x4xf32>
- return %0 : vector<3x4xf32>
-}
-
// PARALLEL-LABEL: func @parrallel_contract_lowering
// PARALLEL: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>
// PARALLEL: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>
//
//===----------------------------------------------------------------------===//
-#include <type_traits>
#include <optional>
+#include <type_traits>
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
*this, "vector-outerproduct",
llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
llvm::cl::init(false)};
- Option<bool> lowerToFilterOuterProduct{
- *this, "vector-filter-outerproduct",
- llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
- "vectors of size 4."),
- llvm::cl::init(false)};
Option<bool> lowerToParallelArith{
*this, "vector-parallel-arith",
llvm::cl::desc("Lower vector.contract to elementwise vector ops."),
return;
}
- // Test on one pattern in isolation.
- if (lowerToFilterOuterProduct) {
- VectorContractLowering lowering = VectorContractLowering::OuterProduct;
- VectorTransformsOptions options{lowering};
- patterns.add<ContractionOpToOuterProductOpLowering>(
- options, &getContext(), /*benefit=*/1, [](vector::ContractionOp op) {
- // Only lowers vector.contract where the lhs as a type vector<MxNx?>
- // where M is not 4.
- if (op.getRhsType().getShape()[0] == 4)
- return failure();
- return success();
- });
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
- return;
- }
-
if (lowerToParallelArith) {
vector::populateVectorContractLoweringPatterns(
patterns,