let results = (outs Shape_SizeType:$result);
}
-def Shape_BroadcastOp : Shape_Op<"broadcast", []> {
+def Shape_BroadcastOp : Shape_Op<"broadcast",
+ [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Returns the broadcasted output shape of two inputs";
let description = [{
Computes the broadcasted output shape following:
let regions = (region SizedRegion<1>:$body);
}
-def Shape_ShapeOfOp : Shape_Op<"shape_of", []> {
+def Shape_ShapeOfOp : Shape_Op<"shape_of",
+ [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Returns shape of a value or shaped type operand";
let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$arg);
// BroadcastOp
//===----------------------------------------------------------------------===//
+LogicalResult BroadcastOp::inferReturnTypes(
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ ArrayRef<NamedAttribute> attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ inferredReturnTypes.push_back(ShapeType::get(context));
+ return success();
+}
+
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
if (!operands[0] || !operands[1])
return nullptr;
// ShapeOfOp
//===----------------------------------------------------------------------===//
+LogicalResult ShapeOfOp::inferReturnTypes(
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ ArrayRef<NamedAttribute> attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ inferredReturnTypes.push_back(ShapeType::get(context));
+ return success();
+}
+
OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
auto type = getOperand().getType().dyn_cast<ShapedType>();
if (!type || !type.hasStaticShape())