From 89d900b2a1c11582a0a4396921282aa8f365d901 Mon Sep 17 00:00:00 2001 From: Tres Popp Date: Mon, 15 Feb 2021 10:58:25 +0100 Subject: [PATCH] [mlir] Add error message on shape.broadcast verification failure --- mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td | 5 +---- mlir/lib/Dialect/Shape/IR/Shape.cpp | 8 ++++++++ 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index f917f2e..20b0706 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -89,10 +89,7 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> { ]; let hasFolder = 1; - let verifier = [{ - return success(succeeded(::verifyShapeOrExtentTensorOp(*this)) && - getNumOperands() >= 2); - }]; + let verifier = [{ return ::verify(*this); }]; } def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> { diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 9657f95..8c75bdc 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -383,6 +383,14 @@ OpFoldResult BroadcastOp::fold(ArrayRef operands) { return builder.getIndexTensorAttr(resultShape); } +static LogicalResult verify(BroadcastOp op) { + // Ensure that AssumingAllOp contains at least one operand + if (op.getNumOperands() < 2) + return op.emitOpError("required at least 2 input shapes"); + + return verifyShapeOrExtentTensorOp(op); +} + //===----------------------------------------------------------------------===// // ConcatOp //===----------------------------------------------------------------------===// -- 2.7.4