[mlir] Expand prefixing to OpFormatGen
authorJacques Pienaar <jpienaar@google.com>
Wed, 20 Oct 2021 14:08:36 +0000 (07:08 -0700)
committerJacques Pienaar <jpienaar@google.com>
Wed, 20 Oct 2021 14:08:37 +0000 (07:08 -0700)
Follow up to also use the prefixed emitters in OpFormatGen (moved
getGetterName(s) and getSetterName(s) to Operator as that is most
convenient usage wise even though it just depends on Dialect). Prefix
accessors in Test dialect and follow up on missed changes in
OpDefinitionsGen.

Differential Revision: https://reviews.llvm.org/D112118

mlir/include/mlir/TableGen/Operator.h
mlir/lib/TableGen/Operator.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/lib/Dialect/Test/TestPatterns.cpp
mlir/test/lib/Transforms/TestInlining.cpp
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/tools/mlir-tblgen/OpFormatGen.cpp

index a9fb93b..e2bd3bb 100644 (file)
@@ -294,6 +294,17 @@ public:
   // Returns the builders of this operation.
   ArrayRef<Builder> getBuilders() const { return builders; }
 
+  // Returns the preferred getter name for the accessor.
+  std::string getGetterName(StringRef name) const {
+    return getGetterNames(name).front();
+  }
+
+  // Returns the getter names for the accessor.
+  SmallVector<std::string, 2> getGetterNames(StringRef name) const;
+
+  // Returns the setter names for the accessor.
+  SmallVector<std::string, 2> getSetterNames(StringRef name) const;
+
 private:
   // Populates the vectors containing operands, attributes, results and traits.
   void populateOpStructure();
index 57c957f..69e6787 100644 (file)
@@ -21,6 +21,7 @@
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
@@ -642,3 +643,57 @@ auto Operator::getArgToOperandOrAttribute(int index) const
     -> OperandOrAttribute {
   return attrOrOperandMapping[index];
 }
+
+// Helper to return the names for accessor.
+static SmallVector<std::string, 2>
+getGetterOrSetterNames(bool isGetter, const Operator &op, StringRef name) {
+  Dialect::EmitPrefix prefixType = op.getDialect().getEmitAccessorPrefix();
+  std::string prefix;
+  if (prefixType != Dialect::EmitPrefix::Raw)
+    prefix = isGetter ? "get" : "set";
+
+  SmallVector<std::string, 2> names;
+  bool rawToo = prefixType == Dialect::EmitPrefix::Both;
+
+  auto skip = [&](StringRef newName) {
+    bool shouldSkip = newName == "getOperands";
+    if (!shouldSkip)
+      return false;
+
+    // This note could be avoided where the final function generated would
+    // have been identical. But preferably in the op definition avoiding using
+    // the generic name and then getting a more specialize type is better.
+    PrintNote(op.getLoc(),
+              "Skipping generation of prefixed accessor `" + newName +
+                  "` as it overlaps with default one; generating raw form (`" +
+                  name + "`) still");
+    return true;
+  };
+
+  if (!prefix.empty()) {
+    names.push_back(
+        prefix + convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true));
+    // Skip cases which would overlap with default ones for now.
+    if (skip(names.back())) {
+      rawToo = true;
+      names.clear();
+    } else {
+      LLVM_DEBUG(llvm::errs() << "WITH_GETTER(\"" << op.getQualCppClassName()
+                              << "::" << names.back() << "\");\n"
+                              << "WITH_GETTER(\"" << op.getQualCppClassName()
+                              << "Adaptor::" << names.back() << "\");\n";);
+    }
+  }
+
+  if (prefix.empty() || rawToo)
+    names.push_back(name.str());
+  return names;
+}
+
+SmallVector<std::string, 2> Operator::getGetterNames(StringRef name) const {
+  return getGetterOrSetterNames(/*isGetter=*/true, *this, name);
+}
+
+SmallVector<std::string, 2> Operator::getSetterNames(StringRef name) const {
+  return getGetterOrSetterNames(/*isGetter=*/false, *this, name);
+}
index 6040823..6ee1e79 100644 (file)
@@ -339,7 +339,7 @@ TestDialect::getOperationPrinter(Operation *op) const {
 Optional<MutableOperandRange>
 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
   assert(index == 0 && "invalid successor index");
-  return targetOperandsMutable();
+  return getTargetOperandsMutable();
 }
 
 //===----------------------------------------------------------------------===//
@@ -369,7 +369,7 @@ struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
 
   LogicalResult matchAndRewrite(FoldToCallOp op,
                                 PatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
+    rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.getCalleeAttr(),
                                         ValueRange());
     return success();
   }
@@ -597,8 +597,8 @@ static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
   p << "test.isolated_region ";
   p.printOperand(op.getOperand());
-  p.shadowRegionArgs(op.region(), op.getOperand());
-  p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
+  p.shadowRegionArgs(op.getRegion(), op.getOperand());
+  p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
 }
 
 //===----------------------------------------------------------------------===//
@@ -622,7 +622,7 @@ static ParseResult parseGraphRegionOp(OpAsmParser &parser,
 
 static void print(OpAsmPrinter &p, GraphRegionOp op) {
   p << "test.graph_region ";
-  p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
+  p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
 }
 
 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
@@ -642,7 +642,7 @@ static ParseResult parseAffineScopeOp(OpAsmParser &parser,
 
 static void print(OpAsmPrinter &p, AffineScopeOp op) {
   p << "test.affine_scope ";
-  p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
+  p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
 }
 
 //===----------------------------------------------------------------------===//
@@ -678,7 +678,7 @@ static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
 }
 
 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
-  p << " " << op.keyword();
+  p << " " << op.getKeyword();
 }
 
 //===----------------------------------------------------------------------===//
@@ -717,7 +717,7 @@ static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
 
 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
   p << " wraps ";
-  p.printGenericOp(&op.region().front().front());
+  p.printGenericOp(&op.getRegion().front().front());
 }
 
 //===----------------------------------------------------------------------===//
@@ -762,7 +762,7 @@ void TestOpWithRegionPattern::getCanonicalizationPatterns(
 }
 
 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
-  return operand();
+  return getOperand();
 }
 
 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
@@ -971,7 +971,7 @@ static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
   // Note that we only need to print the "name" attribute if the asmprinter
   // result name disagrees with it.  This can happen in strange cases, e.g.
   // when there are conflicts.
-  bool namesDisagree = op.names().size() != op.getNumResults();
+  bool namesDisagree = op.getNames().size() != op.getNumResults();
 
   SmallString<32> resultNameStr;
   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
@@ -979,7 +979,7 @@ static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
     llvm::raw_svector_ostream tmpStream(resultNameStr);
     p.printOperand(op.getResult(i), tmpStream);
 
-    auto expectedName = op.names()[i].dyn_cast<StringAttr>();
+    auto expectedName = op.getNames()[i].dyn_cast<StringAttr>();
     if (!expectedName ||
         tmpStream.str().drop_front() != expectedName.getValue()) {
       namesDisagree = true;
@@ -997,7 +997,7 @@ static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
 void StringAttrPrettyNameOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
 
-  auto value = names();
+  auto value = getNames();
   for (size_t i = 0, e = value.size(); i != e; ++i)
     if (auto str = value[i].dyn_cast<StringAttr>())
       if (!str.getValue().empty())
@@ -1014,15 +1014,15 @@ static void print(OpAsmPrinter &p, RegionIfOp op) {
   p << ": " << op.getOperandTypes();
   p.printArrowTypeList(op.getResultTypes());
   p << " then";
-  p.printRegion(op.thenRegion(),
+  p.printRegion(op.getThenRegion(),
                 /*printEntryBlockArgs=*/true,
                 /*printBlockTerminators=*/true);
   p << " else";
-  p.printRegion(op.elseRegion(),
+  p.printRegion(op.getElseRegion(),
                 /*printEntryBlockArgs=*/true,
                 /*printBlockTerminators=*/true);
   p << " join";
-  p.printRegion(op.joinRegion(),
+  p.printRegion(op.getJoinRegion(),
                 /*printEntryBlockArgs=*/true,
                 /*printBlockTerminators=*/true);
 }
@@ -1064,15 +1064,15 @@ void RegionIfOp::getSuccessorRegions(
   // We always branch to the join region.
   if (index.hasValue()) {
     if (index.getValue() < 2)
-      regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
+      regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
     else
       regions.push_back(RegionSuccessor(getResults()));
     return;
   }
 
   // The then and else regions are the entry regions of this op.
-  regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
-  regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
+  regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
+  regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
 }
 
 //===----------------------------------------------------------------------===//
index e7144cc..3a7bd50 100644 (file)
@@ -26,9 +26,7 @@ include "TestInterfaces.td"
 def Test_Dialect : Dialect {
   let name = "test";
   let cppNamespace = "::test";
-  // Temporarily flipping to _Both (given this is test only/not intended for
-  // general use, this won't be following the 2 week process here).
-  let emitAccessorPrefix = kEmitAccessorPrefix_Both;
+  let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
   let hasCanonicalizer = 1;
   let hasConstantMaterializer = 1;
   let hasOperationAttrVerify = 1;
@@ -305,9 +303,9 @@ def RankedIntElementsAttrOp : TEST_Op<"ranked_int_elements_attr"> {
 def DerivedTypeAttrOp : TEST_Op<"derived_type_attr", []> {
   let results = (outs AnyTensor:$output);
   DerivedTypeAttr element_dtype =
-    DerivedTypeAttr<"return getElementTypeOrSelf(output().getType());">;
+    DerivedTypeAttr<"return getElementTypeOrSelf(getOutput().getType());">;
   DerivedAttr size = DerivedAttr<"int",
-    "return output().getType().cast<ShapedType>().getSizeInBits();",
+    "return getOutput().getType().cast<ShapedType>().getSizeInBits();",
     "$_builder.getI32IntegerAttr($_self)">;
 }
 
@@ -374,13 +372,10 @@ def VariadicNoTerminatorOp : TEST_Op<"variadic_no_terminator_op",
 
 def ConversionCallOp : TEST_Op<"conversion_call_op",
     [CallOpInterface]> {
-  let arguments = (ins Variadic<AnyType>:$inputs, SymbolRefAttr:$callee);
+  let arguments = (ins Variadic<AnyType>:$arg_operands, SymbolRefAttr:$callee);
   let results = (outs Variadic<AnyType>);
 
   let extraClassDeclaration = [{
-    /// Get the argument operands to the called function.
-    operand_range getArgOperands() { return inputs(); }
-
     /// Return the callee of this operation.
     ::mlir::CallInterfaceCallable getCallableForCallee() {
       return (*this)->getAttrOfType<::mlir::SymbolRefAttr>("callee");
@@ -394,7 +389,7 @@ def FunctionalRegionOp : TEST_Op<"functional_region_op",
   let results = (outs FunctionType);
 
   let extraClassDeclaration = [{
-    ::mlir::Region *getCallableRegion() { return &body(); }
+    ::mlir::Region *getCallableRegion() { return &getBody(); }
     ::llvm::ArrayRef<::mlir::Type> getCallableResults() {
       return getType().cast<::mlir::FunctionType>().getResults();
     }
@@ -673,7 +668,7 @@ def AttrWithTraitOp : TEST_Op<"attr_with_trait", []> {
   let arguments = (ins AnyAttr:$attr);
 
   let verifier = [{
-    if (this->attr().hasTrait<AttributeTrait::TestAttrTrait>())
+    if (this->getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
       return success();
     return this->emitError("'attr' attribute should have trait 'TestAttrTrait'");
   }];
@@ -2340,6 +2335,10 @@ def TestLinalgConvOp :
     std::string getLibraryCallName() {
       return "";
     }
+
+    // To conform with interface requirement on operand naming.
+    mlir::ValueRange inputs() { return getInputs(); }
+    mlir::ValueRange outputs() { return getOutputs(); }
   }];
 }
 
index c995ee7..3ebebe0 100644 (file)
@@ -32,8 +32,8 @@ static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
 static void handleNoResultOp(PatternRewriter &rewriter,
                              OpSymbolBindingNoResult op) {
   // Turn the no result op to a one-result op.
-  rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(),
-                                    op.operand());
+  rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(),
+                                    op.getOperand());
 }
 
 static bool getFirstI32Result(Operation *op, Value &value) {
@@ -531,7 +531,7 @@ struct TestBoundedRecursiveRewrite
                                 PatternRewriter &rewriter) const final {
     // Decrement the depth of the op in-place.
     rewriter.updateRootInPlace(op, [&] {
-      op->setAttr("depth", rewriter.getI64IntegerAttr(op.depth() - 1));
+      op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1));
     });
     return success();
   }
@@ -705,7 +705,7 @@ struct TestLegalizePatternDriver
 
     // Mark the bound recursion operation as dynamically legal.
     target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
-        [](TestRecursiveRewriteOp op) { return op.depth() == 0; });
+        [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });
 
     // Handle a partial conversion.
     if (mode == ConversionMode::Partial) {
@@ -1026,9 +1026,9 @@ struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
   LogicalResult
   matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
-    Block &firstBlock = op.body().front();
+    Block &firstBlock = op.getBody().front();
     Operation *branchOp = firstBlock.getTerminator();
-    Block *secondBlock = &*(std::next(op.body().begin()));
+    Block *secondBlock = &*(std::next(op.getBody().begin()));
     auto succOperands = branchOp->getOperands();
     SmallVector<Value, 2> replacements(succOperands);
     rewriter.eraseOp(branchOp);
@@ -1073,7 +1073,7 @@ struct TestMergeSingleBlockOps
         op->getParentOfType<SingleBlockImplicitTerminatorOp>();
     if (!parentOp)
       return failure();
-    Block &innerBlock = op.region().front();
+    Block &innerBlock = op.getRegion().front();
     TerminatorOp innerTerminator =
         cast<TerminatorOp>(innerBlock.getTerminator());
     rewriter.mergeBlockBefore(&innerBlock, op);
@@ -1104,7 +1104,7 @@ struct TestMergeBlocksPatternDriver
     /// Expect the op to have a single block after legalization.
     target.addDynamicallyLegalOp<TestMergeBlocksOp>(
         [&](TestMergeBlocksOp op) -> bool {
-          return llvm::hasSingleElement(op.body());
+          return llvm::hasSingleElement(op.getBody());
         });
 
     /// Only allow `test.br` within test.merge_blocks op.
index c88ee9e..3a761fb 100644 (file)
@@ -51,7 +51,7 @@ struct Inliner : public PassWrapper<Inliner, FunctionPass> {
       // Inline the functional region operation, but only clone the internal
       // region if there is more than one use.
       if (failed(inlineRegion(
-              interface, &callee.body(), caller, caller.getArgOperands(),
+              interface, &callee.getBody(), caller, caller.getArgOperands(),
               caller.getResults(), caller.getLoc(),
               /*shouldCloneInlinedRegion=*/!callee.getResult().hasOneUse())))
         continue;
index 8ee897c..e5d2439 100644 (file)
@@ -358,10 +358,6 @@ private:
 
   // The emitter containing all of the locally emitted verification functions.
   const StaticVerifierFunctionEmitter &staticVerifierEmitter;
-
-  // A map of attribute names (including implicit attributes) registered to the
-  // current operation, to the relative order in which they were registered.
-  llvm::MapVector<StringRef, unsigned> attributeNames;
 };
 } // end anonymous namespace
 
@@ -525,62 +521,6 @@ void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); }
 
 void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); }
 
-// Helper to return the names for accessor.
-static SmallVector<std::string, 2>
-getGetterOrSetterNames(bool isGetter, const Operator &op, StringRef name) {
-  Dialect::EmitPrefix prefixType = op.getDialect().getEmitAccessorPrefix();
-  std::string prefix;
-  if (prefixType != Dialect::EmitPrefix::Raw)
-    prefix = isGetter ? "get" : "set";
-
-  SmallVector<std::string, 2> names;
-  bool rawToo = prefixType == Dialect::EmitPrefix::Both;
-
-  auto skip = [&](StringRef newName) {
-    bool shouldSkip = newName == "getOperands";
-    if (!shouldSkip)
-      return false;
-
-    // This note could be avoided where the final function generated would
-    // have been identical. But preferably in the op definition avoiding using
-    // the generic name and then getting a more specialize type is better.
-    PrintNote(op.getLoc(),
-              "Skipping generation of prefixed accessor `" + newName +
-                  "` as it overlaps with default one; generating raw form (`" +
-                  name + "`) still");
-    return true;
-  };
-
-  if (!prefix.empty()) {
-    names.push_back(prefix + convertToCamelFromSnakeCase(name, true));
-    // Skip cases which would overlap with default ones for now.
-    if (skip(names.back())) {
-      rawToo = true;
-      names.clear();
-    } else {
-      LLVM_DEBUG(llvm::errs() << "WITH_GETTER(\"" << op.getQualCppClassName()
-                              << "::" << names.back() << "\");\n"
-                              << "WITH_GETTER(\"" << op.getQualCppClassName()
-                              << "Adaptor::" << names.back() << "\");\n";);
-    }
-  }
-
-  if (prefix.empty() || rawToo)
-    names.push_back(name.str());
-  return names;
-}
-static SmallVector<std::string, 2> getGetterNames(const Operator &op,
-                                                  StringRef name) {
-  return getGetterOrSetterNames(/*isGetter=*/true, op, name);
-}
-static std::string getGetterName(const Operator &op, StringRef name) {
-  return getGetterOrSetterNames(/*isGetter=*/true, op, name).front();
-}
-static SmallVector<std::string, 2> getSetterNames(const Operator &op,
-                                                  StringRef name) {
-  return getGetterOrSetterNames(/*isGetter=*/false, op, name);
-}
-
 static void errorIfPruned(size_t line, OpMethod *m, const Twine &methodName,
                           const Operator &op) {
   if (m)
@@ -593,6 +533,10 @@ static void errorIfPruned(size_t line, OpMethod *m, const Twine &methodName,
 #define ERROR_IF_PRUNED(M, N, O) errorIfPruned(__LINE__, M, N, O)
 
 void OpEmitter::genAttrNameGetters() {
+  // A map of attribute names (including implicit attributes) registered to the
+  // current operation, to the relative order in which they were registered.
+  llvm::MapVector<StringRef, unsigned> attributeNames;
+
   // Enumerate the attribute names of this op, assigning each a relative
   // ordering.
   auto addAttrName = [&](StringRef name) {
@@ -602,10 +546,12 @@ void OpEmitter::genAttrNameGetters() {
   for (const NamedAttribute &namedAttr : op.getAttributes())
     addAttrName(namedAttr.name);
   // Include key attributes from several traits as implicitly registered.
+  std::string operandSizes = "operand_segment_sizes";
   if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
-    addAttrName("operand_segment_sizes");
+    addAttrName(operandSizes);
+  std::string attrSizes = "result_segment_sizes";
   if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
-    addAttrName("result_segment_sizes");
+    addAttrName(attrSizes);
 
   // Emit the getAttributeNames method.
   {
@@ -656,7 +602,7 @@ void OpEmitter::genAttrNameGetters() {
   // users.
   const char *attrNameMethodBody = "  return getAttributeNameForIndex({0});";
   for (const std::pair<StringRef, unsigned> &attrIt : attributeNames) {
-    for (StringRef name : getGetterNames(op, attrIt.first)) {
+    for (StringRef name : op.getGetterNames(attrIt.first)) {
       std::string methodName = (name + "AttrName").str();
 
       // Generate the non-static variant.
@@ -734,7 +680,7 @@ void OpEmitter::genAttrGetters() {
   };
 
   for (const NamedAttribute &namedAttr : op.getAttributes()) {
-    for (StringRef name : getGetterNames(op, namedAttr.name)) {
+    for (StringRef name : op.getGetterNames(namedAttr.name)) {
       if (namedAttr.attr.isDerivedAttr()) {
         emitDerivedAttr(name, namedAttr.attr);
       } else {
@@ -777,8 +723,9 @@ void OpEmitter::genAttrGetters() {
       if (!nonMaterializable.empty()) {
         std::string attrs;
         llvm::raw_string_ostream os(attrs);
-        interleaveComma(nonMaterializable, os,
-                        [&](const NamedAttribute &attr) { os << attr.name; });
+        interleaveComma(nonMaterializable, os, [&](const NamedAttribute &attr) {
+          os << op.getGetterName(attr.name);
+        });
         PrintWarning(
             op.getLoc(),
             formatv(
@@ -799,8 +746,9 @@ void OpEmitter::genAttrGetters() {
           derivedAttrs, body,
           [&](const NamedAttribute &namedAttr) {
             auto tmpl = namedAttr.attr.getConvertFromStorageCall();
-            body << "    {" << namedAttr.name << "AttrName(),\n"
-                 << tgfmt(tmpl, &fctx.withSelf(namedAttr.name + "()")
+            std::string name = op.getGetterName(namedAttr.name);
+            body << "    {" << name << "AttrName(),\n"
+                 << tgfmt(tmpl, &fctx.withSelf(name + "()")
                                      .withBuilder("odsBuilder")
                                      .addSubst("_ctx", "ctx"))
                  << "}";
@@ -826,8 +774,8 @@ void OpEmitter::genAttrSetters() {
 
   for (const NamedAttribute &namedAttr : op.getAttributes()) {
     if (!namedAttr.attr.isDerivedAttr())
-      for (auto names : llvm::zip(getSetterNames(op, namedAttr.name),
-                                  getGetterNames(op, namedAttr.name)))
+      for (auto names : llvm::zip(op.getSetterNames(namedAttr.name),
+                                  op.getGetterNames(namedAttr.name)))
         emitAttrWithStorageType(std::get<0>(names), std::get<1>(names),
                                 namedAttr.attr);
   }
@@ -843,7 +791,7 @@ void OpEmitter::genOptionalAttrRemovers() {
         "::mlir::Attribute", ("remove" + upperInitial + suffix + "Attr").str());
     if (!method)
       return;
-    method->body() << "  return (*this)->removeAttr(" << getGetterName(op, name)
+    method->body() << "  return (*this)->removeAttr(" << op.getGetterName(name)
                    << "AttrName());";
   };
 
@@ -945,7 +893,7 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
     const auto &operand = op.getOperand(i);
     if (operand.name.empty())
       continue;
-    for (StringRef name : getGetterNames(op, operand.name)) {
+    for (StringRef name : op.getGetterNames(operand.name)) {
       if (operand.isOptional()) {
         m = opClass.addMethodAndPrune("::mlir::Value", name);
         ERROR_IF_PRUNED(m, name, op);
@@ -953,8 +901,8 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
                   << "  return operands.empty() ? ::mlir::Value() : "
                      "*operands.begin();";
       } else if (operand.isVariadicOfVariadic()) {
-        StringRef segmentAttr =
-            operand.constraint.getVariadicOfVariadicSegmentSizeAttr();
+        std::string segmentAttr = op.getGetterName(
+            operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
         if (isAdaptor) {
           m = opClass.addMethodAndPrune(
               "::llvm::SmallVector<::mlir::ValueRange>", name);
@@ -982,13 +930,12 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
 }
 
 void OpEmitter::genNamedOperandGetters() {
-  // Build the code snippet used for initializing the operand_segment_sizes
+  // Build the code snippet used for initializing the operand_segment_size)s
   // array.
   std::string attrSizeInitCode;
   if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
-    attrSizeInitCode =
-        formatv(opSegmentSizeAttrInitCode, "operand_segment_sizesAttrName()")
-            .str();
+    std::string attr = op.getGetterName("operand_segment_sizes") + "AttrName()";
+    attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, attr).str();
   }
 
   generateNamedOperandGetters(
@@ -1008,7 +955,7 @@ void OpEmitter::genNamedOperandSetters() {
     const auto &operand = op.getOperand(i);
     if (operand.name.empty())
       continue;
-    for (StringRef name : getGetterNames(op, operand.name)) {
+    for (StringRef name : op.getGetterNames(operand.name)) {
       auto *m = opClass.addMethodAndPrune(
           operand.isVariadicOfVariadic() ? "::mlir::MutableOperandRangeRange"
                                          : "::mlir::MutableOperandRange",
@@ -1022,7 +969,7 @@ void OpEmitter::genNamedOperandSetters() {
       if (attrSizedOperands)
         body << ", ::mlir::MutableOperandRange::OperandSegment(" << i
              << "u, *getOperation()->getAttrDictionary().getNamed("
-                "operand_segment_sizesAttrName()))";
+             << op.getGetterName("operand_segment_sizes") << "AttrName()))";
       body << ");\n";
 
       // If this operand is a nested variadic, we split the range into a
@@ -1032,8 +979,7 @@ void OpEmitter::genNamedOperandSetters() {
         //
         body << "  return "
                 "mutableRange.split(*(*this)->getAttrDictionary().getNamed("
-             << getGetterName(
-                    op,
+             << op.getGetterName(
                     operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
              << "AttrName()));\n";
       } else {
@@ -1076,9 +1022,8 @@ void OpEmitter::genNamedResultGetters() {
   // Build the initializer string for the result segment size attribute.
   std::string attrSizeInitCode;
   if (attrSizedResults) {
-    attrSizeInitCode =
-        formatv(opSegmentSizeAttrInitCode, "result_segment_sizesAttrName()")
-            .str();
+    std::string attr = op.getGetterName("result_segment_sizes") + "AttrName()";
+    attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, attr).str();
   }
 
   generateValueRangeStartAndEnd(
@@ -1096,7 +1041,7 @@ void OpEmitter::genNamedResultGetters() {
     const auto &result = op.getResult(i);
     if (result.name.empty())
       continue;
-    for (StringRef name : getGetterNames(op, result.name)) {
+    for (StringRef name : op.getGetterNames(result.name)) {
       if (result.isOptional()) {
         m = opClass.addMethodAndPrune("::mlir::Value", name);
         ERROR_IF_PRUNED(m, name, op);
@@ -1123,7 +1068,7 @@ void OpEmitter::genNamedRegionGetters() {
     if (region.name.empty())
       continue;
 
-    for (StringRef name : getGetterNames(op, region.name)) {
+    for (StringRef name : op.getGetterNames(region.name)) {
       // Generate the accessors for a variadic region.
       if (region.isVariadic()) {
         auto *m = opClass.addMethodAndPrune(
@@ -1148,7 +1093,7 @@ void OpEmitter::genNamedSuccessorGetters() {
     if (successor.name.empty())
       continue;
 
-    for (StringRef name : getGetterNames(op, successor.name)) {
+    for (StringRef name : op.getGetterNames(successor.name)) {
       // Generate the accessors for a variadic successor list.
       if (successor.isVariadic()) {
         auto *m = opClass.addMethodAndPrune("::mlir::SuccessorRange", name);
@@ -1430,7 +1375,7 @@ void OpEmitter::genUseAttrAsResultTypeBuilder() {
   std::string resultType;
   const auto &namedAttr = op.getAttribute(0);
 
-  body << "  auto attrName = " << getGetterName(op, namedAttr.name)
+  body << "  auto attrName = " << op.getGetterName(namedAttr.name)
        << "AttrName(" << builderOpState
        << ".name);\n"
           "  for (auto attr : attributes) {\n"
@@ -1746,8 +1691,8 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
            << "    for (::mlir::ValueRange range : " << argName << ")\n"
            << "      rangeSegments.push_back(range.size());\n"
            << "    " << builderOpState << ".addAttribute("
-           << getGetterName(
-                  op, operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
+           << op.getGetterName(
+                  operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
            << "AttrName(" << builderOpState << ".name), " << odsBuilder
            << ".getI32TensorAttr(rangeSegments));"
            << "  }\n";
@@ -1761,9 +1706,9 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
 
   // If the operation has the operand segment size attribute, add it here.
   if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
-    body << "  " << builderOpState
-         << ".addAttribute(operand_segment_sizesAttrName(" << builderOpState
-         << ".name), "
+    std::string sizes = op.getGetterName("operand_segment_sizes");
+    body << "  " << builderOpState << ".addAttribute(" << sizes << "AttrName("
+         << builderOpState << ".name), "
          << "odsBuilder.getI32VectorAttr({";
     interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
       const NamedTypeConstraint &operand = op.getOperand(i);
@@ -1816,10 +1761,10 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
       std::string value =
           std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
       body << formatv("  {0}.addAttribute({1}AttrName({0}.name), {2});\n",
-                      builderOpState, getGetterName(op, namedAttr.name), value);
+                      builderOpState, op.getGetterName(namedAttr.name), value);
     } else {
       body << formatv("  {0}.addAttribute({1}AttrName({0}.name), {2});\n",
-                      builderOpState, getGetterName(op, namedAttr.name),
+                      builderOpState, op.getGetterName(namedAttr.name),
                       namedAttr.name);
     }
     if (emitNotNullCheck)
@@ -2255,7 +2200,7 @@ void OpEmitter::genRegionVerifier(OpMethodBody &body) {
                         ? "{0}()"
                         : "::mlir::MutableArrayRef<::mlir::Region>((*this)"
                           "->getRegion({1}))",
-                    region.name, i);
+                    op.getGetterName(region.name), i);
     body << ") {\n";
     auto constraint = tgfmt(region.constraint.getConditionTemplate(),
                             &verifyCtx.withSelf("region"))
@@ -2497,8 +2442,8 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
     ERROR_IF_PRUNED(m, "getOperands", op);
     m->body() << "  return odsOperands;";
   }
-  std::string sizeAttrInit =
-      formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes");
+  std::string attr = op.getGetterName("operand_segment_sizes");
+  std::string sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode, attr);
   generateNamedOperandGetters(op, adaptor,
                               /*isAdaptor=*/true, sizeAttrInit,
                               /*rangeType=*/"::mlir::ValueRange",
@@ -2542,8 +2487,10 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
   for (auto &namedAttr : op.getAttributes()) {
     const auto &name = namedAttr.name;
     const auto &attr = namedAttr.attr;
-    if (!attr.isDerivedAttr())
-      emitAttr(name, attr);
+    if (!attr.isDerivedAttr()) {
+      for (auto emitName : op.getGetterNames(name))
+        emitAttr(emitName, attr);
+    }
   }
 
   unsigned numRegions = op.getNumRegions();
index 708b9b1..3fb0e06 100644 (file)
@@ -873,7 +873,8 @@ static void genLiteralParser(StringRef value, OpMethodBody &body) {
 }
 
 /// Generate the storage code required for parsing the given element.
-static void genElementParserStorage(Element *element, OpMethodBody &body) {
+static void genElementParserStorage(Element *element, const Operator &op,
+                                    OpMethodBody &body) {
   if (auto *optional = dyn_cast<OptionalElement>(element)) {
     auto elements = optional->getThenElements();
 
@@ -885,13 +886,13 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
       elidedAnchorElement = anchor;
     for (auto &childElement : elements)
       if (&childElement != elidedAnchorElement)
-        genElementParserStorage(&childElement, body);
+        genElementParserStorage(&childElement, op, body);
     for (auto &childElement : optional->getElseElements())
-      genElementParserStorage(&childElement, body);
+      genElementParserStorage(&childElement, op, body);
 
   } else if (auto *custom = dyn_cast<CustomDirective>(element)) {
     for (auto &paramElement : custom->getArguments())
-      genElementParserStorage(&paramElement, body);
+      genElementParserStorage(&paramElement, op, body);
 
   } else if (isa<OperandsDirective>(element)) {
     body << "  ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
@@ -1188,7 +1189,7 @@ void OperationFormat::genParser(Operator &op, OpClass &opClass) {
   // allows for referencing these variables in the presence of optional
   // groupings.
   for (auto &element : elements)
-    genElementParserStorage(&*element, body);
+    genElementParserStorage(&*element, op, body);
 
   // A format context used when parsing attributes with buildable types.
   FmtContext attrTypeCtx;
@@ -1735,36 +1736,38 @@ static void genSpacePrinter(bool value, OpMethodBody &body,
 
 /// Generate the printer for a custom directive parameter.
 static void genCustomDirectiveParameterPrinter(Element *element,
+                                               const Operator &op,
                                                OpMethodBody &body) {
   if (auto *attr = dyn_cast<AttributeVariable>(element)) {
-    body << attr->getVar()->name << "Attr()";
+    body << op.getGetterName(attr->getVar()->name) << "Attr()";
 
   } else if (isa<AttrDictDirective>(element)) {
     body << "getOperation()->getAttrDictionary()";
 
   } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
-    body << operand->getVar()->name << "()";
+    body << op.getGetterName(operand->getVar()->name) << "()";
 
   } else if (auto *region = dyn_cast<RegionVariable>(element)) {
-    body << region->getVar()->name << "()";
+    body << op.getGetterName(region->getVar()->name) << "()";
 
   } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
-    body << successor->getVar()->name << "()";
+    body << op.getGetterName(successor->getVar()->name) << "()";
 
   } else if (auto *dir = dyn_cast<RefDirective>(element)) {
-    genCustomDirectiveParameterPrinter(dir->getOperand(), body);
+    genCustomDirectiveParameterPrinter(dir->getOperand(), op, body);
 
   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
     auto *typeOperand = dir->getOperand();
     auto *operand = dyn_cast<OperandVariable>(typeOperand);
     auto *var = operand ? operand->getVar()
                         : cast<ResultVariable>(typeOperand)->getVar();
+    std::string name = op.getGetterName(var->name);
     if (var->isVariadic())
-      body << var->name << "().getTypes()";
+      body << name << "().getTypes()";
     else if (var->isOptional())
-      body << llvm::formatv("({0}() ? {0}().getType() : Type())", var->name);
+      body << llvm::formatv("({0}() ? {0}().getType() : Type())", name);
     else
-      body << var->name << "().getType()";
+      body << name << "().getType()";
   } else {
     llvm_unreachable("unknown custom directive parameter");
   }
@@ -1772,11 +1775,11 @@ static void genCustomDirectiveParameterPrinter(Element *element,
 
 /// Generate the printer for a custom directive.
 static void genCustomDirectivePrinter(CustomDirective *customDir,
-                                      OpMethodBody &body) {
+                                      const Operator &op, OpMethodBody &body) {
   body << "  print" << customDir->getName() << "(p, *this";
   for (Element &param : customDir->getArguments()) {
     body << ", ";
-    genCustomDirectiveParameterPrinter(&param, body);
+    genCustomDirectiveParameterPrinter(&param, op, body);
   }
   body << ");\n";
 }
@@ -1800,7 +1803,8 @@ static void genVariadicRegionPrinter(const Twine &regionListName,
 }
 
 /// Generate the C++ for an operand to a (*-)type directive.
-static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
+static OpMethodBody &genTypeOperandPrinter(Element *arg, const Operator &op,
+                                           OpMethodBody &body) {
   if (isa<OperandsDirective>(arg))
     return body << "getOperation()->getOperandTypes()";
   if (isa<ResultsDirective>(arg))
@@ -1808,26 +1812,29 @@ static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
   auto *operand = dyn_cast<OperandVariable>(arg);
   auto *var = operand ? operand->getVar() : cast<ResultVariable>(arg)->getVar();
   if (var->isVariadicOfVariadic())
-    return body << llvm::formatv("{0}().join().getTypes()", var->name);
+    return body << llvm::formatv("{0}().join().getTypes()",
+                                 op.getGetterName(var->name));
   if (var->isVariadic())
-    return body << var->name << "().getTypes()";
+    return body << op.getGetterName(var->name) << "().getTypes()";
   if (var->isOptional())
     return body << llvm::formatv(
                "({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : "
                "::llvm::ArrayRef<::mlir::Type>())",
-               var->name);
-  return body << "::llvm::ArrayRef<::mlir::Type>(" << var->name
-              << "().getType())";
+               op.getGetterName(var->name));
+  return body << "::llvm::ArrayRef<::mlir::Type>("
+              << op.getGetterName(var->name) << "().getType())";
 }
 
 /// Generate the printer for an enum attribute.
-static void genEnumAttrPrinter(const NamedAttribute *var, OpMethodBody &body) {
+static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
+                               OpMethodBody &body) {
   Attribute baseAttr = var->attr.getBaseAttr();
   const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
   std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
 
   body << llvm::formatv(enumAttrBeginPrinterCode,
-                        (var->attr.isOptional() ? "*" : "") + var->name,
+                        (var->attr.isOptional() ? "*" : "") +
+                            op.getGetterName(var->name),
                         enumAttr.getSymbolToStringFnName());
 
   // Get a string containing all of the cases that can't be represented with a
@@ -1897,25 +1904,28 @@ static void genEnumAttrPrinter(const NamedAttribute *var, OpMethodBody &body) {
 }
 
 /// Generate the check for the anchor of an optional group.
-static void genOptionalGroupPrinterAnchor(Element *anchor, OpMethodBody &body) {
+static void genOptionalGroupPrinterAnchor(Element *anchor, const Operator &op,
+                                          OpMethodBody &body) {
   TypeSwitch<Element *>(anchor)
       .Case<OperandVariable, ResultVariable>([&](auto *element) {
         const NamedTypeConstraint *var = element->getVar();
+        std::string name = op.getGetterName(var->name);
         if (var->isOptional())
-          body << "  if (" << var->name << "()) {\n";
+          body << "  if (" << name << "()) {\n";
         else if (var->isVariadic())
-          body << "  if (!" << var->name << "().empty()) {\n";
+          body << "  if (!" << name << "().empty()) {\n";
       })
       .Case<RegionVariable>([&](RegionVariable *element) {
         const NamedRegion *var = element->getVar();
+        std::string name = op.getGetterName(var->name);
         // TODO: Add a check for optional regions here when ODS supports it.
-        body << "  if (!" << var->name << "().empty()) {\n";
+        body << "  if (!" << name << "().empty()) {\n";
       })
       .Case<TypeDirective>([&](TypeDirective *element) {
-        genOptionalGroupPrinterAnchor(element->getOperand(), body);
+        genOptionalGroupPrinterAnchor(element->getOperand(), op, body);
       })
       .Case<FunctionalTypeDirective>([&](FunctionalTypeDirective *element) {
-        genOptionalGroupPrinterAnchor(element->getInputs(), body);
+        genOptionalGroupPrinterAnchor(element->getInputs(), op, body);
       })
       .Case<AttributeVariable>([&](AttributeVariable *attr) {
         body << "  if ((*this)->getAttr(\"" << attr->getVar()->name
@@ -1943,7 +1953,7 @@ void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
   if (OptionalElement *optional = dyn_cast<OptionalElement>(element)) {
     // Emit the check for the presence of the anchor element.
     Element *anchor = optional->getAnchor();
-    genOptionalGroupPrinterAnchor(anchor, body);
+    genOptionalGroupPrinterAnchor(anchor, op, body);
 
     // If the anchor is a unit attribute, we don't need to print it. When
     // parsing, we will add this attribute if this group is present.
@@ -1998,47 +2008,53 @@ void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
 
     // If we are formatting as an enum, symbolize the attribute as a string.
     if (canFormatEnumAttr(var))
-      return genEnumAttrPrinter(var, body);
+      return genEnumAttrPrinter(var, op, body);
 
     // If we are formatting as a symbol name, handle it as a symbol name.
     if (shouldFormatSymbolNameAttr(var)) {
-      body << "  p.printSymbolName(" << var->name << "Attr().getValue());\n";
+      body << "  p.printSymbolName(" << op.getGetterName(var->name)
+           << "Attr().getValue());\n";
       return;
     }
 
     // Elide the attribute type if it is buildable.
     if (attr->getTypeBuilder())
-      body << "  p.printAttributeWithoutType(" << var->name << "Attr());\n";
+      body << "  p.printAttributeWithoutType(" << op.getGetterName(var->name)
+           << "Attr());\n";
     else
-      body << "  p.printAttribute(" << var->name << "Attr());\n";
+      body << "  p.printAttribute(" << op.getGetterName(var->name)
+           << "Attr());\n";
   } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
     if (operand->getVar()->isVariadicOfVariadic()) {
-      body << "  ::llvm::interleaveComma(" << operand->getVar()->name
+      body << "  ::llvm::interleaveComma("
+           << op.getGetterName(operand->getVar()->name)
            << "(), p, [&](const auto &operands) { p << \"(\" << operands << "
               "\")\"; });\n";
 
     } else if (operand->getVar()->isOptional()) {
-      body << "  if (::mlir::Value value = " << operand->getVar()->name
-           << "())\n"
+      body << "  if (::mlir::Value value = "
+           << op.getGetterName(operand->getVar()->name) << "())\n"
            << "    p << value;\n";
     } else {
-      body << "  p << " << operand->getVar()->name << "();\n";
+      body << "  p << " << op.getGetterName(operand->getVar()->name) << "();\n";
     }
   } else if (auto *region = dyn_cast<RegionVariable>(element)) {
     const NamedRegion *var = region->getVar();
+    std::string name = op.getGetterName(var->name);
     if (var->isVariadic()) {
-      genVariadicRegionPrinter(var->name + "()", body, hasImplicitTermTrait);
+      genVariadicRegionPrinter(name + "()", body, hasImplicitTermTrait);
     } else {
-      genRegionPrinter(var->name + "()", body, hasImplicitTermTrait);
+      genRegionPrinter(name + "()", body, hasImplicitTermTrait);
     }
   } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
     const NamedSuccessor *var = successor->getVar();
+    std::string name = op.getGetterName(var->name);
     if (var->isVariadic())
-      body << "  ::llvm::interleaveComma(" << var->name << "(), p);\n";
+      body << "  ::llvm::interleaveComma(" << name << "(), p);\n";
     else
-      body << "  p << " << var->name << "();\n";
+      body << "  p << " << name << "();\n";
   } else if (auto *dir = dyn_cast<CustomDirective>(element)) {
-    genCustomDirectivePrinter(dir, body);
+    genCustomDirectivePrinter(dir, op, body);
   } else if (isa<OperandsDirective>(element)) {
     body << "  p << getOperation()->getOperands();\n";
   } else if (isa<RegionsDirective>(element)) {
@@ -2052,16 +2068,16 @@ void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
         body << llvm::formatv("  ::llvm::interleaveComma({0}().getTypes(), p, "
                               "[&](::mlir::TypeRange types) {{ p << \"(\" << "
                               "types << \")\"; });\n",
-                              operand->getVar()->name);
+                              op.getGetterName(operand->getVar()->name));
         return;
       }
     }
     body << "  p << ";
-    genTypeOperandPrinter(dir->getOperand(), body) << ";\n";
+    genTypeOperandPrinter(dir->getOperand(), op, body) << ";\n";
   } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
     body << "  p.printFunctionalType(";
-    genTypeOperandPrinter(dir->getInputs(), body) << ", ";
-    genTypeOperandPrinter(dir->getResults(), body) << ");\n";
+    genTypeOperandPrinter(dir->getInputs(), op, body) << ", ";
+    genTypeOperandPrinter(dir->getResults(), op, body) << ");\n";
   } else {
     llvm_unreachable("unknown format element");
   }