From 469c02d0581a4bd7539c7dd62063c29072b55852 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 4 May 2020 19:54:36 -0700 Subject: [PATCH] [mlir] Add support for merging identical blocks during canonicalization This revision adds support for merging identical blocks, or those with the same operations that branch to the same successors. Operands that mismatch between the different blocks are replaced with new block arguments added to the merged block. Differential Revision: https://reviews.llvm.org/D79134 --- llvm/include/llvm/ADT/STLExtras.h | 3 + mlir/include/mlir/IR/BlockSupport.h | 23 ++ mlir/include/mlir/IR/Operation.h | 8 + mlir/include/mlir/IR/OperationSupport.h | 33 ++- mlir/include/mlir/IR/Value.h | 6 + mlir/lib/IR/OperationSupport.cpp | 18 +- mlir/lib/IR/Value.cpp | 14 + mlir/lib/Transforms/Utils/RegionUtils.cpp | 326 ++++++++++++++++++++- mlir/test/Dialect/SPIRV/canonicalize.mlir | 7 +- mlir/test/Transforms/canonicalize-block-merge.mlir | 204 +++++++++++++ mlir/test/Transforms/canonicalize-dce.mlir | 4 - mlir/test/Transforms/canonicalize.mlir | 8 +- 12 files changed, 632 insertions(+), 22 deletions(-) create mode 100644 mlir/test/Transforms/canonicalize-block-merge.mlir diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h index 30bcdf5..71ad4fc 100644 --- a/llvm/include/llvm/ADT/STLExtras.h +++ b/llvm/include/llvm/ADT/STLExtras.h @@ -1174,6 +1174,9 @@ public: return RangeT(iterator_range(*this)); } + /// Returns the base of this range. + const BaseT &getBase() const { return base; } + private: /// Offset the given base by the given amount. static BaseT offset_base(const BaseT &base, size_t n) { diff --git a/mlir/include/mlir/IR/BlockSupport.h b/mlir/include/mlir/IR/BlockSupport.h index 10b8c48..f3dd614 100644 --- a/mlir/include/mlir/IR/BlockSupport.h +++ b/mlir/include/mlir/IR/BlockSupport.h @@ -119,6 +119,29 @@ public: namespace llvm { +/// Provide support for hashing successor ranges. 
+template <> +struct DenseMapInfo { + static mlir::SuccessorRange getEmptyKey() { + auto *pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlir::SuccessorRange(pointer, 0); + } + static mlir::SuccessorRange getTombstoneKey() { + auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::SuccessorRange(pointer, 0); + } + static unsigned getHashValue(mlir::SuccessorRange value) { + return llvm::hash_combine_range(value.begin(), value.end()); + } + static bool isEqual(mlir::SuccessorRange lhs, mlir::SuccessorRange rhs) { + if (rhs.getBase() == getEmptyKey().getBase()) + return lhs.getBase() == getEmptyKey().getBase(); + if (rhs.getBase() == getTombstoneKey().getBase()) + return lhs.getBase() == getTombstoneKey().getBase(); + return lhs == rhs; + } +}; + //===----------------------------------------------------------------------===// // ilist_traits for Operation //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 5c94081..fcde73e 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -554,6 +554,14 @@ public: [](OpResult result) { return result.use_empty(); }); } + /// Returns true if the results of this operation are used outside of the + /// given block. 
+ bool isUsedOutsideOfBlock(Block *block) { + return llvm::any_of(getOpResults(), [block](OpResult result) { + return result.isUsedOutsideOfBlock(block); + }); + } + //===--------------------------------------------------------------------===// // Users //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index edfe89a..8c0a3f1 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -20,6 +20,7 @@ #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/BitmaskEnum.h" #include "llvm/ADT/PointerUnion.h" #include "llvm/Support/PointerLikeTypeTraits.h" #include "llvm/Support/TrailingObjects.h" @@ -617,6 +618,17 @@ public: ValueTypeIterator>::iterator_range; template ValueTypeRange(Container &&c) : ValueTypeRange(c.begin(), c.end()) {} + + /// Compare this range with another. + template + bool operator==(const OtherT &other) const { + return llvm::size(*this) == llvm::size(other) && + std::equal(this->begin(), this->end(), other.begin()); + } + template + bool operator!=(const OtherT &other) const { + return !(*this == other); + } }; template @@ -829,12 +841,29 @@ private: /// This class provides utilities for computing if two operations are /// equivalent. struct OperationEquivalence { + enum Flags { + None = 0, + + /// This flag signals that operands should not be considered when checking + /// for equivalence. This allows for users to implement their own + /// equivalence schemes for operand values. The number of operands is still + /// checked, just not the operands themselves. + IgnoreOperands = 1, + + LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ IgnoreOperands) + }; + /// Compute a hash for the given operation. 
- static llvm::hash_code computeHash(Operation *op); + static llvm::hash_code computeHash(Operation *op, Flags flags = Flags::None); /// Compare two operations and return if they are equivalent. - static bool isEquivalentTo(Operation *lhs, Operation *rhs); + static bool isEquivalentTo(Operation *lhs, Operation *rhs, + Flags flags = Flags::None); }; + +/// Enable Bitmask enums for OperationEquivalence::Flags. +LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE(); + } // end namespace mlir namespace llvm { diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index 95def76..7851730 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -123,6 +123,9 @@ public: /// Return the Region in which this Value is defined. Region *getParentRegion(); + /// Return the Block in which this Value is defined. + Block *getParentBlock(); + //===--------------------------------------------------------------------===// // UseLists //===--------------------------------------------------------------------===// @@ -150,6 +153,9 @@ public: void replaceUsesWithIf(Value newValue, function_ref shouldReplace); + /// Returns true if the value is used outside of the given block. 
+ bool isUsedOutsideOfBlock(Block *block); + //===--------------------------------------------------------------------===// // Uses diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp index a087623..91842cf 100644 --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -412,7 +412,7 @@ Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) { // Operation Equivalency //===----------------------------------------------------------------------===// -llvm::hash_code OperationEquivalence::computeHash(Operation *op) { +llvm::hash_code OperationEquivalence::computeHash(Operation *op, Flags flags) { // Hash operations based upon their: // - Operation Name // - Attributes @@ -438,12 +438,17 @@ llvm::hash_code OperationEquivalence::computeHash(Operation *op) { } // - Operands - // TODO: Allow commutative operations to have different ordering. - return llvm::hash_combine( - hash, llvm::hash_combine_range(op->operand_begin(), op->operand_end())); + bool ignoreOperands = flags & Flags::IgnoreOperands; + if (!ignoreOperands) { + // TODO: Allow commutative operations to have different ordering. + hash = llvm::hash_combine( + hash, llvm::hash_combine_range(op->operand_begin(), op->operand_end())); + } + return hash; } -bool OperationEquivalence::isEquivalentTo(Operation *lhs, Operation *rhs) { +bool OperationEquivalence::isEquivalentTo(Operation *lhs, Operation *rhs, + Flags flags) { if (lhs == rhs) return true; @@ -478,6 +483,9 @@ bool OperationEquivalence::isEquivalentTo(Operation *lhs, Operation *rhs) { break; } // Compare operands. + bool ignoreOperands = flags & Flags::IgnoreOperands; + if (ignoreOperands) + return true; // TODO: Allow commutative operations to have different ordering. 
return std::equal(lhs->operand_begin(), lhs->operand_end(), rhs->operand_begin()); diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index fdc5ad6..6467a7f 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -87,6 +87,13 @@ Region *Value::getParentRegion() { return cast().getOwner()->getParent(); } +/// Return the Block in which this Value is defined. +Block *Value::getParentBlock() { + if (Operation *op = getDefiningOp()) + return op->getBlock(); + return cast().getOwner(); +} + //===----------------------------------------------------------------------===// // Value::UseLists //===----------------------------------------------------------------------===// @@ -134,6 +141,13 @@ void Value::replaceUsesWithIf(Value newValue, use.set(newValue); } +/// Returns true if the value is used outside of the given block. +bool Value::isUsedOutsideOfBlock(Block *block) { + return llvm::any_of(getUsers(), [block](Operation *user) { + return user->getBlock() != block; + }); +} + //===--------------------------------------------------------------------===// // Uses diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp index 7a00032..76b3931 100644 --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -368,6 +368,324 @@ static LogicalResult runRegionDCE(MutableArrayRef regions) { } //===----------------------------------------------------------------------===// +// Block Merging +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// BlockEquivalenceData + +namespace { +/// This class contains the information for comparing the equivalencies of two +/// blocks. Blocks are considered equivalent if they contain the same operations +/// in the same order. The only allowed divergence is for operands that come +/// from sources outside of the parent block, i.e. 
the uses of values produced +/// within the block must be equivalent. +/// e.g., +/// Equivalent: +/// ^bb1(%arg0: i32) +/// return %arg0, %foo : i32, i32 +/// ^bb2(%arg1: i32) +/// return %arg1, %bar : i32, i32 +/// Not Equivalent: +/// ^bb1(%arg0: i32) +/// return %foo, %arg0 : i32, i32 +/// ^bb2(%arg1: i32) +/// return %arg1, %bar : i32, i32 +struct BlockEquivalenceData { + BlockEquivalenceData(Block *block); + + /// Return the order index for the given value that is within the block of + /// this data. + unsigned getOrderOf(Value value) const; + + /// The block this data refers to. + Block *block; + /// A hash value for this block. + llvm::hash_code hash; + /// A map of result producing operations to their relative orders within this + /// block. The order of an operation is the number of defined values that are + /// produced within the block before this operation. + DenseMap opOrderIndex; +}; +} // end anonymous namespace + +BlockEquivalenceData::BlockEquivalenceData(Block *block) + : block(block), hash(0) { + unsigned orderIt = block->getNumArguments(); + for (Operation &op : *block) { + if (unsigned numResults = op.getNumResults()) { + opOrderIndex.try_emplace(&op, orderIt); + orderIt += numResults; + } + auto opHash = OperationEquivalence::computeHash( + &op, OperationEquivalence::Flags::IgnoreOperands); + hash = llvm::hash_combine(hash, opHash); + } +} + +unsigned BlockEquivalenceData::getOrderOf(Value value) const { + assert(value.getParentBlock() == block && "expected value of this block"); + + // Arguments use the argument number as the order index. + if (BlockArgument arg = value.dyn_cast()) + return arg.getArgNumber(); + + // Otherwise, the result order is offset from the parent op's order. 
+ OpResult result = value.cast(); + auto opOrderIt = opOrderIndex.find(result.getDefiningOp()); + assert(opOrderIt != opOrderIndex.end() && "expected op to have an order"); + return opOrderIt->second + result.getResultNumber(); +} + +//===----------------------------------------------------------------------===// +// BlockMergeCluster + +namespace { +/// This class represents a cluster of blocks to be merged together. +class BlockMergeCluster { +public: + BlockMergeCluster(BlockEquivalenceData &&leaderData) + : leaderData(std::move(leaderData)) {} + + /// Attempt to add the given block to this cluster. Returns success if the + /// block was merged, failure otherwise. + LogicalResult addToCluster(BlockEquivalenceData &blockData); + + /// Try to merge all of the blocks within this cluster into the leader block. + LogicalResult merge(); + +private: + /// The equivalence data for the leader of the cluster. + BlockEquivalenceData leaderData; + + /// The set of blocks that can be merged into the leader. + llvm::SmallSetVector blocksToMerge; + + /// A set of operand+index pairs that correspond to operands that need to be + /// replaced by arguments when the cluster gets merged. + std::set> operandsToMerge; + + /// A map of operations with external uses to a replacement within the leader + /// block. + DenseMap opsToReplace; +}; +} // end anonymous namespace + +LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) { + if (leaderData.hash != blockData.hash) + return failure(); + Block *leaderBlock = leaderData.block, *mergeBlock = blockData.block; + if (leaderBlock->getArgumentTypes() != mergeBlock->getArgumentTypes()) + return failure(); + + // A set of operands that mismatch between the leader and the new block. 
+ SmallVector, 8> mismatchedOperands; + SmallVector, 2> newOpsToReplace; + auto lhsIt = leaderBlock->begin(), lhsE = leaderBlock->end(); + auto rhsIt = blockData.block->begin(), rhsE = blockData.block->end(); + for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) { + // Check that the operations are equivalent. + if (!OperationEquivalence::isEquivalentTo( + &*lhsIt, &*rhsIt, OperationEquivalence::Flags::IgnoreOperands)) + return failure(); + + // Compare the operands of the two operations. If the operand is within + // the block, it must refer to the same operation. + auto lhsOperands = lhsIt->getOperands(), rhsOperands = rhsIt->getOperands(); + for (int operand : llvm::seq(0, lhsIt->getNumOperands())) { + Value lhsOperand = lhsOperands[operand]; + Value rhsOperand = rhsOperands[operand]; + if (lhsOperand == rhsOperand) + continue; + + // Check that these uses are both external, or both internal. + bool lhsIsInBlock = lhsOperand.getParentBlock() == leaderBlock; + bool rhsIsInBlock = rhsOperand.getParentBlock() == mergeBlock; + if (lhsIsInBlock != rhsIsInBlock) + return failure(); + // Let the operands differ if they are defined in a different block. These + // will become new arguments if the blocks get merged. + if (!lhsIsInBlock) { + mismatchedOperands.emplace_back(opI, operand); + continue; + } + + // Otherwise, these operands must have the same logical order within the + // parent block. + if (leaderData.getOrderOf(lhsOperand) != blockData.getOrderOf(rhsOperand)) + return failure(); + } + + // If the rhs has external uses, it will need to be replaced. + if (rhsIt->isUsedOutsideOfBlock(mergeBlock)) + newOpsToReplace.emplace_back(&*rhsIt, &*lhsIt); + } + // Make sure that the block sizes are equivalent. + if (lhsIt != lhsE || rhsIt != rhsE) + return failure(); + + // If we get here, the blocks are equivalent and can be merged. 
+ operandsToMerge.insert(mismatchedOperands.begin(), mismatchedOperands.end()); + opsToReplace.insert(newOpsToReplace.begin(), newOpsToReplace.end()); + blocksToMerge.insert(blockData.block); + return success(); +} + +/// Returns true if the predecessor terminators of the given block can have +/// their operands updated. +static bool ableToUpdatePredOperands(Block *block) { + for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) { + auto branch = dyn_cast((*it)->getTerminator()); + if (!branch || !branch.getMutableSuccessorOperands(it.getSuccessorIndex())) + return false; + } + return true; +} + +LogicalResult BlockMergeCluster::merge() { + // Don't consider clusters that don't have blocks to merge. + if (blocksToMerge.empty()) + return failure(); + + Block *leaderBlock = leaderData.block; + if (!operandsToMerge.empty()) { + // If the cluster has operands to merge, verify that the predecessor + // terminators of each of the blocks can have their successor operands + // updated. + // TODO: We could try and sub-partition this cluster if only some blocks + // cause the mismatch. + if (!ableToUpdatePredOperands(leaderBlock) || + !llvm::all_of(blocksToMerge, ableToUpdatePredOperands)) + return failure(); + + // Replace any necessary operations. + for (std::pair &it : opsToReplace) + it.first->replaceAllUsesWith(it.second); + + // Collect the iterators for each of the blocks to merge. We will walk all + // of the iterators at once to avoid operand index invalidation. + SmallVector blockIterators; + blockIterators.reserve(blocksToMerge.size() + 1); + blockIterators.push_back(leaderBlock->begin()); + for (Block *mergeBlock : blocksToMerge) + blockIterators.push_back(mergeBlock->begin()); + + // Update each of the predecessor terminators with the new arguments. 
+ SmallVector, 2> newArguments( + 1 + blocksToMerge.size(), + SmallVector(operandsToMerge.size())); + unsigned curOpIndex = 0; + for (auto it : llvm::enumerate(operandsToMerge)) { + unsigned nextOpOffset = it.value().first - curOpIndex; + curOpIndex = it.value().first; + + // Process the operand for each of the block iterators. + for (unsigned i = 0, e = blockIterators.size(); i != e; ++i) { + Block::iterator &blockIter = blockIterators[i]; + std::advance(blockIter, nextOpOffset); + auto &operand = blockIter->getOpOperand(it.value().second); + newArguments[i][it.index()] = operand.get(); + + // Update the operand and insert an argument if this is the leader. + if (i == 0) + operand.set(leaderBlock->addArgument(operand.get().getType())); + } + } + // Update the predecessors for each of the blocks. + auto updatePredecessors = [&](Block *block, unsigned clusterIndex) { + for (auto predIt = block->pred_begin(), predE = block->pred_end(); + predIt != predE; ++predIt) { + auto branch = cast((*predIt)->getTerminator()); + unsigned succIndex = predIt.getSuccessorIndex(); + branch.getMutableSuccessorOperands(succIndex)->append( + newArguments[clusterIndex]); + } + }; + updatePredecessors(leaderBlock, /*clusterIndex=*/0); + for (unsigned i = 0, e = blocksToMerge.size(); i != e; ++i) + updatePredecessors(blocksToMerge[i], /*clusterIndex=*/i + 1); + } + + // Replace all uses of the merged blocks with the leader and erase them. + for (Block *block : blocksToMerge) { + block->replaceAllUsesWith(leaderBlock); + block->erase(); + } + return success(); +} + +/// Identify identical blocks within the given region and merge them, inserting +/// new block arguments as necessary. Returns success if any blocks were merged, +/// failure otherwise. +static LogicalResult mergeIdenticalBlocks(Region ®ion) { + if (region.empty() || llvm::hasSingleElement(region)) + return failure(); + + // Identify sets of blocks, other than the entry block, that branch to the + // same successors. 
We will use these groups to create clusters of equivalent + // blocks. + DenseMap> matchingSuccessors; + for (Block &block : llvm::drop_begin(region, 1)) + matchingSuccessors[block.getSuccessors()].push_back(&block); + + bool mergedAnyBlocks = false; + for (ArrayRef blocks : llvm::make_second_range(matchingSuccessors)) { + if (blocks.size() == 1) + continue; + + SmallVector clusters; + for (Block *block : blocks) { + BlockEquivalenceData data(block); + + // Don't allow merging if this block has any regions. + // TODO: Add support for regions if necessary. + bool hasNonEmptyRegion = llvm::any_of(*block, [](Operation &op) { + return llvm::any_of(op.getRegions(), + [](Region ®ion) { return !region.empty(); }); + }); + if (hasNonEmptyRegion) + continue; + + // Try to add this block to an existing cluster. + bool addedToCluster = false; + for (auto &cluster : clusters) + if ((addedToCluster = succeeded(cluster.addToCluster(data)))) + break; + if (!addedToCluster) + clusters.emplace_back(std::move(data)); + } + for (auto &cluster : clusters) + mergedAnyBlocks |= succeeded(cluster.merge()); + } + + return success(mergedAnyBlocks); +} + +/// Identify identical blocks within the given regions and merge them, inserting +/// new block arguments as necessary. +static LogicalResult mergeIdenticalBlocks(MutableArrayRef regions) { + llvm::SmallSetVector worklist; + for (auto ®ion : regions) + worklist.insert(®ion); + bool anyChanged = false; + while (!worklist.empty()) { + Region *region = worklist.pop_back_val(); + if (succeeded(mergeIdenticalBlocks(*region))) { + worklist.insert(region); + anyChanged = true; + } + + // Add any nested regions to the worklist. 
+ for (Block &block : *region) + for (auto &op : block) + for (auto &nestedRegion : op.getRegions()) + worklist.insert(&nestedRegion); + } + + return success(anyChanged); +} + +//===----------------------------------------------------------------------===// // Region Simplification //===----------------------------------------------------------------------===// @@ -376,7 +694,9 @@ static LogicalResult runRegionDCE(MutableArrayRef regions) { /// elimination, as well as some other DCE. This function returns success if any /// of the regions were simplified, failure otherwise. LogicalResult mlir::simplifyRegions(MutableArrayRef regions) { - LogicalResult eliminatedBlocks = eraseUnreachableBlocks(regions); - LogicalResult eliminatedOpsOrArgs = runRegionDCE(regions); - return success(succeeded(eliminatedBlocks) || succeeded(eliminatedOpsOrArgs)); + bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(regions)); + bool eliminatedOpsOrArgs = succeeded(runRegionDCE(regions)); + bool mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(regions)); + return success(eliminatedBlocks || eliminatedOpsOrArgs || + mergedIdenticalBlocks); } diff --git a/mlir/test/Dialect/SPIRV/canonicalize.mlir b/mlir/test/Dialect/SPIRV/canonicalize.mlir index f8c3bde..20ed6e9 100644 --- a/mlir/test/Dialect/SPIRV/canonicalize.mlir +++ b/mlir/test/Dialect/SPIRV/canonicalize.mlir @@ -559,15 +559,18 @@ func @cannot_canonicalize_selection_op_0(%cond: i1) -> () { // CHECK: spv.selection { spv.selection { + // CHECK: spv.BranchConditional + // CHECK-SAME: ^bb1(%[[DST_VAR_0]], %[[SRC_VALUE_0]] + // CHECK-SAME: ^bb1(%[[DST_VAR_1]], %[[SRC_VALUE_1]] spv.BranchConditional %cond, ^then, ^else ^then: - // CHECK: spv.Store "Function" %[[DST_VAR_0]], %[[SRC_VALUE_0]] ["Aligned", 8] : vector<3xi32> + // CHECK: ^bb1(%[[ARG0:.*]]: !spv.ptr, Function>, %[[ARG1:.*]]: vector<3xi32>): + // CHECK: spv.Store "Function" %[[ARG0]], %[[ARG1]] ["Aligned", 8] : vector<3xi32> spv.Store "Function" %3, %1 ["Aligned", 8]: 
vector<3xi32> spv.Branch ^merge ^else: - // CHECK: spv.Store "Function" %[[DST_VAR_1]], %[[SRC_VALUE_1]] ["Aligned", 8] : vector<3xi32> spv.Store "Function" %4, %2 ["Aligned", 8] : vector<3xi32> spv.Branch ^merge diff --git a/mlir/test/Transforms/canonicalize-block-merge.mlir b/mlir/test/Transforms/canonicalize-block-merge.mlir new file mode 100644 index 0000000..86cac9d --- /dev/null +++ b/mlir/test/Transforms/canonicalize-block-merge.mlir @@ -0,0 +1,204 @@ +// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s + +// Check the simple case of single operation blocks with a return. + +// CHECK-LABEL: func @return_blocks( +func @return_blocks() { + // CHECK: "foo.cond_br"()[^bb1, ^bb1] + // CHECK: ^bb1: + // CHECK-NEXT: return + // CHECK-NOT: ^bb2 + + "foo.cond_br"() [^bb1, ^bb2] : () -> () + +^bb1: + return +^bb2: + return +} + +// Check the case of identical blocks with matching arguments. + +// CHECK-LABEL: func @matching_arguments( +func @matching_arguments() -> i32 { + // CHECK: "foo.cond_br"()[^bb1, ^bb1] + // CHECK: ^bb1(%{{.*}}: i32): + // CHECK-NEXT: return + // CHECK-NOT: ^bb2 + + "foo.cond_br"() [^bb1, ^bb2] : () -> () + +^bb1(%arg0 : i32): + return %arg0 : i32 +^bb2(%arg1 : i32): + return %arg1 : i32 +} + +// Check that no merging occurs if there is an operand mismatch and we can't +// update the predecessor. + +// CHECK-LABEL: func @mismatch_unknown_terminator +func @mismatch_unknown_terminator(%arg0 : i32, %arg1 : i32) -> i32 { + // CHECK: "foo.cond_br"()[^bb1, ^bb2] + + "foo.cond_br"() [^bb1, ^bb2] : () -> () + +^bb1: + return %arg0 : i32 +^bb2: + return %arg1 : i32 +} + +// Check that merging does occur if there is an operand mismatch and we can +// update the predecessor. 
+ +// CHECK-LABEL: func @mismatch_operands +// CHECK-SAME: %[[COND:.*]]: i1, %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32 +func @mismatch_operands(%cond : i1, %arg0 : i32, %arg1 : i32) -> i32 { + // CHECK: %[[RES:.*]] = select %[[COND]], %[[ARG0]], %[[ARG1]] + // CHECK: return %[[RES]] + + cond_br %cond, ^bb1, ^bb2 + +^bb1: + return %arg0 : i32 +^bb2: + return %arg1 : i32 +} + +// Check the same as above, but with pre-existing arguments. + +// CHECK-LABEL: func @mismatch_operands_matching_arguments( +// CHECK-SAME: %[[COND:.*]]: i1, %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32 +func @mismatch_operands_matching_arguments(%cond : i1, %arg0 : i32, %arg1 : i32) -> (i32, i32) { + // CHECK: %[[RES0:.*]] = select %[[COND]], %[[ARG1]], %[[ARG0]] + // CHECK: %[[RES1:.*]] = select %[[COND]], %[[ARG0]], %[[ARG1]] + // CHECK: return %[[RES1]], %[[RES0]] + + cond_br %cond, ^bb1(%arg1 : i32), ^bb2(%arg0 : i32) + +^bb1(%arg2 : i32): + return %arg0, %arg2 : i32, i32 +^bb2(%arg3 : i32): + return %arg1, %arg3 : i32, i32 +} + +// Check that merging does not occur if the uses of the arguments differ. + +// CHECK-LABEL: func @mismatch_argument_uses( +func @mismatch_argument_uses(%cond : i1, %arg0 : i32, %arg1 : i32) -> (i32, i32) { + // CHECK: cond_br %{{.*}}, ^bb1(%{{.*}}), ^bb2 + + cond_br %cond, ^bb1(%arg1 : i32), ^bb2(%arg0 : i32) + +^bb1(%arg2 : i32): + return %arg0, %arg2 : i32, i32 +^bb2(%arg3 : i32): + return %arg3, %arg1 : i32, i32 +} + +// Check that merging does not occur if the types of the arguments differ. + +// CHECK-LABEL: func @mismatch_argument_types( +func @mismatch_argument_types(%cond : i1, %arg0 : i32, %arg1 : i16) { + // CHECK: cond_br %{{.*}}, ^bb1(%{{.*}}), ^bb2 + + cond_br %cond, ^bb1(%arg0 : i32), ^bb2(%arg1 : i16) + +^bb1(%arg2 : i32): + "foo.return"(%arg2) : (i32) -> () +^bb2(%arg3 : i16): + "foo.return"(%arg3) : (i16) -> () +} + +// Check that merging does not occur if the number of the arguments differ. 
+ +// CHECK-LABEL: func @mismatch_argument_count( +func @mismatch_argument_count(%cond : i1, %arg0 : i32) { + // CHECK: cond_br %{{.*}}, ^bb1(%{{.*}}), ^bb2 + + cond_br %cond, ^bb1(%arg0 : i32), ^bb2 + +^bb1(%arg2 : i32): + "foo.return"(%arg2) : (i32) -> () +^bb2: + "foo.return"() : () -> () +} + +// Check that merging does not occur if the operations differ. + +// CHECK-LABEL: func @mismatch_operations( +func @mismatch_operations(%cond : i1) { + // CHECK: cond_br %{{.*}}, ^bb1, ^bb2 + + cond_br %cond, ^bb1, ^bb2 + +^bb1: + "foo.return"() : () -> () +^bb2: + return +} + +// Check that merging does not occur if the number of operations differs. + +// CHECK-LABEL: func @mismatch_operation_count( +func @mismatch_operation_count(%cond : i1) { + // CHECK: cond_br %{{.*}}, ^bb1, ^bb2 + + cond_br %cond, ^bb1, ^bb2 + +^bb1: + "foo.op"() : () -> () + return +^bb2: + return +} + +// Check that merging does not occur if the blocks contain regions. + +// CHECK-LABEL: func @contains_regions( +func @contains_regions(%cond : i1) { + // CHECK: cond_br %{{.*}}, ^bb1, ^bb2 + + cond_br %cond, ^bb1, ^bb2 + +^bb1: + loop.if %cond { + "foo.op"() : () -> () + } + return +^bb2: + loop.if %cond { + "foo.op"() : () -> () + } + return +} + +// Check that merging properly handles back edges and the case where a value from one +// block is used in another. 
+ +// CHECK-LABEL: func @mismatch_loop( +// CHECK-SAME: %[[ARG:.*]]: i1 +func @mismatch_loop(%cond : i1) { + // CHECK: cond_br %{{.*}}, ^bb1(%[[ARG]] : i1), ^bb2 + + cond_br %cond, ^bb2, ^bb3 + +^bb1: + // CHECK: ^bb1(%[[ARG2:.*]]: i1): + // CHECK-NEXT: %[[LOOP_CARRY:.*]] = "foo.op" + // CHECK-NEXT: cond_br %[[ARG2]], ^bb1(%[[LOOP_CARRY]] : i1), ^bb2 + + %ignored = "foo.op"() : () -> (i1) + cond_br %cond2, ^bb1, ^bb3 + +^bb2: + %cond2 = "foo.op"() : () -> (i1) + cond_br %cond, ^bb1, ^bb3 + +^bb3: + // CHECK: ^bb2: + // CHECK-NEXT: return + + return +} diff --git a/mlir/test/Transforms/canonicalize-dce.mlir b/mlir/test/Transforms/canonicalize-dce.mlir index b93af00..6028821 100644 --- a/mlir/test/Transforms/canonicalize-dce.mlir +++ b/mlir/test/Transforms/canonicalize-dce.mlir @@ -62,10 +62,6 @@ func @f(%arg0: f32) { // Test case: Delete block arguments for cond_br. // CHECK: func @f(%arg0: f32, %arg1: i1) -// CHECK-NEXT: cond_br %arg1, ^bb1, ^bb2 -// CHECK-NEXT: ^bb1: -// CHECK-NEXT: return -// CHECK-NEXT: ^bb2: // CHECK-NEXT: return func @f(%arg0: f32, %pred: i1) { diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir index 69e8b39..1cff314 100644 --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s -dump-input-on-failure // CHECK-LABEL: func @test_subi_zero func @test_subi_zero(%arg0: i32) -> i32 { @@ -361,19 +361,15 @@ func @dead_dealloc_fold() { // CHECK-LABEL: func @dead_dealloc_fold_multi_use func @dead_dealloc_fold_multi_use(%cond : i1) { - // CHECK-NEXT: cond_br + // CHECK-NEXT: return %a = alloc() : memref<4xf32> cond_br %cond, ^bb1, ^bb2 - // CHECK-LABEL: bb1: ^bb1: - // CHECK-NEXT: return dealloc %a: memref<4xf32> return - // 
CHECK-LABEL: bb2: ^bb2: - // CHECK-NEXT: return dealloc %a: memref<4xf32> return } -- 2.7.4