//===----------------------------------------------------------------------===//
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
- if (!operands[1])
- return nullptr;
+ if (operands.size() == 1)
+ return shapes().front();
// TODO: Support folding with more than 2 input shapes
if (shapes().size() > 2)
return nullptr;
+ if (!operands[1])
+ return nullptr;
+
auto rhsShape = llvm::to_vector<6>(
operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
if (rhsShape.empty())
}
static LogicalResult verify(BroadcastOp op) {
- // Ensure that AssumingAllOp contains at least one operand
- if (op.getNumOperands() < 2)
- return op.emitOpError("required at least 2 input shapes");
-
return verifyShapeOrExtentTensorOp(op);
}
+namespace {
+template <typename OpTy>
+struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const override {
+ // Find unique operands.
+ SmallVector<Value, 2> unique;
+ for (Value v : op.getOperands()) {
+ if (!llvm::is_contained(unique, v))
+ unique.push_back(v);
+ }
+
+ // Reduce op to equivalent with unique operands.
+ if (unique.size() < op.getNumOperands()) {
+ rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique,
+ op.getAttrs());
+ return success();
+ }
+
+ return failure();
+ }
+};
+} // namespace
+
+void BroadcastOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *context) {
+ patterns.insert<RemoveDuplicateOperandsPattern<BroadcastOp>>(context);
+}
+
//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//
// IsBroadcastableOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(IsBroadcastableOp op) {
- // Ensure that AssumingAllOp contains at least one operand
- if (op.getNumOperands() < 2)
- return op.emitOpError("required at least 2 input shapes");
- return success();
+void IsBroadcastableOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *context) {
+ patterns.insert<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
}
-namespace {
-struct IsBroadcastableCanonicalizationPattern
- : public OpRewritePattern<IsBroadcastableOp> {
- using OpRewritePattern<IsBroadcastableOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(IsBroadcastableOp op,
- PatternRewriter &rewriter) const override {
- // Find unique operands.
- SmallVector<Value, 2> unique;
- for (Value v : op.getOperands()) {
- if (!llvm::is_contained(unique, v))
- unique.push_back(v);
- }
-
- // Can always broadcast fewer than two shapes.
- if (unique.size() < 2) {
- rewriter.replaceOpWithNewOp<mlir::ConstantOp>(op,
- rewriter.getBoolAttr(true));
- return success();
- }
-
- // Reduce op to equivalent with unique operands.
- if (unique.size() < op.getNumOperands()) {
- rewriter.replaceOpWithNewOp<IsBroadcastableOp>(op, rewriter.getI1Type(),
- unique);
- return success();
- }
-
- return failure();
+OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
+ // Can always broadcast fewer than two shapes.
+ if (operands.size() < 2) {
+ return BoolAttr::get(getContext(), true);
}
-};
-} // namespace
-void IsBroadcastableOp::getCanonicalizationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
- patterns.insert<IsBroadcastableCanonicalizationPattern>(context);
+ return nullptr;
}
//===----------------------------------------------------------------------===//
// CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: !shape.shape)
func @is_broadcastable_on_duplicate_shapes(%a : !shape.shape, %b : !shape.shape)
-> i1 {
- // CHECK: %[[RES:.*]] = shape.is_broadcastable %[[A]], %[[B]]
+ // CHECK: %[[RES:.*]] = shape.is_broadcastable %[[A]], %[[B]] :
// CHECK: return %[[RES]]
%0 = shape.is_broadcastable %a, %b, %a, %a, %a, %b : !shape.shape,
!shape.shape, !shape.shape, !shape.shape, !shape.shape, !shape.shape
return %0 : i1
}
+
+// -----
+
+// CHECK-LABEL: @broadcast_on_same_shape
+// CHECK-SAME: (%[[SHAPE:.*]]: !shape.shape)
+func @broadcast_on_same_shape(%shape : !shape.shape) -> !shape.shape {
+ // CHECK-NOT: broadcast
+ // CHECK: return %[[SHAPE]]
+ %0 = shape.broadcast %shape, %shape, %shape : !shape.shape, !shape.shape,
+ !shape.shape -> !shape.shape
+ return %0 : !shape.shape
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_on_duplicate_shapes
+// CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: !shape.shape)
+func @broadcast_on_duplicate_shapes(%a : !shape.shape, %b : !shape.shape)
+ -> !shape.shape {
+ // CHECK: %[[RES:.*]] = shape.broadcast %[[A]], %[[B]] :
+ // CHECK: return %[[RES]]
+ %0 = shape.broadcast %a, %b, %a, %a, %a, %b : !shape.shape, !shape.shape,
+ !shape.shape, !shape.shape, !shape.shape, !shape.shape -> !shape.shape
+ return %0 : !shape.shape
+}