[mlir:ODS] Generate unwrapped operation attribute setters
authorRiver Riddle <riddleriver@gmail.com>
Thu, 13 Oct 2022 01:01:03 +0000 (18:01 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Fri, 14 Oct 2022 22:57:51 +0000 (15:57 -0700)
This allows for setting an attribute using the underlying C++ type,
which is generally much nicer to interact with than the attribute type.

Differential Revision: https://reviews.llvm.org/D135838

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
mlir/test/mlir-tblgen/op-attribute.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

index df0ce36..bb505c5 100644 (file)
@@ -718,7 +718,6 @@ def AffineParallelOp : Affine_Op<"parallel",
 
     /// Sets elements of the loop lower bound.
     void setLowerBounds(ValueRange operands, AffineMap map);
-    void setLowerBoundsMap(AffineMap map);
 
     /// Returns elements of the loop upper bound.
     AffineMap getUpperBoundMap(unsigned pos);
@@ -727,7 +726,6 @@ def AffineParallelOp : Affine_Op<"parallel",
 
     /// Sets elements fo the loop upper bound.
     void setUpperBounds(ValueRange operands, AffineMap map);
-    void setUpperBoundsMap(AffineMap map);
 
     void setSteps(ArrayRef<int64_t> newSteps);
 
index f200135..1d05601 100644 (file)
@@ -3579,22 +3579,6 @@ void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
   setUpperBoundsMapAttr(AffineMapAttr::get(map));
 }
 
-void AffineParallelOp::setLowerBoundsMap(AffineMap map) {
-  AffineMap lbMap = getLowerBoundsMap();
-  assert(lbMap.getNumDims() == map.getNumDims() &&
-         lbMap.getNumSymbols() == map.getNumSymbols());
-  (void)lbMap;
-  setLowerBoundsMapAttr(AffineMapAttr::get(map));
-}
-
-void AffineParallelOp::setUpperBoundsMap(AffineMap map) {
-  AffineMap ubMap = getUpperBoundsMap();
-  assert(ubMap.getNumDims() == map.getNumDims() &&
-         ubMap.getNumSymbols() == map.getNumSymbols());
-  (void)ubMap;
-  setUpperBoundsMapAttr(AffineMapAttr::get(map));
-}
-
 void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
   setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
 }
index 0378f5f..d1d03a5 100644 (file)
@@ -1481,7 +1481,7 @@ OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
     Pred origPred = getPredicate();
     for (auto pred : invPreds) {
       if (origPred == pred.first) {
-        setPredicateAttr(CmpIPredicateAttr::get(getContext(), pred.second));
+        setPredicate(pred.second);
         Value lhs = getLhs();
         Value rhs = getRhs();
         getLhsMutable().assign(rhs);
index 365febf..3e9c235 100644 (file)
@@ -2551,8 +2551,7 @@ OpFoldResult LLVM::GEPOp::fold(ArrayRef<Attribute> operands) {
                        dynamicIndices);
 
     getDynamicIndicesMutable().assign(dynamicIndices);
-    setRawConstantIndicesAttr(
-        DenseI32ArrayAttr::get(getContext(), rawConstantIndices));
+    setRawConstantIndices(rawConstantIndices);
     return Value{*this};
   }
 
index 302afdc..0070000 100644 (file)
@@ -640,10 +640,9 @@ GlobalOp Importer::processGlobal(llvm::GlobalVariable *gv) {
     b.create<ReturnOp>(op.getLoc(), ArrayRef<Value>({v}));
   }
   if (gv->hasAtLeastLocalUnnamedAddr())
-    op.setUnnamedAddrAttr(UnnamedAddrAttr::get(
-        context, convertUnnamedAddrFromLLVM(gv->getUnnamedAddr())));
+    op.setUnnamedAddr(convertUnnamedAddrFromLLVM(gv->getUnnamedAddr()));
   if (gv->hasSection())
-    op.setSectionAttr(b.getStringAttr(gv->getSection()));
+    op.setSection(gv->getSection());
 
   return globals[gv] = op;
 }
@@ -1046,13 +1045,13 @@ LogicalResult Importer::processFunction(llvm::Function *f) {
   }
 
   if (FlatSymbolRefAttr personality = getPersonalityAsAttr(f))
-    fop->setAttr(b.getStringAttr("personality"), personality);
+    fop.setPersonalityAttr(personality);
   else if (f->hasPersonalityFn())
     emitWarning(UnknownLoc::get(context),
                 "could not deduce personality, skipping it");
 
   if (f->hasGC())
-    fop.setGarbageCollectorAttr(b.getStringAttr(f->getGC()));
+    fop.setGarbageCollector(StringRef(f->getGC()));
 
   // Handle Function attributes.
   processFunctionAttributes(f, fop);
index e6cc49d..7e7a762 100644 (file)
@@ -127,10 +127,18 @@ def AOp : NS_Op<"a_op", []> {
 
 // DEF:      void AOp::setAAttrAttr(some-attr-kind attr) {
 // DEF-NEXT:   (*this)->setAttr(getAAttrAttrName(), attr);
+// DEF:      void AOp::setAAttr(some-return-type attrValue) {
+// DEF-NEXT:   (*this)->setAttr(getAAttrAttrName(), some-const-builder-call(::mlir::Builder(getContext()), attrValue));
 // DEF:      void AOp::setBAttrAttr(some-attr-kind attr) {
 // DEF-NEXT:   (*this)->setAttr(getBAttrAttrName(), attr);
+// DEF:      void AOp::setBAttr(some-return-type attrValue) {
+// DEF-NEXT:   (*this)->setAttr(getBAttrAttrName(), some-const-builder-call(::mlir::Builder(getContext()), attrValue));
 // DEF:      void AOp::setCAttrAttr(some-attr-kind attr) {
 // DEF-NEXT:   (*this)->setAttr(getCAttrAttrName(), attr);
+// DEF:      void AOp::setCAttr(::llvm::Optional<some-return-type> attrValue) {
+// DEF-NEXT:   if (attrValue)
+// DEF-NEXT:     return (*this)->setAttr(getCAttrAttrName(), some-const-builder-call(::mlir::Builder(getContext()), *attrValue));
+// DEF-NEXT:   (*this)->removeAttr(getCAttrAttrName());
 
 // Test remove methods
 // ---
index 6304f74..5b3d0ad 100644 (file)
@@ -188,6 +188,22 @@ static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
          !attr.getConstBuilderTemplate().empty();
 }
 
+/// Build an attribute from a parameter value using the constant builder.
+static std::string constBuildAttrFromParam(const tblgen::Attribute &attr,
+                                           FmtContext &fctx,
+                                           StringRef paramName) {
+  std::string builderTemplate = attr.getConstBuilderTemplate().str();
+
+  // For StringAttr, its constant builder call will wrap the input in
+  // quotes, which is correct for normal string literals, but incorrect
+  // here given we use function arguments. So we need to strip the
+  // wrapping quotes.
+  if (StringRef(builderTemplate).contains("\"$0\""))
+    builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
+
+  return tgfmt(builderTemplate, &fctx, paramName).str();
+}
+
 namespace {
 /// Metadata on a registered attribute. Given that attributes are stored in
 /// sorted order on operations, we can use information from ODS to deduce the
@@ -1092,13 +1108,69 @@ void OpEmitter::genAttrSetters() {
                                 getterName);
   };
 
+  // Generate a setter that accepts the underlying C++ type as opposed to the
+  // attribute type.
+  auto emitAttrWithReturnType = [&](StringRef setterName, StringRef getterName,
+                                    Attribute attr) {
+    Attribute baseAttr = attr.getBaseAttr();
+    if (!canUseUnwrappedRawValue(baseAttr))
+      return;
+    FmtContext fctx;
+    fctx.withBuilder("::mlir::Builder(getContext())");
+    bool isUnitAttr = attr.getAttrDefName() == "UnitAttr";
+    bool isOptional = attr.isOptional();
+
+    auto createMethod = [&](const Twine &paramType) {
+      return opClass.addMethod("void", setterName,
+                               MethodParameter(paramType.str(), "attrValue"));
+    };
+
+    // Build the method using the correct parameter type depending on
+    // optionality.
+    Method *method = nullptr;
+    if (isUnitAttr)
+      method = createMethod("bool");
+    else if (isOptional)
+      method =
+          createMethod("::llvm::Optional<" + baseAttr.getReturnType() + ">");
+    else
+      method = createMethod(attr.getReturnType());
+    if (!method)
+      return;
+
+    // If the value isn't optional, just set it directly.
+    if (!isOptional) {
+      method->body() << formatv(
+          "  (*this)->setAttr({0}AttrName(), {1});", getterName,
+          constBuildAttrFromParam(attr, fctx, "attrValue"));
+      return;
+    }
+
+    // Otherwise, we only set if the provided value is valid. If it isn't, we
+    // remove the attribute.
+
+    // TODO: Handle unit attr parameters specially, given that it is treated as
+    // optional but not in the same way as the others (i.e. it uses bool over
+    // Optional<>).
+    StringRef paramStr = isUnitAttr ? "attrValue" : "*attrValue";
+    const char *optionalCodeBody = R"(
+    if (attrValue)
+      return (*this)->setAttr({0}AttrName(), {1});
+    (*this)->removeAttr({0}AttrName());)";
+    method->body() << formatv(
+        optionalCodeBody, getterName,
+        constBuildAttrFromParam(baseAttr, fctx, paramStr));
+  };
+
   for (const NamedAttribute &namedAttr : op.getAttributes()) {
     if (namedAttr.attr.isDerivedAttr())
       continue;
-    for (auto names : llvm::zip(op.getSetterNames(namedAttr.name),
-                                op.getGetterNames(namedAttr.name)))
-      emitAttrWithStorageType(std::get<0>(names), std::get<1>(names),
-                              namedAttr.attr);
+    for (auto [setterName, getterName] :
+         llvm::zip(op.getSetterNames(namedAttr.name),
+                   op.getGetterNames(namedAttr.name))) {
+      emitAttrWithStorageType(setterName, getterName, namedAttr.attr);
+      emitAttrWithReturnType(setterName, getterName, namedAttr.attr);
+    }
   }
 }
 
@@ -2160,20 +2232,9 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
       // instance.
       FmtContext fctx;
       fctx.withBuilder("odsBuilder");
-
-      std::string builderTemplate = std::string(attr.getConstBuilderTemplate());
-
-      // For StringAttr, its constant builder call will wrap the input in
-      // quotes, which is correct for normal string literals, but incorrect
-      // here given we use function arguments. So we need to strip the
-      // wrapping quotes.
-      if (StringRef(builderTemplate).contains("\"$0\""))
-        builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
-
-      std::string value =
-          std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
       body << formatv("  {0}.addAttribute({1}AttrName({0}.name), {2});\n",
-                      builderOpState, op.getGetterName(namedAttr.name), value);
+                      builderOpState, op.getGetterName(namedAttr.name),
+                      constBuildAttrFromParam(attr, fctx, namedAttr.name));
     } else {
       body << formatv("  {0}.addAttribute({1}AttrName({0}.name), {2});\n",
                       builderOpState, op.getGetterName(namedAttr.name),