OperationName::PrintAssemblyFn printFn;
OperationName::FoldHookFn foldHookFn;
OperationName::GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
+ OperationName::PopulateDefaultAttrsFn getPopulateDefaultAttrsFn;
friend ExtensibleDialect;
};
static void getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {}
+ /// This hook populates any unset default attrs.
+ static void populateDefaultAttrs(const RegisteredOperationName &,
+ NamedAttrList &) {}
+
protected:
/// If the concrete type didn't implement a custom verifier hook, just fall
/// back to this one which accepts everything.
OpState::printOpName(op, p, defaultDialect);
return cast<ConcreteType>(op).print(p);
}
+ /// Implementation of `PopulateDefaultAttrsFn` OperationName hook.
+ static OperationName::PopulateDefaultAttrsFn getPopulateDefaultAttrsFn() {
+ return ConcreteType::populateDefaultAttrs;
+ }
/// Implementation of `VerifyInvariantsFn` OperationName hook.
static LogicalResult verifyInvariants(Operation *op) {
static_assert(hasNoDataMembers(),
setAttrs(attrs.getDictionary(getContext()));
}
+ /// Sets default attributes on unset attributes.
+ void populateDefaultAttrs() {
+ if (auto registered = getRegisteredInfo()) {
+ NamedAttrList attrs(getAttrDictionary());
+ registered->populateDefaultAttrs(attrs);
+ setAttrs(attrs.getDictionary(getContext()));
+ }
+ }
+
//===--------------------------------------------------------------------===//
// Blocks
//===--------------------------------------------------------------------===//
class DictionaryAttr;
class ElementsAttr;
class MutableOperandRangeRange;
+class NamedAttrList;
class Operation;
struct OperationState;
class OpAsmParser;
using HasTraitFn = llvm::unique_function<bool(TypeID) const>;
using ParseAssemblyFn =
llvm::unique_function<ParseResult(OpAsmParser &, OperationState &) const>;
+ // Note: RegisteredOperationName is passed as reference here as the derived
+ // class is defined below.
+ using PopulateDefaultAttrsFn = llvm::unique_function<void(
+ const RegisteredOperationName &, NamedAttrList &) const>;
using PrintAssemblyFn =
llvm::unique_function<void(Operation *, OpAsmPrinter &, StringRef) const>;
using VerifyInvariantsFn =
GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
HasTraitFn hasTraitFn;
ParseAssemblyFn parseAssemblyFn;
+ PopulateDefaultAttrsFn populateDefaultAttrsFn;
PrintAssemblyFn printAssemblyFn;
VerifyInvariantsFn verifyInvariantsFn;
VerifyRegionInvariantsFn verifyRegionInvariantsFn;
T::getParseAssemblyFn(), T::getPrintAssemblyFn(),
T::getVerifyInvariantsFn(), T::getVerifyRegionInvariantsFn(),
T::getFoldHookFn(), T::getGetCanonicalizationPatternsFn(),
- T::getInterfaceMap(), T::getHasTraitFn(), T::getAttributeNames());
+ T::getInterfaceMap(), T::getHasTraitFn(), T::getAttributeNames(),
+ T::getPopulateDefaultAttrsFn());
}
/// The use of this method is in general discouraged in favor of
/// 'insert<CustomOp>(dialect)'.
FoldHookFn &&foldHook,
GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
- ArrayRef<StringRef> attrNames);
+ ArrayRef<StringRef> attrNames,
+ PopulateDefaultAttrsFn &&populateDefaultAttrs);
/// Return the dialect this operation is registered to.
Dialect &getDialect() const { return *impl->dialect; }
return impl->attributeNames;
}
+ /// This hook implements the method to populate defaults attributes that are
+ /// unset.
+ void populateDefaultAttrs(NamedAttrList &attrs) const;
+
/// Represent the operation name as an opaque pointer. (Used to support
/// PointerLikeTypeTraits).
static RegisteredOperationName getFromOpaquePointer(const void *pointer) {
std::move(op->printFn), std::move(op->verifyFn),
std::move(op->verifyRegionFn), std::move(op->foldHookFn),
std::move(op->getCanonicalizationPatternsFn),
- detail::InterfaceMap::get<>(), std::move(hasTraitFn), {});
+ detail::InterfaceMap::get<>(), std::move(hasTraitFn), {},
+ std::move(op->getPopulateDefaultAttrsFn));
}
bool ExtensibleDialect::classof(const Dialect *dialect) {
return impl->parseAssemblyFn(parser, result);
}
+void RegisteredOperationName::populateDefaultAttrs(NamedAttrList &attrs) const {
+ impl->populateDefaultAttrsFn(*this, attrs);
+}
+
void RegisteredOperationName::insert(
StringRef name, Dialect &dialect, TypeID typeID,
ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
VerifyRegionInvariantsFn &&verifyRegionInvariants, FoldHookFn &&foldHook,
GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
- ArrayRef<StringRef> attrNames) {
+ ArrayRef<StringRef> attrNames,
+ PopulateDefaultAttrsFn &&populateDefaultAttrs) {
MLIRContext *ctx = dialect.getContext();
auto &ctxImpl = ctx->getImpl();
assert(ctxImpl.multiThreadedExecutionContext == 0 &&
impl.verifyInvariantsFn = std::move(verifyInvariants);
impl.verifyRegionInvariantsFn = std::move(verifyRegionInvariants);
impl.attributeNames = cachedAttrNames;
+ impl.populateDefaultAttrsFn = std::move(populateDefaultAttrs);
}
//===----------------------------------------------------------------------===//
// Generates getters for named successors.
void genNamedSuccessorGetters();
+ // Generates the method to populate default attributes.
+ void genPopulateDefaultAttributes();
+
// Generates builder methods for the operation.
void genBuilder();
genAttrSetters();
genOptionalAttrRemovers();
genBuilder();
+ genPopulateDefaultAttributes();
genParser();
genPrinter();
genVerifier();
<< llvm::join(resultTypes, ", ") << "});\n\n";
}
+void OpEmitter::genPopulateDefaultAttributes() {
+ // All done if no attributes have default values.
+ if (llvm::all_of(op.getAttributes(), [](const NamedAttribute &named) {
+ return !named.attr.hasDefaultValue();
+ }))
+ return;
+
+ SmallVector<MethodParameter> paramList;
+ paramList.emplace_back("const ::mlir::RegisteredOperationName &", "opName");
+ paramList.emplace_back("::mlir::NamedAttrList &", "attributes");
+ auto *m = opClass.addStaticMethod("void", "populateDefaultAttrs", paramList);
+ ERROR_IF_PRUNED(m, "populateDefaultAttrs", op);
+ auto &body = m->body();
+ body.indent();
+
+ // Set default attributes that are unset.
+ body << "auto attrNames = opName.getAttributeNames();\n";
+ body << "::mlir::Builder " << odsBuilder
+ << "(attrNames.front().getContext());\n";
+ StringMap<int> attrIndex;
+ for (const auto &it : llvm::enumerate(emitHelper.getAttrMetadata())) {
+ attrIndex[it.value().first] = it.index();
+ }
+ for (const NamedAttribute &namedAttr : op.getAttributes()) {
+ auto &attr = namedAttr.attr;
+ if (!attr.hasDefaultValue())
+ continue;
+ auto index = attrIndex[namedAttr.name];
+ body << "if (!attributes.get(attrNames[" << index << "])) {\n";
+ FmtContext fctx;
+ fctx.withBuilder(odsBuilder);
+ std::string defaultValue = std::string(
+ tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
+ body.indent() << formatv(" attributes.append(attrNames[{0}], {1});\n",
+ index, defaultValue);
+ body.unindent() << "}\n";
+ }
+}
+
void OpEmitter::genInferredTypeCollectiveParamBuilder() {
SmallVector<MethodParameter> paramList;
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
auto numResults = op.getNumResults();
resultTypeNames.reserve(numResults);
- paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
+ paramList.emplace_back("::mlir::OpBuilder &", odsBuilder);
paramList.emplace_back("::mlir::OperationState &", builderOpState);
switch (typeParamKind) {
tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
body << " if (!attr)\n attr = " << defaultValue << ";\n";
}
- body << " return attr;\n";
+ body << "return attr;\n";
};
{
//===----------------------------------------------------------------------===//
#include "mlir/IR/OperationSupport.h"
+#include "../../test/lib/Dialect/Test/TestDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/BitVector.h"
attrs.assign({});
ASSERT_TRUE(attrs.empty());
}
+
+TEST(OperandStorageTest, PopulateDefaultAttrs) {
+ MLIRContext context;
+ context.getOrLoadDialect<test::TestDialect>();
+ Builder builder(&context);
+
+ OpBuilder b(&context);
+ auto req1 = b.getI32IntegerAttr(10);
+ auto req2 = b.getI32IntegerAttr(60);
+ Operation *op = b.create<test::OpAttrMatch1>(b.getUnknownLoc(), req1, nullptr,
+ nullptr, req2);
+ EXPECT_EQ(op->getAttr("default_valued_attr"), nullptr);
+ op->populateDefaultAttrs();
+ auto opt = op->getAttr("default_valued_attr");
+ EXPECT_NE(opt, nullptr) << *op;
+
+ op->destroy();
+}
} // namespace