[mlir] Enable folding memref alias for `vector.load`
author    Guray Ozen <guray.ozen@gmail.com>
          Thu, 25 May 2023 14:10:38 +0000 (16:10 +0200)
committer Guray Ozen <guray.ozen@gmail.com>
          Thu, 25 May 2023 15:07:20 +0000 (17:07 +0200)
This work enables the memref alias folding pass for `vector.load`: a `vector.load` whose base memref is produced by a `memref.subview` is rewritten to load directly from the subview's source memref, as sketched below.

Reviewed By: qcolombet

Differential Revision: https://reviews.llvm.org/D151447
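
For illustration, the rewrite mirrors the new test case added below: the subview is bypassed and the load uses the composed indices (a sketch derived from the test in this patch, not additional functionality):

  // Before fold-memref-alias-ops: the load goes through a subview.
  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1]
         : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
  %1 = vector.load %0[] : memref<f32, strided<[], offset: ?>>, vector<12x32xf32>

  // After the fold: the load reads directly from the source memref.
  %1 = vector.load %arg0[%arg1, %arg2] : memref<12x32xf32>, vector<12x32xf32>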

mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

index 5916d64..1fee97a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -173,6 +173,8 @@ static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
   return op.getSrcMemref();
 }
 
+static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }
+
 static Value getMemRefOperand(vector::TransferWriteOp op) {
   return op.getSource();
 }
@@ -397,6 +399,10 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
         rewriter.replaceOpWithNewOp<memref::LoadOp>(
             loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
       })
+      .Case([&](vector::LoadOp op) {
+        rewriter.replaceOpWithNewOp<vector::LoadOp>(
+            op, op.getType(), subViewOp.getSource(), sourceIndices);
+      })
       .Case([&](vector::TransferReadOp op) {
         rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
             op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
@@ -668,6 +674,7 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
   patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
                LoadOpOfSubViewOpFolder<memref::LoadOp>,
                LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
+               LoadOpOfSubViewOpFolder<vector::LoadOp>,
                LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
                LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
                StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
index 0e9df29..acfddc3 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -621,3 +621,18 @@ func.func @test_ldmatrix(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: ind
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: index
 //      CHECK:   nvgpu.ldmatrix %[[ARG0]][%[[ARG1]], %[[ARG2]], %[[ARG3]]] {numTiles = 4 : i32, transpose = false} : memref<4x32x32xf16, 3> -> vector<4x2xf16>
+
+// -----
+
+func.func @fold_vector_load(
+  %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> {
+  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
+  %1 = vector.load %0[] : memref<f32, strided<[], offset: ?>>, vector<12x32xf32>
+  return %1 : vector<12x32xf32>
+}
+
+//      CHECK: func @fold_vector_load
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
+//      CHECK:   vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] :  memref<12x32xf32>, vector<12x32xf32>
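
The new pattern is exercised through the existing fold-memref-alias-ops test file; assuming it keeps its usual RUN line, the case can be checked locally with something along the lines of:

  mlir-opt -fold-memref-alias-ops -split-input-file mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir | FileCheck mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir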