[mlir] Expand operand adapter to take attributes
authorJacques Pienaar <jpienaar@google.com>
Mon, 25 May 2020 03:42:58 +0000 (20:42 -0700)
committerJacques Pienaar <jpienaar@google.com>
Mon, 25 May 2020 04:06:47 +0000 (21:06 -0700)
* Enables using with more variadic sized operands;
* Generate convenience accessors for attributes;
  - The accessor are named the same as their name in ODS and returns attribute
    type (not convenience type) and no derived attributes.

This is first step to changing adapter to support verifying argument
constraints before the op is even created. This does not change the name of
adaptor nor does it require it except for ops with variadic operands to keep this change smaller.

Considered creating separate adapter but decided against that given operands also require attributes in general (and definitely for verification of operands and attributes).

Differential Revision: https://reviews.llvm.org/D80420

mlir/include/mlir/TableGen/OpClass.h
mlir/lib/TableGen/OpClass.cpp
mlir/test/mlir-tblgen/op-decl.td
mlir/test/mlir-tblgen/op-operand.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

index 8788a50..e8f73c6 100644 (file)
@@ -145,10 +145,6 @@ class OpClass : public Class {
 public:
   explicit OpClass(StringRef name, StringRef extraClassDeclaration = "");
 
-  // Sets whether this OpClass should generate the using directive for its
-  // associate operand adaptor class.
-  void setHasOperandAdaptorClass(bool has);
-
   // Adds an op trait.
   void addTrait(Twine trait);
 
@@ -160,7 +156,6 @@ private:
   StringRef extraClassDeclaration;
   SmallVector<std::string, 4> traitsVec;
   StringSet<> traitsSet;
-  bool hasOperandAdaptor;
 };
 
 } // namespace tblgen
index 26519df..bfdcbdc 100644 (file)
@@ -188,12 +188,7 @@ void tblgen::Class::writeDefTo(raw_ostream &os) const {
 //===----------------------------------------------------------------------===//
 
 tblgen::OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
-    : Class(name), extraClassDeclaration(extraClassDeclaration),
-      hasOperandAdaptor(true) {}
-
-void tblgen::OpClass::setHasOperandAdaptorClass(bool has) {
-  hasOperandAdaptor = has;
-}
+    : Class(name), extraClassDeclaration(extraClassDeclaration) {}
 
 void tblgen::OpClass::addTrait(Twine trait) {
   auto traitStr = trait.str();
@@ -207,8 +202,7 @@ void tblgen::OpClass::writeDeclTo(raw_ostream &os) const {
     os << ", " << trait;
   os << "> {\npublic:\n";
   os << "  using Op::Op;\n";
-  if (hasOperandAdaptor)
-    os << "  using OperandAdaptor = " << className << "OperandAdaptor;\n";
+  os << "  using OperandAdaptor = " << className << "OperandAdaptor;\n";
 
   bool hasPrivateMethod = false;
   for (const auto &method : methods) {
index 0b9bac2..c68d03c 100644 (file)
@@ -50,12 +50,14 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
 
 // CHECK: class AOpOperandAdaptor {
 // CHECK: public:
-// CHECK:   AOpOperandAdaptor(ArrayRef<Value> values);
+// CHECK:   AOpOperandAdaptor(ArrayRef<Value> values
 // CHECK:   ArrayRef<Value> getODSOperands(unsigned index);
 // CHECK:   Value a();
 // CHECK:   ArrayRef<Value> b();
+// CHECK:   IntegerAttr attr1();
+// CHECL:   FloatAttr attr2();
 // CHECK: private:
-// CHECK:   ArrayRef<Value> tblgen_operands;
+// CHECK:   ArrayRef<Value> odsOperands;
 // CHECK: };
 
 // CHECK: class AOp : public Op<AOp, OpTrait::AtLeastNRegions<1>::Impl, OpTrait::AtLeastNResults<1>::Impl, OpTrait::ZeroSuccessor, OpTrait::AtLeastNOperands<1>::Impl, OpTrait::IsIsolatedFromAbove
@@ -90,6 +92,29 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
 // CHECK:   void displayGraph();
 // CHECK: };
 
+// Check AttrSizedOperandSegments
+// ---
+
+def NS_AttrSizedOperandOp : NS_Op<"attr_sized_operands",
+                                 [AttrSizedOperandSegments]> {
+  let arguments = (ins
+    Variadic<I32>:$a,
+    Variadic<I32>:$b,
+    I32:$c,
+    Variadic<I32>:$d,
+    I32ElementsAttr:$operand_segment_sizes
+  );
+}
+
+// CHECK-LABEL: AttrSizedOperandOpOperandAdaptor(
+// CHECK-SAME:    ArrayRef<Value> values
+// CHECK-SAME:    DictionaryAttr attrs
+// CHECK:  ArrayRef<Value> a();
+// CHECK:  ArrayRef<Value> b();
+// CHECK:  Value c();
+// CHECK:  ArrayRef<Value> d();
+// CHECK:  DenseIntElementsAttr operand_segment_sizes();
+
 // Check op trait for different number of operands
 // ---
 
@@ -150,3 +175,4 @@ def _BOp : NS_Op<"_op_with_leading_underscore_and_no_namespace", []>;
 
 // CHECK-LABEL: _BOp declarations
 // CHECK: class _BOp : public Op<_BOp
+
index 2ffde33..5f0bfae 100644 (file)
@@ -15,7 +15,7 @@ def OpA : NS_Op<"one_normal_operand_op", []> {
 // CHECK-LABEL: OpA definitions
 
 // CHECK:      OpAOperandAdaptor::OpAOperandAdaptor
-// CHECK-NEXT: tblgen_operands = values
+// CHECK-NEXT: odsOperands = values
 
 // CHECK:      void OpA::build
 // CHECK:        Value input
index 8709760..2010262 100644 (file)
@@ -70,13 +70,19 @@ const char *sameVariadicSizeValueRangeCalcCode = R"(
 // (variadic or not).
 //
 // {0}: The name of the attribute specifying the segment sizes.
-const char *attrSizedSegmentValueRangeCalcCode = R"(
+const char *adapterSegmentSizeAttrInitCode = R"(
+  assert(odsAttrs && "missing segment size attribute for op");
+  auto sizeAttr = odsAttrs.get("{0}").cast<DenseIntElementsAttr>();
+)";
+const char *opSegmentSizeAttrInitCode = R"(
   auto sizeAttr = getAttrOfType<DenseIntElementsAttr>("{0}");
+)";
+const char *attrSizedSegmentValueRangeCalcCode = R"(
   unsigned start = 0;
   for (unsigned i = 0; i < index; ++i)
     start += (*(sizeAttr.begin() + i)).getZExtValue();
   unsigned size = (*(sizeAttr.begin() + index)).getZExtValue();
-  return {{start, size};
+  return {start, size};
 )";
 
 // The logic to build a range of either operand or result values.
@@ -496,15 +502,14 @@ static void
 generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
                               int numVariadic, int numNonVariadic,
                               StringRef rangeSizeCall, bool hasAttrSegmentSize,
-                              StringRef segmentSizeAttr, RangeT &&odsValues) {
+                              StringRef sizeAttrInit, RangeT &&odsValues) {
   auto &method = opClass.newMethod("std::pair<unsigned, unsigned>", methodName,
                                    "unsigned index");
 
   if (numVariadic == 0) {
     method.body() << "  return {index, 1};\n";
   } else if (hasAttrSegmentSize) {
-    method.body() << formatv(attrSizedSegmentValueRangeCalcCode,
-                             segmentSizeAttr);
+    method.body() << sizeAttrInit << attrSizedSegmentValueRangeCalcCode;
   } else {
     // Because the op can have arbitrarily interleaved variadic and non-variadic
     // operands, we need to embed a list in the "sink" getter method for
@@ -532,6 +537,7 @@ generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
 // of ops, in particular for one-operand ops that may not have the
 // `getOperand(unsigned)` method.
 static void generateNamedOperandGetters(const Operator &op, Class &opClass,
+                                        StringRef sizeAttrInit,
                                         StringRef rangeType,
                                         StringRef rangeBeginCall,
                                         StringRef rangeSizeCall,
@@ -563,10 +569,10 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
 
   // First emit a few "sink" getter methods upon which we layer all nicer named
   // getter methods.
-  generateValueRangeStartAndEnd(
-      opClass, "getODSOperandIndexAndLength", numVariadicOperands,
-      numNormalOperands, rangeSizeCall, attrSizedOperands,
-      "operand_segment_sizes", const_cast<Operator &>(op).getOperands());
+  generateValueRangeStartAndEnd(opClass, "getODSOperandIndexAndLength",
+                                numVariadicOperands, numNormalOperands,
+                                rangeSizeCall, attrSizedOperands, sizeAttrInit,
+                                const_cast<Operator &>(op).getOperands());
 
   auto &m = opClass.newMethod(rangeType, "getODSOperands", "unsigned index");
   m.body() << formatv(valueRangeReturnCode, rangeBeginCall,
@@ -574,7 +580,6 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
 
   // Then we emit nicer named getter methods by redirecting to the "sink" getter
   // method.
-
   for (int i = 0; i != numOperands; ++i) {
     const auto &operand = op.getOperand(i);
     if (operand.name.empty())
@@ -595,11 +600,11 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
 }
 
 void OpEmitter::genNamedOperandGetters() {
-  if (op.getTrait("OpTrait::AttrSizedOperandSegments"))
-    opClass.setHasOperandAdaptorClass(false);
-
   generateNamedOperandGetters(
-      op, opClass, /*rangeType=*/"Operation::operand_range",
+      op, opClass,
+      /*sizeAttrInit=*/
+      formatv(opSegmentSizeAttrInitCode, "operand_segment_sizes").str(),
+      /*rangeType=*/"Operation::operand_range",
       /*rangeBeginCall=*/"getOperation()->operand_begin()",
       /*rangeSizeCall=*/"getOperation()->getNumOperands()",
       /*getOperandCallPattern=*/"getOperation()->getOperand({0})");
@@ -656,7 +661,8 @@ void OpEmitter::genNamedResultGetters() {
   generateValueRangeStartAndEnd(
       opClass, "getODSResultIndexAndLength", numVariadicResults,
       numNormalResults, "getOperation()->getNumResults()", attrSizedResults,
-      "result_segment_sizes", op.getResults());
+      formatv(opSegmentSizeAttrInitCode, "result_segment_sizes").str(),
+      op.getResults());
   auto &m = opClass.newMethod("Operation::result_range", "getODSResults",
                               "unsigned index");
   m.body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()",
@@ -1840,15 +1846,56 @@ private:
 
 OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
     : adapterClass(op.getCppClassName().str() + "OperandAdaptor") {
-  adapterClass.newField("ArrayRef<Value>", "tblgen_operands");
-  auto &constructor = adapterClass.newConstructor("ArrayRef<Value> values");
-  constructor.body() << "  tblgen_operands = values;\n";
-
-  generateNamedOperandGetters(op, adapterClass,
+  adapterClass.newField("ArrayRef<Value>", "odsOperands");
+  adapterClass.newField("DictionaryAttr", "odsAttrs");
+  const auto *attrSizedOperands =
+      op.getTrait("OpTrait::AttrSizedOperandSegments");
+  auto &constructor = adapterClass.newConstructor(
+      attrSizedOperands
+          ? "ArrayRef<Value> values, DictionaryAttr attrs"
+          : "ArrayRef<Value> values, DictionaryAttr attrs = nullptr");
+  constructor.body() << "  odsOperands = values;\n";
+  constructor.body() << "  odsAttrs = attrs;\n";
+
+  std::string sizeAttrInit =
+      formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes");
+  generateNamedOperandGetters(op, adapterClass, sizeAttrInit,
                               /*rangeType=*/"ArrayRef<Value>",
-                              /*rangeBeginCall=*/"tblgen_operands.begin()",
-                              /*rangeSizeCall=*/"tblgen_operands.size()",
-                              /*getOperandCallPattern=*/"tblgen_operands[{0}]");
+                              /*rangeBeginCall=*/"odsOperands.begin()",
+                              /*rangeSizeCall=*/"odsOperands.size()",
+                              /*getOperandCallPattern=*/"odsOperands[{0}]");
+
+  FmtContext fctx;
+  fctx.withBuilder("mlir::Builder(odsAttrs.getContext())");
+
+  auto emitAttr = [&](StringRef name, Attribute attr) {
+    auto &body = adapterClass.newMethod(attr.getStorageType(), name).body();
+    body << "  assert(odsAttrs && \"no attributes when constructing adapter\");"
+         << "\n  " << attr.getStorageType() << " attr = "
+         << "odsAttrs.get(\"" << name << "\").";
+    if (attr.hasDefaultValue() || attr.isOptional())
+      body << "dyn_cast_or_null<";
+    else
+      body << "cast<";
+    body << attr.getStorageType() << ">();\n";
+
+    if (attr.hasDefaultValue()) {
+      // Use the default value if attribute is not set.
+      // TODO: this is inefficient, we are recreating the attribute for every
+      // call. This should be set instead.
+      std::string defaultValue = std::string(
+          tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
+      body << "  if (!attr)\n    attr = " << defaultValue << ";\n";
+    }
+    body << "  return attr;\n";
+  };
+
+  for (auto &namedAttr : op.getAttributes()) {
+    const auto &name = namedAttr.name;
+    const auto &attr = namedAttr.attr;
+    if (!attr.isDerivedAttr())
+      emitAttr(name, attr);
+  }
 }
 
 void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {
@@ -1873,19 +1920,13 @@ static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os,
   }
   for (auto *def : defs) {
     Operator op(*def);
-    const auto *attrSizedOperands =
-        op.getTrait("OpTrait::AttrSizedOperandSegments");
     if (emitDecl) {
       os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
-      // We cannot generate the operand adaptor class if operand getters depend
-      // on an attribute.
-      if (!attrSizedOperands)
-        OpOperandAdaptorEmitter::emitDecl(op, os);
+      OpOperandAdaptorEmitter::emitDecl(op, os);
       OpEmitter::emitDecl(op, os);
     } else {
       os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
-      if (!attrSizedOperands)
-        OpOperandAdaptorEmitter::emitDef(op, os);
+      OpOperandAdaptorEmitter::emitDef(op, os);
       OpEmitter::emitDef(op, os);
     }
   }