From: Tres Popp Date: Tue, 9 Jun 2020 09:33:43 +0000 (+0200) Subject: [mlir] Add a pass to remove all shape.cstr_ and assuming_ ops. X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=3324598844a28850527b8abc8b83579ad7ab94a2;p=platform%2Fupstream%2Fllvm.git [mlir] Add a pass to remove all shape.cstr_ and assuming_ ops. Summary: This is to provide a utility to remove unsupported constraints or for pipelines that happen to receive these but cannot lower them due to not supporting assertions. Differential Revision: https://reviews.llvm.org/D81560 --- diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h index 9949758..b78b730 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h +++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h @@ -22,6 +22,8 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" namespace mlir { +class PatternRewriter; + namespace shape { namespace ShapeTypes { diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index a6f579c..49714c9 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -375,7 +375,7 @@ def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> { let hasFolder = 1; } -def Shape_YieldOp : Shape_Op<"yield", +def Shape_YieldOp : Shape_Op<"yield", [HasParent<"ReduceOp">, NoSideEffect, ReturnLike, @@ -533,6 +533,14 @@ def Shape_AssumingOp : Shape_Op<"assuming", let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; + let extraClassDeclaration = [{ + // Inline the region into the region containing the AssumingOp and delete + // the AssumingOp. + // + // This does no checks on the inputs to the AssumingOp. + static void inlineRegionIntoParent(AssumingOp &op, PatternRewriter &rewriter); + }]; + let hasCanonicalizer = 1; } diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h index 7e60653..e8d2167 100644 --- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h @@ -18,6 +18,7 @@ namespace mlir { +class FunctionPass; class MLIRContext; class OwningRewritePatternList; class Pass; @@ -30,6 +31,17 @@ std::unique_ptr createShapeToShapeLowering(); /// Collects a set of patterns to rewrite ops within the Shape dialect. void populateShapeRewritePatterns(MLIRContext *context, OwningRewritePatternList &patterns); + +// Collects a set of patterns to replace all constraints with passing witnesses. +// This is intended to then allow all ShapeConstraint related ops and data to +// have no effects and allow them to be freely removed such as through +// canonicalization and dead code elimination. +// +// After this pass, no cstr_ operations exist. +void populateRemoveShapeConstraintsPatterns(OwningRewritePatternList &patterns, + MLIRContext *ctx); +std::unique_ptr createRemoveShapeConstraintsPass(); + } // end namespace mlir #endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_ diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td index 46dc4dc..022bd37 100644 --- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td @@ -11,6 +11,11 @@ include "mlir/Pass/PassBase.td" +def RemoveShapeConstraints : FunctionPass<"remove-shape-constraints"> { + let summary = "Replace all cstr_ ops with a true witness"; + let constructor = "mlir::createRemoveShapeConstraintsPass()"; +} + def ShapeToShapeLowering : FunctionPass<"shape-to-shape-lowering"> { let summary = "Legalize Shape dialect to be convertible to Standard"; let constructor = "mlir::createShapeToShapeLowering()"; diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 4a876e1..664c0cb 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -168,22 +168,7 @@ struct AssumingWithTrue : public OpRewritePattern { if (!witness || !witness.passingAttr()) return failure(); - auto *blockBeforeAssuming = rewriter.getInsertionBlock(); - auto *assumingBlock = op.getBody(); - auto initPosition = rewriter.getInsertionPoint(); - auto *blockAfterAssuming = - rewriter.splitBlock(blockBeforeAssuming, initPosition); - - // Remove the AssumingOp and AssumingYieldOp. - auto &yieldOp = assumingBlock->back(); - rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming); - rewriter.replaceOp(op, yieldOp.getOperands()); - rewriter.eraseOp(&yieldOp); - - // Merge blocks together as there was no branching behavior from the - // AssumingOp. - rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); - rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); + AssumingOp::inlineRegionIntoParent(op, rewriter); return success(); } }; @@ -191,10 +176,30 @@ struct AssumingWithTrue : public OpRewritePattern { void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, MLIRContext *context) { - // If taking a passing witness, inline region + // If taking a passing witness, inline region. patterns.insert(context); } +void AssumingOp::inlineRegionIntoParent(AssumingOp &op, + PatternRewriter &rewriter) { + auto *blockBeforeAssuming = rewriter.getInsertionBlock(); + auto *assumingBlock = op.getBody(); + auto initPosition = rewriter.getInsertionPoint(); + auto *blockAfterAssuming = + rewriter.splitBlock(blockBeforeAssuming, initPosition); + + // Remove the AssumingOp and AssumingYieldOp. + auto &yieldOp = assumingBlock->back(); + rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming); + rewriter.replaceOp(op, yieldOp.getOperands()); + rewriter.eraseOp(&yieldOp); + + // Merge blocks together as there was no branching behavior from the + // AssumingOp. + rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); + rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); +} + //===----------------------------------------------------------------------===// // AssumingAllOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt index 6f812b6..987f9c5 100644 --- a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRShapeOpsTransforms + RemoveShapeConstraints.cpp ShapeToShapeLowering.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp b/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp new file mode 100644 index 0000000..641b4bc --- /dev/null +++ b/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp @@ -0,0 +1,64 @@ +//===-- RemoveShapeConstraints.cpp - Remove Shape Cstr and Assuming Ops ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Shape/Transforms/Passes.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; + +namespace { +/// Removal patterns. +class RemoveCstrBroadcastableOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op.getOperation(), true); + return success(); + } +}; + +class RemoveCstrEqOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(shape::CstrEqOp op, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op.getOperation(), true); + return success(); + } +}; + +/// Removal pass. +class RemoveShapeConstraintsPass + : public RemoveShapeConstraintsBase { + + void runOnFunction() override { + MLIRContext &ctx = getContext(); + + OwningRewritePatternList patterns; + populateRemoveShapeConstraintsPatterns(patterns, &ctx); + + applyPatternsAndFoldGreedily(getFunction(), patterns); + } +}; + +} // namespace + +void mlir::populateRemoveShapeConstraintsPatterns( + OwningRewritePatternList &patterns, MLIRContext *ctx) { + patterns.insert(ctx); +} + +std::unique_ptr mlir::createRemoveShapeConstraintsPass() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/Shape/remove-shape-constraints.mlir b/mlir/test/Dialect/Shape/remove-shape-constraints.mlir new file mode 100644 index 0000000..69887c6 --- /dev/null +++ b/mlir/test/Dialect/Shape/remove-shape-constraints.mlir @@ -0,0 +1,56 @@ +// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -remove-shape-constraints -canonicalize <%s | FileCheck %s --dump-input=fail --check-prefixes=CANON,CHECK-BOTH +// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -remove-shape-constraints <%s | FileCheck %s --dump-input=fail --check-prefixes=REPLACE,CHECK-BOTH + +// ----- +// Check that cstr_broadcastable is removed. +// +// CHECK-BOTH: func @f +func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index { + // REPLACE-NEXT: %[[WITNESS:.+]] = shape.const_witness true + // REPLACE-NOT: shape.cstr_eq + // REPLACE: shape.assuming %[[WITNESS]] + // CANON-NEXT: test.source + // CANON-NEXT: return + %0 = shape.cstr_broadcastable %arg0, %arg1 + %1 = shape.assuming %0 -> index { + %2 = "test.source"() : () -> (index) + shape.assuming_yield %2 : index + } + return %1 : index +} + +// ----- +// Check that cstr_eq is removed. +// +// CHECK-BOTH: func @f +func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index { + // REPLACE-NEXT: %[[WITNESS:.+]] = shape.const_witness true + // REPLACE-NOT: shape.cstr_eq + // REPLACE: shape.assuming %[[WITNESS]] + // CANON-NEXT: test.source + // CANON-NEXT: return + %0 = shape.cstr_eq %arg0, %arg1 + %1 = shape.assuming %0 -> index { + %2 = "test.source"() : () -> (index) + shape.assuming_yield %2 : index + } + return %1 : index +} + +// ----- +// With a non-const value, we cannot fold away the code, but all constraints +// should be removed still. +// +// CHECK-BOTH: func @f +func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index { + // CANON-NEXT: test.source + // CANON-NEXT: return + %0 = shape.cstr_broadcastable %arg0, %arg1 + %1 = shape.cstr_eq %arg0, %arg1 + %2 = shape.assuming_all %0, %1 + %3 = shape.assuming %0 -> index { + %4 = "test.source"() : () -> (index) + shape.assuming_yield %4 : index + } + return %3 : index +}