From 670455c77d4b2ee3bcf90fb454f62ae69ec47239 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Fri, 19 Jun 2020 17:33:15 -0700 Subject: [PATCH] [mlir][spirv] Legalize subviewop when used with vector transfer Subview operations are not natively supported downstream in the spirv path. This change allows removing subview when used by vector transfer the same way we already do it when they are used by LoadOp/StoreOp Differential Revision: https://reviews.llvm.org/D82106 --- .../StandardToSPIRV/LegalizeStandardForSPIRV.cpp | 88 +++++++++++++++++----- .../Conversion/StandardToSPIRV/legalization.mlir | 34 +++++++++ 2 files changed, 103 insertions(+), 19 deletions(-) diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp index 3acd595..0d949f7 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp @@ -15,28 +15,41 @@ #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h" #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" using namespace mlir; namespace { -/// Merges subview operation with load operation. -class LoadOpOfSubViewFolder final : public OpRewritePattern { +/// Merges subview operation with load/transferRead operation. +template +class LoadOpOfSubViewFolder final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(LoadOp loadOp, + LogicalResult matchAndRewrite(OpTy loadOp, PatternRewriter &rewriter) const override; + +private: + void replaceOp(OpTy loadOp, SubViewOp subViewOp, + ArrayRef sourceIndices, + PatternRewriter &rewriter) const; }; -/// Merges subview operation with store operation. -class StoreOpOfSubViewFolder final : public OpRewritePattern { +/// Merges subview operation with store/transferWriteOp operation. +template +class StoreOpOfSubViewFolder final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(StoreOp storeOp, + LogicalResult matchAndRewrite(OpTy storeOp, PatternRewriter &rewriter) const override; + +private: + void replaceOp(OpTy StoreOp, SubViewOp subViewOp, + ArrayRef sourceIndices, + PatternRewriter &rewriter) const; }; } // namespace @@ -85,13 +98,14 @@ resolveSourceIndices(Location loc, PatternRewriter &rewriter, } //===----------------------------------------------------------------------===// -// Folding SubViewOp and LoadOp. +// Folding SubViewOp and LoadOp/TransferReadOp. //===----------------------------------------------------------------------===// +template LogicalResult -LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp, - PatternRewriter &rewriter) const { - auto subViewOp = loadOp.memref().getDefiningOp(); +LoadOpOfSubViewFolder::matchAndRewrite(OpTy loadOp, + PatternRewriter &rewriter) const { + auto subViewOp = loadOp.memref().template getDefiningOp(); if (!subViewOp) { return failure(); } @@ -100,19 +114,36 @@ LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp, loadOp.indices(), sourceIndices))) return failure(); + replaceOp(loadOp, subViewOp, sourceIndices, rewriter); + return success(); +} + +template <> +void LoadOpOfSubViewFolder::replaceOp(LoadOp loadOp, + SubViewOp subViewOp, + ArrayRef sourceIndices, + PatternRewriter &rewriter) const { rewriter.replaceOpWithNewOp(loadOp, subViewOp.source(), sourceIndices); - return success(); +} + +template <> +void LoadOpOfSubViewFolder::replaceOp( + vector::TransferReadOp loadOp, SubViewOp subViewOp, + ArrayRef sourceIndices, PatternRewriter &rewriter) const { + rewriter.replaceOpWithNewOp( + loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices); } //===----------------------------------------------------------------------===// -// Folding SubViewOp and StoreOp. +// Folding SubViewOp and StoreOp/TransferWriteOp. //===----------------------------------------------------------------------===// +template LogicalResult -StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp, - PatternRewriter &rewriter) const { - auto subViewOp = storeOp.memref().getDefiningOp(); +StoreOpOfSubViewFolder::matchAndRewrite(OpTy storeOp, + PatternRewriter &rewriter) const { + auto subViewOp = storeOp.memref().template getDefiningOp(); if (!subViewOp) { return failure(); } @@ -121,9 +152,25 @@ StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp, storeOp.indices(), sourceIndices))) return failure(); + replaceOp(storeOp, subViewOp, sourceIndices, rewriter); + return success(); +} + +template <> +void StoreOpOfSubViewFolder::replaceOp( + StoreOp storeOp, SubViewOp subViewOp, ArrayRef sourceIndices, + PatternRewriter &rewriter) const { rewriter.replaceOpWithNewOp(storeOp, storeOp.value(), subViewOp.source(), sourceIndices); - return success(); +} + +template <> +void StoreOpOfSubViewFolder::replaceOp( + vector::TransferWriteOp tranferWriteOp, SubViewOp subViewOp, + ArrayRef sourceIndices, PatternRewriter &rewriter) const { + rewriter.replaceOpWithNewOp( + tranferWriteOp, tranferWriteOp.vector(), subViewOp.source(), + sourceIndices); } //===----------------------------------------------------------------------===// @@ -132,7 +179,10 @@ StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp, void mlir::populateStdLegalizationPatternsForSPIRVLowering( MLIRContext *context, OwningRewritePatternList &patterns) { - patterns.insert(context); + patterns.insert, + LoadOpOfSubViewFolder, + StoreOpOfSubViewFolder, + StoreOpOfSubViewFolder>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/StandardToSPIRV/legalization.mlir b/mlir/test/Conversion/StandardToSPIRV/legalization.mlir index d3b339e..acbda35 100644 --- a/mlir/test/Conversion/StandardToSPIRV/legalization.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/legalization.mlir @@ -62,3 +62,37 @@ func @fold_dynamic_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 : store %arg7, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [?, ?]> return } + +// CHECK-LABEL: @fold_static_stride_subview_with_transfer_read +// CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index +func @fold_static_stride_subview_with_transfer_read(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> vector<4xf32> { + // CHECK-NOT: subview + // CHECK: [[C2:%.*]] = constant 2 : index + // CHECK: [[C3:%.*]] = constant 3 : index + // CHECK: [[STRIDE1:%.*]] = muli [[ARG3]], [[C2]] : index + // CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index + // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index + // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index + // CHECK: vector.transfer_read [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}} + %f0 = constant 0.0 : f32 + %0 = subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]> + %1 = vector.transfer_read %0[%arg3, %arg4], %f0 : memref<4x4xf32, offset:?, strides: [64, 3]>, vector<4xf32> + return %1 : vector<4xf32> +} + +// CHECK-LABEL: @fold_static_stride_subview_with_transfer_write +// CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index, [[ARG5:%.*]]: vector<4xf32> +func @fold_static_stride_subview_with_transfer_write(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : vector<4xf32>) { + // CHECK-NOT: subview + // CHECK: [[C2:%.*]] = constant 2 : index + // CHECK: [[C3:%.*]] = constant 3 : index + // CHECK: [[STRIDE1:%.*]] = muli [[ARG3]], [[C2]] : index + // CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index + // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index + // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index + // CHECK: vector.transfer_write [[ARG5]], [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}} + %0 = subview %arg0[%arg1, %arg2][4, 4][2, 3] : + memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]> + vector.transfer_write %arg5, %0[%arg3, %arg4] : vector<4xf32>, memref<4x4xf32, offset:?, strides: [64, 3]> + return +} -- 2.7.4