[mlir] Shape.AssumingOp implements RegionBranchOpInterface.
authorTres Popp <tpopp@google.com>
Fri, 18 Sep 2020 09:14:32 +0000 (11:14 +0200)
committerTres Popp <tpopp@google.com>
Mon, 21 Sep 2020 09:33:11 +0000 (11:33 +0200)
This adds support for the interface and provides unambigious information
on the control flow as it is unconditional on any runtime values.
The code is tested through confirming that buffer-placement behaves as
expected.

Differential Revision: https://reviews.llvm.org/D87894

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Transforms/buffer-placement.mlir

index b709d3c342e88187f16d36be9a3009c801610095..2d7c4841f68167f592cce2c9b9ead78605e1579d 100644 (file)
@@ -619,6 +619,7 @@ def Shape_AssumingAllOp : Shape_Op<"assuming_all", [Commutative, NoSideEffect]>
 
 def Shape_AssumingOp : Shape_Op<"assuming",
                            [SingleBlockImplicitTerminator<"AssumingYieldOp">,
+                            DeclareOpInterfaceMethods<RegionBranchOpInterface>,
                             RecursiveSideEffects]> {
   let summary = "Execute the region";
   let description = [{
index 70621295e39cf221f8a22e4777c0662e5778684c..7da3b3989b9b8ed119a67b70908a6aa70b09605b 100644 (file)
@@ -233,6 +233,21 @@ void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
   patterns.insert<AssumingWithTrue>(context);
 }
 
+// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
+void AssumingOp::getSuccessorRegions(
+    Optional<unsigned> index, ArrayRef<Attribute> operands,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  // AssumingOp has unconditional control flow into the region and back to the
+  // parent, so return the correct RegionSuccessor purely based on the index
+  // being None or 0.
+  if (index.hasValue()) {
+    regions.push_back(RegionSuccessor(getResults()));
+    return;
+  }
+
+  regions.push_back(RegionSuccessor(&doRegion()));
+}
+
 void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                         PatternRewriter &rewriter) {
   auto *blockBeforeAssuming = rewriter.getInsertionBlock();
index dc9ff44bf4838ebaffc178b5ed0fd86a873f7834..e03f8c9318109a6d6526f30bed0ff29d44fca137 100644 (file)
@@ -1417,3 +1417,42 @@ func @do_loop_alloc(
 }
 
 // expected-error@+1 {{Structured control-flow loops are supported only}}
+
+// -----
+
+func @assumingOp(%arg0: !shape.witness, %arg2: memref<2xf32>, %arg3: memref<2xf32>) {
+  // Confirm the alloc will be dealloc'ed in the block.
+  %1 = shape.assuming %arg0 -> memref<2xf32> {
+     %0 = alloc() : memref<2xf32>
+    shape.assuming_yield %arg2 : memref<2xf32>
+  }
+  // Confirm the alloc will be returned and dealloc'ed after its use.
+  %3 = shape.assuming %arg0 -> memref<2xf32> {
+    %2 = alloc() : memref<2xf32>
+    shape.assuming_yield %2 : memref<2xf32>
+  }
+  "linalg.copy"(%3, %arg3) : (memref<2xf32>, memref<2xf32>) -> ()
+  return
+}
+
+// CHECK-LABEL: func @assumingOp(
+// CHECK-SAME:     %[[ARG0:.*]]: !shape.witness,
+// CHECK-SAME:     %[[ARG1:.*]]: memref<2xf32>,
+// CHECK-SAME:     %[[ARG2:.*]]: memref<2xf32>) {
+// CHECK:        %[[UNUSED_RESULT:.*]] = shape.assuming %[[ARG0]] -> (memref<2xf32>) {
+// CHECK:         %[[ALLOC0:.*]] = alloc() : memref<2xf32>
+// CHECK:         dealloc %[[ALLOC0]] : memref<2xf32>
+// CHECK:         shape.assuming_yield %[[ARG1]] : memref<2xf32>
+// CHECK:        }
+// CHECK:        %[[ASSUMING_RESULT:.*]] = shape.assuming %[[ARG0]] -> (memref<2xf32>) {
+// CHECK:         %[[TMP_ALLOC:.*]] = alloc() : memref<2xf32>
+// CHECK:         %[[RETURNING_ALLOC:.*]] = alloc() : memref<2xf32>
+// CHECK:         linalg.copy(%[[TMP_ALLOC]], %[[RETURNING_ALLOC]]) : memref<2xf32>, memref<2xf32>
+// CHECK:         dealloc %[[TMP_ALLOC]] : memref<2xf32>
+// CHECK:         shape.assuming_yield %[[RETURNING_ALLOC]] : memref<2xf32>
+// CHECK:        }
+// CHECK:        linalg.copy(%[[ASSUMING_RESULT:.*]], %[[ARG2]]) : memref<2xf32>, memref<2xf32>
+// CHECK:        dealloc %[[ASSUMING_RESULT]] : memref<2xf32>
+// CHECK:        return
+// CHECK:       }
+