];
}
+// Convenience class grouping together type and shaped type op interfaces for
+// ops that have tensor return types.
+class InferTensorType<list<string> overridenMethods = []> {
+ list<OpTrait> traits = [
+ // Op implements infer type op interface.
+ InferTypeOpInterface,
+ // The op will have methods implementing the ShapedType type inference
+ // interface.
+ DeclareOpInterfaceMethods<InferShapedTypeOpInterface, overridenMethods>,
+ // The op produces tensors and will use the ShapedType type infer interface
+ // along with knowledge that it is producing Tensors to infer the type.
+ NativeOpTrait<"InferTensorType">
+ ];
+}
+defvar InferTensorTypeWithReify = InferTensorType<["reifyReturnTypeShapes"]>;
+
#endif // MLIR_INFERTYPEOPINTERFACE
let results = (outs AnyTensor);
}
-def InferTensorType : NativeOpTrait<"InferTensorType">;
def OpWithShapedTypeInferTypeInterfaceOp : TEST_Op<"op_with_shaped_type_infer_type_if",
- [
- // Op implements infer type op interface.
- InferTypeOpInterface,
- // The op will have methods implementing the ShapedType type infer interface.
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface>,
- // The op produces tensors and will use the ShapedType type infer interface
- // along with knowledge that it is producing Tensors to infer shape.
- InferTensorType
- ]> {
+ InferTensorTypeWithReify.traits> {
let arguments = (ins AnyTensor, AnyTensor);
let results = (outs AnyTensor);
-
- let extraClassDeclaration = [{
- LogicalResult reifyReturnTypeShapes(OpBuilder &builder,
- SmallVectorImpl<Value> &shapes);
- }];
}
def IsNotScalar : Constraint<CPred<"$0.getType().getRank() != 0">>;