From ffdd4a46a9a90d7b63b840c4b3c775074815f3ed Mon Sep 17 00:00:00 2001 From: Tres Popp Date: Fri, 18 Sep 2020 11:14:32 +0200 Subject: [PATCH] [mlir] Shape.AssumingOp implements RegionBranchOpInterface. This adds support for the interface and provides unambigious information on the control flow as it is unconditional on any runtime values. The code is tested through confirming that buffer-placement behaves as expected. Differential Revision: https://reviews.llvm.org/D87894 --- mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td | 1 + mlir/lib/Dialect/Shape/IR/Shape.cpp | 15 ++++++++++ mlir/test/Transforms/buffer-placement.mlir | 39 ++++++++++++++++++++++++++ 3 files changed, 55 insertions(+) diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index b709d3c..2d7c484 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -619,6 +619,7 @@ def Shape_AssumingAllOp : Shape_Op<"assuming_all", [Commutative, NoSideEffect]> def Shape_AssumingOp : Shape_Op<"assuming", [SingleBlockImplicitTerminator<"AssumingYieldOp">, + DeclareOpInterfaceMethods, RecursiveSideEffects]> { let summary = "Execute the region"; let description = [{ diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 7062129..7da3b39 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -233,6 +233,21 @@ void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, patterns.insert(context); } +// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td +void AssumingOp::getSuccessorRegions( + Optional index, ArrayRef operands, + SmallVectorImpl ®ions) { + // AssumingOp has unconditional control flow into the region and back to the + // parent, so return the correct RegionSuccessor purely based on the index + // being None or 0. + if (index.hasValue()) { + regions.push_back(RegionSuccessor(getResults())); + return; + } + + regions.push_back(RegionSuccessor(&doRegion())); +} + void AssumingOp::inlineRegionIntoParent(AssumingOp &op, PatternRewriter &rewriter) { auto *blockBeforeAssuming = rewriter.getInsertionBlock(); diff --git a/mlir/test/Transforms/buffer-placement.mlir b/mlir/test/Transforms/buffer-placement.mlir index dc9ff44..e03f8c9 100644 --- a/mlir/test/Transforms/buffer-placement.mlir +++ b/mlir/test/Transforms/buffer-placement.mlir @@ -1417,3 +1417,42 @@ func @do_loop_alloc( } // expected-error@+1 {{Structured control-flow loops are supported only}} + +// ----- + +func @assumingOp(%arg0: !shape.witness, %arg2: memref<2xf32>, %arg3: memref<2xf32>) { + // Confirm the alloc will be dealloc'ed in the block. + %1 = shape.assuming %arg0 -> memref<2xf32> { + %0 = alloc() : memref<2xf32> + shape.assuming_yield %arg2 : memref<2xf32> + } + // Confirm the alloc will be returned and dealloc'ed after its use. + %3 = shape.assuming %arg0 -> memref<2xf32> { + %2 = alloc() : memref<2xf32> + shape.assuming_yield %2 : memref<2xf32> + } + "linalg.copy"(%3, %arg3) : (memref<2xf32>, memref<2xf32>) -> () + return +} + +// CHECK-LABEL: func @assumingOp( +// CHECK-SAME: %[[ARG0:.*]]: !shape.witness, +// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>, +// CHECK-SAME: %[[ARG2:.*]]: memref<2xf32>) { +// CHECK: %[[UNUSED_RESULT:.*]] = shape.assuming %[[ARG0]] -> (memref<2xf32>) { +// CHECK: %[[ALLOC0:.*]] = alloc() : memref<2xf32> +// CHECK: dealloc %[[ALLOC0]] : memref<2xf32> +// CHECK: shape.assuming_yield %[[ARG1]] : memref<2xf32> +// CHECK: } +// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[ARG0]] -> (memref<2xf32>) { +// CHECK: %[[TMP_ALLOC:.*]] = alloc() : memref<2xf32> +// CHECK: %[[RETURNING_ALLOC:.*]] = alloc() : memref<2xf32> +// CHECK: linalg.copy(%[[TMP_ALLOC]], %[[RETURNING_ALLOC]]) : memref<2xf32>, memref<2xf32> +// CHECK: dealloc %[[TMP_ALLOC]] : memref<2xf32> +// CHECK: shape.assuming_yield %[[RETURNING_ALLOC]] : memref<2xf32> +// CHECK: } +// CHECK: linalg.copy(%[[ASSUMING_RESULT:.*]], %[[ARG2]]) : memref<2xf32>, memref<2xf32> +// CHECK: dealloc %[[ASSUMING_RESULT]] : memref<2xf32> +// CHECK: return +// CHECK: } + -- 2.7.4