From fc5c1a767658314ff30e1af6bf52a956d96b1b04 Mon Sep 17 00:00:00 2001
From: Manish Gupta
Date: Wed, 12 Apr 2023 01:00:58 +0000
Subject: [PATCH] [mlir][Memref] Fold nvgpu device cp.async on src memref to dst memref

Fold producing memref.subview ops on the source and destination memrefs of
nvgpu.device_async_copy into the copy itself, by resolving the subview offsets
into the copy indices.

Differential Revision: https://reviews.llvm.org/D148161
---
 .../MemRef/Transforms/FoldMemRefAliasOps.cpp       | 72 +++++++++++++++++++++-
 .../test/Dialect/MemRef/fold-memref-alias-ops.mlir | 58 +++++++++++++++++
 2 files changed, 129 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index a43184e..d99cf91 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/AffineMap.h"
@@ -26,6 +27,10 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "fold-memref-alias-ops"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
 
 namespace mlir {
 namespace memref {
@@ -283,6 +288,17 @@ public:
     return success();
   }
 };
+
+/// Folds nvgpu.device_async_copy subviews into the copy itself. This pattern
+/// folds subviews on both the src and dst memrefs of the copy.
+class NvgpuAsyncCopyOpSubViewOpFolder final
+    : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
+public:
+  using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
+                                PatternRewriter &rewriter) const override;
+};
 } // namespace
 
 static SmallVector<Value>
@@ -580,6 +596,59 @@ LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
   return success();
 }
 
+LogicalResult NvgpuAsyncCopyOpSubViewOpFolder::matchAndRewrite(
+    nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {
+
+  LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");
+
+  auto srcSubViewOp =
+      copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
+  auto dstSubViewOp =
+      copyOp.getDst().template getDefiningOp<memref::SubViewOp>();
+
+  if (!(srcSubViewOp || dstSubViewOp))
+    return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
+                                               "source or destination");
+
+  // If the source is a subview, we need to resolve the indices.
+  SmallVector<Value> srcindices(copyOp.getSrcIndices().begin(),
+                                copyOp.getSrcIndices().end());
+  SmallVector<Value> foldedSrcIndices(srcindices);
+
+  if (srcSubViewOp) {
+    LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
+    resolveIndicesIntoOpWithOffsetsAndStrides(
+        rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
+        srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
+        srcindices, foldedSrcIndices);
+  }
+
+  // If the destination is a subview, we need to resolve the indices.
+  SmallVector<Value> dstindices(copyOp.getDstIndices().begin(),
+                                copyOp.getDstIndices().end());
+  SmallVector<Value> foldedDstIndices(dstindices);
+
+  if (dstSubViewOp) {
+    LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
+    resolveIndicesIntoOpWithOffsetsAndStrides(
+        rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
+        dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
+        dstindices, foldedDstIndices);
+  }
+
+  // Replace the copy op with a new copy op that uses the source and
+  // destination of the subview.
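+  // Operands that were not produced by a subview keep their original memref
+  // and indices; the element counts and the bypassL1 attribute are carried
+  // over unchanged.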
+  rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
+      copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
+      (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
+      foldedDstIndices,
+      (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
+      foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
+      copyOp.getBypassL1Attr());
+
+  return success();
+}
+
 void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
   patterns.add<LoadOpOfSubViewOpFolder<AffineLoadOp>,
                LoadOpOfSubViewOpFolder<memref::LoadOp>,
@@ -597,7 +666,8 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
                LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
                StoreOpOfCollapseShapeOpFolder<AffineStoreOp>,
                StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
-               SubViewOfSubViewFolder>(patterns.getContext());
+               SubViewOfSubViewFolder, NvgpuAsyncCopyOpSubViewOpFolder>(
+      patterns.getContext());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index a29f86e..93e8a20 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -541,3 +541,61 @@ func.func @fold_gpu_subgroup_mma_load_matrix_2d(%arg0 : memref<128x128xf32>, %ar
   gpu.subgroup_mma_store_matrix %matrix, %subview[%arg3, %arg4] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<64x32xf32, strided<[64, 1], offset: ?>>
   return
 }
+
+// -----
+
+
+func.func @fold_nvgpu_device_async_copy_zero_sub_idx(%gmem_memref_3d : memref<2x128x768xf16>, %idx_1 : index, %idx_2 : index, %idx_3 : index) {
+
+  %c0 = arith.constant 0 : index
+  %smem_memref_4d = memref.alloc() : memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
+  %gmem_memref_subview_2d = memref.subview %gmem_memref_3d[%idx_1, %idx_2, %idx_3] [1, 1, 8] [1, 1, 1] : memref<2x128x768xf16> to memref<1x8xf16, strided<[98304, 1], offset: ?>>
+  %async_token = nvgpu.device_async_copy %gmem_memref_subview_2d[%c0, %c0], %smem_memref_4d[%c0, %c0, %c0, %c0], 8 {bypassL1} : memref<1x8xf16, strided<[98304, 1], offset: ?>> to memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
+  return
+}
+
+// CHECK-LABEL: func.func @fold_nvgpu_device_async_copy_zero_sub_idx
+// CHECK-SAME: (%[[GMEM_MEMREF_3d:.+]]: memref<2x128x768xf16>, %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index, %[[IDX_3:.+]]: index)
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[SMEM_MEMREF_4d:.+]] = memref.alloc() : memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
+// CHECK: nvgpu.device_async_copy %[[GMEM_MEMREF_3d]][%[[IDX_1]], %[[IDX_2]], %[[IDX_3]]], %[[SMEM_MEMREF_4d]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]], 8 {bypassL1} : memref<2x128x768xf16> to memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
+
+// -----
+
+
+func.func @fold_src_nvgpu_device_async_copy(%gmem_memref_3d : memref<2x128x768xf16>, %src_idx_0 : index, %src_idx_1 : index, %src_idx_2 : index, %src_sub_idx_0 : index, %src_sub_idx_1 : index) {
+  %c0 = arith.constant 0 : index
+  %smem_memref_4d = memref.alloc() : memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
+  %gmem_memref_subview_2d = memref.subview %gmem_memref_3d[%src_idx_0, %src_idx_1, %src_idx_2] [1, 1, 8] [1, 1, 1] : memref<2x128x768xf16> to memref<1x8xf16, strided<[98304, 1], offset: ?>>
+  %async_token = nvgpu.device_async_copy %gmem_memref_subview_2d[%src_sub_idx_0, %src_sub_idx_1], %smem_memref_4d[%c0, %c0, %c0, %c0], 8 {bypassL1} : memref<1x8xf16, strided<[98304, 1], offset: ?>> to memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
+  return
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK: func.func @fold_src_nvgpu_device_async_copy
+// CHECK-SAME: (%[[GMEM_MEMREF_3d:.+]]: memref<2x128x768xf16>, %[[SRC_IDX_0:.+]]: index, %[[SRC_IDX_1:.+]]: index, %[[SRC_IDX_2:.+]]: index, %[[SRC_SUB_IDX_0:.+]]: index, %[[SRC_SUB_IDX_1:.+]]: index)
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[RESOLVED_SRC_IDX_0:.+]] = affine.apply #[[MAP]]()[%[[SRC_IDX_0]], %[[SRC_SUB_IDX_0]]]
+// CHECK-DAG: %[[RESOLVED_SRC_IDX_1:.+]] = affine.apply #[[MAP]]()[%[[SRC_IDX_2]], %[[SRC_SUB_IDX_1]]]
+// CHECK-DAG: nvgpu.device_async_copy %[[GMEM_MEMREF_3d]][%[[RESOLVED_SRC_IDX_0]], %[[SRC_IDX_1]], %[[RESOLVED_SRC_IDX_1]]], %[[SMEM_MEMREF_4d]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]], 8 {bypassL1} : memref<2x128x768xf16> to memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
+
+// -----
+
+
+func.func @fold_src_fold_dest_nvgpu_device_async_copy(%gmem_memref_3d : memref<2x128x768xf16>, %src_idx_0 : index, %src_idx_1 : index, %src_idx_2 : index, %src_sub_idx_0 : index, %src_sub_idx_1 : index, %dest_idx_0 : index, %dest_idx_1 : index, %dest_idx_2 : index, %dest_idx_3 : index, %dest_sub_idx_0 : index, %dest_sub_idx_1 : index) {
+  %c0 = arith.constant 0 : index
+  %smem_memref_4d = memref.alloc() : memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
+  %gmem_memref_subview_2d = memref.subview %gmem_memref_3d[%src_idx_0, %src_idx_1, %src_idx_2] [1, 1, 8] [1, 1, 1] : memref<2x128x768xf16> to memref<1x8xf16, strided<[98304, 1], offset: ?>>
+  %smem_memref_2d = memref.subview %smem_memref_4d[%dest_idx_0, %dest_idx_1, %dest_idx_2, %dest_idx_3] [1, 1, 1, 8] [1, 1, 1, 1] : memref<5x1x64x64xf16, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[4096, 1], offset: ?>, #gpu.address_space<workgroup>>
+  %async_token = nvgpu.device_async_copy %gmem_memref_subview_2d[%src_sub_idx_0, %src_sub_idx_1], %smem_memref_2d[%dest_sub_idx_0, %dest_sub_idx_1], 8 {bypassL1} : memref<1x8xf16, strided<[98304, 1], offset: ?>> to memref<1x8xf16, strided<[4096, 1], offset: ?>, #gpu.address_space<workgroup>>
+  return
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK: func.func @fold_src_fold_dest_nvgpu_device_async_copy
+// CHECK-SAME: (%[[GMEM_MEMREF_3d:.+]]: memref<2x128x768xf16>, %[[SRC_IDX_0:.+]]: index, %[[SRC_IDX_1:.+]]: index, %[[SRC_IDX_2:.+]]: index, %[[SRC_SUB_IDX_0:.+]]: index, %[[SRC_SUB_IDX_1:.+]]: index, %[[DEST_IDX_0:.+]]: index, %[[DEST_IDX_1:.+]]: index, %[[DEST_IDX_2:.+]]: index, %[[DEST_IDX_3:.+]]: index, %[[DEST_SUB_IDX_0:.+]]: index, %[[DEST_SUB_IDX_1:.+]]: index)
+// CHECK-DAG: %[[RESOLVED_SRC_IDX_0:.+]] = affine.apply #[[MAP]]()[%[[SRC_IDX_0]], %[[SRC_SUB_IDX_0]]]
+// CHECK-DAG: %[[RESOLVED_SRC_IDX_1:.+]] = affine.apply #[[MAP]]()[%[[SRC_IDX_2]], %[[SRC_SUB_IDX_1]]]
+// CHECK-DAG: %[[RESOLVED_DST_IDX_1:.+]] = affine.apply #[[MAP]]()[%[[DEST_IDX_1]], %[[DEST_SUB_IDX_0]]]
+// CHECK-DAG: %[[RESOLVED_DST_IDX_3:.+]] = affine.apply #[[MAP]]()[%[[DEST_IDX_3]], %[[DEST_SUB_IDX_1]]]
+// CHECK-DAG: nvgpu.device_async_copy %[[GMEM_MEMREF_3d]][%[[RESOLVED_SRC_IDX_0]], %[[SRC_IDX_1]], %[[RESOLVED_SRC_IDX_1]]], %[[SMEM_MEMREF_4d]][%[[DEST_IDX_0]], %[[RESOLVED_DST_IDX_1]], %[[DEST_IDX_2]], %[[RESOLVED_DST_IDX_3]]], 8 {bypassL1} : memref<2x128x768xf16> to memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
-- 
2.7.4