From: Tres Popp Date: Wed, 20 May 2020 13:52:55 +0000 (+0200) Subject: [mlir] Canonicalization and folding of shape.cstr_broadcastable X-Git-Tag: llvmorg-12-init~4039 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=6aab70945915ef1d565f1146734416029549a5a9;p=platform%2Fupstream%2Fllvm.git [mlir] Canonicalization and folding of shape.cstr_broadcastable This allows replacing of this op with a true witness in the case of both inputs being const_shapes and being found to be broadcastable. Differential Revision: https://reviews.llvm.org/D80304 --- diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index 075050a..a05273f 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -531,7 +531,7 @@ def Shape_AssumingYieldOp : Shape_Op<"assuming_yield", let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; } -def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", []> { +def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> { let summary = "Determines if 2 shapes can be successfully broadcasted"; let description = [{ Given 2 input shapes, return a witness specifying if they are broadcastable. @@ -550,6 +550,9 @@ def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", []> { let results = (outs Shape_WitnessType:$result); let assemblyFormat = "$lhs `,` $rhs attr-dict"; + + let hasCanonicalizer = 1; + let hasFolder = 1; } def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> { diff --git a/mlir/lib/Dialect/Shape/CMakeLists.txt b/mlir/lib/Dialect/Shape/CMakeLists.txt index 2af3de8..0a03849 100644 --- a/mlir/lib/Dialect/Shape/CMakeLists.txt +++ b/mlir/lib/Dialect/Shape/CMakeLists.txt @@ -1,3 +1,7 @@ +set(LLVM_TARGET_DEFINITIONS IR/ShapeCanonicalization.td) +mlir_tablegen(IR/ShapeCanonicalization.inc -gen-rewriters) +add_public_tablegen_target(MLIRShapeCanonicalizationIncGen) + add_mlir_dialect_library(MLIRShape IR/Shape.cpp diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 2b05c4c..3a8831c 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -18,6 +18,10 @@ using namespace mlir; using namespace mlir::shape; +namespace { +#include "IR/ShapeCanonicalization.inc" +} + ShapeDialect::ShapeDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) { addOperations< @@ -261,6 +265,32 @@ static ParseResult parseConstShapeOp(OpAsmParser &parser, OpFoldResult ConstShapeOp::fold(ArrayRef) { return shapeAttr(); } //===----------------------------------------------------------------------===// +// CstrBroadcastableOp +//===----------------------------------------------------------------------===// + +void CstrBroadcastableOp::getCanonicalizationPatterns( + OwningRewritePatternList &patterns, MLIRContext *context) { + // If inputs are equal, return passing witness + patterns.insert(context); +} + +OpFoldResult CstrBroadcastableOp::fold(ArrayRef operands) { + if (!operands[0] || !operands[1]) + return nullptr; + auto lhsShape = llvm::to_vector<6>( + operands[0].cast().getValues()); + auto rhsShape = llvm::to_vector<6>( + operands[1].cast().getValues()); + SmallVector resultShape; + if (OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) + return BoolAttr::get(true, getContext()); + + // Because a failing witness result here represents an eventual assertion + // failure, we do not replace it with a constant witness. + return nullptr; +} + +//===----------------------------------------------------------------------===// // ConstSizeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td new file mode 100644 index 0000000..9a73a88 --- /dev/null +++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td @@ -0,0 +1,8 @@ +include "mlir/Dialect/Shape/IR/ShapeOps.td" + +def EqualBinaryOperands : Constraint>; + +// Canonicalization patterns. +def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp:$op $lhs, $rhs), + (Shape_ConstWitnessOp ConstBoolAttrTrue), + [(EqualBinaryOperands $lhs, $rhs)]>; diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir index 0f92c18..93ce36a 100644 --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -267,3 +267,59 @@ func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape { %1 = shape.any %arg0, %arg1 return %1 : !shape.shape } + +// ----- +// Broadcastable with broadcastable constant shapes can be removed. +// CHECK-LABEL: func @f +func @f() { + // CHECK-NEXT: shape.const_witness true + // CHECK-NEXT: consume.witness + // CHECK-NEXT: return + %cs0 = shape.const_shape [3, 1] + %cs1 = shape.const_shape [1, 5] + %0 = shape.cstr_broadcastable %cs0, %cs1 + "consume.witness"(%0) : (!shape.witness) -> () + return +} + +// ----- +// Broadcastable with non-broadcastable constant shapes is always false +// CHECK-LABEL: func @f +func @f() { + // CHECK-NEXT: shape.const_shape + // CHECK-NEXT: shape.const_shape + // CHECK-NEXT: shape.cstr_broadcastable + // CHECK-NEXT: consume.witness + // CHECK-NEXT: return + %cs0 = shape.const_shape [1, 3] + %cs1 = shape.const_shape [1, 5] + %0 = shape.cstr_broadcastable %cs0, %cs1 + "consume.witness"(%0) : (!shape.witness) -> () + return +} + +// ----- +// Broadcastable without guaranteed broadcastable shapes cannot be removed. +// CHECK-LABEL: func @f +func @f(%arg0 : !shape.shape) { + // CHECK-NEXT: shape.const_shape + // CHECK-NEXT: shape.cstr_broadcastable + // CHECK-NEXT: consume.witness + // CHECK-NEXT: return + %cs0 = shape.const_shape [1,3] + %0 = shape.cstr_broadcastable %arg0, %cs0 + "consume.witness"(%0) : (!shape.witness) -> () + return +} + +// ----- +// Broadcastable with non-constant but known equal shapes can be removed. +// CHECK-LABEL: func @f +func @f(%arg0 : !shape.shape) { + // CHECK-NEXT: shape.const_witness true + // CHECK-NEXT: consume.witness + // CHECK-NEXT: return + %0 = shape.cstr_broadcastable %arg0, %arg0 + "consume.witness"(%0) : (!shape.witness) -> () + return +}