let hasFolder = 1;
let hasCanonicalizer = 1;
- let verifier = [{ return ::verify(*this); }];
+ let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
}
def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
}];
let hasFolder = 1;
- let verifier = [{ return ::verify(*this); }];
+ let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
}
def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [NoSideEffect]> {
$lhs `,` $rhs `:` type($lhs) `,` type($rhs) `->` type($result) attr-dict
}];
- let verifier = [{ return ::verify(*this); }];
+ let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
let hasFolder = 1;
}
let assemblyFormat = "$arg `:` type($arg) `->` type($result) attr-dict";
- let verifier = [{ return ::verify(*this); }];
+ let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
let hasFolder = 1;
}
return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
}
-static bool isErrorPropagationPossible(ArrayRef<Type> operandTypes) {
+static bool isErrorPropagationPossible(TypeRange operandTypes) {
for (Type ty : operandTypes)
if (ty.isa<SizeType>() || ty.isa<ShapeType>() || ty.isa<ValueShapeType>())
return true;
return false;
}
+static LogicalResult verifySizeOrIndexOp(Operation *op) {
+ assert(op != nullptr && op->getNumResults() == 1);
+ Type resultTy = op->getResultTypes().front();
+ if (isErrorPropagationPossible(op->getOperandTypes())) {
+ if (!resultTy.isa<SizeType>())
+ return op->emitOpError()
+ << "if at least one of the operands can hold error values then "
+ "the result must be of type `size` to propagate them";
+ }
+ return success();
+}
+
+static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
+ assert(op != nullptr && op->getNumResults() == 1);
+ Type resultTy = op->getResultTypes().front();
+ if (isErrorPropagationPossible(op->getOperandTypes())) {
+ if (!resultTy.isa<ShapeType>())
+ return op->emitOpError()
+ << "if at least one of the operands can hold error values then "
+ "the result must be of type `shape` to propagate them";
+ }
+ return success();
+}
+
ShapeDialect::ShapeDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
addOperations<
// GetExtentOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(GetExtentOp op) {
- Type shapeTy = op.shape().getType();
- Type dimTy = op.dim().getType();
- Type extentTy = op.extent().getType();
- if (isErrorPropagationPossible({shapeTy, dimTy})) {
- if (!extentTy.isa<SizeType>())
- op.emitError()
- << "if at least one of the operands can hold error values then the "
- "result must be of type `size` to propagate them";
- } else {
- if (extentTy.isa<SizeType>())
- op.emitError() << "if none of the operands can hold error values then "
- "the result must be of type `index`";
- }
- return success();
-}
-
Optional<int64_t> GetExtentOp::getConstantDim() {
if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
return constSizeOp.value().getLimitedValue();
// RankOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(shape::RankOp op) {
- if (op.shape().getType().isa<ShapeType>() &&
- !op.rank().getType().isa<SizeType>())
- return op.emitOpError()
- << "if operand is of type `shape` then the result must be of type "
- "`size` to propagate potential errors";
- return success();
-}
-
OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
if (!shape)
// MulOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(MulOp op) {
- Type resultTy = op.result().getType();
- if (isErrorPropagationPossible({op.lhs().getType(), op.rhs().getType()})) {
- if (!resultTy.isa<SizeType>())
- return op.emitOpError()
- << "if at least one of the operands can hold error values then "
- "the result must be of type `size` to propagate them";
- } else {
- if (resultTy.isa<SizeType>())
- return op.emitError() << "if none of the operands can hold error values "
- "then the result must be of type `index`";
- }
- return success();
-}
-
OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
if (!lhs)
return builder.getIndexTensorAttr(type.getShape());
}
-static LogicalResult verify(ShapeOfOp op) {
- Type resultTy = op.result().getType();
- if (isErrorPropagationPossible(op.arg().getType())) {
- if (!resultTy.isa<ShapeType>())
- return op.emitOpError()
- << "if operand is of type `value_shape` then the result must be "
- "of type `shape` to propagate potential error shapes";
- } else {
- if (resultTy != getExtentTensorType(op.getContext()))
- return op.emitOpError() << "if operand is a shaped type then the result "
- "must be an extent tensor";
- }
- return success();
-}
-
//===----------------------------------------------------------------------===//
// SizeToIndexOp
//===----------------------------------------------------------------------===//
func @shape_of(%value_arg : !shape.value_shape,
%shaped_arg : tensor<?x3x4xf32>) {
- // expected-error@+1 {{if operand is of type `value_shape` then the result must be of type `shape` to propagate potential error shapes}}
+ // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `shape` to propagate them}}
%0 = shape.shape_of %value_arg : !shape.value_shape -> tensor<?xindex>
return
}
// -----
-func @shape_of(%value_arg : !shape.value_shape,
- %shaped_arg : tensor<?x3x4xf32>) {
- // expected-error@+1 {{if operand is a shaped type then the result must be an extent tensor}}
- %1 = shape.shape_of %shaped_arg : tensor<?x3x4xf32> -> !shape.shape
- return
-}
-
-// -----
-
func @rank(%arg : !shape.shape) {
- // expected-error@+1 {{if operand is of type `shape` then the result must be of type `size` to propagate potential errors}}
+ // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
%0 = shape.rank %arg : !shape.shape -> index
return
}
// -----
-func @get_extent_error_free(%arg : tensor<?xindex>) -> !shape.size {
- %c0 = constant 0 : index
- // expected-error@+1 {{if none of the operands can hold error values then the result must be of type `index`}}
- %result = shape.get_extent %arg, %c0 : tensor<?xindex>, index -> !shape.size
- return %result : !shape.size
-}
-
-// -----
-
func @get_extent_error_possible(%arg : tensor<?xindex>) -> index {
%c0 = shape.const_size 0
// expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
// -----
-func @mul_error_free(%arg : index) -> !shape.size {
- // expected-error@+1 {{if none of the operands can hold error values then the result must be of type `index`}}
- %result = shape.mul %arg, %arg : index, index -> !shape.size
- return %result : !shape.size
-}
-
-// -----
-
func @mul_error_possible(%lhs : !shape.size, %rhs : index) -> index {
// expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
%result = shape.mul %lhs, %rhs : !shape.size, index -> index