[MLIR] Add canonicalization for `shape.broadcast`
authorFrederik Gossen <frgossen@google.com>
Mon, 15 Mar 2021 09:10:07 +0000 (10:10 +0100)
committerFrederik Gossen <frgossen@google.com>
Mon, 15 Mar 2021 09:11:28 +0000 (10:11 +0100)
Remove redundant (duplicate) operands from `shape.broadcast`, and fold the op to its sole remaining operand when only one is left.

Differential Revision: https://reviews.llvm.org/D98402

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
mlir/test/Dialect/Shape/invalid.mlir

index ae14d81..a17be38 100644 (file)
@@ -89,6 +89,7 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, NoSideEffect]> {
   ];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
   let verifier = [{ return ::verify(*this); }];
 }
 
@@ -277,10 +278,10 @@ def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable",
     };
   }];
 
+  let hasFolder = 1;
   let hasCanonicalizer = 1;
 
   let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
-  let verifier = [{ return ::verify(*this); }];
 }
 
 def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
index 741d065..472197c 100644 (file)
@@ -354,13 +354,16 @@ static LogicalResult verify(AssumingAllOp op) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
-  if (!operands[1])
-    return nullptr;
+  if (operands.size() == 1)
+    return shapes().front();
 
   // TODO: Support folding with more than 2 input shapes
   if (shapes().size() > 2)
     return nullptr;
 
+  if (!operands[1])
+    return nullptr;
+
   auto rhsShape = llvm::to_vector<6>(
       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
   if (rhsShape.empty())
@@ -384,13 +387,40 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
 }
 
 static LogicalResult verify(BroadcastOp op) {
-  // Ensure that AssumingAllOp contains at least one operand
-  if (op.getNumOperands() < 2)
-    return op.emitOpError("required at least 2 input shapes");
-
   return verifyShapeOrExtentTensorOp(op);
 }
 
+namespace {
+template <typename OpTy>
+struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Find unique operands.
+    SmallVector<Value, 2> unique;
+    for (Value v : op.getOperands()) {
+      if (!llvm::is_contained(unique, v))
+        unique.push_back(v);
+    }
+
+    // Reduce op to equivalent with unique operands.
+    if (unique.size() < op.getNumOperands()) {
+      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique,
+                                        op.getAttrs());
+      return success();
+    }
+
+    return failure();
+  }
+};
+} // namespace
+
+void BroadcastOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<RemoveDuplicateOperandsPattern<BroadcastOp>>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ConcatOp
 //===----------------------------------------------------------------------===//
@@ -772,49 +802,18 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
 // IsBroadcastableOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(IsBroadcastableOp op) {
-  // Ensure that AssumingAllOp contains at least one operand
-  if (op.getNumOperands() < 2)
-    return op.emitOpError("required at least 2 input shapes");
-  return success();
+void IsBroadcastableOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
 }
 
-namespace {
-struct IsBroadcastableCanonicalizationPattern
-    : public OpRewritePattern<IsBroadcastableOp> {
-  using OpRewritePattern<IsBroadcastableOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(IsBroadcastableOp op,
-                                PatternRewriter &rewriter) const override {
-    // Find unique operands.
-    SmallVector<Value, 2> unique;
-    for (Value v : op.getOperands()) {
-      if (!llvm::is_contained(unique, v))
-        unique.push_back(v);
-    }
-
-    // Can always broadcast fewer than two shapes.
-    if (unique.size() < 2) {
-      rewriter.replaceOpWithNewOp<mlir::ConstantOp>(op,
-                                                    rewriter.getBoolAttr(true));
-      return success();
-    }
-
-    // Reduce op to equivalent with unique operands.
-    if (unique.size() < op.getNumOperands()) {
-      rewriter.replaceOpWithNewOp<IsBroadcastableOp>(op, rewriter.getI1Type(),
-                                                     unique);
-      return success();
-    }
-
-    return failure();
+OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
+  // Can always broadcast fewer than two shapes.
+  if (operands.size() < 2) {
+    return BoolAttr::get(getContext(), true);
   }
-};
-} // namespace
 
-void IsBroadcastableOp::getCanonicalizationPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *context) {
-  patterns.insert<IsBroadcastableCanonicalizationPattern>(context);
+  return nullptr;
 }
 
 //===----------------------------------------------------------------------===//
index 5589221..53f27e4 100644 (file)
@@ -1088,9 +1088,34 @@ func @is_broadcastable_on_same_shape(%shape : !shape.shape) -> i1 {
 // CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: !shape.shape)
 func @is_broadcastable_on_duplicate_shapes(%a : !shape.shape, %b : !shape.shape)
     -> i1 {
-  // CHECK: %[[RES:.*]] = shape.is_broadcastable %[[A]], %[[B]]
+  // CHECK: %[[RES:.*]] = shape.is_broadcastable %[[A]], %[[B]] :
   // CHECK: return %[[RES]]
   %0 = shape.is_broadcastable %a, %b, %a, %a, %a, %b : !shape.shape,
       !shape.shape, !shape.shape, !shape.shape, !shape.shape, !shape.shape
   return %0 : i1
 }
+
+// -----
+
+// CHECK-LABEL: @broadcast_on_same_shape
+// CHECK-SAME: (%[[SHAPE:.*]]: !shape.shape)
+func @broadcast_on_same_shape(%shape : !shape.shape) -> !shape.shape {
+  // CHECK-NOT: broadcast
+  // CHECK: return %[[SHAPE]]
+  %0 = shape.broadcast %shape, %shape, %shape : !shape.shape, !shape.shape,
+      !shape.shape -> !shape.shape
+  return %0 : !shape.shape
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_on_duplicate_shapes
+// CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: !shape.shape)
+func @broadcast_on_duplicate_shapes(%a : !shape.shape, %b : !shape.shape)
+    -> !shape.shape {
+  // CHECK: %[[RES:.*]] = shape.broadcast %[[A]], %[[B]] :
+  // CHECK: return %[[RES]]
+  %0 = shape.broadcast %a, %b, %a, %a, %a, %b : !shape.shape, !shape.shape,
+      !shape.shape, !shape.shape, !shape.shape, !shape.shape -> !shape.shape
+  return %0 : !shape.shape
+}
index d685e67..d42f0fa 100644 (file)
@@ -249,18 +249,8 @@ module attributes {shape.lib = @fn} { }
 
 // -----
 
-func @fn(%arg: !shape.shape) -> i1 {
-  // expected-error@+1 {{required at least 2 input shapes}}
-  %0 = shape.is_broadcastable %arg : !shape.shape
-  return %0 : i1
-}
-
-// -----
-
 func @fn(%arg: !shape.shape) -> !shape.witness {
   // expected-error@+1 {{required at least 2 input shapes}}
   %0 = shape.cstr_broadcastable %arg : !shape.shape
   return %0 : !shape.witness
 }
-
-