Fix backward slice corner case
authorNicolas Vasilache <ntv@google.com>
Fri, 26 Jul 2019 10:48:51 +0000 (03:48 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 26 Jul 2019 10:49:17 +0000 (03:49 -0700)
In the backward slice computation, BlockArgument coming from function arguments represent a natural boundary for the traversal and should not trigger llvm_unreachable.
This CL also improves the error message and adds a relevant test.

PiperOrigin-RevId: 260118630

mlir/lib/Analysis/SliceAnalysis.cpp
mlir/test/Transforms/slicing-utils.mlir

index 05dcfce..68ab2d3 100644 (file)
@@ -23,6 +23,7 @@
 #include "mlir/AffineOps/AffineOps.h"
 #include "mlir/Analysis/VectorAnalysis.h"
 #include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/IR/Function.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Support/Functional.h"
 #include "mlir/Support/STLExtras.h"
@@ -103,8 +104,9 @@ static void getBackwardSliceImpl(Operation *op,
     return;
   }
 
-  for (auto *operand : op->getOperands()) {
-    if (isa<BlockArgument>(operand)) {
+  for (auto en : llvm::enumerate(op->getOperands())) {
+    auto *operand = en.value();
+    if (auto *blockArg = dyn_cast<BlockArgument>(operand)) {
       if (auto affIv = getForInductionVarOwner(operand)) {
         auto *affOp = affIv.getOperation();
         if (backwardSlice->count(affOp) == 0)
@@ -113,7 +115,9 @@ static void getBackwardSliceImpl(Operation *op,
         auto *loopOp = loopIv.getOperation();
         if (backwardSlice->count(loopOp) == 0)
           getBackwardSliceImpl(loopOp, backwardSlice, filter);
-      } else {
+      } else if (blockArg->getOwner() !=
+                 &op->getParentOfType<FuncOp>().getBody().front()) {
+        op->emitError("Unsupported CF for operand ") << en.index();
         llvm_unreachable("Unsupported control flow");
       }
       continue;
index 4849c18..49410db 100644 (file)
@@ -264,7 +264,18 @@ func @slicing_test_3() {
     %d = "slicing-test-op"(%c, %i2): (index, index) -> index
   }
   return
-}// This test dumps 2 sets of outputs: first the test outputs themselves followed
+}
+
+// FWD-LABEL: slicing_test_function_argument
+// BWD-LABEL: slicing_test_function_argument
+// FWDBWD-LABEL: slicing_test_function_argument
+func @slicing_test_function_argument(%arg0: index) -> index {
+  // BWD: matched: {{.*}} (index, index) -> index backward static slice:
+  %0 = "slicing-test-op"(%arg0, %arg0): (index, index) -> index
+  return %0 : index
+}
+
+// This test dumps 2 sets of outputs: first the test outputs themselves followed
 // by the module. These labels isolate the test outputs from the module dump.
 // FWD-LABEL: slicing_test
 // BWD-LABEL: slicing_test
@@ -275,3 +286,6 @@ func @slicing_test_3() {
 // FWD-LABEL: slicing_test_3
 // BWD-LABEL: slicing_test_3
 // FWDBWD-LABEL: slicing_test_3
+// FWD-LABEL: slicing_test_function_argument
+// BWD-LABEL: slicing_test_function_argument
+// FWDBWD-LABEL: slicing_test_function_argument