This change allows analyzing ops from different blocks, in particular when used in programs that have `cf` branches.
Differential Revision: https://reviews.llvm.org/D135644
Block *block,
function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
-/// Compute a topological ordering of the given ops. All ops must belong to the
-/// specified block.
-///
-/// This sort is not stable.
+/// Compute a topological ordering of the given ops. This sort is not stable.
///
/// Note: If the specified ops contain incomplete/interrupted SSA use-def
/// chains, the result may not actually be a topological sorting with respect to
/// the entire program.
bool computeTopologicalSorting(
- Block *block, MutableArrayRef<Operation *> ops,
+ MutableArrayRef<Operation *> ops,
function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
} // end namespace mlir
using namespace mlir;
/// Return `true` if the given operation is ready to be scheduled.
-static bool isOpReady(Block *block, Operation *op,
- DenseSet<Operation *> &unscheduledOps,
+static bool isOpReady(Operation *op, DenseSet<Operation *> &unscheduledOps,
function_ref<bool(Value, Operation *)> isOperandReady) {
// An operation is ready to be scheduled if all its operands are ready. An
// operation is ready if:
- const auto isReady = [&](Value value, Operation *top) {
+ const auto isReady = [&](Value value) {
// - the user-provided callback marks it as ready,
if (isOperandReady && isOperandReady(value, op))
return true;
// - it is a block argument,
if (!parent)
return true;
- Operation *ancestor = block->findAncestorOpInBlock(*parent);
- // - it is an implicit capture,
- if (!ancestor)
- return true;
- // - it is defined in a nested region, or
- if (ancestor == op)
- return true;
- // - its ancestor in the block is scheduled.
- return !unscheduledOps.contains(ancestor);
+ // - or it is not defined by an unscheduled op (and also not nested within
+ // an unscheduled op).
+ do {
+ // Stop traversal when op under examination is reached.
+ if (parent == op)
+ return true;
+ if (unscheduledOps.contains(parent))
+ return false;
+ } while ((parent = parent->getParentOp()));
+ // No unscheduled op found.
+ return true;
};
  // An operation is recursively ready to be scheduled if it and its nested
  // operations are ready.
WalkResult readyToSchedule = op->walk([&](Operation *nestedOp) {
return llvm::all_of(nestedOp->getOperands(),
- [&](Value operand) { return isReady(operand, op); })
+ [&](Value operand) { return isReady(operand); })
? WalkResult::advance()
: WalkResult::interrupt();
});
// set, and "schedule" it (move it before the `nextScheduledOp`).
for (Operation &op :
llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) {
- if (!isOpReady(block, &op, unscheduledOps, isOperandReady))
+ if (!isOpReady(&op, unscheduledOps, isOperandReady))
continue;
// Schedule the operation by moving it to the start.
}
bool mlir::computeTopologicalSorting(
- Block *block, MutableArrayRef<Operation *> ops,
+ MutableArrayRef<Operation *> ops,
function_ref<bool(Value, Operation *)> isOperandReady) {
if (ops.empty())
return true;
DenseSet<Operation *> unscheduledOps;
// Mark all operations as unscheduled.
- for (Operation *op : ops) {
- assert(op->getBlock() == block && "op must belong to block");
+ for (Operation *op : ops)
unscheduledOps.insert(op);
- }
unsigned nextScheduledOp = 0;
// i.e. the ones for which there aren't any operand produced by an op in the
// set, and "schedule" it (swap it with the op at `nextScheduledOp`).
for (unsigned i = nextScheduledOp; i < ops.size(); ++i) {
- if (!isOpReady(block, ops[i], unscheduledOps, isOperandReady))
+ if (!isOpReady(ops[i], unscheduledOps, isOperandReady))
continue;
// Schedule the operation by moving it to the start.
-// RUN: mlir-opt -topological-sort %s | FileCheck %s
-// RUN: mlir-opt -test-topological-sort-analysis %s | FileCheck %s -check-prefix=CHECK-ANALYSIS
+// RUN: mlir-opt %s -topological-sort | FileCheck %s
+// RUN: mlir-opt %s -test-topological-sort-analysis -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ANALYSIS
// Test producer is after user.
// CHECK-LABEL: test.graph_region
%3 = "test.d"() {selected} : () -> i32
}
+// Test not all scheduled.
+// CHECK-LABEL: test.graph_region
+// CHECK-ANALYSIS-LABEL: test.graph_region
+// expected-error@+1 {{could not schedule all ops}}
+test.graph_region attributes{"root"} {
+ %0 = "test.a"(%1) {selected} : (i32) -> i32
+ %1 = "test.b"(%0) {selected} : (i32) -> i32
+}
+
+// CHECK-LABEL: func @test_multiple_blocks
+// CHECK-ANALYSIS-LABEL: func @test_multiple_blocks
+func.func @test_multiple_blocks() -> (i32) attributes{"root", "ordered"} {
+ // CHECK-ANALYSIS-NEXT: test.foo{{.*}} {pos = 0
+ %0 = "test.foo"() {selected = 2} : () -> (i32)
+ // CHECK-ANALYSIS-NEXT: test.foo
+ %1 = "test.foo"() : () -> (i32)
+ cf.br ^bb0
+^bb0:
+ // CHECK-ANALYSIS: test.foo{{.*}} {pos = 1
+ %2 = "test.foo"() {selected = 3} : () -> (i32)
+ // CHECK-ANALYSIS-NEXT: test.bar{{.*}} {pos = 2
+ %3 = "test.bar"(%0, %1, %2) {selected = 0} : (i32, i32, i32) -> (i32)
+ cf.br ^bb1 (%2 : i32)
+^bb1(%arg0: i32):
+ // CHECK-ANALYSIS: test.qux{{.*}} {pos = 3
+ %4 = "test.qux"(%arg0, %0) {selected = 1} : (i32, i32) -> (i32)
+ return %4 : i32
+}
+
// Test block arguments.
// CHECK-LABEL: test.graph_region
test.graph_region {
Operation *op = getOperation();
OpBuilder builder(op->getContext());
- op->walk([&](Operation *root) {
+ WalkResult result = op->walk([&](Operation *root) {
if (!root->hasAttr("root"))
return WalkResult::advance();
- assert(root->getNumRegions() == 1 && root->getRegion(0).hasOneBlock() &&
- "expected one block");
- Block *block = &root->getRegion(0).front();
SmallVector<Operation *> selectedOps;
- block->walk([&](Operation *op) {
- if (op->hasAttr("selected"))
- selectedOps.push_back(op);
+ root->walk([&](Operation *selected) {
+ if (!selected->hasAttr("selected"))
+ return WalkResult::advance();
+ if (root->hasAttr("ordered")) {
+ // If the root has an "ordered" attribute, we fill the selectedOps
+ // vector in a certain order.
+ int64_t pos =
+ selected->getAttr("selected").cast<IntegerAttr>().getInt();
+ if (pos >= static_cast<int64_t>(selectedOps.size()))
+ selectedOps.append(pos + 1 - selectedOps.size(), nullptr);
+ selectedOps[pos] = selected;
+ } else {
+ selectedOps.push_back(selected);
+ }
+ return WalkResult::advance();
});
- computeTopologicalSorting(block, selectedOps);
+ if (llvm::find(selectedOps, nullptr) != selectedOps.end()) {
+ root->emitError("invalid test case: some indices are missing among the "
+ "selected ops");
+ return WalkResult::skip();
+ }
+
+ if (!computeTopologicalSorting(selectedOps)) {
+ root->emitError("could not schedule all ops");
+ return WalkResult::skip();
+ }
+
for (const auto &it : llvm::enumerate(selectedOps))
it.value()->setAttr("pos", builder.getIndexAttr(it.index()));
return WalkResult::advance();
});
+
+ if (result.wasSkipped())
+ signalPassFailure();
}
};
} // namespace