[mlir] Add a pass to remove all shape.cstr_ and assuming_ ops.
authorTres Popp <tpopp@google.com>
Tue, 9 Jun 2020 09:33:43 +0000 (11:33 +0200)
committerTres Popp <tpopp@google.com>
Thu, 18 Jun 2020 11:31:30 +0000 (13:31 +0200)
Summary:
This is to provide a utility to remove unsupported constraints or for
pipelines that happen to receive these but cannot lower them due to not
supporting assertions.

Differential Revision: https://reviews.llvm.org/D81560

mlir/include/mlir/Dialect/Shape/IR/Shape.h
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp [new file with mode: 0644]
mlir/test/Dialect/Shape/remove-shape-constraints.mlir [new file with mode: 0644]

index 9949758..b78b730 100644 (file)
@@ -22,6 +22,8 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 namespace mlir {
+class PatternRewriter;
+
 namespace shape {
 
 namespace ShapeTypes {
index a6f579c..49714c9 100644 (file)
@@ -375,7 +375,7 @@ def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> {
   let hasFolder = 1;
 }
 
-def Shape_YieldOp : Shape_Op<"yield", 
+def Shape_YieldOp : Shape_Op<"yield",
     [HasParent<"ReduceOp">,
      NoSideEffect,
      ReturnLike,
@@ -533,6 +533,14 @@ def Shape_AssumingOp : Shape_Op<"assuming",
   let printer = [{ return ::print(p, *this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
 
+  let extraClassDeclaration = [{
+    // Inline the region into the region containing the AssumingOp and delete
+    // the AssumingOp.
+    //
+    // This does no checks on the inputs to the AssumingOp.
+    static void inlineRegionIntoParent(AssumingOp &op, PatternRewriter &rewriter);
+  }];
+
   let hasCanonicalizer = 1;
 }
 
index 7e60653..e8d2167 100644 (file)
@@ -18,6 +18,7 @@
 
 namespace mlir {
 
+class FunctionPass;
 class MLIRContext;
 class OwningRewritePatternList;
 class Pass;
@@ -30,6 +31,17 @@ std::unique_ptr<Pass> createShapeToShapeLowering();
 /// Collects a set of patterns to rewrite ops within the Shape dialect.
 void populateShapeRewritePatterns(MLIRContext *context,
                                   OwningRewritePatternList &patterns);
+
+// Collects a set of patterns to replace all constraints with passing witnesses.
+// This is intended to then allow all ShapeConstraint related ops and data to
+// have no effects and allow them to be freely removed such as through
+// canonicalization and dead code elimination.
+//
+// After this pass, no cstr_ operations exist.
+void populateRemoveShapeConstraintsPatterns(OwningRewritePatternList &patterns,
+                                            MLIRContext *ctx);
+std::unique_ptr<FunctionPass> createRemoveShapeConstraintsPass();
+
 } // end namespace mlir
 
 #endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_
index 46dc4dc..022bd37 100644 (file)
 
 include "mlir/Pass/PassBase.td"
 
+def RemoveShapeConstraints : FunctionPass<"remove-shape-constraints"> {
+  let summary = "Replace all cstr_ ops with a true witness";
+  let constructor = "mlir::createRemoveShapeConstraintsPass()";
+}
+
 def ShapeToShapeLowering : FunctionPass<"shape-to-shape-lowering"> {
   let summary = "Legalize Shape dialect to be convertible to Standard";
   let constructor = "mlir::createShapeToShapeLowering()";
index 4a876e1..664c0cb 100644 (file)
@@ -168,22 +168,7 @@ struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
     if (!witness || !witness.passingAttr())
       return failure();
 
-    auto *blockBeforeAssuming = rewriter.getInsertionBlock();
-    auto *assumingBlock = op.getBody();
-    auto initPosition = rewriter.getInsertionPoint();
-    auto *blockAfterAssuming =
-        rewriter.splitBlock(blockBeforeAssuming, initPosition);
-
-    // Remove the AssumingOp and AssumingYieldOp.
-    auto &yieldOp = assumingBlock->back();
-    rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
-    rewriter.replaceOp(op, yieldOp.getOperands());
-    rewriter.eraseOp(&yieldOp);
-
-    // Merge blocks together as there was no branching behavior from the
-    // AssumingOp.
-    rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
-    rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
+    AssumingOp::inlineRegionIntoParent(op, rewriter);
     return success();
   }
 };
@@ -191,10 +176,30 @@ struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
 
 void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                              MLIRContext *context) {
-  // If taking a passing witness, inline region
+  // If taking a passing witness, inline region.
   patterns.insert<AssumingWithTrue>(context);
 }
 
+void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
+                                        PatternRewriter &rewriter) {
+  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
+  auto *assumingBlock = op.getBody();
+  auto initPosition = rewriter.getInsertionPoint();
+  auto *blockAfterAssuming =
+      rewriter.splitBlock(blockBeforeAssuming, initPosition);
+
+  // Remove the AssumingOp and AssumingYieldOp.
+  auto &yieldOp = assumingBlock->back();
+  rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
+  rewriter.replaceOp(op, yieldOp.getOperands());
+  rewriter.eraseOp(&yieldOp);
+
+  // Merge blocks together as there was no branching behavior from the
+  // AssumingOp.
+  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
+  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
+}
+
 //===----------------------------------------------------------------------===//
 // AssumingAllOp
 //===----------------------------------------------------------------------===//
index 6f812b6..987f9c5 100644 (file)
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRShapeOpsTransforms
+  RemoveShapeConstraints.cpp
   ShapeToShapeLowering.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp b/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp
new file mode 100644 (file)
index 0000000..641b4bc
--- /dev/null
@@ -0,0 +1,64 @@
+//===-- RemoveShapeConstraints.cpp - Remove Shape Cstr and Assuming Ops ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+/// Removal patterns.
+class RemoveCstrBroadcastableOp
+    : public OpRewritePattern<shape::CstrBroadcastableOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true);
+    return success();
+  }
+};
+
+class RemoveCstrEqOp : public OpRewritePattern<shape::CstrEqOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(shape::CstrEqOp op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true);
+    return success();
+  }
+};
+
+/// Removal pass.
+class RemoveShapeConstraintsPass
+    : public RemoveShapeConstraintsBase<RemoveShapeConstraintsPass> {
+
+  void runOnFunction() override {
+    MLIRContext &ctx = getContext();
+
+    OwningRewritePatternList patterns;
+    populateRemoveShapeConstraintsPatterns(patterns, &ctx);
+
+    applyPatternsAndFoldGreedily(getFunction(), patterns);
+  }
+};
+
+} // namespace
+
+void mlir::populateRemoveShapeConstraintsPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<RemoveCstrBroadcastableOp, RemoveCstrEqOp>(ctx);
+}
+
+std::unique_ptr<FunctionPass> mlir::createRemoveShapeConstraintsPass() {
+  return std::make_unique<RemoveShapeConstraintsPass>();
+}
diff --git a/mlir/test/Dialect/Shape/remove-shape-constraints.mlir b/mlir/test/Dialect/Shape/remove-shape-constraints.mlir
new file mode 100644 (file)
index 0000000..69887c6
--- /dev/null
@@ -0,0 +1,56 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -remove-shape-constraints -canonicalize <%s | FileCheck %s --dump-input=fail --check-prefixes=CANON,CHECK-BOTH
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -remove-shape-constraints <%s | FileCheck %s --dump-input=fail --check-prefixes=REPLACE,CHECK-BOTH
+
+// -----
+// Check that cstr_broadcastable is removed.
+//
+// CHECK-BOTH: func @f
+func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
+  // REPLACE-NEXT: %[[WITNESS:.+]] = shape.const_witness true
+  // REPLACE-NOT: shape.cstr_eq
+  // REPLACE: shape.assuming %[[WITNESS]]
+  // CANON-NEXT: test.source
+  // CANON-NEXT: return
+  %0 = shape.cstr_broadcastable %arg0, %arg1
+  %1 = shape.assuming %0 -> index {
+    %2 = "test.source"() : () -> (index)
+    shape.assuming_yield %2 : index
+  }
+  return %1 : index
+}
+
+// -----
+// Check that cstr_eq is removed.
+//
+// CHECK-BOTH: func @f
+func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
+  // REPLACE-NEXT: %[[WITNESS:.+]] = shape.const_witness true
+  // REPLACE-NOT: shape.cstr_eq
+  // REPLACE: shape.assuming %[[WITNESS]]
+  // CANON-NEXT: test.source
+  // CANON-NEXT: return
+  %0 = shape.cstr_eq %arg0, %arg1
+  %1 = shape.assuming %0 -> index {
+    %2 = "test.source"() : () -> (index)
+    shape.assuming_yield %2 : index
+  }
+  return %1 : index
+}
+
+// -----
+// With a non-const value, we cannot fold away the code, but all constraints
+// should be removed still.
+//
+// CHECK-BOTH: func @f
+func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
+  // CANON-NEXT: test.source
+  // CANON-NEXT: return
+  %0 = shape.cstr_broadcastable %arg0, %arg1
+  %1 = shape.cstr_eq %arg0, %arg1
+  %2 = shape.assuming_all %0, %1
+  %3 = shape.assuming %0 -> index {
+    %4 = "test.source"() : () -> (index)
+    shape.assuming_yield %4 : index
+  }
+  return %3 : index
+}