// Compute MemRefRegion 'dstWriteRegion' for 'dstStoreOpInst' on 'memref'.
SmallVector<Operation *, 2> dstStoreOps;
dstNode->getStoreOpsForMemref(memref, &dstStoreOps);
- assert(dstStoreOps.size() == 1);
+ // TODO(andydavis) Compute 'unionboundingbox' of all write regions (one for
+ // each store op in 'dstStoreOps').
auto *dstStoreOpInst = dstStoreOps[0];
MemRefRegion dstWriteRegion(dstStoreOpInst->getLoc());
if (failed(dstWriteRegion.compute(dstStoreOpInst, /*loopDepth=*/0))) {
if (srcNode->stores.size() != 1)
continue;
- // Skip 'srcNode' if it has in edges on 'memref'.
- // TODO(andydavis) Track dependence type with edges, and just check
- // for WAW dependence edge here. Note that this check is overly
- // conservative and will be removed in the future.
- if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) != 0)
- continue;
-
// Skip if 'srcNode' writes to any live in or escaping memrefs,
// and cannot be fused.
bool writesToLiveInOrOut =
// -----
-// CHECK-LABEL: func @should_not_fuse_across_waw_dep() {
-func @should_not_fuse_across_waw_dep() {
+// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1)
+
+// CHECK-LABEL: func @should_fuse_producer_consumer() {
+func @should_fuse_producer_consumer() {
%m = alloc() : memref<10xf32>
%cf7 = constant 7.0 : f32
affine.for %i2 = 0 to 10 {
%v1 = load %m[%i2] : memref<10xf32>
}
- // Fusing loop %i0 to %i2 would violate the WAW dependence between %i0 and %i1
+ // Fusing loop %i0 to %i2 would violate the WAW dependence between %i0 and
+ // %i1, but OK to fuse %i1 into %i2.
+ // TODO(andydavis) When the fusion pass is run to a fixed-point, it should
+ // fuse all three of these loop nests.
+ // CHECK: %0 = alloc() : memref<1xf32>
+ // CHECK: %1 = alloc() : memref<10xf32>
// CHECK: affine.for %i0 = 0 to 10 {
- // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32>
- // CHECK-NEXT: }
- // CHECK: affine.for %i1 = 0 to 10 {
- // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32>
+ // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32>
// CHECK-NEXT: }
- // CHECK: affine.for %i2 = 0 to 10 {
- // CHECK-NEXT: %1 = load %0[%i2] : memref<10xf32>
+ // CHECK-NEXT: affine.for %i1 = 0 to 10 {
+ // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i1, %i1)
+ // CHECK-NEXT: store %cst, %0[%2] : memref<1xf32>
+ // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i1, %i1)
+ // CHECK-NEXT: %4 = load %0[%3] : memref<1xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return
return
// CHECK-NEXT: return
return
}
+
+// -----
+
+func @affine_add_mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024xf32>, %arg2: memref<1024x1024xf32>, %arg3: memref<1024x1024xf32>) {
+ affine.for %i2 = 0 to 1024 {
+ affine.for %i3 = 0 to 1024 {
+ %0 = load %arg3[%i2, %i3] : memref<1024x1024xf32>
+ %1 = load %arg2[%i2, %i3] : memref<1024x1024xf32>
+ %2 = addf %1, %0 : f32
+ store %2, %arg2[%i2, %i3] : memref<1024x1024xf32>
+ }
+ }
+ affine.for %i4 = 0 to 1024 {
+ affine.for %i5 = 0 to 1024 {
+ affine.for %i6 = 0 to 1024 {
+ %3 = load %arg1[%i6, %i5] : memref<1024x1024xf32>
+ %4 = load %arg0[%i4, %i6] : memref<1024x1024xf32>
+ %5 = mulf %4, %3 : f32
+ %6 = load %arg2[%i4, %i5] : memref<1024x1024xf32>
+ %7 = addf %6, %5 : f32
+ store %7, %arg2[%i4, %i5] : memref<1024x1024xf32>
+ }
+ }
+ }
+ // Should fuse elementwise add loop at loop depth 2, above loop-carried
+ // dependence between load/store on '%arg2', carried on reduction loop %i6.
+ // CHECK: affine.for %i0 = 0 to 1024 {
+ // CHECK-NEXT: affine.for %i1 = 0 to 1024 {
+ // CHECK-NEXT: %0 = load %arg3[%i0, %i1] : memref<1024x1024xf32>
+ // CHECK-NEXT: %1 = load %arg2[%i0, %i1] : memref<1024x1024xf32>
+ // CHECK-NEXT: %2 = addf %1, %0 : f32
+ // CHECK-NEXT: store %2, %arg2[%i0, %i1] : memref<1024x1024xf32>
+ // CHECK-NEXT: affine.for %i2 = 0 to 1024 {
+ // CHECK-NEXT: %3 = load %arg1[%i2, %i1] : memref<1024x1024xf32>
+ // CHECK-NEXT: %4 = load %arg0[%i0, %i2] : memref<1024x1024xf32>
+ // CHECK-NEXT: %5 = mulf %4, %3 : f32
+ // CHECK-NEXT: %6 = load %arg2[%i0, %i1] : memref<1024x1024xf32>
+ // CHECK-NEXT: %7 = addf %6, %5 : f32
+ // CHECK-NEXT: store %7, %arg2[%i0, %i1] : memref<1024x1024xf32>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: }
+ // CHECK-NEXT: }
+ return
+}