From b975e3b5aa8c6c8b608302997a3bf0fda06bf8d8 Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Thu, 11 Mar 2021 10:09:26 +0100 Subject: [PATCH] [MLIR] Add canoncalization for `shape.is_broadcastable` Canonicalize `is_broadcastable` to constant true if fewer than 2 unique shape operands. Eliminate redundant operands, otherwise. Differential Revision: https://reviews.llvm.org/D98361 --- mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td | 3 +- mlir/lib/Dialect/Shape/IR/Shape.cpp | 38 ++++++++++++++++++++++++++ mlir/test/Dialect/Shape/canonicalize.mlir | 25 +++++++++++++++++ 3 files changed, 65 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index a176e6d..ae14d81 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -277,9 +277,10 @@ def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable", }; }]; + let hasCanonicalizer = 1; + let assemblyFormat = "$shapes attr-dict `:` type($shapes)"; let verifier = [{ return ::verify(*this); }]; - } def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> { diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 719f4bd..741d065 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -779,6 +779,44 @@ static LogicalResult verify(IsBroadcastableOp op) { return success(); } +namespace { +struct IsBroadcastableCanonicalizationPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IsBroadcastableOp op, + PatternRewriter &rewriter) const override { + // Find unique operands. + SmallVector unique; + for (Value v : op.getOperands()) { + if (!llvm::is_contained(unique, v)) + unique.push_back(v); + } + + // Can always broadcast fewer than two shapes. + if (unique.size() < 2) { + rewriter.replaceOpWithNewOp(op, + rewriter.getBoolAttr(true)); + return success(); + } + + // Reduce op to equivalent with unique operands. + if (unique.size() < op.getNumOperands()) { + rewriter.replaceOpWithNewOp(op, rewriter.getI1Type(), + unique); + return success(); + } + + return failure(); + } +}; +} // namespace + +void IsBroadcastableOp::getCanonicalizationPatterns( + OwningRewritePatternList &patterns, MLIRContext *context) { + patterns.insert(context); +} + //===----------------------------------------------------------------------===// // RankOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir index 5ee495d..5589221 100644 --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -1069,3 +1069,28 @@ func @fold_tensor.cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor to tensor return %1 : tensor } + +// ----- + +// CHECK-LABEL: @is_broadcastable_on_same_shape +func @is_broadcastable_on_same_shape(%shape : !shape.shape) -> i1 { + // CHECK-NOT: is_broadcastable + // CHECK: %[[RES:.*]] = constant true + // CHECK: return %[[RES]] + %0 = shape.is_broadcastable %shape, %shape, %shape + : !shape.shape, !shape.shape, !shape.shape + return %0 : i1 +} + +// ----- + +// CHECK-LABEL: @is_broadcastable_on_duplicate_shapes +// CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: !shape.shape) +func @is_broadcastable_on_duplicate_shapes(%a : !shape.shape, %b : !shape.shape) + -> i1 { + // CHECK: %[[RES:.*]] = shape.is_broadcastable %[[A]], %[[B]] + // CHECK: return %[[RES]] + %0 = shape.is_broadcastable %a, %b, %a, %a, %a, %b : !shape.shape, + !shape.shape, !shape.shape, !shape.shape, !shape.shape, !shape.shape + return %0 : i1 +} -- 2.7.4