[mlir][transforms] TopologicalSort: Support ops from different blocks
authorMatthias Springer <springerm@google.com>
Thu, 13 Oct 2022 01:26:25 +0000 (10:26 +0900)
committerMatthias Springer <springerm@google.com>
Thu, 13 Oct 2022 01:36:06 +0000 (10:36 +0900)
This change allows analyzing ops from different block, in particular when used in programs that have `cf` branches.

Differential Revision: https://reviews.llvm.org/D135644

mlir/include/mlir/Transforms/TopologicalSortUtils.h
mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp
mlir/test/Transforms/test-toposort.mlir
mlir/test/lib/Transforms/TestTopologicalSort.cpp

index 1a50d4d..74e44b1 100644 (file)
@@ -95,16 +95,13 @@ bool sortTopologically(
     Block *block,
     function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
 
-/// Compute a topological ordering of the given ops. All ops must belong to the
-/// specified block.
-///
-/// This sort is not stable.
+/// Compute a topological ordering of the given ops. This sort is not stable.
 ///
 /// Note: If the specified ops contain incomplete/interrupted SSA use-def
 /// chains, the result may not actually be a topological sorting with respect to
 /// the entire program.
 bool computeTopologicalSorting(
-    Block *block, MutableArrayRef<Operation *> ops,
+    MutableArrayRef<Operation *> ops,
     function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
 
 } // end namespace mlir
index 8767877..f3a9d21 100644 (file)
 using namespace mlir;
 
 /// Return `true` if the given operation is ready to be scheduled.
-static bool isOpReady(Block *block, Operation *op,
-                      DenseSet<Operation *> &unscheduledOps,
+static bool isOpReady(Operation *op, DenseSet<Operation *> &unscheduledOps,
                       function_ref<bool(Value, Operation *)> isOperandReady) {
   // An operation is ready to be scheduled if all its operands are ready. An
   // operation is ready if:
-  const auto isReady = [&](Value value, Operation *top) {
+  const auto isReady = [&](Value value) {
     // - the user-provided callback marks it as ready,
     if (isOperandReady && isOperandReady(value, op))
       return true;
@@ -25,22 +24,24 @@ static bool isOpReady(Block *block, Operation *op,
     // - it is a block argument,
     if (!parent)
       return true;
-    Operation *ancestor = block->findAncestorOpInBlock(*parent);
-    // - it is an implicit capture,
-    if (!ancestor)
-      return true;
-    // - it is defined in a nested region, or
-    if (ancestor == op)
-      return true;
-    // - its ancestor in the block is scheduled.
-    return !unscheduledOps.contains(ancestor);
+    // - or it is not defined by an unscheduled op (and also not nested within
+    //   an unscheduled op).
+    do {
+      // Stop traversal when op under examination is reached.
+      if (parent == op)
+        return true;
+      if (unscheduledOps.contains(parent))
+        return false;
+    } while ((parent = parent->getParentOp()));
+    // No unscheduled op found.
+    return true;
   };
 
   // An operation is recursively ready to be scheduled of it and its nested
   // operations are ready.
   WalkResult readyToSchedule = op->walk([&](Operation *nestedOp) {
     return llvm::all_of(nestedOp->getOperands(),
-                        [&](Value operand) { return isReady(operand, op); })
+                        [&](Value operand) { return isReady(operand); })
                ? WalkResult::advance()
                : WalkResult::interrupt();
   });
@@ -71,7 +72,7 @@ bool mlir::sortTopologically(
     // set, and "schedule" it (move it before the `nextScheduledOp`).
     for (Operation &op :
          llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) {
-      if (!isOpReady(block, &op, unscheduledOps, isOperandReady))
+      if (!isOpReady(&op, unscheduledOps, isOperandReady))
         continue;
 
       // Schedule the operation by moving it to the start.
@@ -104,7 +105,7 @@ bool mlir::sortTopologically(
 }
 
 bool mlir::computeTopologicalSorting(
-    Block *block, MutableArrayRef<Operation *> ops,
+    MutableArrayRef<Operation *> ops,
     function_ref<bool(Value, Operation *)> isOperandReady) {
   if (ops.empty())
     return true;
@@ -113,10 +114,8 @@ bool mlir::computeTopologicalSorting(
   DenseSet<Operation *> unscheduledOps;
 
   // Mark all operations as unscheduled.
-  for (Operation *op : ops) {
-    assert(op->getBlock() == block && "op must belong to block");
+  for (Operation *op : ops)
     unscheduledOps.insert(op);
-  }
 
   unsigned nextScheduledOp = 0;
 
@@ -128,7 +127,7 @@ bool mlir::computeTopologicalSorting(
     // i.e. the ones for which there aren't any operand produced by an op in the
     // set, and "schedule" it (swap it with the op at `nextScheduledOp`).
     for (unsigned i = nextScheduledOp; i < ops.size(); ++i) {
-      if (!isOpReady(block, ops[i], unscheduledOps, isOperandReady))
+      if (!isOpReady(ops[i], unscheduledOps, isOperandReady))
         continue;
 
       // Schedule the operation by moving it to the start.
index 2ebf35c..c47b885 100644 (file)
@@ -1,5 +1,5 @@
-// RUN: mlir-opt -topological-sort %s | FileCheck %s
-// RUN: mlir-opt -test-topological-sort-analysis %s | FileCheck %s -check-prefix=CHECK-ANALYSIS
+// RUN: mlir-opt %s -topological-sort | FileCheck %s
+// RUN: mlir-opt %s -test-topological-sort-analysis -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ANALYSIS
 
 // Test producer is after user.
 // CHECK-LABEL: test.graph_region
@@ -36,6 +36,35 @@ test.graph_region attributes{"root"} {
   %3 = "test.d"() {selected} : () -> i32
 }
 
+// Test not all scheduled.
+// CHECK-LABEL: test.graph_region
+// CHECK-ANALYSIS-LABEL: test.graph_region
+// expected-error@+1 {{could not schedule all ops}}
+test.graph_region attributes{"root"} {
+  %0 = "test.a"(%1) {selected} : (i32) -> i32
+  %1 = "test.b"(%0) {selected} : (i32) -> i32
+}
+
+// CHECK-LABEL: func @test_multiple_blocks
+// CHECK-ANALYSIS-LABEL: func @test_multiple_blocks
+func.func @test_multiple_blocks() -> (i32) attributes{"root", "ordered"} {
+  // CHECK-ANALYSIS-NEXT: test.foo{{.*}} {pos = 0
+  %0 = "test.foo"() {selected = 2} : () -> (i32)
+  // CHECK-ANALYSIS-NEXT: test.foo
+  %1 = "test.foo"() : () -> (i32)
+  cf.br ^bb0
+^bb0:
+  // CHECK-ANALYSIS: test.foo{{.*}} {pos = 1
+  %2 = "test.foo"() {selected = 3} : () -> (i32)
+  // CHECK-ANALYSIS-NEXT: test.bar{{.*}} {pos = 2
+  %3 = "test.bar"(%0, %1, %2) {selected = 0} : (i32, i32, i32) -> (i32)
+  cf.br ^bb1 (%2 : i32)
+^bb1(%arg0: i32):
+  // CHECK-ANALYSIS: test.qux{{.*}} {pos = 3
+  %4 = "test.qux"(%arg0, %0) {selected = 1} : (i32, i32) -> (i32)
+  return %4 : i32
+}
+
 // Test block arguments.
 // CHECK-LABEL: test.graph_region
 test.graph_region {
index 9ed64ea..4ad5b5c 100644 (file)
@@ -30,25 +30,47 @@ struct TestTopologicalSortAnalysisPass
     Operation *op = getOperation();
     OpBuilder builder(op->getContext());
 
-    op->walk([&](Operation *root) {
+    WalkResult result = op->walk([&](Operation *root) {
       if (!root->hasAttr("root"))
         return WalkResult::advance();
 
-      assert(root->getNumRegions() == 1 && root->getRegion(0).hasOneBlock() &&
-             "expected one block");
-      Block *block = &root->getRegion(0).front();
       SmallVector<Operation *> selectedOps;
-      block->walk([&](Operation *op) {
-        if (op->hasAttr("selected"))
-          selectedOps.push_back(op);
+      root->walk([&](Operation *selected) {
+        if (!selected->hasAttr("selected"))
+          return WalkResult::advance();
+        if (root->hasAttr("ordered")) {
+          // If the root has an "ordered" attribute, we fill the selectedOps
+          // vector in a certain order.
+          int64_t pos =
+              selected->getAttr("selected").cast<IntegerAttr>().getInt();
+          if (pos >= static_cast<int64_t>(selectedOps.size()))
+            selectedOps.append(pos + 1 - selectedOps.size(), nullptr);
+          selectedOps[pos] = selected;
+        } else {
+          selectedOps.push_back(selected);
+        }
+        return WalkResult::advance();
       });
 
-      computeTopologicalSorting(block, selectedOps);
+      if (llvm::find(selectedOps, nullptr) != selectedOps.end()) {
+        root->emitError("invalid test case: some indices are missing among the "
+                        "selected ops");
+        return WalkResult::skip();
+      }
+
+      if (!computeTopologicalSorting(selectedOps)) {
+        root->emitError("could not schedule all ops");
+        return WalkResult::skip();
+      }
+
       for (const auto &it : llvm::enumerate(selectedOps))
         it.value()->setAttr("pos", builder.getIndexAttr(it.index()));
 
       return WalkResult::advance();
     });
+
+    if (result.wasSkipped())
+      signalPassFailure();
   }
 };
 } // namespace