From: MLIR Team Date: Thu, 28 Mar 2019 21:54:49 +0000 (-0700) Subject: Remove overly conservative check in LoopFusion pass (enables fusion in tutorial example). X-Git-Tag: llvmorg-11-init~1466^2~2095 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=9d9675fc8fa96e78efa17dcc2d6fcc3e773f7a5f;p=platform%2Fupstream%2Fllvm.git Remove overly conservative check in LoopFusion pass (enables fusion in tutorial example). PiperOrigin-RevId: 240859227 --- diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 80308ea..c35b75f 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -1275,7 +1275,8 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, // Compute MemRefRegion 'dstWriteRegion' for 'dstStoreOpInst' on 'memref'. SmallVector dstStoreOps; dstNode->getStoreOpsForMemref(memref, &dstStoreOps); - assert(dstStoreOps.size() == 1); + // TODO(andydavis) Compute 'unionboundingbox' of all write regions (one for + // each store op in 'dstStoreOps'). auto *dstStoreOpInst = dstStoreOps[0]; MemRefRegion dstWriteRegion(dstStoreOpInst->getLoc()); if (failed(dstWriteRegion.compute(dstStoreOpInst, /*loopDepth=*/0))) { @@ -1886,13 +1887,6 @@ public: if (srcNode->stores.size() != 1) continue; - // Skip 'srcNode' if it has in edges on 'memref'. - // TODO(andydavis) Track dependence type with edges, and just check - // for WAW dependence edge here. Note that this check is overly - // conservative and will be removed in the future. - if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) != 0) - continue; - // Skip if 'srcNode' writes to any live in or escaping memrefs, // and cannot be fused. bool writesToLiveInOrOut = diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 4d21d00..dd3af06 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -342,8 +342,10 @@ func @should_not_fuse_would_create_cycle() { // ----- -// CHECK-LABEL: func @should_not_fuse_across_waw_dep() { -func @should_not_fuse_across_waw_dep() { +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) + +// CHECK-LABEL: func @should_fuse_producer_consumer() { +func @should_fuse_producer_consumer() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 @@ -356,15 +358,20 @@ func @should_not_fuse_across_waw_dep() { affine.for %i2 = 0 to 10 { %v1 = load %m[%i2] : memref<10xf32> } - // Fusing loop %i0 to %i2 would violate the WAW dependence between %i0 and %i1 + // Fusing loop %i0 to %i2 would violate the WAW dependence between %i0 and + // %i1, but OK to fuse %i1 into %i2. + // TODO(andydavis) When the fusion pass is run to a fixed-point, it should + // fuse all three of these loop nests. + // CHECK: %0 = alloc() : memref<1xf32> + // CHECK: %1 = alloc() : memref<10xf32> // CHECK: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> - // CHECK-NEXT: } - // CHECK: affine.for %i1 = 0 to 10 { - // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32> + // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK: affine.for %i2 = 0 to 10 { - // CHECK-NEXT: %1 = load %0[%i2] : memref<10xf32> + // CHECK-NEXT: affine.for %i1 = 0 to 10 { + // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i1, %i1) + // CHECK-NEXT: store %cst, %0[%2] : memref<1xf32> + // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i1, %i1) + // CHECK-NEXT: %4 = load %0[%3] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -2289,3 +2296,47 @@ func @should_fuse_with_slice_union() { // CHECK-NEXT: return return } + +// ----- + +func @affine_add_mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024xf32>, %arg2: memref<1024x1024xf32>, %arg3: memref<1024x1024xf32>) { + affine.for %i2 = 0 to 1024 { + affine.for %i3 = 0 to 1024 { + %0 = load %arg3[%i2, %i3] : memref<1024x1024xf32> + %1 = load %arg2[%i2, %i3] : memref<1024x1024xf32> + %2 = addf %1, %0 : f32 + store %2, %arg2[%i2, %i3] : memref<1024x1024xf32> + } + } + affine.for %i4 = 0 to 1024 { + affine.for %i5 = 0 to 1024 { + affine.for %i6 = 0 to 1024 { + %3 = load %arg1[%i6, %i5] : memref<1024x1024xf32> + %4 = load %arg0[%i4, %i6] : memref<1024x1024xf32> + %5 = mulf %4, %3 : f32 + %6 = load %arg2[%i4, %i5] : memref<1024x1024xf32> + %7 = addf %6, %5 : f32 + store %7, %arg2[%i4, %i5] : memref<1024x1024xf32> + } + } + } + // Should fuse elementwise add loop at loop depth 2, above loop-carried + // dependence between load/store on '%arg2', carried on reduction loop %i6. + // CHECK: affine.for %i0 = 0 to 1024 { + // CHECK-NEXT: affine.for %i1 = 0 to 1024 { + // CHECK-NEXT: %0 = load %arg3[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: %1 = load %arg2[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: %2 = addf %1, %0 : f32 + // CHECK-NEXT: store %2, %arg2[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: affine.for %i2 = 0 to 1024 { + // CHECK-NEXT: %3 = load %arg1[%i2, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: %4 = load %arg0[%i0, %i2] : memref<1024x1024xf32> + // CHECK-NEXT: %5 = mulf %4, %3 : f32 + // CHECK-NEXT: %6 = load %arg2[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: %7 = addf %6, %5 : f32 + // CHECK-NEXT: store %7, %arg2[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: } + return +}