[MLIR] Add canoncalization for `shape.is_broadcastable`
authorFrederik Gossen <frgossen@google.com>
Thu, 11 Mar 2021 09:09:26 +0000 (10:09 +0100)
committerFrederik Gossen <frgossen@google.com>
Thu, 11 Mar 2021 09:10:34 +0000 (10:10 +0100)
Canonicalize `is_broadcastable` to constant true if fewer than 2 unique shape
operands. Eliminate redundant operands, otherwise.

Differential Revision: https://reviews.llvm.org/D98361

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir

index a176e6d..ae14d81 100644 (file)
@@ -277,9 +277,10 @@ def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable",
     };
   }];
 
+  let hasCanonicalizer = 1;
+
   let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
   let verifier = [{ return ::verify(*this); }];
-
 }
 
 def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
index 719f4bd..741d065 100644 (file)
@@ -779,6 +779,44 @@ static LogicalResult verify(IsBroadcastableOp op) {
   return success();
 }
 
+namespace {
+struct IsBroadcastableCanonicalizationPattern
+    : public OpRewritePattern<IsBroadcastableOp> {
+  using OpRewritePattern<IsBroadcastableOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IsBroadcastableOp op,
+                                PatternRewriter &rewriter) const override {
+    // Find unique operands.
+    SmallVector<Value, 2> unique;
+    for (Value v : op.getOperands()) {
+      if (!llvm::is_contained(unique, v))
+        unique.push_back(v);
+    }
+
+    // Can always broadcast fewer than two shapes.
+    if (unique.size() < 2) {
+      rewriter.replaceOpWithNewOp<mlir::ConstantOp>(op,
+                                                    rewriter.getBoolAttr(true));
+      return success();
+    }
+
+    // Reduce op to equivalent with unique operands.
+    if (unique.size() < op.getNumOperands()) {
+      rewriter.replaceOpWithNewOp<IsBroadcastableOp>(op, rewriter.getI1Type(),
+                                                     unique);
+      return success();
+    }
+
+    return failure();
+  }
+};
+} // namespace
+
+void IsBroadcastableOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<IsBroadcastableCanonicalizationPattern>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // RankOp
 //===----------------------------------------------------------------------===//
index 5ee495d..5589221 100644 (file)
@@ -1069,3 +1069,28 @@ func @fold_tensor.cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor<?xind
   %1 = tensor.cast %0 : tensor<1xindex> to tensor<?xindex>
   return %1 : tensor<?xindex>
 }
+
+// -----
+
+// CHECK-LABEL: @is_broadcastable_on_same_shape
+func @is_broadcastable_on_same_shape(%shape : !shape.shape) -> i1 {
+  // CHECK-NOT: is_broadcastable
+  // CHECK: %[[RES:.*]] = constant true
+  // CHECK: return %[[RES]]
+  %0 = shape.is_broadcastable %shape, %shape, %shape
+      : !shape.shape, !shape.shape, !shape.shape
+  return %0 : i1
+}
+
+// -----
+
+// CHECK-LABEL: @is_broadcastable_on_duplicate_shapes
+// CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: !shape.shape)
+func @is_broadcastable_on_duplicate_shapes(%a : !shape.shape, %b : !shape.shape)
+    -> i1 {
+  // CHECK: %[[RES:.*]] = shape.is_broadcastable %[[A]], %[[B]]
+  // CHECK: return %[[RES]]
+  %0 = shape.is_broadcastable %a, %b, %a, %a, %a, %b : !shape.shape,
+      !shape.shape, !shape.shape, !shape.shape, !shape.shape, !shape.shape
+  return %0 : i1
+}