return newMemRef;
}
+/// Walking from node 'srcId' to node 'dstId' (exclusive of 'srcId' and
+/// 'dstId'), if there is any non-affine operation accessing 'memref', return
+/// false. Otherwise, return true.
+static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
+ Value memref,
+ MemRefDependenceGraph *mdg) {
+ auto *srcNode = mdg->getNode(srcId);
+ auto *dstNode = mdg->getNode(dstId);
+ Value::user_range users = memref.getUsers();
+ // For each MemRefDependenceGraph's node that is between 'srcNode' and
+ // 'dstNode' (exclusive of 'srcNodes' and 'dstNode'), check whether any
+ // non-affine operation in the node accesses the 'memref'.
+ for (auto &idAndNode : mdg->nodes) {
+ Operation *op = idAndNode.second.op;
+ // Take care of operations between 'srcNode' and 'dstNode'.
+ if (srcNode->op->isBeforeInBlock(op) && op->isBeforeInBlock(dstNode->op)) {
+ // Walk inside the operation to find any use of the memref.
+ // Interrupt the walk if found.
+ auto walkResult = op->walk([&](Operation *user) {
+ // Skip affine ops.
+ if (isMemRefDereferencingOp(*user))
+ return WalkResult::advance();
+ // Find a non-affine op that uses the memref.
+ if (llvm::is_contained(users, user))
+ return WalkResult::interrupt();
+ return WalkResult::advance();
+ });
+ if (walkResult.wasInterrupted())
+ return true;
+ }
+ }
+ return false;
+}
+
+/// Check whether a memref value in node 'srcId' has a non-affine that
+/// is between node 'srcId' and node 'dstId' (exclusive of 'srcNode' and
+/// 'dstNode').
+static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
+ MemRefDependenceGraph *mdg) {
+ // Collect memref values in node 'srcId'.
+ auto *srcNode = mdg->getNode(srcId);
+ llvm::SmallDenseSet<Value, 2> memRefValues;
+ srcNode->op->walk([&](Operation *op) {
+ // Skip affine ops.
+ if (isa<AffineForOp>(op))
+ return WalkResult::advance();
+ for (Value v : op->getOperands())
+ // Collect memref values only.
+ if (v.getType().isa<MemRefType>())
+ memRefValues.insert(v);
+ return WalkResult::advance();
+ });
+ // Looking for users between node 'srcId' and node 'dstId'.
+ for (Value memref : memRefValues)
+ if (hasNonAffineUsersOnThePath(srcId, dstId, memref, mdg))
+ return true;
+ return false;
+}
+
// Checks if node 'srcId' can be safely fused into node 'dstId'. Node 'srcId'
// may write to multiple memrefs but it is required that only one of them,
// 'srcLiveOutStoreOp', has output edges.
// TODO(andydavis) Check the shape and lower bounds here too.
if (srcNumElements != dstNumElements)
return false;
+
+ // Return false if 'memref' is used by a non-affine operation that is
+ // between node 'srcId' and node 'dstId'.
+ if (hasNonAffineUsersOnThePath(srcId, dstId, mdg))
+ return false;
+
return true;
}
}
if (storeMemrefs.size() != 1)
return false;
+
+ // Skip if a memref value in one node is used by a non-affine memref
+ // access that lies between 'dstNode' and 'sibNode'.
+ if (hasNonAffineUsersOnThePath(dstNode->id, sibNode->id, mdg) ||
+ hasNonAffineUsersOnThePath(sibNode->id, dstNode->id, mdg))
+ return false;
return true;
};
// CHECK-NEXT: affine.store %{{.*}}, %arg{{.*}}[%arg{{.*}}] : memref<?xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return
+
+// -----
+
+// CHECK-LABEL: func @should_not_fuse_since_non_affine_users
+func @should_not_fuse_since_non_affine_users(%in0 : memref<32xf32>,
+ %in1 : memref<32xf32>) {
+ affine.for %d = 0 to 32 {
+ %lhs = affine.load %in0[%d] : memref<32xf32>
+ %rhs = affine.load %in1[%d] : memref<32xf32>
+ %add = addf %lhs, %rhs : f32
+ affine.store %add, %in0[%d] : memref<32xf32>
+ }
+ affine.for %d = 0 to 32 {
+ %lhs = load %in0[%d] : memref<32xf32>
+ %rhs = load %in1[%d] : memref<32xf32>
+ %add = subf %lhs, %rhs : f32
+ store %add, %in0[%d] : memref<32xf32>
+ }
+ affine.for %d = 0 to 32 {
+ %lhs = affine.load %in0[%d] : memref<32xf32>
+ %rhs = affine.load %in1[%d] : memref<32xf32>
+ %add = mulf %lhs, %rhs : f32
+ affine.store %add, %in0[%d] : memref<32xf32>
+ }
+ return
+}
+
+// CHECK: affine.for
+// CHECK: addf
+// CHECK: affine.for
+// CHECK: subf
+// CHECK: affine.for
+// CHECK: mulf
+
+// -----
+
+// CHECK-LABEL: func @should_not_fuse_since_top_level_non_affine_users
+func @should_not_fuse_since_top_level_non_affine_users(%in0 : memref<32xf32>,
+ %in1 : memref<32xf32>) {
+ %sum = alloc() : memref<f32>
+ affine.for %d = 0 to 32 {
+ %lhs = affine.load %in0[%d] : memref<32xf32>
+ %rhs = affine.load %in1[%d] : memref<32xf32>
+ %add = addf %lhs, %rhs : f32
+ store %add, %sum[] : memref<f32>
+ affine.store %add, %in0[%d] : memref<32xf32>
+ }
+ %load_sum = load %sum[] : memref<f32>
+ affine.for %d = 0 to 32 {
+ %lhs = affine.load %in0[%d] : memref<32xf32>
+ %rhs = affine.load %in1[%d] : memref<32xf32>
+ %add = mulf %lhs, %rhs : f32
+ %sub = subf %add, %load_sum: f32
+ affine.store %sub, %in0[%d] : memref<32xf32>
+ }
+ dealloc %sum : memref<f32>
+ return
+}
+
+// CHECK: affine.for
+// CHECK: addf
+// CHECK: affine.for
+// CHECK: mulf
+// CHECK: subf