+// Base class for all shape-dialect ops: pins the dialect and forwards the
+// mnemonic and trait list to the generic Op class.
class Shape_Op<string mnemonic, list<OpTrait> traits = []> :
Op<ShapeDialect, mnemonic, traits>;
-def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> {
+def Shape_AddOp : Shape_Op<"add",
+ [Commutative, NoSideEffect,
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Addition of sizes and indices";
let description = [{
Adds two sizes or indices. If either operand is an error it will be
}];
let verifier = [{ return verifySizeOrIndexOp(*this); }];
+
+ let extraClassDeclaration = [{
+ // Returns when two result types are compatible for this op; method used by
+ // InferTypeOpInterface
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
}
def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, NoSideEffect]> {
OptionalAttr<StrAttr>:$error);
let results = (outs Shape_ShapeOrExtentTensorType:$result);
+ let builders = [OpBuilder<(ins "Value":$shape)>];
+
let assemblyFormat = [{
$shapes attr-dict `:` type($shapes) `->` type($result)
}];
let hasFolder = 1;
}
-def Shape_DivOp : Shape_Op<"div", [NoSideEffect]> {
+def Shape_DivOp : Shape_Op<"div", [NoSideEffect,
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Division of sizes and indices";
let description = [{
Divides two sizes or indices. If either operand is an error it will be
let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
let hasFolder = 1;
+
+ let extraClassDeclaration = [{
+ // Returns when two result types are compatible for this op; method used by
+ // InferTypeOpInterface
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
}
-def Shape_ShapeEqOp : Shape_Op<"shape_eq", [NoSideEffect, Commutative,
- InferTypeOpInterface]> {
+def Shape_ShapeEqOp : Shape_Op<"shape_eq",
+ [NoSideEffect, Commutative, InferTypeOpInterface]> {
let summary = "Returns whether the input shapes or extent tensors are equal";
let description = [{
Takes one or more shape or extent tensor operands and determines whether
let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
}
-def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
+def Shape_RankOp : Shape_Op<"rank",
+ [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Gets the rank of a shape";
let description = [{
Returns the rank of the shape or extent tensor, i.e. the number of extents.
let hasFolder = 1;
let hasCanonicalizer = 1;
let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
+
+ let extraClassDeclaration = [{
+ // Returns when two result types are compatible for this op; method used by
+ // InferTypeOpInterface
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
}
def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
let hasFolder = 1;
}
-def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
+def Shape_GetExtentOp : Shape_Op<"get_extent",
+ [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Gets the specified extent from a shape or extent tensor";
let description = [{
Gets the extent indexed by `dim` from the `shape` operand. If the shape is
let extraClassDeclaration = [{
/// Get the `dim` value as integer if it is constant.
Optional<int64_t> getConstantDim();
+ /// Returns when two result types are compatible for this op; method used by
+ /// InferTypeOpInterface
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}
-def Shape_JoinOp : Shape_Op<"join", [Commutative]> {
+def Shape_JoinOp : Shape_Op<"join",
+ [Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Returns the least general shape.shape of its operands";
let description = [{
An operation that computes the least general shape of input operands.
$arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
type($arg0) `,` type($arg1) `->` type($result)
}];
+
+ let extraClassDeclaration = [{
+ // Returns when two result types are compatible for this op; method used by
+ // InferTypeOpInterface
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
}
-def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> {
+def Shape_MaxOp : Shape_Op<"max",
+ [Commutative, NoSideEffect,
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Elementwise maximum";
let description = [{
Computes the elementwise maximum of two sizes or shapes with equal ranks.
}];
let hasFolder = 1;
+
+ let extraClassDeclaration = [{
+ // Returns when two result types are compatible for this op; method used by
+ // InferTypeOpInterface
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
}
-def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> {
+def Shape_MinOp : Shape_Op<"min",
+ [Commutative, NoSideEffect,
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Elementwise minimum";
let description = [{
Computes the elementwise minimum of two sizes or shapes with equal ranks.
}];
let hasFolder = 1;
+
+ let extraClassDeclaration = [{
+ // Returns when two result types are compatible for this op; method used by
+ // InferTypeOpInterface
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
}
-def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
+def Shape_MulOp : Shape_Op<"mul",
+ [Commutative, NoSideEffect,
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Multiplication of sizes and indices";
let description = [{
Multiplies two sizes or indices. If either operand is an error it will be
let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
let hasFolder = 1;
+
+ let extraClassDeclaration = [{
+ // Returns when two result types are compatible for this op; method used by
+ // InferTypeOpInterface
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
}
-def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
+def Shape_NumElementsOp : Shape_Op<"num_elements",
+ [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Returns the number of elements for a given shape";
let description = [{
Returns the number of elements for a given shape which is the product of its
let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
let results = (outs Shape_SizeOrIndexType:$result);
- let builders = [OpBuilder<(ins "Value":$shape)>];
-
let assemblyFormat = "$shape attr-dict `:` type($shape) `->` type($result)";
let hasFolder = 1;
let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
+ let extraClassDeclaration = [{
+ // Returns when two result types are compatible for this op; method used by
+ // InferTypeOpInterface
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
}
def Shape_ReduceOp : Shape_Op<"reduce",
let parser = [{ return ::parse$cppClass(parser, result); }];
}
-def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> {
+def Shape_ShapeOfOp : Shape_Op<"shape_of",
+ [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Returns shape of a value or shaped type operand";
let description = [{
let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)";
- let builders = [OpBuilder<(ins "Value":$arg)>];
-
let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
let hasCanonicalizer = 1;
let hasFolder = 1;
+
+ let extraClassDeclaration = [{
+ // Returns when two result types are compatible for this op; method used by
+ // InferTypeOpInterface
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
}
def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> {
return success();
}
+// Returns true iff `typeRange` contains exactly one type and that type is one
+// of the Ty... candidates. Used by the isCompatibleReturnTypes hooks below.
+template <typename... Ty>
+static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
+ return typeRange.size() == 1 && typeRange.front().isa<Ty...>();
+}
+
+// Variadic overload: true iff every supplied range independently satisfies the
+// single-type check above.
+template <typename... Ty, typename... ranges>
+static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
+ return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
+}
+
//===----------------------------------------------------------------------===//
// InlinerInterface
//===----------------------------------------------------------------------===//
}
//===----------------------------------------------------------------------===//
+// AddOp
+//===----------------------------------------------------------------------===//
+
+// shape.add result type inference: if either operand is !shape.size the
+// result is !shape.size; otherwise (both operands index) the result is index.
+LogicalResult mlir::shape::AddOp::inferReturnTypes(
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ DictionaryAttr attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ if (operands[0].getType().isa<SizeType>() ||
+ operands[1].getType().isa<SizeType>())
+ inferredReturnTypes.assign({SizeType::get(context)});
+ else
+ inferredReturnTypes.assign({IndexType::get(context)});
+ return success();
+}
+
+// InferTypeOpInterface compatibility hook: a single !shape.size or index
+// result is accepted on either side.
+bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+ // SizeType is compatible with IndexType.
+ return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
+}
+
+//===----------------------------------------------------------------------===//
// AssumingAllOp
//===----------------------------------------------------------------------===//
return IntegerAttr::get(indexTy, quotient);
}
+// shape.div result type inference: !shape.size if either operand is
+// !shape.size, index otherwise (mirrors AddOp above).
+LogicalResult mlir::shape::DivOp::inferReturnTypes(
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ DictionaryAttr attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ if (operands[0].getType().isa<SizeType>() ||
+ operands[1].getType().isa<SizeType>())
+ inferredReturnTypes.assign({SizeType::get(context)});
+ else
+ inferredReturnTypes.assign({IndexType::get(context)});
+ return success();
+}
+
+// InferTypeOpInterface compatibility hook: a single !shape.size or index
+// result is accepted on either side.
+bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+ // SizeType is compatible with IndexType.
+ return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
+}
+
//===----------------------------------------------------------------------===//
// ShapeEqOp
//===----------------------------------------------------------------------===//
}
}
+// shape.get_extent result type inference: the extent is always inferred as
+// index, regardless of the source operand's type.
+LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ DictionaryAttr attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ inferredReturnTypes.assign({IndexType::get(context)});
+ return success();
+}
+
+// InferTypeOpInterface compatibility hook: a single !shape.size or index
+// result is accepted on either side, so existing !shape.size results verify.
+bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
+ TypeRange r) {
+ // SizeType is compatible with IndexType.
+ return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
+}
+
//===----------------------------------------------------------------------===//
// IsBroadcastableOp
//===----------------------------------------------------------------------===//
}
//===----------------------------------------------------------------------===//
+// JoinOp
+//===----------------------------------------------------------------------===//
+
+// shape.join result type inference: the result simply takes the type of the
+// first operand.
+LogicalResult mlir::shape::JoinOp::inferReturnTypes(
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ DictionaryAttr attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ inferredReturnTypes.assign({operands[0].getType()});
+ return success();
+}
+
+// InferTypeOpInterface compatibility hook: only single-result ranges are
+// considered, and ranges that compare equal are trivially compatible.
+bool mlir::shape::JoinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+ if (l.size() != 1 || r.size() != 1)
+ return false;
+ if (l == r)
+ return true;
+
+ Type lhs = l.front();
+ Type rhs = r.front();
+
+ if (lhs != rhs)
+ return false;
+
+ // NOTE(review): for size-1 ranges, lhs == rhs should already have been
+ // caught by the `l == r` fast path above, which makes the two checks below
+ // look unreachable — confirm intent against upstream before simplifying.
+ if (lhs.isa<SizeType>() || lhs.isa<ShapeType>())
+ return true;
+
+ if (succeeded(verifyCompatibleShapes({lhs, rhs})))
+ return true;
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//
patterns.add<RankShapeOfCanonicalizationPattern>(context);
}
+// shape.rank result type inference: a !shape.shape operand yields a
+// !shape.size rank; any other operand (extent tensor) yields index.
+LogicalResult mlir::shape::RankOp::inferReturnTypes(
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ DictionaryAttr attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ if (operands[0].getType().isa<ShapeType>())
+ inferredReturnTypes.assign({SizeType::get(context)});
+ else
+ inferredReturnTypes.assign({IndexType::get(context)});
+ return success();
+}
+
+// InferTypeOpInterface compatibility hook: a single !shape.size or index
+// result is accepted on either side.
+bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+ // SizeType is compatible with IndexType.
+ return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
+}
+
//===----------------------------------------------------------------------===//
// NumElementsOp
//===----------------------------------------------------------------------===//
return builder.getIndexAttr(product.getLimitedValue());
}
-void NumElementsOp::build(OpBuilder &builder, OperationState &result,
- Value shape) {
- if (shape.getType().isa<ShapedType>()) {
- auto type = builder.getIndexType();
- return build(builder, result, type, shape);
- }
- auto type = SizeType::get(builder.getContext());
- return build(builder, result, type, shape);
+// shape.num_elements result type inference (supersedes the hand-written
+// builder removed above): a !shape.shape operand yields !shape.size, any
+// other operand (extent tensor) yields index.
+LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ DictionaryAttr attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ if (operands[0].getType().isa<ShapeType>())
+ inferredReturnTypes.assign({SizeType::get(context)});
+ else
+ inferredReturnTypes.assign({IndexType::get(context)});
+ return success();
+}
+
+// InferTypeOpInterface compatibility hook: a single !shape.size or index
+// result is accepted on either side.
+bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
+ TypeRange r) {
+ // SizeType is compatible with IndexType.
+ return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
//===----------------------------------------------------------------------===//
return nullptr;
}
+// shape.max result type inference: identical operand types propagate
+// unchanged; for mixed operand types the result falls back to !shape.size.
+LogicalResult mlir::shape::MaxOp::inferReturnTypes(
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ DictionaryAttr attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ if (operands[0].getType() == operands[1].getType())
+ inferredReturnTypes.assign({operands[0].getType()});
+ else
+ inferredReturnTypes.assign({SizeType::get(context)});
+ return success();
+}
+
+// InferTypeOpInterface compatibility hook: only shape/shape and size/size
+// single-result pairs are treated as compatible.
+bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+ if (l.size() != 1 || r.size() != 1)
+ return false;
+ if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
+ return true;
+ if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
+ return true;
+ return false;
+}
+
//===----------------------------------------------------------------------===//
// MinOp
//===----------------------------------------------------------------------===//
return nullptr;
}
+// shape.min result type inference: identical operand types propagate
+// unchanged; for mixed operand types the result falls back to !shape.size.
+// (Intentionally the same logic as MaxOp above.)
+LogicalResult mlir::shape::MinOp::inferReturnTypes(
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ DictionaryAttr attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ if (operands[0].getType() == operands[1].getType())
+ inferredReturnTypes.assign({operands[0].getType()});
+ else
+ inferredReturnTypes.assign({SizeType::get(context)});
+ return success();
+}
+
+// InferTypeOpInterface compatibility hook: only shape/shape and size/size
+// single-result pairs are treated as compatible.
+bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+ if (l.size() != 1 || r.size() != 1)
+ return false;
+ if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
+ return true;
+ if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
+ return true;
+ return false;
+}
+
//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//
return IntegerAttr::get(indexTy, folded);
}
+// shape.mul result type inference: !shape.size if either operand is
+// !shape.size, index otherwise (mirrors AddOp/DivOp).
+LogicalResult mlir::shape::MulOp::inferReturnTypes(
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ DictionaryAttr attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ if (operands[0].getType().isa<SizeType>() ||
+ operands[1].getType().isa<SizeType>())
+ inferredReturnTypes.assign({SizeType::get(context)});
+ else
+ inferredReturnTypes.assign({IndexType::get(context)});
+ return success();
+}
+
+// InferTypeOpInterface compatibility hook: a single !shape.size or index
+// result is accepted on either side.
+bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+ // SizeType is compatible with IndexType.
+ return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
+}
//===----------------------------------------------------------------------===//
// ShapeOfOp
//===----------------------------------------------------------------------===//
return builder.getIndexTensorAttr(type.getShape());
}
-void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
- if (auto shapedTy = arg.getType().dyn_cast<ShapedType>()) {
- int64_t rank =
- shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
- Type indexTy = builder.getIndexType();
- Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
- return ShapeOfOp::build(builder, result, extentTensorTy, arg);
- }
- Type shapeTy = builder.getType<ShapeType>();
- return ShapeOfOp::build(builder, result, shapeTy, arg);
-}
-
namespace {
struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor>(context);
}
+// shape.shape_of result type inference: a !shape.value_shape operand yields
+// !shape.shape; otherwise the operand is cast to ShapedType and the result is
+// a 1-D index extent tensor whose length is the source rank (dynamic when the
+// source is unranked).
+LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ DictionaryAttr attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ if (operands[0].getType().isa<ValueShapeType>())
+ inferredReturnTypes.assign({ShapeType::get(context)});
+ else {
+ auto shapedTy = operands[0].getType().cast<ShapedType>();
+ int64_t rank =
+ shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
+ Type indexTy = IndexType::get(context);
+ Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
+ inferredReturnTypes.assign({extentTensorTy});
+ }
+ return success();
+}
+
+// InferTypeOpInterface compatibility hook: equal single-result ranges are
+// compatible; otherwise both sides must be !shape.shape or a ShapedType, with
+// !shape.shape compatible with everything and shaped types checked via
+// verifyCompatibleShapes.
+bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+ if (l.size() != 1 || r.size() != 1)
+ return false;
+ if (l == r)
+ return true;
+
+ Type lhs = l.front();
+ Type rhs = r.front();
+
+ if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>())
+ return false;
+
+ if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
+ // Shape type is compatible with all other valid return types.
+ return true;
+
+ if (succeeded(verifyCompatibleShapes({lhs, rhs})))
+ return true;
+ return false;
+}
+
//===----------------------------------------------------------------------===//
// SizeToIndexOp
//===----------------------------------------------------------------------===//