#include "ShapeCanonicalization.inc"
}
-static RankedTensorType getExtentTensorType(MLIRContext *ctx) {
+RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) {
return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
}
}
void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
- if (arg.getType().isa<ShapedType>()) {
- auto type = RankedTensorType::get({ShapedType::kDynamicSize},
- builder.getIndexType());
- return ShapeOfOp::build(builder, result, type, arg);
- }
- auto type = ShapeType::get(builder.getContext());
+ Type type = arg.getType().isa<ShapedType>()
+ ? (Type)getExtentTensorType(builder.getContext())
+ : (Type)builder.getType<ShapeType>();
return ShapeOfOp::build(builder, result, type, arg);
}