From cdfeeb8a4058130d8ce59300867e272642c97dfa Mon Sep 17 00:00:00 2001 From: River Riddle Date: Wed, 12 Oct 2022 18:01:03 -0700 Subject: [PATCH] [mlir:ODS] Generate unwrapped operation attribute setters This allows for setting an attribute using the underlying C++ type, which is generally much nicer to interact with than the attribute type. Differential Revision: https://reviews.llvm.org/D135838 --- mlir/include/mlir/Dialect/Affine/IR/AffineOps.td | 2 - mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 16 ---- mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 2 +- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 3 +- mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp | 9 +-- mlir/test/mlir-tblgen/op-attribute.td | 8 ++ mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 95 +++++++++++++++++++----- 7 files changed, 92 insertions(+), 43 deletions(-) diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td index df0ce36..bb505c5 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -718,7 +718,6 @@ def AffineParallelOp : Affine_Op<"parallel", /// Sets elements of the loop lower bound. void setLowerBounds(ValueRange operands, AffineMap map); - void setLowerBoundsMap(AffineMap map); /// Returns elements of the loop upper bound. AffineMap getUpperBoundMap(unsigned pos); @@ -727,7 +726,6 @@ def AffineParallelOp : Affine_Op<"parallel", /// Sets elements fo the loop upper bound. void setUpperBounds(ValueRange operands, AffineMap map); - void setUpperBoundsMap(AffineMap map); void setSteps(ArrayRef newSteps); diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index f200135..1d05601 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -3579,22 +3579,6 @@ void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) { setUpperBoundsMapAttr(AffineMapAttr::get(map)); } -void AffineParallelOp::setLowerBoundsMap(AffineMap map) { - AffineMap lbMap = getLowerBoundsMap(); - assert(lbMap.getNumDims() == map.getNumDims() && - lbMap.getNumSymbols() == map.getNumSymbols()); - (void)lbMap; - setLowerBoundsMapAttr(AffineMapAttr::get(map)); -} - -void AffineParallelOp::setUpperBoundsMap(AffineMap map) { - AffineMap ubMap = getUpperBoundsMap(); - assert(ubMap.getNumDims() == map.getNumDims() && - ubMap.getNumSymbols() == map.getNumSymbols()); - (void)ubMap; - setUpperBoundsMapAttr(AffineMapAttr::get(map)); -} - void AffineParallelOp::setSteps(ArrayRef newSteps) { setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps)); } diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 0378f5f..d1d03a5 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1481,7 +1481,7 @@ OpFoldResult arith::CmpIOp::fold(ArrayRef operands) { Pred origPred = getPredicate(); for (auto pred : invPreds) { if (origPred == pred.first) { - setPredicateAttr(CmpIPredicateAttr::get(getContext(), pred.second)); + setPredicate(pred.second); Value lhs = getLhs(); Value rhs = getRhs(); getLhsMutable().assign(rhs); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 365febf..3e9c235 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -2551,8 +2551,7 @@ OpFoldResult LLVM::GEPOp::fold(ArrayRef operands) { dynamicIndices); getDynamicIndicesMutable().assign(dynamicIndices); - setRawConstantIndicesAttr( - DenseI32ArrayAttr::get(getContext(), rawConstantIndices)); + setRawConstantIndices(rawConstantIndices); return Value{*this}; } diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp index 302afdc..0070000 100644 --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -640,10 +640,9 @@ GlobalOp Importer::processGlobal(llvm::GlobalVariable *gv) { b.create(op.getLoc(), ArrayRef({v})); } if (gv->hasAtLeastLocalUnnamedAddr()) - op.setUnnamedAddrAttr(UnnamedAddrAttr::get( - context, convertUnnamedAddrFromLLVM(gv->getUnnamedAddr()))); + op.setUnnamedAddr(convertUnnamedAddrFromLLVM(gv->getUnnamedAddr())); if (gv->hasSection()) - op.setSectionAttr(b.getStringAttr(gv->getSection())); + op.setSection(gv->getSection()); return globals[gv] = op; } @@ -1046,13 +1045,13 @@ LogicalResult Importer::processFunction(llvm::Function *f) { } if (FlatSymbolRefAttr personality = getPersonalityAsAttr(f)) - fop->setAttr(b.getStringAttr("personality"), personality); + fop.setPersonalityAttr(personality); else if (f->hasPersonalityFn()) emitWarning(UnknownLoc::get(context), "could not deduce personality, skipping it"); if (f->hasGC()) - fop.setGarbageCollectorAttr(b.getStringAttr(f->getGC())); + fop.setGarbageCollector(StringRef(f->getGC())); // Handle Function attributes. processFunctionAttributes(f, fop); diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td index e6cc49d..7e7a762 100644 --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -127,10 +127,18 @@ def AOp : NS_Op<"a_op", []> { // DEF: void AOp::setAAttrAttr(some-attr-kind attr) { // DEF-NEXT: (*this)->setAttr(getAAttrAttrName(), attr); +// DEF: void AOp::setAAttr(some-return-type attrValue) { +// DEF-NEXT: (*this)->setAttr(getAAttrAttrName(), some-const-builder-call(::mlir::Builder(getContext()), attrValue)); // DEF: void AOp::setBAttrAttr(some-attr-kind attr) { // DEF-NEXT: (*this)->setAttr(getBAttrAttrName(), attr); +// DEF: void AOp::setBAttr(some-return-type attrValue) { +// DEF-NEXT: (*this)->setAttr(getBAttrAttrName(), some-const-builder-call(::mlir::Builder(getContext()), attrValue)); // DEF: void AOp::setCAttrAttr(some-attr-kind attr) { // DEF-NEXT: (*this)->setAttr(getCAttrAttrName(), attr); +// DEF: void AOp::setCAttr(::llvm::Optional attrValue) { +// DEF-NEXT: if (attrValue) +// DEF-NEXT: return (*this)->setAttr(getCAttrAttrName(), some-const-builder-call(::mlir::Builder(getContext()), *attrValue)); +// DEF-NEXT: (*this)->removeAttr(getCAttrAttrName()); // Test remove methods // --- diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 6304f74..5b3d0ad 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -188,6 +188,22 @@ static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) { !attr.getConstBuilderTemplate().empty(); } +/// Build an attribute from a parameter value using the constant builder. +static std::string constBuildAttrFromParam(const tblgen::Attribute &attr, + FmtContext &fctx, + StringRef paramName) { + std::string builderTemplate = attr.getConstBuilderTemplate().str(); + + // For StringAttr, its constant builder call will wrap the input in + // quotes, which is correct for normal string literals, but incorrect + // here given we use function arguments. So we need to strip the + // wrapping quotes. + if (StringRef(builderTemplate).contains("\"$0\"")) + builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0"); + + return tgfmt(builderTemplate, &fctx, paramName).str(); +} + namespace { /// Metadata on a registered attribute. Given that attributes are stored in /// sorted order on operations, we can use information from ODS to deduce the @@ -1092,13 +1108,69 @@ void OpEmitter::genAttrSetters() { getterName); }; + // Generate a setter that accepts the underlying C++ type as opposed to the + // attribute type. + auto emitAttrWithReturnType = [&](StringRef setterName, StringRef getterName, + Attribute attr) { + Attribute baseAttr = attr.getBaseAttr(); + if (!canUseUnwrappedRawValue(baseAttr)) + return; + FmtContext fctx; + fctx.withBuilder("::mlir::Builder(getContext())"); + bool isUnitAttr = attr.getAttrDefName() == "UnitAttr"; + bool isOptional = attr.isOptional(); + + auto createMethod = [&](const Twine ¶mType) { + return opClass.addMethod("void", setterName, + MethodParameter(paramType.str(), "attrValue")); + }; + + // Build the method using the correct parameter type depending on + // optionality. + Method *method = nullptr; + if (isUnitAttr) + method = createMethod("bool"); + else if (isOptional) + method = + createMethod("::llvm::Optional<" + baseAttr.getReturnType() + ">"); + else + method = createMethod(attr.getReturnType()); + if (!method) + return; + + // If the value isn't optional, just set it directly. + if (!isOptional) { + method->body() << formatv( + " (*this)->setAttr({0}AttrName(), {1});", getterName, + constBuildAttrFromParam(attr, fctx, "attrValue")); + return; + } + + // Otherwise, we only set if the provided value is valid. If it isn't, we + // remove the attribute. + + // TODO: Handle unit attr parameters specially, given that it is treated as + // optional but not in the same way as the others (i.e. it uses bool over + // Optional<>). + StringRef paramStr = isUnitAttr ? "attrValue" : "*attrValue"; + const char *optionalCodeBody = R"( + if (attrValue) + return (*this)->setAttr({0}AttrName(), {1}); + (*this)->removeAttr({0}AttrName());)"; + method->body() << formatv( + optionalCodeBody, getterName, + constBuildAttrFromParam(baseAttr, fctx, paramStr)); + }; + for (const NamedAttribute &namedAttr : op.getAttributes()) { if (namedAttr.attr.isDerivedAttr()) continue; - for (auto names : llvm::zip(op.getSetterNames(namedAttr.name), - op.getGetterNames(namedAttr.name))) - emitAttrWithStorageType(std::get<0>(names), std::get<1>(names), - namedAttr.attr); + for (auto [setterName, getterName] : + llvm::zip(op.getSetterNames(namedAttr.name), + op.getGetterNames(namedAttr.name))) { + emitAttrWithStorageType(setterName, getterName, namedAttr.attr); + emitAttrWithReturnType(setterName, getterName, namedAttr.attr); + } } } @@ -2160,20 +2232,9 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder( // instance. FmtContext fctx; fctx.withBuilder("odsBuilder"); - - std::string builderTemplate = std::string(attr.getConstBuilderTemplate()); - - // For StringAttr, its constant builder call will wrap the input in - // quotes, which is correct for normal string literals, but incorrect - // here given we use function arguments. So we need to strip the - // wrapping quotes. - if (StringRef(builderTemplate).contains("\"$0\"")) - builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0"); - - std::string value = - std::string(tgfmt(builderTemplate, &fctx, namedAttr.name)); body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n", - builderOpState, op.getGetterName(namedAttr.name), value); + builderOpState, op.getGetterName(namedAttr.name), + constBuildAttrFromParam(attr, fctx, namedAttr.name)); } else { body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n", builderOpState, op.getGetterName(namedAttr.name), -- 2.7.4