From 622702d08948cb105cae2fc6fd0d4c3d988e65d4 Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Tue, 7 Mar 2023 08:11:58 +0530 Subject: [PATCH] [MLIR] Fix affine analysis check for ops with no common block Fix affine analysis check for ops with no common block in their affine scope. Clean up some out of date comments and naming. Fixes: https://github.com/llvm/llvm-project/issues/59444 Differential Revision: https://reviews.llvm.org/D145460 --- .../lib/Dialect/Affine/Analysis/AffineAnalysis.cpp | 41 +++++++++++++--------- mlir/test/Dialect/Affine/scalrep.mlir | 20 +++++++++++ 2 files changed, 44 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp index 6e73014..e702d2f 100644 --- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp @@ -322,9 +322,10 @@ getNumCommonLoops(const FlatAffineValueConstraints &srcDomain, } /// Returns the closest surrounding block common to `opA` and `opB`. `opA` and -/// `opB` should be in the same affine scope and thus such a block is guaranteed -/// to exist. -static Block *getCommonBlock(Operation *opA, Operation *opB) { +/// `opB` should be in the same affine scope. Returns nullptr if such a block +/// does not exist (when the two ops are in different blocks of an op starting +/// an `AffineScope`). +static Block *getCommonBlockInAffineScope(Operation *opA, Operation *opB) { // Get the chain of ancestor blocks for the given `MemRefAccess` instance. The // chain extends up to and includnig an op that starts an affine scope. auto getChainOfAncestorBlocks = @@ -342,7 +343,7 @@ static Block *getCommonBlock(Operation *opA, Operation *opB) { ancestorBlocks.push_back(currBlock); }; - // Find the closest common block including those in AffineIf. + // Find the closest common block. SmallVector srcAncestorBlocks, dstAncestorBlocks; getChainOfAncestorBlocks(opA, srcAncestorBlocks); getChainOfAncestorBlocks(opB, dstAncestorBlocks); @@ -352,28 +353,31 @@ static Block *getCommonBlock(Operation *opA, Operation *opB) { i >= 0 && j >= 0 && srcAncestorBlocks[i] == dstAncestorBlocks[j]; i--, j--) commonBlock = srcAncestorBlocks[i]; - // This is guaranteed since both ops are from the same affine scope. - assert(commonBlock && "ops expected to have a common surrounding block"); + return commonBlock; } /// Returns true if the ancestor operation of 'srcAccess' appears before the /// ancestor operation of 'dstAccess' in their common ancestral block. The /// operations for `srcAccess` and `dstAccess` are expected to be in the same -/// affine scope. +/// affine scope and have a common surrounding block within it. static bool srcAppearsBeforeDstInAncestralBlock(const MemRefAccess &srcAccess, const MemRefAccess &dstAccess) { // Get Block common to 'srcAccess.opInst' and 'dstAccess.opInst'. - auto *commonBlock = getCommonBlock(srcAccess.opInst, dstAccess.opInst); + Block *commonBlock = + getCommonBlockInAffineScope(srcAccess.opInst, dstAccess.opInst); + assert(commonBlock && + "ops expected to have a common surrounding block in affine scope"); + // Check the dominance relationship between the respective ancestors of the // src and dst in the Block of the innermost among the common loops. - auto *srcInst = commonBlock->findAncestorOpInBlock(*srcAccess.opInst); - assert(srcInst && "src access op must lie in common block"); - auto *dstInst = commonBlock->findAncestorOpInBlock(*dstAccess.opInst); - assert(dstInst && "dest access op must lie in common block"); + Operation *srcOp = commonBlock->findAncestorOpInBlock(*srcAccess.opInst); + assert(srcOp && "src access op must lie in common block"); + Operation *dstOp = commonBlock->findAncestorOpInBlock(*dstAccess.opInst); + assert(dstOp && "dest access op must lie in common block"); - // Determine whether dstInst comes after srcInst. - return srcInst->isBeforeInBlock(dstInst); + // Determine whether dstOp comes after srcOp. + return srcOp->isBeforeInBlock(dstOp); } // Adds ordering constraints to 'dependenceDomain' based on number of loops @@ -607,8 +611,8 @@ DependenceResult mlir::checkMemrefAccessDependence( SmallVector *dependenceComponents, bool allowRAR) { LLVM_DEBUG(llvm::dbgs() << "Checking for dependence at depth: " << Twine(loopDepth) << " between:\n";); - LLVM_DEBUG(srcAccess.opInst->dump();); - LLVM_DEBUG(dstAccess.opInst->dump();); + LLVM_DEBUG(srcAccess.opInst->dump()); + LLVM_DEBUG(dstAccess.opInst->dump()); // Return 'NoDependence' if these accesses do not access the same memref. if (srcAccess.memref != dstAccess.memref) @@ -620,9 +624,12 @@ DependenceResult mlir::checkMemrefAccessDependence( !isa(dstAccess.opInst)) return DependenceResult::NoDependence; - // We can't analyze further if the ops lie in different affine scopes. + // We can't analyze further if the ops lie in different affine scopes or have + // no common block in an affine scope. if (getAffineScope(srcAccess.opInst) != getAffineScope(dstAccess.opInst)) return DependenceResult::Failure; + if (!getCommonBlockInAffineScope(srcAccess.opInst, dstAccess.opInst)) + return DependenceResult::Failure; // Create access relation from each MemRefAccess. FlatAffineRelation srcRel, dstRel; diff --git a/mlir/test/Dialect/Affine/scalrep.mlir b/mlir/test/Dialect/Affine/scalrep.mlir index c5862be..64b8534 100644 --- a/mlir/test/Dialect/Affine/scalrep.mlir +++ b/mlir/test/Dialect/Affine/scalrep.mlir @@ -868,3 +868,23 @@ func.func @dead_affine_region_op() { // CHECK-NEXT: affine.load return } + +// We perform no scalar replacement here since we don't depend on dominance +// info, which would be needed in such cases when ops fall in different blocks +// of a CFG region. + +// CHECK-LABEL: func @cross_block +func.func @cross_block() { + %c10 = arith.constant 10 : index + %alloc_83 = memref.alloc() : memref<1x13xf32> + %alloc_99 = memref.alloc() : memref<13xi1> + %true_110 = arith.constant true + affine.store %true_110, %alloc_99[%c10] : memref<13xi1> + %true = arith.constant true + affine.store %true, %alloc_99[%c10] : memref<13xi1> + cf.br ^bb1(%alloc_83 : memref<1x13xf32>) +^bb1(%35: memref<1x13xf32>): + // CHECK: affine.load + %69 = affine.load %alloc_99[%c10] : memref<13xi1> + return +} -- 2.7.4