Fix backward slice computation to iterate through known control flow
authorNicolas Vasilache <ntv@google.com>
Thu, 25 Jul 2019 08:33:02 +0000 (01:33 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 25 Jul 2019 08:33:35 +0000 (01:33 -0700)
This CL fixes an oversight with dealing with loops in slicing analysis.
The forward slice computation properly propagates through loops but not the backward slice.

Add relevant unit tests.

PiperOrigin-RevId: 259903396

mlir/include/mlir/Analysis/SliceAnalysis.h
mlir/lib/Analysis/SliceAnalysis.cpp
mlir/test/Transforms/slicing-utils.mlir

index fec79cd..ad6b653 100644 (file)
@@ -56,7 +56,7 @@ using TransitiveFilter = std::function<bool(Operation *)>;
 /// Example starting from node 0
 /// ============================
 ///
-///              0
+///               0
 ///    ___________|___________
 ///    1       2      3      4
 ///    |_______|      |______|
index 1823f3d..05dcfce 100644 (file)
@@ -89,9 +89,8 @@ void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
 static void getBackwardSliceImpl(Operation *op,
                                  SetVector<Operation *> *backwardSlice,
                                  TransitiveFilter filter) {
-  if (!op) {
+  if (!op)
     return;
-  }
 
   assert((op->getNumRegions() == 0 || isa<AffineForOp>(op) ||
           isa<loop::ForOp>(op)) &&
@@ -105,6 +104,20 @@ static void getBackwardSliceImpl(Operation *op,
   }
 
   for (auto *operand : op->getOperands()) {
+    if (isa<BlockArgument>(operand)) {
+      if (auto affIv = getForInductionVarOwner(operand)) {
+        auto *affOp = affIv.getOperation();
+        if (backwardSlice->count(affOp) == 0)
+          getBackwardSliceImpl(affOp, backwardSlice, filter);
+      } else if (auto loopIv = loop::getForInductionVarOwner(operand)) {
+        auto *loopOp = loopIv.getOperation();
+        if (backwardSlice->count(loopOp) == 0)
+          getBackwardSliceImpl(loopOp, backwardSlice, filter);
+      } else {
+        llvm_unreachable("Unsupported control flow");
+      }
+      continue;
+    }
     auto *op = operand->getDefiningOp();
     if (backwardSlice->count(op) == 0) {
       getBackwardSliceImpl(op, backwardSlice, filter);
index ae59ecd..4849c18 100644 (file)
@@ -217,8 +217,61 @@ func @slicing_test() {
   return
 }
 
-// This test dumps 2 sets of outputs: first the test outputs themselves followed
+// FWD-LABEL: slicing_test_2
+// BWD-LABEL: slicing_test_2
+// FWDBWD-LABEL: slicing_test_2
+func @slicing_test_2() {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c16 = constant 16 : index
+  loop.for %i0 = %c0 to %c16 step %c1 {
+    affine.for %i1 = (i)[] -> (i)(%i0) to 10 {
+      // BWD: matched: %[[b:.*]] {{.*}} backward static slice:
+      // BWD: loop.for {{.*}}
+
+      // affine.for appears in the body of loop.for
+      // BWD: affine.for {{.*}}
+
+      // affine.for appears as a proper op in the backward slice
+      // BWD: affine.for {{.*}}
+      %b = "slicing-test-op"(%i1): (index) -> index
+
+      // BWD: matched: %[[c:.*]] {{.*}} backward static slice:
+      // BWD: loop.for {{.*}}
+
+      // affine.for appears in the body of loop.for
+      // BWD-NEXT: affine.for {{.*}}
+
+      // affine.for only appears in the body of loop.for
+      // BWD-NOT: affine.for {{.*}}
+      %c = "slicing-test-op"(%i0): (index) -> index
+    }
+  }
+  return
+}
+
+// FWD-LABEL: slicing_test_3
+// BWD-LABEL: slicing_test_3
+// FWDBWD-LABEL: slicing_test_3
+func @slicing_test_3() {
+  %f = constant 1.0 : f32
+  %c = "slicing-test-op"(%f): (f32) -> index
+  // FWD: matched: {{.*}} (f32) -> index forward static slice:
+  // FWD: loop.for {{.*}}
+  // FWD: matched: {{.*}} (index, index) -> index forward static slice:
+  loop.for %i2 = %c to %c step %c {
+    %d = "slicing-test-op"(%c, %i2): (index, index) -> index
+  }
+  return
+}// This test dumps 2 sets of outputs: first the test outputs themselves followed
 // by the module. These labels isolate the test outputs from the module dump.
 // FWD-LABEL: slicing_test
 // BWD-LABEL: slicing_test
 // FWDBWD-LABEL: slicing_test
+// FWD-LABEL: slicing_test_2
+// BWD-LABEL: slicing_test_2
+// FWDBWD-LABEL: slicing_test_2
+// FWD-LABEL: slicing_test_3
+// BWD-LABEL: slicing_test_3
+// FWDBWD-LABEL: slicing_test_3