[ods] Update Operator to record Arg->[Attr|Operand]Index mapping
authorJacques Pienaar <jpienaar@google.com>
Mon, 29 Jun 2020 23:40:52 +0000 (16:40 -0700)
committerJacques Pienaar <jpienaar@google.com>
Mon, 29 Jun 2020 23:40:52 +0000 (16:40 -0700)
Also fixed bug in type inferface generator to address bug where operands and
attributes are interleaved.

Differential Revision: https://reviews.llvm.org/D82819

mlir/include/mlir/TableGen/Operator.h
mlir/lib/TableGen/Operator.cpp
mlir/test/mlir-tblgen/op-result.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

index cce754d..8f567a7 100644 (file)
@@ -261,6 +261,22 @@ public:
   // Requires: all result types are known.
   ArrayRef<ArgOrType> getSameTypeAsResult(int index) const;
 
+  // Pair consisting kind of argument and index into operands or attributes.
+  struct OperandOrAttribute {
+    enum class Kind { Operand, Attribute };
+    OperandOrAttribute(Kind kind, int index) {
+      packed = (index << 1) & (kind == Kind::Attribute);
+    }
+    int operandOrAttributeIndex() const { return (packed >> 1); }
+    Kind kind() { return (packed & 0x1) ? Kind::Attribute : Kind::Operand; }
+
+  private:
+    int packed;
+  };
+
+  // Returns the OperandOrAttribute corresponding to the index.
+  OperandOrAttribute getArgToOperandOrAttribute(int index) const;
+
 private:
   // Populates the vectors containing operands, attributes, results and traits.
   void populateOpStructure();
@@ -303,6 +319,9 @@ private:
   // The argument with the same type as the result.
   SmallVector<SmallVector<ArgOrType, 2>, 4> resultTypeMapping;
 
+  // Map from argument to attribute or operand number.
+  SmallVector<OperandOrAttribute, 4> attrOrOperandMapping;
+
   // The number of native attributes stored in the leading positions of
   // `attributes`.
   int numNativeAttributes;
index 8350cd1..7e8b4d8 100644 (file)
@@ -436,9 +436,13 @@ void tblgen::Operator::populateOpStructure() {
       argDef = argDef->getValueAsDef("constraint");
 
     if (argDef->isSubClassOf(typeConstraintClass)) {
+      attrOrOperandMapping.push_back(
+          {OperandOrAttribute::Kind::Operand, operandIndex});
       arguments.emplace_back(&operands[operandIndex++]);
     } else {
       assert(argDef->isSubClassOf(attrClass));
+      attrOrOperandMapping.push_back(
+          {OperandOrAttribute::Kind::Attribute, attrIndex});
       arguments.emplace_back(&attributes[attrIndex++]);
     }
   }
@@ -581,3 +585,8 @@ auto tblgen::Operator::VariableDecoratorIterator::unwrap(llvm::Init *init)
     -> VariableDecorator {
   return VariableDecorator(cast<llvm::DefInit>(init)->getDef());
 }
+
+auto tblgen::Operator::getArgToOperandOrAttribute(int index) const
+    -> OperandOrAttribute {
+  return attrOrOperandMapping[index];
+}
index c9959c5..4b091e4 100644 (file)
@@ -1,6 +1,7 @@
 // RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s
 
 include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
 
 def Test_Dialect : Dialect {
   let name = "test";
@@ -111,3 +112,15 @@ def OpK : NS_Op<"only_input_is_variadic_with_same_value_type_op", [SameOperandsA
 
 // CHECK-LABEL: OpK::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange input)
 // CHECK: odsState.addTypes({input.front().getType()});
+
+// Test with inferred shapes and interleaved with operands/attributes.
+//
+def OpL : NS_Op<"op_with_all_types_constraint",
+    [AllTypesMatch<["a", "b"]>]> {
+  let arguments = (ins I32Attr:$attr1, AnyType:$a);
+  let results = (outs Res<AnyType, "output b", []>:$b);
+}
+
+// CHECK-LABEL: LogicalResult OpL::inferReturnTypes
+// CHECK-NOT: }
+// CHECK: inferredReturnTypes[0] = operands[0].getType();
index 5f07704..fd0b4f4 100644 (file)
@@ -1601,7 +1601,12 @@ void OpEmitter::genTypeInterfaceMethods() {
     if (type.isArg()) {
       auto argIndex = type.getArg();
       assert(!op.getArg(argIndex).is<NamedAttribute *>());
-      return os << "operands[" << argIndex << "].getType()";
+      auto arg = op.getArgToOperandOrAttribute(argIndex);
+      if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand)
+        return os << "operands[" << arg.operandOrAttributeIndex()
+                  << "].getType()";
+      return os << "attributes[" << arg.operandOrAttributeIndex()
+                << "].getType()";
     } else {
       return os << tgfmt(*type.getType().getBuilderCall(), &fctx);
     }