// An optional code block containing extra declarations to place in both
// the interface and trait declaration.
code extraSharedClassDeclaration = "";
+
+ // An optional code block for adding additional "classof" logic. This can
+ // be used to better enable "optional" interfaces, where an entity only
+ // implements the interface if some dynamic characteristic holds.
+ // `$_attr`/`$_op`/`$_type` may be used to refer to an instance of the
+ // entity being checked.
+ code extraClassOf = "";
}
// AttrInterface represents an interface registered to an attribute.
return success();
}];
- let extraClassDeclaration = [{
- /// Convenience version of `getNameAttr` that returns a StringRef.
- StringRef getName() {
- return getNameAttr().getValue();
- }
-
- /// Convenience version of `setName` that take a StringRef.
- void setName(StringRef name) {
- setName(StringAttr::get(this->getContext(), name));
- }
-
- /// Custom classof that handles the case where the symbol is optional.
- static bool classof(Operation *op) {
- auto *opConcept = getInterfaceFor(op);
- if (!opConcept)
- return false;
- return !opConcept->isOptionalSymbol(opConcept, op) ||
- op->getAttr(::mlir::SymbolTable::getSymbolAttrName());
- }
- }];
-
- let extraTraitClassDeclaration = [{
+ let extraSharedClassDeclaration = [{
using Visibility = mlir::SymbolTable::Visibility;
/// Convenience version of `getNameAttr` that returns a StringRef.
setName(StringAttr::get($_op->getContext(), name));
}
}];
+
+ // Add additional classof checks to properly handle "optional" symbols.
+ let extraClassOf = [{
+ return $_op->hasAttr(::mlir::SymbolTable::getSymbolAttrName());
+ }];
}
//===----------------------------------------------------------------------===//
"expected value to provide interface instance");
}
+ /// Constructor for a known concept.
+ Interface(ValueT t, Concept *conceptImpl)
+ : BaseType(t), conceptImpl(conceptImpl) {
+ assert(!t || ConcreteType::getInterfaceFor(t) == conceptImpl);
+ }
+
/// Constructor for DenseMapInfo's empty key and tombstone key.
Interface(ValueT t, std::nullptr_t) : BaseType(t), conceptImpl(nullptr) {}
None,
Custom, // For custom placeholders
Builder, // For the $_builder placeholder
- Op, // For the $_op placeholder
Self, // For the $_self placeholder
};
// Setters for builtin placeholders
FmtContext &withBuilder(Twine subst);
- FmtContext &withOp(Twine subst);
FmtContext &withSelf(Twine subst);
std::optional<StringRef> getSubstFor(PHKind placeholder) const;
// trait classes.
std::optional<StringRef> getExtraSharedClassDeclaration() const;
+ // Return the extra classof method code.
+ std::optional<StringRef> getExtraClassOf() const;
+
// Return the verify method body if it has one.
std::optional<StringRef> getVerify() const;
const ConstraintMap &constraints, StringRef selfName,
const char *const codeTemplate) {
FmtContext ctx;
- ctx.withOp("*op").withSelf(selfName);
+ ctx.addSubst("_op", "*op").withSelf(selfName);
for (auto &it : constraints) {
os << formatv(codeTemplate, it.second,
tgfmt(it.first.getConditionTemplate(), &ctx),
void StaticVerifierFunctionEmitter::emitPatternConstraints() {
FmtContext ctx;
- ctx.withOp("*op").withBuilder("rewriter").withSelf("type");
+ ctx.addSubst("_op", "*op").withBuilder("rewriter").withSelf("type");
for (auto &it : typeConstraints) {
os << formatv(patternAttrOrTypeConstraintCode, it.second,
tgfmt(it.first.getConditionTemplate(), &ctx),
/// because ops use cached identifiers.
static bool canUniqueAttrConstraint(Attribute attr) {
FmtContext ctx;
- auto test =
- tgfmt(attr.getConditionTemplate(), &ctx.withSelf("attr").withOp("*op"))
- .str();
+ auto test = tgfmt(attr.getConditionTemplate(),
+ &ctx.withSelf("attr").addSubst("_op", "*op"))
+ .str();
return !StringRef(test).contains("<no-subst-found>");
}
return *this;
}
-FmtContext &FmtContext::withOp(Twine subst) {
- builtinSubstMap[PHKind::Op] = subst.str();
- return *this;
-}
-
FmtContext &FmtContext::withSelf(Twine subst) {
builtinSubstMap[PHKind::Self] = subst.str();
return *this;
FmtContext::PHKind FmtContext::getPlaceHolderKind(StringRef str) {
return StringSwitch<FmtContext::PHKind>(str)
.Case("_builder", FmtContext::PHKind::Builder)
- .Case("_op", FmtContext::PHKind::Op)
.Case("_self", FmtContext::PHKind::Self)
.Case("", FmtContext::PHKind::None)
.Default(FmtContext::PHKind::Custom);
return value.empty() ? std::optional<StringRef>() : value;
}
+std::optional<StringRef> Interface::getExtraClassOf() const {
+ auto value = def->getValueAsString("extraClassOf");
+ return value.empty() ? std::optional<StringRef>() : value;
+}
+
// Return the body for this method if it has one.
std::optional<StringRef> Interface::getVerify() const {
// Only OpInterface supports the verify method.
include "mlir/IR/OpBase.td"
+def ExtraClassOfInterface : OpInterface<"ExtraClassOfInterface"> {
+ let extraClassOf = "return $_op->someOtherMethod();";
+}
+
+// DECL: class ExtraClassOfInterface
+// DECL: static bool classof(::mlir::Operation * base) {
+// DECL-NEXT: if (!getInterfaceFor(base))
+// DECL-NEXT: return false;
+// DECL-NEXT: return base->someOtherMethod();
+// DECL-NEXT: }
+
def ExtraShardDeclsInterface : OpInterface<"ExtraShardDeclsInterface"> {
let extraSharedClassDeclaration = [{
bool sharedMethodDeclaration() {
formatExtraDefinitions(op)),
staticVerifierEmitter(staticVerifierEmitter),
emitHelper(op, /*emitForOp=*/true) {
- verifyCtx.withOp("(*this->getOperation())");
+ verifyCtx.addSubst("_op", "(*this->getOperation())");
verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()");
genTraits();
StringRef interfaceBaseType;
/// The name of the typename for the value template.
StringRef valueTemplate;
+ /// The name of the substituion variable for the value.
+ StringRef substVar;
/// The format context to use for methods.
tblgen::FmtContext nonStaticMethodFmt;
tblgen::FmtContext traitMethodFmt;
valueType = "::mlir::Attribute";
interfaceBaseType = "AttributeInterface";
valueTemplate = "ConcreteAttr";
+ substVar = "_attr";
StringRef castCode = "(tablegen_opaque_val.cast<ConcreteAttr>())";
- nonStaticMethodFmt.addSubst("_attr", castCode).withSelf(castCode);
- traitMethodFmt.addSubst("_attr",
+ nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode);
+ traitMethodFmt.addSubst(substVar,
"(*static_cast<const ConcreteAttr *>(this))");
- extraDeclsFmt.addSubst("_attr", "(*this)");
+ extraDeclsFmt.addSubst(substVar, "(*this)");
}
};
/// A specialized generator for operation interfaces.
valueType = "::mlir::Operation *";
interfaceBaseType = "OpInterface";
valueTemplate = "ConcreteOp";
+ substVar = "_op";
StringRef castCode = "(llvm::cast<ConcreteOp>(tablegen_opaque_val))";
nonStaticMethodFmt.addSubst("_this", "impl")
- .withOp(castCode)
+ .addSubst(substVar, castCode)
.withSelf(castCode);
- traitMethodFmt.withOp("(*static_cast<ConcreteOp *>(this))");
- extraDeclsFmt.withOp("(*this)");
+ traitMethodFmt.addSubst(substVar, "(*static_cast<ConcreteOp *>(this))");
+ extraDeclsFmt.addSubst(substVar, "(*this)");
}
};
/// A specialized generator for type interfaces.
valueType = "::mlir::Type";
interfaceBaseType = "TypeInterface";
valueTemplate = "ConcreteType";
+ substVar = "_type";
StringRef castCode = "(tablegen_opaque_val.cast<ConcreteType>())";
- nonStaticMethodFmt.addSubst("_type", castCode).withSelf(castCode);
- traitMethodFmt.addSubst("_type",
+ nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode);
+ traitMethodFmt.addSubst(substVar,
"(*static_cast<const ConcreteType *>(this))");
- extraDeclsFmt.addSubst("_type", "(*this)");
+ extraDeclsFmt.addSubst(substVar, "(*this)");
}
};
} // namespace
assert(isa<OpInterface>(interface) && "only OpInterface supports 'verify'");
tblgen::FmtContext verifyCtx;
- verifyCtx.withOp("op");
+ verifyCtx.addSubst("_op", "op");
os << llvm::formatv(
" static ::mlir::LogicalResult {0}(::mlir::Operation *op) ",
(interface.verifyWithRegions() ? "verifyRegionTrait"
interface.getExtraSharedClassDeclaration())
os << tblgen::tgfmt(*extraDecls, &extraDeclsFmt);
+ // Emit classof code if necessary.
+ if (std::optional<StringRef> extraClassOf = interface.getExtraClassOf()) {
+ auto extraClassOfFmt = tblgen::FmtContext();
+ extraClassOfFmt.addSubst(substVar, "base");
+ os << " static bool classof(" << valueType << " base) {\n"
+ << " if (!getInterfaceFor(base))\n"
+ " return false;\n"
+ << " " << tblgen::tgfmt(extraClassOf->trim(), &extraClassOfFmt)
+ << "\n }\n";
+ }
+
os << "};\n";
os << "namespace detail {\n";
EXPECT_THAT(result, StrEq("bbb"));
}
-TEST(FormatTest, PlaceHolderFmtStrWithOp) {
- FmtContext ctx;
- std::string result = std::string(tgfmt("$_op", &ctx.withOp("ooo")));
- EXPECT_THAT(result, StrEq("ooo"));
-}
-
TEST(FormatTest, PlaceHolderMissingCtx) {
std::string result = std::string(tgfmt("$_op", nullptr));
EXPECT_THAT(result, StrEq("$_op<no-subst-found>"));