[mlir] Add convenience grouping for tensor type inference
authorJacques Pienaar <jpienaar@google.com>
Mon, 1 Mar 2021 13:21:07 +0000 (05:21 -0800)
committerJacques Pienaar <jpienaar@google.com>
Mon, 1 Mar 2021 13:21:08 +0000 (05:21 -0800)
For ops that produces tensor types and implement the shaped type component interface, the type inference interface can be used. Create a grouping of these together to make it easier to specify (it cannot be added into a list of traits, but must rather be appended/concated to one as it isn't a trait but a list of traits).

Differential Revision: https://reviews.llvm.org/D97636

mlir/include/mlir/Interfaces/InferTypeOpInterface.td
mlir/test/lib/Dialect/Test/TestOps.td

index ca044f0..f15d9b3 100644 (file)
@@ -116,4 +116,20 @@ def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
   ];
 }
 
+// Convenience class grouping together type and shaped type op interfaces for
+// ops that have tensor return types.
+class InferTensorType<list<string> overridenMethods = []> {
+  list<OpTrait> traits = [
+    // Op implements infer type op interface.
+    InferTypeOpInterface,
+    // The op will have methods implementing the ShapedType type inference
+    // interface.
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface, overridenMethods>,
+    // The op produces tensors and will use the ShapedType type infer interface
+    // along with knowledge that it is producing Tensors to infer the type.
+    NativeOpTrait<"InferTensorType">
+  ];
+}
+defvar InferTensorTypeWithReify = InferTensorType<["reifyReturnTypeShapes"]>;
+
 #endif // MLIR_INFERTYPEOPINTERFACE
index 44df4ab..4893ac3 100644 (file)
@@ -503,24 +503,10 @@ def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [
   let results = (outs AnyTensor);
 }
 
-def InferTensorType : NativeOpTrait<"InferTensorType">;
 def OpWithShapedTypeInferTypeInterfaceOp : TEST_Op<"op_with_shaped_type_infer_type_if",
-  [
-     // Op implements infer type op interface.
-     InferTypeOpInterface,
-     // The op will have methods implementing the ShapedType type infer interface.
-     DeclareOpInterfaceMethods<InferShapedTypeOpInterface>,
-     // The op produces tensors and will use the ShapedType type infer interface
-     // along with knowledge that it is producing Tensors to infer shape.
-     InferTensorType
-   ]> {
+      InferTensorTypeWithReify.traits> {
   let arguments = (ins AnyTensor, AnyTensor);
   let results = (outs AnyTensor);
-
-  let extraClassDeclaration = [{
-    LogicalResult reifyReturnTypeShapes(OpBuilder &builder,
-                                        SmallVectorImpl<Value> &shapes);
-  }];
 }
 
 def IsNotScalar : Constraint<CPred<"$0.getType().getRank() != 0">>;