Refactor / improve replaceAllMemRefUsesWith
authorUday Bondhugula <udayb@iisc.ac.in>
Wed, 28 Aug 2019 00:56:25 +0000 (17:56 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 28 Aug 2019 00:56:56 +0000 (17:56 -0700)
Refactor replaceAllMemRefUsesWith to split it into two methods: the new
method does the replacement on a single op, and is used by the existing
one.

- make the methods return LogicalResult instead of bool

- Earlier, when replacement failed (due to non-deferencing uses of the
  memref), the set of ops that had already been processed would have
  been replaced leaving the IR in an inconsistent state. Now, a
  pass is made over all ops to first check for non-deferencing
  uses, and then replacement is performed. No test cases were affected
  because all clients of this method were first checking for
  non-deferencing uses before calling this method (for other reasons).
  This isn't true for a use case in another upcoming PR (scalar
  replacement); clients can now bail out with consistent IR on failure
  of replaceAllMemRefUsesWith. Add test case.

- multiple deferencing uses of the same memref in a single op is
  possible (we have no such use cases/scenarios), and this has always
  remained unsupported. Add an assertion for this.

- minor fix to another test pipeline-data-transfer case.

Signed-off-by: Uday Bondhugula <uday@polymagelabs.com>
Closes tensorflow/mlir#87

PiperOrigin-RevId: 265808183

mlir/include/mlir/Transforms/Utils.h
mlir/lib/Transforms/LoopFusion.cpp
mlir/lib/Transforms/PipelineDataTransfer.cpp
mlir/lib/Transforms/Utils/Utils.cpp
mlir/test/Transforms/pipeline-data-transfer.mlir

index c59d76a..23286af 100644 (file)
@@ -37,26 +37,26 @@ class AffineForOp;
 class Location;
 class OpBuilder;
 
-/// Replaces all "deferencing" uses of oldMemRef with newMemRef while optionally
-/// remapping the old memref's indices using the supplied affine map,
-/// 'indexRemap'. The new memref could be of a different shape or rank.
-/// 'extraIndices' provides additional access indices to be added to the start.
+/// Replaces all "dereferencing" uses of `oldMemRef` with `newMemRef` while
+/// optionally remapping the old memref's indices using the supplied affine map,
+/// `indexRemap`. The new memref could be of a different shape or rank.
+/// `extraIndices` provides additional access indices to be added to the start.
 ///
-/// 'indexRemap' remaps indices of the old memref access to a new set of indices
+/// `indexRemap` remaps indices of the old memref access to a new set of indices
 /// that are used to index the memref. Additional input operands to indexRemap
 /// can be optionally provided, and they are added at the start of its input
-/// list. 'indexRemap' is expected to have only dimensional inputs, and the
+/// list. `indexRemap` is expected to have only dimensional inputs, and the
 /// number of its inputs equal to extraOperands.size() plus rank of the memref.
 /// 'extraOperands' is an optional argument that corresponds to additional
 /// operands (inputs) for indexRemap at the beginning of its input list.
 ///
-/// 'domInstFilter', if non-null, restricts the replacement to only those
+/// `domInstFilter`, if non-null, restricts the replacement to only those
 /// operations that are dominated by the former; similarly, `postDomInstFilter`
 /// restricts replacement to only those operations that are postdominated by it.
 ///
 /// Returns true on success and false if the replacement is not possible,
-/// whenever a memref is used as an operand in a non-deferencing context, except
-/// for dealloc's on the memref which are left untouched. See comments at
+/// whenever a memref is used as an operand in a non-dereferencing context,
+/// except for dealloc's on the memref which are left untouched. See comments at
 /// function definition for an example.
 //
 //  Ex: to replace load %A[%i, %j] with load %Abuf[%t mod 2, %ii - %i, %j]:
@@ -66,12 +66,20 @@ class OpBuilder;
 //  extra operands, note that 'indexRemap' would just be applied to existing
 //  indices (%i, %j).
 //  TODO(bondhugula): allow extraIndices to be added at any position.
-bool replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
-                              ArrayRef<Value *> extraIndices = {},
-                              AffineMap indexRemap = AffineMap(),
-                              ArrayRef<Value *> extraOperands = {},
-                              Operation *domInstFilter = nullptr,
-                              Operation *postDomInstFilter = nullptr);
+LogicalResult replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
+                                       ArrayRef<Value *> extraIndices = {},
+                                       AffineMap indexRemap = AffineMap(),
+                                       ArrayRef<Value *> extraOperands = {},
+                                       Operation *domInstFilter = nullptr,
+                                       Operation *postDomInstFilter = nullptr);
+
+/// Performs the same replacement as the other version above but only for the
+/// dereferencing uses of `oldMemRef` in `op`.
+LogicalResult replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
+                                       Operation *op,
+                                       ArrayRef<Value *> extraIndices = {},
+                                       AffineMap indexRemap = AffineMap(),
+                                       ArrayRef<Value *> extraOperands = {});
 
 /// Creates and inserts into 'builder' a new AffineApplyOp, with the number of
 /// its results equal to the number of operands, as a composition
index 46713dc..a17481f 100644 (file)
@@ -952,12 +952,13 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
                         ? AffineMap()
                         : b.getAffineMap(outerIVs.size() + rank, 0, remapExprs);
   // Replace all users of 'oldMemRef' with 'newMemRef'.
-  bool ret =
+  LogicalResult res =
       replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
                                /*extraOperands=*/outerIVs,
                                /*domInstFilter=*/&*forOp.getBody()->begin());
-  assert(ret && "replaceAllMemrefUsesWith should always succeed here");
-  (void)ret;
+  assert(succeeded(res) &&
+         "replaceAllMemrefUsesWith should always succeed here");
+  (void)res;
   return newMemRef;
 }
 
index 0cd979a..a814af9 100644 (file)
@@ -115,13 +115,14 @@ static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) {
   auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap,
                                                  forOp.getInductionVar());
 
-  // replaceAllMemRefUsesWith will always succeed unless the forOp body has
-  // non-deferencing uses of the memref (dealloc's are fine though).
-  if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef,
-                                /*extraIndices=*/{ivModTwoOp},
-                                /*indexRemap=*/AffineMap(),
-                                /*extraOperands=*/{},
-                                /*domInstFilter=*/&*forOp.getBody()->begin())) {
+  // replaceAllMemRefUsesWith will succeed unless the forOp body has
+  // non-dereferencing uses of the memref (dealloc's are fine though).
+  if (failed(replaceAllMemRefUsesWith(
+          oldMemRef, newMemRef,
+          /*extraIndices=*/{ivModTwoOp},
+          /*indexRemap=*/AffineMap(),
+          /*extraOperands=*/{},
+          /*domInstFilter=*/&*forOp.getBody()->begin()))) {
     LLVM_DEBUG(
         forOp.emitError("memref replacement for double buffering failed"));
     ivModTwoOp.erase();
@@ -276,9 +277,9 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
     if (!doubleBuffer(oldMemRef, forOp)) {
       // Normally, double buffering should not fail because we already checked
       // that there are no uses outside.
-      LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";);
-      LLVM_DEBUG(dmaStartInst->dump());
-      // IR still in a valid state.
+      LLVM_DEBUG(llvm::dbgs()
+                     << "double buffering failed for" << dmaStartInst << "\n";);
+      // IR still valid and semantically correct.
       return;
     }
     // If the old memref has no more uses, remove its 'dead' alloc if it was
index 8d7b7a8..b0c9b94 100644 (file)
@@ -57,16 +57,181 @@ static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value *memref) {
   return cast<AffineDmaWaitOp>(op).getAffineMapAttrForMemRef(memref);
 }
 
-bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
-                                    ArrayRef<Value *> extraIndices,
-                                    AffineMap indexRemap,
-                                    ArrayRef<Value *> extraOperands,
-                                    Operation *domInstFilter,
-                                    Operation *postDomInstFilter) {
+// Perform the replacement in `op`.
+LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
+                                             Operation *op,
+                                             ArrayRef<Value *> extraIndices,
+                                             AffineMap indexRemap,
+                                             ArrayRef<Value *> extraOperands) {
   unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
   (void)newMemRefRank; // unused in opt mode
   unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
-  (void)newMemRefRank;
+  (void)oldMemRefRank;
+  if (indexRemap) {
+    assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected");
+    assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank);
+    assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
+  } else {
+    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
+  }
+
+  // Assert same elemental type.
+  assert(oldMemRef->getType().cast<MemRefType>().getElementType() ==
+         newMemRef->getType().cast<MemRefType>().getElementType());
+
+  if (!isMemRefDereferencingOp(*op))
+    // Failure: memref used in a non-dereferencing context (potentially
+    // escapes); no replacement in these cases.
+    return failure();
+
+  SmallVector<unsigned, 2> usePositions;
+  for (const auto &opEntry : llvm::enumerate(op->getOperands())) {
+    if (opEntry.value() == oldMemRef)
+      usePositions.push_back(opEntry.index());
+  }
+
+  // If memref doesn't appear, nothing to do.
+  if (usePositions.empty())
+    return success();
+
+  if (usePositions.size() > 1) {
+    // TODO(mlir-team): extend it for this case when needed (rare).
+    assert(false && "multiple dereferencing uses in a single op not supported");
+    return failure();
+  }
+
+  unsigned memRefOperandPos = usePositions.front();
+
+  OpBuilder builder(op);
+  NamedAttribute oldMapAttrPair = getAffineMapAttrForMemRef(op, oldMemRef);
+  AffineMap oldMap = oldMapAttrPair.second.cast<AffineMapAttr>().getValue();
+  unsigned oldMapNumInputs = oldMap.getNumInputs();
+  SmallVector<Value *, 4> oldMapOperands(
+      op->operand_begin() + memRefOperandPos + 1,
+      op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
+
+  // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
+  SmallVector<Value *, 4> oldMemRefOperands;
+  SmallVector<Value *, 4> affineApplyOps;
+  oldMemRefOperands.reserve(oldMemRefRank);
+  if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
+    for (auto resultExpr : oldMap.getResults()) {
+      auto singleResMap = builder.getAffineMap(
+          oldMap.getNumDims(), oldMap.getNumSymbols(), resultExpr);
+      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
+                                                oldMapOperands);
+      oldMemRefOperands.push_back(afOp);
+      affineApplyOps.push_back(afOp);
+    }
+  } else {
+    oldMemRefOperands.append(oldMapOperands.begin(), oldMapOperands.end());
+  }
+
+  // Construct new indices as a remap of the old ones if a remapping has been
+  // provided. The indices of a memref come right after it, i.e.,
+  // at position memRefOperandPos + 1.
+  SmallVector<Value *, 4> remapOperands;
+  remapOperands.reserve(extraOperands.size() + oldMemRefRank);
+  remapOperands.append(extraOperands.begin(), extraOperands.end());
+  remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
+
+  SmallVector<Value *, 4> remapOutputs;
+  remapOutputs.reserve(oldMemRefRank);
+
+  if (indexRemap &&
+      indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
+    // Remapped indices.
+    for (auto resultExpr : indexRemap.getResults()) {
+      auto singleResMap = builder.getAffineMap(
+          indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
+      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
+                                                remapOperands);
+      remapOutputs.push_back(afOp);
+      affineApplyOps.push_back(afOp);
+    }
+  } else {
+    // No remapping specified.
+    remapOutputs.append(remapOperands.begin(), remapOperands.end());
+  }
+
+  SmallVector<Value *, 4> newMapOperands;
+  newMapOperands.reserve(newMemRefRank);
+
+  // Prepend 'extraIndices' in 'newMapOperands'.
+  for (auto *extraIndex : extraIndices) {
+    assert(extraIndex->getDefiningOp()->getNumResults() == 1 &&
+           "single result op's expected to generate these indices");
+    assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
+           "invalid memory op index");
+    newMapOperands.push_back(extraIndex);
+  }
+
+  // Append 'remapOutputs' to 'newMapOperands'.
+  newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
+
+  // Create new fully composed AffineMap for new op to be created.
+  assert(newMapOperands.size() == newMemRefRank);
+  auto newMap = builder.getMultiDimIdentityMap(newMemRefRank);
+  // TODO(b/136262594) Avoid creating/deleting temporary AffineApplyOps here.
+  fullyComposeAffineMapAndOperands(&newMap, &newMapOperands);
+  newMap = simplifyAffineMap(newMap);
+  canonicalizeMapAndOperands(&newMap, &newMapOperands);
+  // Remove any affine.apply's that became dead as a result of composition.
+  for (auto *value : affineApplyOps)
+    if (value->use_empty())
+      value->getDefiningOp()->erase();
+
+  // Construct the new operation using this memref.
+  OperationState state(op->getLoc(), op->getName());
+  state.setOperandListToResizable(op->hasResizableOperandsList());
+  state.operands.reserve(op->getNumOperands() + extraIndices.size());
+  // Insert the non-memref operands.
+  state.operands.append(op->operand_begin(),
+                        op->operand_begin() + memRefOperandPos);
+  // Insert the new memref value.
+  state.operands.push_back(newMemRef);
+
+  // Insert the new memref map operands.
+  state.operands.append(newMapOperands.begin(), newMapOperands.end());
+
+  // Insert the remaining operands unmodified.
+  state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
+                            oldMapNumInputs,
+                        op->operand_end());
+
+  // Result types don't change. Both memref's are of the same elemental type.
+  state.types.reserve(op->getNumResults());
+  for (auto *result : op->getResults())
+    state.types.push_back(result->getType());
+
+  // Add attribute for 'newMap', other Attributes do not change.
+  auto newMapAttr = builder.getAffineMapAttr(newMap);
+  for (auto namedAttr : op->getAttrs()) {
+    if (namedAttr.first == oldMapAttrPair.first) {
+      state.attributes.push_back({namedAttr.first, newMapAttr});
+    } else {
+      state.attributes.push_back(namedAttr);
+    }
+  }
+
+  // Create the new operation.
+  auto *repOp = builder.createOperation(state);
+  op->replaceAllUsesWith(repOp);
+  op->erase();
+
+  return success();
+}
+
+LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
+                                             ArrayRef<Value *> extraIndices,
+                                             AffineMap indexRemap,
+                                             ArrayRef<Value *> extraOperands,
+                                             Operation *domInstFilter,
+                                             Operation *postDomInstFilter) {
+  unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
+  (void)newMemRefRank; // unused in opt mode
+  unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
+  (void)oldMemRefRank;
   if (indexRemap) {
     assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected");
     assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank);
@@ -89,170 +254,44 @@ bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
     postDomInfo = std::make_unique<PostDominanceInfo>(
         postDomInstFilter->getParentOfType<FuncOp>());
 
-  // The ops where memref replacement succeeds are replaced with new ones.
-  SmallVector<Operation *, 8> opsToErase;
-
-  // Walk all uses of old memref. Operation using the memref gets replaced.
-  for (auto *opInst : llvm::make_early_inc_range(oldMemRef->getUsers())) {
+  // Walk all uses of old memref; collect ops to perform replacement. We use a
+  // DenseSet since an operation could potentially have multiple uses of a
+  // memref (although rare), and the replacement later is going to erase ops.
+  DenseSet<Operation *> opsToReplace;
+  for (auto *op : oldMemRef->getUsers()) {
     // Skip this use if it's not dominated by domInstFilter.
-    if (domInstFilter && !domInfo->dominates(domInstFilter, opInst))
+    if (domInstFilter && !domInfo->dominates(domInstFilter, op))
       continue;
 
     // Skip this use if it's not post-dominated by postDomInstFilter.
-    if (postDomInstFilter &&
-        !postDomInfo->postDominates(postDomInstFilter, opInst))
+    if (postDomInstFilter && !postDomInfo->postDominates(postDomInstFilter, op))
       continue;
 
-    // Skip dealloc's - no replacement is necessary, and a replacement doesn't
-    // hurt dealloc's.
-    if (isa<DeallocOp>(opInst))
+    // Skip dealloc's - no replacement is necessary, and a memref replacement
+    // at other uses doesn't hurt these dealloc's.
+    if (isa<DeallocOp>(op))
       continue;
 
-    // Check if the memref was used in a non-deferencing context. It is fine for
-    // the memref to be used in a non-deferencing way outside of the region
-    // where this replacement is happening.
-    if (!isMemRefDereferencingOp(*opInst))
-      // Failure: memref used in a non-deferencing op (potentially escapes); no
-      // replacement in these cases.
-      return false;
-
-    auto getMemRefOperandPos = [&]() -> unsigned {
-      unsigned i, e;
-      for (i = 0, e = opInst->getNumOperands(); i < e; i++) {
-        if (opInst->getOperand(i) == oldMemRef)
-          break;
-      }
-      assert(i < opInst->getNumOperands() && "operand guaranteed to be found");
-      return i;
-    };
-
-    OpBuilder builder(opInst);
-    unsigned memRefOperandPos = getMemRefOperandPos();
-    NamedAttribute oldMapAttrPair =
-        getAffineMapAttrForMemRef(opInst, oldMemRef);
-    AffineMap oldMap = oldMapAttrPair.second.cast<AffineMapAttr>().getValue();
-    unsigned oldMapNumInputs = oldMap.getNumInputs();
-    SmallVector<Value *, 4> oldMapOperands(
-        opInst->operand_begin() + memRefOperandPos + 1,
-        opInst->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
-    SmallVector<Value *, 4> affineApplyOps;
-
-    // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
-    SmallVector<Value *, 4> oldMemRefOperands;
-    oldMemRefOperands.reserve(oldMemRefRank);
-    if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
-      for (auto resultExpr : oldMap.getResults()) {
-        auto singleResMap = builder.getAffineMap(
-            oldMap.getNumDims(), oldMap.getNumSymbols(), resultExpr);
-        auto afOp = builder.create<AffineApplyOp>(opInst->getLoc(),
-                                                  singleResMap, oldMapOperands);
-        oldMemRefOperands.push_back(afOp);
-        affineApplyOps.push_back(afOp);
-      }
-    } else {
-      oldMemRefOperands.append(oldMapOperands.begin(), oldMapOperands.end());
-    }
-
-    // Construct new indices as a remap of the old ones if a remapping has been
-    // provided. The indices of a memref come right after it, i.e.,
-    // at position memRefOperandPos + 1.
-    SmallVector<Value *, 4> remapOperands;
-    remapOperands.reserve(extraOperands.size() + oldMemRefRank);
-    remapOperands.append(extraOperands.begin(), extraOperands.end());
-    remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
-
-    SmallVector<Value *, 4> remapOutputs;
-    remapOutputs.reserve(oldMemRefRank);
-
-    if (indexRemap &&
-        indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
-      // Remapped indices.
-      for (auto resultExpr : indexRemap.getResults()) {
-        auto singleResMap = builder.getAffineMap(
-            indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
-        auto afOp = builder.create<AffineApplyOp>(opInst->getLoc(),
-                                                  singleResMap, remapOperands);
-        remapOutputs.push_back(afOp);
-        affineApplyOps.push_back(afOp);
-      }
-    } else {
-      // No remapping specified.
-      remapOutputs.append(remapOperands.begin(), remapOperands.end());
-    }
-
-    SmallVector<Value *, 4> newMapOperands;
-    newMapOperands.reserve(newMemRefRank);
-
-    // Prepend 'extraIndices' in 'newMapOperands'.
-    for (auto *extraIndex : extraIndices) {
-      assert(extraIndex->getDefiningOp()->getNumResults() == 1 &&
-             "single result op's expected to generate these indices");
-      assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
-             "invalid memory op index");
-      newMapOperands.push_back(extraIndex);
-    }
-
-    // Append 'remapOutputs' to 'newMapOperands'.
-    newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
-
-    // Create new fully composed AffineMap for new op to be created.
-    assert(newMapOperands.size() == newMemRefRank);
-    auto newMap = builder.getMultiDimIdentityMap(newMemRefRank);
-    // TODO(b/136262594) Avoid creating/deleting temporary AffineApplyOps here.
-    fullyComposeAffineMapAndOperands(&newMap, &newMapOperands);
-    newMap = simplifyAffineMap(newMap);
-    canonicalizeMapAndOperands(&newMap, &newMapOperands);
-    // Remove any affine.apply's that became dead as a result of composition.
-    for (auto *value : affineApplyOps)
-      if (value->use_empty())
-        value->getDefiningOp()->erase();
-
-    // Construct the new operation using this memref.
-    OperationState state(opInst->getLoc(), opInst->getName());
-    state.setOperandListToResizable(opInst->hasResizableOperandsList());
-    state.operands.reserve(opInst->getNumOperands() + extraIndices.size());
-    // Insert the non-memref operands.
-    state.operands.append(opInst->operand_begin(),
-                          opInst->operand_begin() + memRefOperandPos);
-    // Insert the new memref value.
-    state.operands.push_back(newMemRef);
-
-    // Insert the new memref map operands.
-    state.operands.append(newMapOperands.begin(), newMapOperands.end());
-
-    // Insert the remaining operands unmodified.
-    state.operands.append(opInst->operand_begin() + memRefOperandPos + 1 +
-                              oldMapNumInputs,
-                          opInst->operand_end());
-
-    // Result types don't change. Both memref's are of the same elemental type.
-    state.types.reserve(opInst->getNumResults());
-    for (auto *result : opInst->getResults())
-      state.types.push_back(result->getType());
-
-    // Add attribute for 'newMap', other Attributes do not change.
-    auto newMapAttr = builder.getAffineMapAttr(newMap);
-    for (auto namedAttr : opInst->getAttrs()) {
-      if (namedAttr.first == oldMapAttrPair.first) {
-        state.attributes.push_back({namedAttr.first, newMapAttr});
-      } else {
-        state.attributes.push_back(namedAttr);
-      }
-    }
-
-    // Create the new operation.
-    auto *repOp = builder.createOperation(state);
-    opInst->replaceAllUsesWith(repOp);
-
-    // Collect and erase at the end since one of these op's could be
-    // domInstFilter or postDomInstFilter as well!
-    opsToErase.push_back(opInst);
+    // Check if the memref was used in a non-dereferencing context. It is fine
+    // for the memref to be used in a non-dereferencing way outside of the
+    // region where this replacement is happening.
+    if (!isMemRefDereferencingOp(*op))
+      // Failure: memref used in a non-dereferencing op (potentially escapes);
+      // no replacement in these cases.
+      return failure();
+
+    // We'll first collect and then replace --- since replacement erases the op
+    // that has the use, and that op could be postDomFilter or domFilter itself!
+    opsToReplace.insert(op);
   }
 
-  for (auto *opInst : opsToErase)
-    opInst->erase();
+  for (auto *op : opsToReplace) {
+    if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef, op, extraIndices,
+                                        indexRemap, extraOperands)))
+      assert(false && "memref replacement guaranteed to succeed here");
+  }
 
-  return true;
+  return success();
 }
 
 /// Given an operation, inserts one or more single result affine
index 2b4d143..ce266d5 100644 (file)
@@ -14,7 +14,7 @@ func @loop_nest_dma() {
   %tag = alloc() : memref<1 x f32>
 
   %zero = constant 0 : index
-  %num_elts = constant 128 : index
+  %num_elts = constant 32 : index
 
   affine.for %i = 0 to 8 {
     affine.dma_start %A[%i], %Ah[%i], %tag[%zero], %num_elts : memref<256 x f32>, memref<32 x f32, 1>, memref<1 x f32>
@@ -22,7 +22,7 @@ func @loop_nest_dma() {
     %v = affine.load %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1>
     %r = "compute"(%v) : (f32) -> (f32)
     affine.store %r, %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1>
-    affine.for %j = 0 to 128 {
+    affine.for %j = 0 to 32 {
       "do_more_compute"(%i, %j) : (index, index) -> ()
     }
   }
@@ -41,7 +41,7 @@ func @loop_nest_dma() {
 // CHECK-NEXT:    %{{.*}} = affine.load %{{.*}}[%{{.*}} mod 2, %{{.*}}] : memref<2x32xf32, 1>
 // CHECK-NEXT:    %{{.*}} = "compute"(%{{.*}}) : (f32) -> f32
 // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}} mod 2, %{{.*}}] : memref<2x32xf32, 1>
-// CHECK-NEXT:    affine.for %{{.*}} = 0 to 128 {
+// CHECK-NEXT:    affine.for %{{.*}} = 0 to 32 {
 // CHECK-NEXT:      "do_more_compute"(%{{.*}}, %{{.*}}) : (index, index) -> ()
 // CHECK-NEXT:    }
 // CHECK-NEXT:  }
@@ -52,7 +52,7 @@ func @loop_nest_dma() {
 // CHECK-NEXT:  %{{.*}} = affine.load %{{.*}}[%{{.*}} mod 2, %{{.*}}] : memref<2x32xf32, 1>
 // CHECK-NEXT:  %{{.*}} = "compute"(%{{.*}}) : (f32) -> f32
 // CHECK-NEXT:  affine.store %{{.*}}, %{{.*}}[%{{.*}} mod 2, %{{.*}}] : memref<2x32xf32, 1>
-// CHECK-NEXT:  affine.for %{{.*}} = 0 to 128 {
+// CHECK-NEXT:  affine.for %{{.*}} = 0 to 32 {
 // CHECK-NEXT:    "do_more_compute"(%{{.*}}, %{{.*}}) : (index, index) -> ()
 // CHECK-NEXT:  }
 // CHECK-NEXT:  dealloc %{{.*}} : memref<2x1xf32>
@@ -297,3 +297,32 @@ func @dynamic_shape_dma_buffer(%arg0: memref<512 x 32 x f32>) {
 // CHECK:       affine.dma_wait %{{.*}}[%{{.*}} mod 2, symbol(%{{.*}})], %{{.*}} : memref<2x1xi32>
 // CHECK:       return
 }
+
+// Memref replacement will fail here due to a non-dereferencing use. However,
+// no incorrect transformation is performed since replaceAllMemRefUsesWith
+// checks for escaping uses before performing any replacement.
+// CHECK-LABEL: func @escaping_use
+func @escaping_use() {
+  %A = alloc() : memref<256 x f32, (d0) -> (d0), 0>
+  %Ah = alloc() : memref<32 x f32, (d0) -> (d0), 1>
+  %tag = alloc() : memref<1 x f32>
+  %zero = constant 0 : index
+  %num_elts = constant 32 : index
+
+  // alloc for the buffer is created but no replacement should happen.
+  affine.for %i = 0 to 8 {
+    affine.dma_start %A[%i], %Ah[%i], %tag[%zero], %num_elts : memref<256 x f32>, memref<32 x f32, 1>, memref<1 x f32>
+    affine.dma_wait %tag[%zero], %num_elts : memref<1 x f32>
+    "compute"(%Ah) : (memref<32 x f32, 1>) -> ()
+    %v = affine.load %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1>
+    "foo"(%v) : (f32) -> ()
+  }
+  return
+}
+// No replacement
+// CHECK: affine.for %{{.*}} = 0 to 8 {
+// CHECK-NEXT:   affine.dma_start %{{.*}}[%{{.*}}], %{{.*}}[%{{.*}}], %{{.*}}[%{{.*}}], %{{.*}}
+// CHECK-NEXT:   affine.dma_wait %{{.*}}[%{{.*}}], %{{.*}} : memref<1xf32>
+// CHECK-NEXT:   "compute"(%{{.*}}) : (memref<32xf32, 1>) -> ()
+// CHECK-NEXT:   [[VAL:%[0-9]+]] = affine.load %{{.*}}[%{{.*}}] : memref<32xf32, 1>
+// CHECK-NEXT:   "foo"([[VAL]]) : (f32) -> ()