[mlir] Add shaped container component type interface
authorJacques Pienaar <jpienaar@google.com>
Thu, 9 Jan 2020 02:48:38 +0000 (18:48 -0800)
committerJacques Pienaar <jpienaar@google.com>
Wed, 15 Jan 2020 21:28:39 +0000 (13:28 -0800)
Summary:
* Add shaped container type interface which allows infering the shape, element
  type and attribute of shaped container type separately. Show usage by way of
  tensor type inference trait which combines the shape & element type in
  infering a tensor type;
  - All components need not be specified;
  - Attribute is added to allow for layout attribute that was previously
    discussed;
* Expand the test driver to make it easier to test new creation instances
  (adding new operands or ops with attributes or regions would trigger build
  functions/type inference methods);
  - The verification part will be moved out of the test and to verify method
    instead of ops implementing the type inference interface in a follow up;
* Add MLIRContext as arg to possible to create type for ops without arguments,
  region or location;
* Also move out the section in OpDefinitions doc to separate ShapeInference doc
  where the shape function requirements can be captured;
  - Part of this would move to the shape dialect and/or shape dialect ops be
    included as subsection of this doc;
* Update ODS's variable usage to match camelBack format for builder,
  state and arg variables;
  - I could have split this out, but I had to make some changes around
    these and the inconsistency bugged me :)

Differential Revision: https://reviews.llvm.org/D72432

15 files changed:
mlir/docs/OpDefinitions.md
mlir/docs/ShapeInference.md [new file with mode: 0644]
mlir/include/mlir/Analysis/InferTypeOpInterface.h
mlir/include/mlir/Analysis/InferTypeOpInterface.td
mlir/include/mlir/IR/OpBase.td
mlir/lib/Analysis/InferTypeOpInterface.cpp
mlir/test/lib/TestDialect/TestDialect.cpp
mlir/test/lib/TestDialect/TestOps.td
mlir/test/lib/TestDialect/TestPatterns.cpp
mlir/test/mlir-tblgen/op-attribute.td
mlir/test/mlir-tblgen/op-decl.td
mlir/test/mlir-tblgen/op-operand.td
mlir/test/mlir-tblgen/op-result.td
mlir/test/mlir-tblgen/return-types.mlir
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

index f92fa1d..a035d9b 100644 (file)
@@ -429,14 +429,14 @@ The following builders are generated:
 
 ```c++
 // All result-types/operands/attributes have one aggregate parameter.
-static void build(Builder *tblgen_builder, OperationState &tblgen_state,
+static void build(Builder *odsBuilder, OperationState &odsState,
                   ArrayRef<Type> resultTypes,
                   ValueRange operands,
                   ArrayRef<NamedAttribute> attributes);
 
 // Each result-type/operand/attribute has a separate parameter. The parameters
 // for attributes are of mlir::Attribute types.
-static void build(Builder *tblgen_builder, OperationState &tblgen_state,
+static void build(Builder *odsBuilder, OperationState &odsState,
                   Type i32_result, Type f32_result, ...,
                   Value i32_operand, Value f32_operand, ...,
                   IntegerAttr i32_attr, FloatAttr f32_attr, ...);
@@ -445,20 +445,20 @@ static void build(Builder *tblgen_builder, OperationState &tblgen_state,
 // for attributes are raw values unwrapped with mlir::Attribute instances.
 // (Note that this builder will not always be generated. See the following
 // explanation for more details.)
-static void build(Builder *tblgen_builder, OperationState &tblgen_state,
+static void build(Builder *odsBuilder, OperationState &odsState,
                   Type i32_result, Type f32_result, ...,
                   Value i32_operand, Value f32_operand, ...,
                   APInt i32_attr, StringRef f32_attr, ...);
 
 // Each operand/attribute has a separate parameter but result type is aggregate.
-static void build(Builder *tblgen_builder, OperationState &tblgen_state,
+static void build(Builder *odsBuilder, OperationState &odsState,
                   ArrayRef<Type> resultTypes,
                   Value i32_operand, Value f32_operand, ...,
                   IntegerAttr i32_attr, FloatAttr f32_attr, ...);
 
 // All operands/attributes have aggregate parameters.
 // Generated if InferTypeOpInterface interface is specified.
-static void build(Builder *tblgen_builder, OperationState &tblgen_state,
+static void build(Builder *odsBuilder, OperationState &odsState,
                   ValueRange operands,
                   ArrayRef<NamedAttribute> attributes);
 
@@ -1099,7 +1099,7 @@ requirements that were desirable:
 *   The op's traits (e.g., commutative) are modelled along with the op in the
     registry.
 *   The op's operand/return type constraints are modelled along with the op in
-    the registry (see [Shape inference](#shape-inference) discussion below),
+    the registry (see [Shape inference](ShapeInference.md) discussion below),
     this allows (e.g.) optimized concise syntax in textual dumps.
 *   Behavior of the op is documented along with the op with a summary and a
     description. The description is written in markdown and extracted for
@@ -1156,49 +1156,6 @@ tfl.add $lhs, $rhs {fused_activation_function: $fused_activation_function}: ${ty
 Printing is effectively the inverse of the parsing function generated with the
 mnemonic string serving as a template.
 
-### Shape inference
-
-Type constraints are along (at least) three axis: 1) elemental type, 2) rank
-(including static or dynamic), 3) dimensions. While some ops have no compile
-time fixed shape (e.g., output shape is dictated by data) we could still have
-some knowledge of constraints/bounds in the system for that op (e.g., the output
-of a `tf.where` is at most the size of the input data). And so there are
-additional valuable constraints that could be captured even without full
-knowledge.
-
-Initially the shape inference will be declaratively specified using:
-
-*   Constraint on the operands of an operation directly. For example
-    constraining the input type to be tensor/vector elements or that the
-    elemental type be of a specific type (e.g., output of sign is of elemental
-    type `i1`) or class (e.g., float like).
-*   Constraints across operands and results of an operation. For example,
-    enabling specifying equality constraints on type/constituents of a type
-    (shape and elemental type) between operands and results (e.g., the output
-    type of an add is the same as those of the input operands).
-
-In general there is an input/output transfer function which maps the inputs to
-the outputs (e.g., given input X and Y [or slices thereof] with these sizes, the
-output is Z [or this slice thereof]). Such a function could be used to determine
-the output type (shape) for given input type (shape).
-
-But shape functions are determined by attributes and could be arbitrarily
-complicated with a wide-range of specification possibilities. Equality
-relationships are common (e.g., the elemental type of the output matches the
-primitive type of the inputs, both inputs have exactly the same type [primitive
-type and shape]) and so these should be easy to specify. Algebraic relationships
-would also be common (e.g., a concat of `[n,m]` and `[n,m]` matrix along axis 0
-is `[n+n, m]` matrix), while some ops only have defined shapes under certain
-cases (e.g., matrix multiplication of `[a,b]` and `[c,d]` is only defined if
-`b == c`). As ops are also verified, the shape inference need only specify rules
-for the allowed cases (e.g., shape inference for matmul can ignore the case
-where `b != c`), which would simplify type constraint specification.
-
-Instead of specifying an additional mechanism to specify a shape transfer
-function, the reference implementation of the operation will be used to derive
-the shape function. The reference implementation is general and can support the
-arbitrary computations needed to specify output shapes.
-
 [TableGen]: https://llvm.org/docs/TableGen/index.html
 [TableGenIntro]: https://llvm.org/docs/TableGen/LangIntro.html
 [TableGenRef]: https://llvm.org/docs/TableGen/LangRef.html
diff --git a/mlir/docs/ShapeInference.md b/mlir/docs/ShapeInference.md
new file mode 100644 (file)
index 0000000..8c268f1
--- /dev/null
@@ -0,0 +1,72 @@
+# Shape inference
+
+Shape inference as discussed here is considered a specific instance of type
+inference for [ShapedType][ShapedType]. Type constraints are along (at least)
+three axis: 1) elemental type, 2) rank (including static or dynamic), 3)
+dimensions. While some operations have no compile time fixed shape (e.g., output
+shape is dictated by data) we could still have some knowledge of
+constraints/bounds in the system for that operation (e.g., the output of a
+`tf.where` is at most the size of the input data). That is, there are additional
+valuable constraints that could be captured even without full knowledge of the
+shape.
+
+Type inference is currently modelled executionally for op creation using the
+[`InferTypeOpInterface`][InferTypeOpInterface], while
+`InferShapedTypeOpInterface` is used to implement the shape and element type
+inference. The return type can often be deduced from the deduced return shape
+and elemental type (queryable from `InferShapedTypeOpInterface`) and so type
+inference for tensor types can be implemented with `InferShapedTypeOpInterface`.
+
+## Shape functions
+
+The C++ interfaces are the base mechanism whereby shape inference is queried and
+executed, but not the intended way to specify shape constraints in general.
+
+Initially the shape inference will be declaratively specified using:
+
+*   Constraints on the operands of an operation directly. For example
+    constraining the input type to be tensor/vector elements or that the
+    elemental type be of a specific type (e.g., output of computing the size
+    of a value is of elemental type `i1`) or class (e.g., float like).
+*   Constraints across operands and results of an operation.
+
+    - For example, specifying equality constraints on type/constituents of a
+      type (shape and elemental type) between operands and results (e.g., the
+      output type of an add is the same as those of the input operands).
+
+NOTE: The C++ shape functions are an intermediate step until the shape dialect
+is more full-fledged, at which point the C++ functions should become the
+exceptional case.
+
+## Testing
+
+Shape inference is currently tested alongside type inference by
+`TestReturnTypeDriver` in the test dialect. The driver performs two checks:
+
+1.  Verification that the return types specified matches the infered types. This
+    explicit check will be removed and made part of Op verificaton instead.
+2.  Test the creation of Ops without specifying the return type explicitly in
+    function `testCreateFunctions` by creating new binary Ops (Op classes
+    specified in `TestReturnTypeDriver`) using 1) all operands to
+    `testCreateFunctions` as both operands, and 2) using combinations of input
+    operands of the function.
+
+## WIP/Future considerations
+
+Shape functions are determined by attributes and could be arbitrarily
+complicated with a wide-range of specification possibilities. Equality
+relationships are common (e.g., the elemental type of the output matches the
+primitive type of the inputs, both inputs have exactly the same type [primitive
+type and shape]) and so these should be easy to specify. Algebraic relationships
+would also be common (e.g., a concat of `[n,m]` and `[n,m]` matrix along axis 0
+is `[n+n, m]` matrix), while some ops only have defined shapes under certain
+cases (e.g., matrix multiplication of `[a,b]` and `[c,d]` is only defined if `b
+== c`).
+
+Instead of specifying an additional mechanism to specify a shape transfer
+function, the reference implementation of the operation will be used to derive
+the shape function. The reference implementation is general and can support the
+arbitrary computations needed to specify output shapes.
+
+[InferTypeOpInterface]: https://github.com/llvm/llvm-project/tree/master/mlir/include/mlir/Analysis/InferTypeOpInterface.td
+[ShapedType]: https://github.com/llvm/llvm-project/tree/master/mlir/include/mlir/IR/StandardTypes.h
index baf1616..2607871 100644 (file)
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/Types.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/SmallVector.h"
 
 namespace mlir {
 
+/// ShapedTypeComponents that represents the components of a ShapedType.
+/// The components consist of
+///  - A ranked or unranked shape with the dimension specification match those
+///    of ShapeType's getShape() (e.g., dynamic dimension represented using
+///    ShapedType::kDynamicSize)
+///  - A element type, may be unset (nullptr)
+///  - A attribute, may be unset (nullptr)
+/// Used by ShapedType type inferences.
+class ShapedTypeComponents {
+  /// Internal storage type for shape.
+  using ShapeStorageT = SmallVector<int64_t, 3>;
+
+public:
+  /// Default construction is an unranked shape.
+  ShapedTypeComponents() : ranked(false), elementType(nullptr), attr(nullptr){};
+
+  template <typename Arg, typename = typename std::enable_if_t<
+                              std::is_constructible<ShapeStorageT, Arg>::value>>
+  ShapedTypeComponents(Arg &&arg, Type elementType = nullptr,
+                       Attribute attr = nullptr)
+      : dims(std::forward<Arg>(arg)), ranked(true), elementType(elementType),
+        attr(attr) {}
+  ShapedTypeComponents(ArrayRef<int64_t> vec, Type elementType = nullptr,
+                       Attribute attr = nullptr)
+      : dims(vec.begin(), vec.end()), ranked(true), elementType(elementType),
+        attr(attr) {}
+
+  /// Return the dimensions of the shape.
+  /// Requires: shape is ranked.
+  ArrayRef<int64_t> getDims() const {
+    assert(ranked && "requires ranked shape");
+    return dims;
+  }
+
+  /// Return whether the shape has a rank.
+  bool hasRank() const { return ranked; };
+
+  /// Return the element type component.
+  Type getElementType() const { return elementType; };
+
+  /// Return the raw attribute component.
+  Attribute getAttribute() const { return attr; };
+
+private:
+  ShapeStorageT dims;
+  bool ranked;
+  Type elementType;
+  Attribute attr;
+};
+
 #include "mlir/Analysis/InferTypeOpInterface.h.inc"
 
+namespace detail {
+// Helper function to infer return tensor returns types given element and shape
+// inference function.
+//
+// TODO: Consider generating typedefs for trait member functions if this usage
+// becomes more common.
+LogicalResult inferReturnTensorTypes(
+    function_ref<LogicalResult(
+        MLIRContext *, Optional<Location> location, ValueRange operands,
+        ArrayRef<NamedAttribute> attributes, RegionRange regions,
+        SmallVectorImpl<ShapedTypeComponents> &retComponents)>
+        componentTypeFn,
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    ArrayRef<NamedAttribute> attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferedReturnTypes);
+} // namespace detail
+
 namespace OpTrait {
+
+/// Tensor type inference trait that constructs a tensor from the infered
+/// shape and elemental types.
+/// Requires: Op implements functions of InferShapedTypeOpInterface.
 template <typename ConcreteType>
-class TypeOpInterfaceDefault
-    : public TraitBase<ConcreteType, TypeOpInterfaceDefault> {
+class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {
 public:
-  /// Returns whether two arrays are equal as strongest check for compatibility
-  /// by default.
-  static bool isCompatibleReturnTypes(ArrayRef<Type> lhs, ArrayRef<Type> rhs) {
-    return lhs == rhs;
-  };
+  static LogicalResult
+  inferReturnTypes(MLIRContext *context, Optional<Location> location,
+                   ValueRange operands, ArrayRef<NamedAttribute> attributes,
+                   RegionRange regions,
+                   SmallVectorImpl<Type> &inferedReturnTypes) {
+    return ::mlir::detail::inferReturnTensorTypes(
+        ConcreteType::inferReturnTypeComponents, context, location, operands,
+        attributes, regions, inferedReturnTypes);
+  }
 };
-} // namespace OpTrait
 
+} // namespace OpTrait
 } // namespace mlir
 
 #endif // MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_
index bbcea6b..bc06b29 100644 (file)
@@ -22,9 +22,8 @@ include "mlir/IR/OpBase.td"
 // mismatch).
 def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
   let description = [{
-    Interface to access a registered method to infer the return types for an
-    operation that could be used during op construction, verification or
-    type inference.
+    Interface to infer the return types for an operation that could be used
+    during op construction, verification or type inference.
   }];
 
   let methods = [
@@ -38,7 +37,8 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
       }],
       /*retTy=*/"LogicalResult",
       /*methodName=*/"inferReturnTypes",
-      /*args=*/(ins "Optional<Location>":$location,
+      /*args=*/(ins "MLIRContext*":$context,
+                    "Optional<Location>":$location,
                     "ValueRange":$operands,
                     "ArrayRef<NamedAttribute>":$attributes,
                     "RegionRange":$regions,
@@ -62,4 +62,38 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
   ];
 }
 
+def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
+  let description = [{
+    Interface to infer the components of a ShapedType returned by an operation
+    that could be used during op construction, verification or shape inference.
+
+    The components consists of element type, shape and raw attribute.
+  }];
+
+  let methods = [
+    StaticInterfaceMethod<
+      /*desc=*/[{Infer the components of return type of shape containter.
+
+      The method takes an optional location which, if set, will be used to
+      report errors on. The operands and attributes correspond to those with
+      which an Operation would be created (e.g., as used in Operation::create)
+      and the regions of the op.
+
+      Unknown (e.g., unranked) shape and nullptrs for element type and attribute
+      may be returned by this function while returning success. E.g., partial
+      population of components is not error condition.
+      }],
+      /*retTy=*/"LogicalResult",
+      /*methodName=*/"inferReturnTypeComponents",
+      /*args=*/(ins "MLIRContext*":$context,
+                    "Optional<Location>":$location,
+                    "ValueRange":$operands,
+                    "ArrayRef<NamedAttribute>":$attributes,
+                    "RegionRange":$regions,
+                    "SmallVectorImpl<ShapedTypeComponents>&":
+                      $inferedReturnShapes)
+    >,
+  ];
+}
+
 #endif // MLIR_INFERTYPEOPINTERFACE
index eeb55c9..4420ffe 100644 (file)
@@ -1539,7 +1539,7 @@ class Op<Dialect dialect, string mnemonic, list<OpTrait> props = []> {
   // following signatures:
   //
   // ```c++
-  // static void build(Builder *, OperationState &tblgen_state,
+  // static void build(Builder *, OperationState &odsState,
   //                   Type <result0-name>, Type <result1-name>, ...,
   //                   Value <arg0-name>, Value <arg1-name>, ...,
   //                   Attribute <attr0-name>, Attribute <attr1-name>, ...);
@@ -1547,7 +1547,7 @@ class Op<Dialect dialect, string mnemonic, list<OpTrait> props = []> {
   // * where the attributes follow the same declaration order as in the op.
   //
   // ```c++
-  // static void build(Builder *, OperationState &tblgen_state,
+  // static void build(Builder *, OperationState &odsState,
   //                   ArrayRef<Type> resultTypes,
   //                   ArrayRef<Value> operands,
   //                   ArrayRef<NamedAttribute> attributes);
index 2e52de2..b1637b8 100644 (file)
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/InferTypeOpInterface.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/Types.h"
-#include "llvm/ADT/SmallVector.h"
+
+#include "mlir/IR/StandardTypes.h"
+
+using namespace mlir;
 
 namespace mlir {
 #include "mlir/Analysis/InferTypeOpInterface.cpp.inc"
 } // namespace mlir
+
+LogicalResult mlir::detail::inferReturnTensorTypes(
+    function_ref<LogicalResult(
+        MLIRContext *, Optional<Location> location, ValueRange operands,
+        ArrayRef<NamedAttribute> attributes, RegionRange regions,
+        SmallVectorImpl<ShapedTypeComponents> &retComponents)>
+        componentTypeFn,
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    ArrayRef<NamedAttribute> attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferedReturnTypes) {
+  SmallVector<ShapedTypeComponents, 2> retComponents;
+  if (failed(componentTypeFn(context, location, operands, attributes, regions,
+                             retComponents)))
+    return failure();
+  for (auto shapeAndType : retComponents) {
+    assert(shapeAndType.getAttribute() == nullptr && "attribute not supported");
+    if (shapeAndType.hasRank())
+      inferedReturnTypes.push_back(RankedTensorType::get(
+          shapeAndType.getDims(), shapeAndType.getElementType()));
+    else
+      inferedReturnTypes.push_back(
+          UnrankedTensorType::get(shapeAndType.getElementType()));
+  }
+  return success();
+}
index fef1f65..d64a9f4 100644 (file)
@@ -295,7 +295,7 @@ LogicalResult TestOpWithVariadicResultsAndFolder::fold(
 }
 
 LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
-    llvm::Optional<Location> location, ValueRange operands,
+    MLIRContext *, Optional<Location> location, ValueRange operands,
     ArrayRef<NamedAttribute> attributes, RegionRange regions,
     SmallVectorImpl<Type> &inferedReturnTypes) {
   if (operands[0].getType() != operands[1].getType()) {
@@ -307,6 +307,30 @@ LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
   return success();
 }
 
+LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    ArrayRef<NamedAttribute> attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferedComponents) {
+  // Create return type consisting of the first element of each shape of the
+  // input operands or unknown for unranked operand.
+  std::vector<int64_t> shape;
+  shape.reserve(operands.size());
+  for (auto operandType : operands.getTypes()) {
+    if (auto sval = operandType.dyn_cast<ShapedType>()) {
+      if (sval.hasRank())
+        shape.push_back(sval.getShape().front());
+      else
+        shape.push_back(ShapedType::kDynamicSize);
+    } else {
+      return emitOptionalError(location, "only shaped type operands allowed");
+    }
+  }
+  inferedComponents.reserve(1);
+  auto type = IntegerType::get(17, context);
+  inferedComponents.emplace_back(shape, type);
+  return success();
+}
+
 // Static initialization for Test dialect registration.
 static mlir::DialectRegistration<mlir::TestDialect> testDialect;
 
index 0dd32d7..546d21d 100644 (file)
@@ -402,6 +402,21 @@ def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [
   let results = (outs AnyTensor);
 }
 
+def InferTensorType : NativeOpTrait<"InferTensorType">;
+def OpWithShapedTypeInferTypeInterfaceOp : TEST_Op<"op_with_shaped_type_infer_type_if",
+  [
+     // Op implements infer type op interface.
+     InferTypeOpInterface,
+     // The op will have methods implementing the ShapedType type infer interface.
+     DeclareOpInterfaceMethods<InferShapedTypeOpInterface>,
+     // The op produces tensors and will use the ShapedType type infer interface
+     // along with knowledge that it is producing Tensors to infer shape.
+     InferTensorType
+   ]> {
+  let arguments = (ins AnyTensor, AnyTensor);
+  let results = (outs AnyTensor);
+}
+
 def IsNotScalar : Constraint<CPred<"$0.getType().getRank() != 0">>;
 
 def UpdateAttr : Pat<(I32ElementsAttrOp $attr),
index bbeb2c6..ecd3611 100644 (file)
@@ -58,50 +58,71 @@ static mlir::PassRegistration<TestPatternDriver>
 //===----------------------------------------------------------------------===//
 
 namespace {
-struct ReturnTypeOpMatch : public RewritePattern {
-  ReturnTypeOpMatch(MLIRContext *ctx)
-      : RewritePattern(OpWithInferTypeInterfaceOp::getOperationName(), 1, ctx) {
-  }
-
-  PatternMatchResult matchAndRewrite(Operation *op,
-                                     PatternRewriter &rewriter) const final {
-    if (auto retTypeFn = dyn_cast<InferTypeOpInterface>(op)) {
-      SmallVector<Value, 4> values(op->getOperands());
+// Generate ops for each instance where the type can be succesfully infered.
+template <typename OpTy>
+static void invokeCreateWithInferedReturnType(Operation *op) {
+  auto *context = op->getContext();
+  auto fop = op->getParentOfType<FuncOp>();
+  auto location = UnknownLoc::get(context);
+  OpBuilder b(op);
+  b.setInsertionPointAfter(op);
+
+  // Use permutations of 2 args as operands.
+  assert(fop.getNumArguments() >= 2);
+  for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
+    for (int j = 0; j < e; ++j) {
+      std::array<Value, 2> values = {fop.getArgument(i), fop.getArgument(j)};
       SmallVector<Type, 2> inferedReturnTypes;
-      if (failed(retTypeFn.inferReturnTypes(op->getLoc(), values,
-                                            op->getAttrs(), op->getRegions(),
-                                            inferedReturnTypes)))
-        return matchFailure();
-      SmallVector<Type, 1> resultTypes(op->getResultTypes());
-      if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes))
-        return op->emitOpError(
-                   "inferred type incompatible with return type of operation"),
-               matchFailure();
-
-      // TODO(jpienaar): Split this out to make the test more focused.
-      // Create new op with unknown location to verify building with
-      // InferTypeOpInterface is triggered.
-      auto fop = op->getParentOfType<FuncOp>();
-      if (values[0] == fop.getArgument(0)) {
-        // Use the 2nd function argument if the first function argument is used
-        // when constructing the new op so that a new return type is inferred.
-        values[0] = fop.getArgument(1);
-        values[1] = fop.getArgument(1);
+      if (succeeded(OpTy::inferReturnTypes(context, llvm::None, values,
+                                           op->getAttrs(), op->getRegions(),
+                                           inferedReturnTypes))) {
+        OperationState state(location, OpTy::getOperationName());
         // TODO(jpienaar): Expand to regions.
-        rewriter.create<OpWithInferTypeInterfaceOp>(
-            UnknownLoc::get(op->getContext()), values, op->getAttrs());
+        OpTy::build(&b, state, values, op->getAttrs());
+        (void)b.createOperation(state);
       }
     }
-    return matchFailure();
   }
-};
+}
 
 struct TestReturnTypeDriver : public FunctionPass<TestReturnTypeDriver> {
   void runOnFunction() override {
-    mlir::OwningRewritePatternList patterns;
-    populateWithGenerated(&getContext(), &patterns);
-    patterns.insert<ReturnTypeOpMatch>(&getContext());
-    applyPatternsGreedily(getFunction(), patterns);
+    if (getFunction().getName() == "testCreateFunctions") {
+      std::vector<Operation *> ops;
+      // Collect ops to avoid triggering on inserted ops.
+      for (auto &op : getFunction().getBody().front())
+        ops.push_back(&op);
+      // Generate test patterns for each, but skip terminator.
+      for (auto *op : llvm::makeArrayRef(ops).drop_back()) {
+        // Test create method of each of the Op classes below. The resultant
+        // output would be in reverse order underneath `op` from which
+        // the attributes and regions are used.
+        invokeCreateWithInferedReturnType<OpWithInferTypeInterfaceOp>(op);
+        invokeCreateWithInferedReturnType<OpWithShapedTypeInferTypeInterfaceOp>(
+            op);
+      };
+      return;
+    }
+
+    // Verification check.
+    // TODO: Move to ops that implement type infer interface.
+    getFunction().walk([this](Operation *op) -> void {
+      auto retTypeFn = dyn_cast<InferTypeOpInterface>(op);
+      if (!retTypeFn)
+        return;
+      auto *context = &getContext();
+      SmallVector<Type, 2> inferedReturnTypes;
+      if (failed(retTypeFn.inferReturnTypes(
+              context, op->getLoc(), op->getOperands(), op->getAttrs(),
+              op->getRegions(), inferedReturnTypes)))
+        return;
+      SmallVector<Type, 1> resultTypes(op->getResultTypes());
+      if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes)) {
+        op->emitOpError(
+            "inferred type incompatible with return type of operation");
+        return;
+      }
+    });
   }
 };
 } // end anonymous namespace
index 5e6d56c..11d5794 100644 (file)
@@ -56,18 +56,18 @@ def AOp : NS_Op<"a_op", []> {
 // ---
 
 // DEF:      void AOp::build(
-// DEF:        tblgen_state.addAttribute("aAttr", aAttr);
-// DEF:        tblgen_state.addAttribute("bAttr", bAttr);
+// DEF:        odsState.addAttribute("aAttr", aAttr);
+// DEF:        odsState.addAttribute("bAttr", bAttr);
 // DEF:        if (cAttr) {
-// DEF-NEXT:     tblgen_state.addAttribute("cAttr", cAttr);
+// DEF-NEXT:     odsState.addAttribute("cAttr", cAttr);
 
 // DEF:      void AOp::build(
 // DEF:        some-return-type aAttr, some-return-type bAttr, /*optional*/some-attr-kind cAttr
-// DEF:        tblgen_state.addAttribute("aAttr", some-const-builder-call((*tblgen_builder), aAttr));
+// DEF:        odsState.addAttribute("aAttr", some-const-builder-call((*odsBuilder), aAttr));
 
 // DEF:      void AOp::build(
 // DEF:        ArrayRef<NamedAttribute> attributes
-// DEF:      tblgen_state.addAttributes(attributes);
+// DEF:      odsState.addAttributes(attributes);
 
 // Test verify method
 // ---
@@ -218,7 +218,7 @@ def MixOperandsAndAttrs : NS_Op<"mix_operands_and_attrs", []> {
 // DEF-LABEL: MixOperandsAndAttrs definitions
 // DEF-DAG: Value MixOperandsAndAttrs::operand()
 // DEF-DAG: Value MixOperandsAndAttrs::otherArg()
-// DEF-DAG: void MixOperandsAndAttrs::build(Builder *tblgen_builder, OperationState &tblgen_state, FloatAttr attr, Value operand, FloatAttr otherAttr, Value otherArg)
+// DEF-DAG: void MixOperandsAndAttrs::build(Builder *odsBuilder, OperationState &odsState, FloatAttr attr, Value operand, FloatAttr otherAttr, Value otherArg)
 // DEF-DAG: APFloat MixOperandsAndAttrs::attr()
 // DEF-DAG: APFloat MixOperandsAndAttrs::otherAttr()
 
@@ -233,4 +233,4 @@ def UnitAttrOp : NS_Op<"unit_attr_op", []> {
 // DEF: bool UnitAttrOp::attr() {
 // DEF:   return {{.*}} != nullptr
 
-// DEF: build(Builder *tblgen_builder, OperationState &tblgen_state, /*optional*/UnitAttr attr)
+// DEF: build(Builder *odsBuilder, OperationState &odsState, /*optional*/UnitAttr attr)
index 74da938..61f0c56 100644 (file)
@@ -70,9 +70,9 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect, NoSideEffect]> {
 // CHECK:   FloatAttr attr2Attr()
 // CHECK:   Optional< APFloat > attr2();
 // CHECK:   static void build(Value val);
-// CHECK:   static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type r, ArrayRef<Type> s, Value a, ValueRange b, IntegerAttr attr1, /*optional*/FloatAttr attr2)
-// CHECK:   static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type r, ArrayRef<Type> s, Value a, ValueRange b, APInt attr1, /*optional*/FloatAttr attr2)
-// CHECK:   static void build(Builder *, OperationState &tblgen_state, ArrayRef<Type> resultTypes, ValueRange operands, ArrayRef<NamedAttribute> attributes)
+// CHECK:   static void build(Builder *odsBuilder, OperationState &odsState, Type r, ArrayRef<Type> s, Value a, ValueRange b, IntegerAttr attr1, /*optional*/FloatAttr attr2)
+// CHECK:   static void build(Builder *odsBuilder, OperationState &odsState, Type r, ArrayRef<Type> s, Value a, ValueRange b, APInt attr1, /*optional*/FloatAttr attr2)
+// CHECK:   static void build(Builder *, OperationState &odsState, ArrayRef<Type> resultTypes, ValueRange operands, ArrayRef<NamedAttribute> attributes)
 // CHECK:   static ParseResult parse(OpAsmParser &parser, OperationState &result);
 // CHECK:   void print(OpAsmPrinter &p);
 // CHECK:   LogicalResult verify();
index e2d5862..2ffde33 100644 (file)
@@ -19,12 +19,12 @@ def OpA : NS_Op<"one_normal_operand_op", []> {
 
 // CHECK:      void OpA::build
 // CHECK:        Value input
-// CHECK:        tblgen_state.addOperands(input);
+// CHECK:        odsState.addOperands(input);
 
 // CHECK:      void OpA::build
 // CHECK:        ValueRange operands
 // CHECK:        assert(operands.size() == 1u && "mismatched number of parameters");
-// CHECK:        tblgen_state.addOperands(operands);
+// CHECK:        odsState.addOperands(operands);
 
 def OpB : NS_Op<"one_variadic_operand_op", []> {
   let arguments = (ins Variadic<I32>:$input);
@@ -33,7 +33,7 @@ def OpB : NS_Op<"one_variadic_operand_op", []> {
 // CHECK-LABEL: OpB::build
 // CHECK:         ValueRange input
 // CHECK-NOT:     assert
-// CHECK:         tblgen_state.addOperands(input);
+// CHECK:         odsState.addOperands(input);
 
 def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]> {
   let arguments = (ins Variadic<AnyTensor>:$input1, AnyTensor:$input2, Variadic<AnyTensor>:$input3);
@@ -55,6 +55,6 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]>
 // CHECK-NEXT: return *getODSOperands(1).begin();
 
 // CHECK-LABEL: OpD::build
-// CHECK-NEXT: tblgen_state.addOperands(input1);
-// CHECK-NEXT: tblgen_state.addOperands(input2);
-// CHECK-NEXT: tblgen_state.addOperands(input3);
+// CHECK-NEXT: odsState.addOperands(input1);
+// CHECK-NEXT: odsState.addOperands(input2);
+// CHECK-NEXT: odsState.addOperands(input3);
index 6de6c1f..af88e0b 100644 (file)
@@ -15,7 +15,7 @@ def OpA : NS_Op<"one_normal_result_op", []> {
 // CHECK-LABEL: void OpA::build
 // CHECK:         ArrayRef<Type> resultTypes, ValueRange operands
 // CHECK:         assert(resultTypes.size() == 1u && "mismatched number of return types");
-// CHECK-NEXT:    tblgen_state.addTypes(resultTypes);
+// CHECK-NEXT:    odsState.addTypes(resultTypes);
 
 def OpB : NS_Op<"same_input_output_type_op", [SameOperandsAndResultType]> {
   let arguments = (ins I32:$x);
@@ -23,20 +23,20 @@ def OpB : NS_Op<"same_input_output_type_op", [SameOperandsAndResultType]> {
 }
 
 // CHECK-LABEL: OpB definitions
-// CHECK: void OpB::build(Builder *tblgen_builder, OperationState &tblgen_state, Type y, Value x)
-// CHECK:   tblgen_state.addTypes(y);
-// CHECK: void OpB::build(Builder *tblgen_builder, OperationState &tblgen_state, Value x)
-// CHECK:   tblgen_state.addTypes({x.getType()});
+// CHECK: void OpB::build(Builder *odsBuilder, OperationState &odsState, Type y, Value x)
+// CHECK:   odsState.addTypes(y);
+// CHECK: void OpB::build(Builder *odsBuilder, OperationState &odsState, Value x)
+// CHECK:   odsState.addTypes({x.getType()});
 
 def OpC : NS_Op<"three_normal_result_op", []> {
   let results = (outs I32:$x, /*unnamed*/I32, I32:$z);
 }
 
 // CHECK-LABEL: OpC definitions
-// CHECK:       void OpC::build(Builder *tblgen_builder, OperationState &tblgen_state, Type x, Type resultType1, Type z)
-// CHECK-NEXT:   tblgen_state.addTypes(x)
-// CHECK-NEXT:   tblgen_state.addTypes(resultType1)
-// CHECK-NEXT:   tblgen_state.addTypes(z)
+// CHECK:       void OpC::build(Builder *odsBuilder, OperationState &odsState, Type x, Type resultType1, Type z)
+// CHECK-NEXT:   odsState.addTypes(x)
+// CHECK-NEXT:   odsState.addTypes(resultType1)
+// CHECK-NEXT:   odsState.addTypes(z)
 
 def IntegerTypeAttr : TypeAttrBase<"IntegerType", "Integer type attribute">;
 def OpD : NS_Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> {
@@ -45,8 +45,8 @@ def OpD : NS_Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> {
 }
 
 // CHECK-LABEL: OpD definitions
-// CHECK: void OpD::build(Builder *, OperationState &tblgen_state, ValueRange operands, ArrayRef<NamedAttribute> attributes)
-// CHECK: tblgen_state.addTypes({attr.second.cast<TypeAttr>().getValue()});
+// CHECK: void OpD::build(Builder *, OperationState &odsState, ValueRange operands, ArrayRef<NamedAttribute> attributes)
+// CHECK: odsState.addTypes({attr.second.cast<TypeAttr>().getValue()});
 
 def OpE : NS_Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> {
   let arguments = (ins I32:$x, F32Attr:$attr);
@@ -54,8 +54,8 @@ def OpE : NS_Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> {
 }
 
 // CHECK-LABEL: OpE definitions
-// CHECK: void OpE::build(Builder *, OperationState &tblgen_state, ValueRange operands, ArrayRef<NamedAttribute> attributes)
-// CHECK: tblgen_state.addTypes({attr.second.getType()});
+// CHECK: void OpE::build(Builder *, OperationState &odsState, ValueRange operands, ArrayRef<NamedAttribute> attributes)
+// CHECK: odsState.addTypes({attr.second.getType()});
 
 def OpF : NS_Op<"one_variadic_result_op", []> {
   let results = (outs Variadic<I32>:$x);
@@ -64,7 +64,7 @@ def OpF : NS_Op<"one_variadic_result_op", []> {
 // CHECK-LABEL: void OpF::build
 // CHECK-SAME:    ArrayRef<Type> x
 // CHECK-NOT:     assert
-// CHECK:         tblgen_state.addTypes(x);
+// CHECK:         odsState.addTypes(x);
 
 def OpG : NS_Op<"one_normal_and_one_variadic_result_op", []> {
 
@@ -73,14 +73,14 @@ def OpG : NS_Op<"one_normal_and_one_variadic_result_op", []> {
 
 // CHECK-LABEL: OpG definitions
 
-// CHECK:      void OpG::build(Builder *tblgen_builder, OperationState &tblgen_state, Type x, ArrayRef<Type> y)
-// CHECK-NEXT:   tblgen_state.addTypes(x);
-// CHECK-NEXT:   tblgen_state.addTypes(y);
+// CHECK:      void OpG::build(Builder *odsBuilder, OperationState &odsState, Type x, ArrayRef<Type> y)
+// CHECK-NEXT:   odsState.addTypes(x);
+// CHECK-NEXT:   odsState.addTypes(y);
 
 // CHECK:       void OpG::build
 // CHECK:         ArrayRef<Type> resultTypes
 // CHECK:         assert(resultTypes.size() >= 1u && "mismatched number of return types");
-// CHECK-NEXT:    tblgen_state.addTypes(resultTypes);
+// CHECK-NEXT:    odsState.addTypes(resultTypes);
 
 def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]> {
   let results = (outs Variadic<AnyTensor>:$output1, AnyTensor:$output2, Variadic<AnyTensor>:$output3);
@@ -93,9 +93,9 @@ def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]>
 // CHECK-NEXT:    return *getODSResults(1).begin();
 
 // CHECK-LABEL: OpI::build
-// CHECK-NEXT:    tblgen_state.addTypes(output1);
-// CHECK-NEXT:    tblgen_state.addTypes(output2);
-// CHECK-NEXT:    tblgen_state.addTypes(output3);
+// CHECK-NEXT:    odsState.addTypes(output1);
+// CHECK-NEXT:    odsState.addTypes(output2);
+// CHECK-NEXT:    odsState.addTypes(output3);
 
 // Test that if the only operand is variadic, we access the first value in the
 // pack to set result type
@@ -105,5 +105,5 @@ def OpK : NS_Op<"only_input_is_variadic_with_same_value_type_op", [SameOperandsA
   let results = (outs AnyTensor:$result);
 }
 
-// CHECK-LABEL: OpK::build(Builder *tblgen_builder, OperationState &tblgen_state, ValueRange input)
-// CHECK: tblgen_state.addTypes({input.front().getType()});
+// CHECK-LABEL: OpK::build(Builder *odsBuilder, OperationState &odsState, ValueRange input)
+// CHECK: odsState.addTypes({input.front().getType()});
index e8c76f1..640f06d 100644 (file)
@@ -1,12 +1,23 @@
 // RUN: mlir-opt %s -test-return-type -split-input-file -verify-diagnostics | FileCheck %s --dump-input-on-failure
 
-// CHECK-LABEL: testReturnTypeOpInterface
-func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>, %arg1 : tensor<20xi32>) {
-  %good = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
-  // CHECK: test.op_with_infer_type_if
-  // CHECK-SAME: tensor<20xi32>
-  // CHECK: test.op_with_infer_type_if
-  // CHECK-SAME: tensor<10xf32>
+// CHECK-LABEL: testCreateFunctions
+// This function tests invoking the create method with different inference
+// methods. The attributes of the ops inside are used to test creation.
+func @testCreateFunctions(%arg0 : tensor<10xf32>, %arg1 : tensor<20xi32>) {
+// CHECK: "test.no_attributes"
+  %good = "test.no_attributes"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+// CHECK: "test.op_with_shaped_type_infer_type_if"
+// CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xi17>
+// CHECK: "test.op_with_shaped_type_infer_type_if"
+// CHECK-SAME: (tensor<10xf32>, tensor<20xi32>) -> tensor<10x20xi17>
+// CHECK: "test.op_with_shaped_type_infer_type_if"
+// CHECK-SAME: (tensor<20xi32>, tensor<10xf32>) -> tensor<20x10xi17>
+// CHECK: "test.op_with_shaped_type_infer_type_if"
+// CHECK-SAME: (tensor<20xi32>, tensor<20xi32>) -> tensor<20x20xi17>
+// CHECK: "test.op_with_infer_type_if"
+// CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+// CHECK: "test.op_with_infer_type_if"
+// CHECK-SAME: (tensor<20xi32>, tensor<20xi32>) -> tensor<20xi32>
   return
 }
 
index 9241573..14402e9 100644 (file)
@@ -58,8 +58,8 @@ ODSDialectHookRegistration::ODSDialectHookRegistration(
 //===----------------------------------------------------------------------===//
 
 static const char *const tblgenNamePrefix = "tblgen_";
-static const char *const generatedArgName = "tblgen_arg";
-static const char *const builderOpState = "tblgen_state";
+static const char *const generatedArgName = "odsArg";
+static const char *const builderOpState = "odsState";
 
 // The logic to calculate the actual value range for a declared operand/result
 // of an op with variadic operands/results. Note that this logic is not for
@@ -627,8 +627,9 @@ void OpEmitter::genSeparateArgParamBuilder() {
       // TODO(jpienaar): Expand to handle regions.
       body << formatv(R"(
         SmallVector<Type, 2> inferedReturnTypes;
-        if (succeeded({0}::inferReturnTypes({1}.location, {1}.operands,
-                      {1}.attributes, /*regions=*/{{}, inferedReturnTypes)))
+        if (succeeded({0}::inferReturnTypes(odsBuilder->getContext(),
+                      {1}.location, {1}.operands, {1}.attributes,
+                      /*regions=*/{{}, inferedReturnTypes)))
           {1}.addTypes(inferedReturnTypes);
         else
           llvm::report_fatal_error("Failed to infer result type(s).");)",
@@ -702,7 +703,7 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
 void OpEmitter::genInferedTypeCollectiveParamBuilder() {
   // TODO(jpienaar): Expand to support regions.
   const char *params =
-      "Builder *builder, OperationState &{0}, "
+      "Builder *odsBuilder, OperationState &{0}, "
       "ValueRange operands, ArrayRef<NamedAttribute> attributes";
   auto &m =
       opClass.newMethod("void", "build", formatv(params, builderOpState).str(),
@@ -710,9 +711,10 @@ void OpEmitter::genInferedTypeCollectiveParamBuilder() {
   auto &body = m.body();
   body << formatv(R"(
     SmallVector<Type, 2> inferedReturnTypes;
-    if (succeeded({0}::inferReturnTypes({1}.location, operands, attributes,
+    if (succeeded({0}::inferReturnTypes(odsBuilder->getContext(),
+                  {1}.location, operands, attributes,
                   /*regions=*/{{}, inferedReturnTypes)))
-      build(builder, tblgen_state, inferedReturnTypes, operands, attributes);
+      build(odsBuilder, odsState, inferedReturnTypes, operands, attributes);
     else
       llvm::report_fatal_error("Failed to infer result type(s).");)",
                   opClass.getClassName(), builderOpState);
@@ -878,7 +880,7 @@ void OpEmitter::buildParamList(std::string &paramList,
   auto numResults = op.getNumResults();
   resultTypeNames.reserve(numResults);
 
-  paramList = "Builder *tblgen_builder, OperationState &";
+  paramList = "Builder *odsBuilder, OperationState &";
   paramList.append(builderOpState);
 
   switch (typeParamKind) {
@@ -1000,7 +1002,7 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
         // If this is a raw value, then we need to wrap it in an Attribute
         // instance.
         FmtContext fctx;
-        fctx.withBuilder("(*tblgen_builder)");
+        fctx.withBuilder("(*odsBuilder)");
         std::string value =
             tgfmt(attr.getConstBuilderTemplate(), &fctx, namedAttr.name);
         body << formatv("  {0}.addAttribute(\"{1}\", {2});\n", builderOpState,