namespace {
// LoopNestStateCollector walks loop nests and collects load and store
-// operations, and whether or not an IfInst was encountered in the loop nest.
+// operations, and whether or not a region holding op other than ForOp and IfOp
+// was encountered in the loop nest.
struct LoopNestStateCollector {
SmallVector<AffineForOp, 4> forOps;
SmallVector<Operation *, 4> loadOpInsts;
SmallVector<Operation *, 4> storeOpInsts;
- bool hasNonForRegion = false;
+ bool hasNonAffineRegionOp = false;
void collect(Operation *opToWalk) {
opToWalk->walk([&](Operation *op) {
if (isa<AffineForOp>(op))
forOps.push_back(cast<AffineForOp>(op));
- else if (op->getNumRegions() != 0)
- hasNonForRegion = true;
+ else if (op->getNumRegions() != 0 && !isa<AffineIfOp>(op))
+ hasNonAffineRegionOp = true;
else if (isa<AffineReadOpInterface>(op))
loadOpInsts.push_back(op);
else if (isa<AffineWriteOpInterface>(op))
// all loads and store accesses it contains.
LoopNestStateCollector collector;
collector.collect(&op);
- // Return false if a non 'affine.for' region was found (not currently
- // supported).
- if (collector.hasNonForRegion)
+ // Return false if a region holding op other than 'affine.for' and
+ // 'affine.if' was found (not currently supported).
+ if (collector.hasNonAffineRegionOp)
return false;
Node node(nextNodeId++, &op);
for (auto *opInst : collector.loadOpInsts) {
#set0 = affine_set<(d0) : (1 == 0)>
-// CHECK-LABEL: func @should_not_fuse_if_inst_at_top_level() {
-func @should_not_fuse_if_inst_at_top_level() {
+// CHECK-LABEL: func @should_not_fuse_if_op_at_top_level() {
+func @should_not_fuse_if_op_at_top_level() {
%m = memref.alloc() : memref<10xf32>
%cf7 = constant 7.0 : f32
#set0 = affine_set<(d0) : (1 == 0)>
-// CHECK-LABEL: func @should_not_fuse_if_inst_in_loop_nest() {
-func @should_not_fuse_if_inst_in_loop_nest() {
+// CHECK-LABEL: func @should_not_fuse_if_op_in_loop_nest() {
+func @should_not_fuse_if_op_in_loop_nest() {
%m = memref.alloc() : memref<10xf32>
%cf7 = constant 7.0 : f32
%c4 = constant 4 : index
%v0 = affine.load %m[%i1] : memref<10xf32>
}
- // IfOp in ForInst should prevent fusion.
+ // IfOp in ForOp should prevent fusion.
// CHECK: affine.for %{{.*}} = 0 to 10 {
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT: }
// -----
+#set = affine_set<(d0) : (d0 - 1 >= 0)>
+
+// CHECK-LABEL: func @should_fuse_if_op_in_loop_nest_not_sandwiched() -> memref<10xf32> {
+func @should_fuse_if_op_in_loop_nest_not_sandwiched() -> memref<10xf32> {
+ %a = memref.alloc() : memref<10xf32>
+ %b = memref.alloc() : memref<10xf32>
+ %cf7 = constant 7.0 : f32
+
+ affine.for %i0 = 0 to 10 {
+ affine.store %cf7, %a[%i0] : memref<10xf32>
+ }
+ affine.for %i1 = 0 to 10 {
+ %v0 = affine.load %a[%i1] : memref<10xf32>
+ affine.store %v0, %b[%i1] : memref<10xf32>
+ }
+ affine.for %i2 = 0 to 10 {
+ affine.if #set(%i2) {
+ %v0 = affine.load %b[%i2] : memref<10xf32>
+ }
+ }
+
+ // IfOp in ForOp should not prevent fusion if it does not in between the
+ // source and dest ForOp ops.
+
+ // CHECK: affine.for
+ // CHECK-NEXT: affine.store
+ // CHECK-NEXT: affine.load
+ // CHECK-NEXT: affine.store
+ // CHECK: affine.for
+ // CHECK-NEXT: affine.if
+ // CHECK-NEXT: affine.load
+ // CHECK-NOT: affine.for
+ // CHECK: return
+
+ return %a : memref<10xf32>
+}
+
+// -----
+
+#set = affine_set<(d0) : (d0 - 1 >= 0)>
+
+// CHECK-LABEL: func @should_not_fuse_if_op_in_loop_nest_between_src_and_dest() -> memref<10xf32> {
+func @should_not_fuse_if_op_in_loop_nest_between_src_and_dest() -> memref<10xf32> {
+ %a = memref.alloc() : memref<10xf32>
+ %b = memref.alloc() : memref<10xf32>
+ %cf7 = constant 7.0 : f32
+
+ affine.for %i0 = 0 to 10 {
+ affine.store %cf7, %a[%i0] : memref<10xf32>
+ }
+ affine.for %i1 = 0 to 10 {
+ affine.if #set(%i1) {
+ affine.store %cf7, %a[%i1] : memref<10xf32>
+ }
+ }
+ affine.for %i3 = 0 to 10 {
+ %v0 = affine.load %a[%i3] : memref<10xf32>
+ affine.store %v0, %b[%i3] : memref<10xf32>
+ }
+ return %b : memref<10xf32>
+
+ // IfOp in ForOp which modifies the memref should prevent fusion if it is in
+ // between the source and dest ForOp.
+
+ // CHECK: affine.for
+ // CHECK-NEXT: affine.store
+ // CHECK: affine.for
+ // CHECK-NEXT: affine.if
+ // CHECK-NEXT: affine.store
+ // CHECK: affine.for
+ // CHECK-NEXT: affine.load
+ // CHECK-NEXT: affine.store
+ // CHECK: return
+}
+
+// -----
+
// CHECK-LABEL: func @permute_and_fuse() {
func @permute_and_fuse() {
%m = memref.alloc() : memref<10x20x30xf32>