[mlir] Move SymbolOpInterfaces "classof" check to a proper "extraClassOf" interface...
authorRiver Riddle <riddleriver@gmail.com>
Fri, 16 Dec 2022 09:16:15 +0000 (01:16 -0800)
committerRiver Riddle <riddleriver@gmail.com>
Thu, 19 Jan 2023 03:16:30 +0000 (19:16 -0800)
SymbolOpInterface overrides the base classof to provide support
for optionally implementing the interface. This is currently placed
in the extraClassDeclarations, but that is kind of awkard given that
it requires underlying knowledge of how the base classof is implemented.
This commit adds a proper "extraClassOf" field to allow interfaces to
implement this, which abstracts away the default classof logic.

Differential Revision: https://reviews.llvm.org/D140197

12 files changed:
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/SymbolInterfaces.td
mlir/include/mlir/Support/InterfaceSupport.h
mlir/include/mlir/TableGen/Format.h
mlir/include/mlir/TableGen/Interfaces.h
mlir/lib/TableGen/CodeGenHelpers.cpp
mlir/lib/TableGen/Format.cpp
mlir/lib/TableGen/Interfaces.cpp
mlir/test/mlir-tblgen/op-interface.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
mlir/unittests/TableGen/FormatTest.cpp

index 00b70ab..ea9976c 100644 (file)
@@ -2048,6 +2048,13 @@ class Interface<string name> {
   // An optional code block containing extra declarations to place in both
   // the interface and trait declaration.
   code extraSharedClassDeclaration = "";
+
+  // An optional code block for adding additional "classof" logic. This can
+  // be used to better enable "optional" interfaces, where an entity only
+  // implements the interface if some dynamic characteristic holds.
+  // `$_attr`/`$_op`/`$_type` may be used to refer to an instance of the
+  // entity being checked.
+  code extraClassOf = "";
 }
 
 // AttrInterface represents an interface registered to an attribute.
index 5073774..a3a6833 100644 (file)
@@ -174,28 +174,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
     return success();
   }];
 
-  let extraClassDeclaration = [{
-    /// Convenience version of `getNameAttr` that returns a StringRef.
-    StringRef getName() {
-      return getNameAttr().getValue();
-    }
-
-    /// Convenience version of `setName` that take a StringRef.
-    void setName(StringRef name) {
-      setName(StringAttr::get(this->getContext(), name));
-    }
-
-    /// Custom classof that handles the case where the symbol is optional.
-    static bool classof(Operation *op) {
-      auto *opConcept = getInterfaceFor(op);
-      if (!opConcept)
-        return false;
-      return !opConcept->isOptionalSymbol(opConcept, op) ||
-             op->getAttr(::mlir::SymbolTable::getSymbolAttrName());
-    }
-  }];
-
-  let extraTraitClassDeclaration = [{
+  let extraSharedClassDeclaration = [{
     using Visibility = mlir::SymbolTable::Visibility;
 
     /// Convenience version of `getNameAttr` that returns a StringRef.
@@ -208,6 +187,11 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
       setName(StringAttr::get($_op->getContext(), name));
     }
   }];
+
+  // Add additional classof checks to properly handle "optional" symbols.
+  let extraClassOf = [{
+    return $_op->hasAttr(::mlir::SymbolTable::getSymbolAttrName());
+  }];
 }
 
 //===----------------------------------------------------------------------===//
index d8f63e0..6ba7b33 100644 (file)
@@ -110,6 +110,12 @@ public:
            "expected value to provide interface instance");
   }
 
+  /// Constructor for a known concept.
+  Interface(ValueT t, Concept *conceptImpl)
+      : BaseType(t), conceptImpl(conceptImpl) {
+    assert(!t || ConcreteType::getInterfaceFor(t) == conceptImpl);
+  }
+
   /// Constructor for DenseMapInfo's empty key and tombstone key.
   Interface(ValueT t, std::nullptr_t) : BaseType(t), conceptImpl(nullptr) {}
 
index 60d5887..79d3d26 100644 (file)
@@ -44,7 +44,6 @@ public:
     None,
     Custom,  // For custom placeholders
     Builder, // For the $_builder placeholder
-    Op,      // For the $_op placeholder
     Self,    // For the $_self placeholder
   };
 
@@ -58,7 +57,6 @@ public:
 
   // Setters for builtin placeholders
   FmtContext &withBuilder(Twine subst);
-  FmtContext &withOp(Twine subst);
   FmtContext &withSelf(Twine subst);
 
   std::optional<StringRef> getSubstFor(PHKind placeholder) const;
index aeef360..7168f13 100644 (file)
@@ -95,6 +95,9 @@ public:
   // trait classes.
   std::optional<StringRef> getExtraSharedClassDeclaration() const;
 
+  // Return the extra classof method code.
+  std::optional<StringRef> getExtraClassOf() const;
+
   // Return the verify method body if it has one.
   std::optional<StringRef> getVerify() const;
 
index 5caefb4..193e8c1 100644 (file)
@@ -190,7 +190,7 @@ void StaticVerifierFunctionEmitter::emitConstraints(
     const ConstraintMap &constraints, StringRef selfName,
     const char *const codeTemplate) {
   FmtContext ctx;
-  ctx.withOp("*op").withSelf(selfName);
+  ctx.addSubst("_op", "*op").withSelf(selfName);
   for (auto &it : constraints) {
     os << formatv(codeTemplate, it.second,
                   tgfmt(it.first.getConditionTemplate(), &ctx),
@@ -216,7 +216,7 @@ void StaticVerifierFunctionEmitter::emitRegionConstraints() {
 
 void StaticVerifierFunctionEmitter::emitPatternConstraints() {
   FmtContext ctx;
-  ctx.withOp("*op").withBuilder("rewriter").withSelf("type");
+  ctx.addSubst("_op", "*op").withBuilder("rewriter").withSelf("type");
   for (auto &it : typeConstraints) {
     os << formatv(patternAttrOrTypeConstraintCode, it.second,
                   tgfmt(it.first.getConditionTemplate(), &ctx),
@@ -240,9 +240,9 @@ void StaticVerifierFunctionEmitter::emitPatternConstraints() {
 /// because ops use cached identifiers.
 static bool canUniqueAttrConstraint(Attribute attr) {
   FmtContext ctx;
-  auto test =
-      tgfmt(attr.getConditionTemplate(), &ctx.withSelf("attr").withOp("*op"))
-          .str();
+  auto test = tgfmt(attr.getConditionTemplate(),
+                    &ctx.withSelf("attr").addSubst("_op", "*op"))
+                  .str();
   return !StringRef(test).contains("<no-subst-found>");
 }
 
index 2595215..03f888b 100644 (file)
@@ -38,11 +38,6 @@ FmtContext &FmtContext::withBuilder(Twine subst) {
   return *this;
 }
 
-FmtContext &FmtContext::withOp(Twine subst) {
-  builtinSubstMap[PHKind::Op] = subst.str();
-  return *this;
-}
-
 FmtContext &FmtContext::withSelf(Twine subst) {
   builtinSubstMap[PHKind::Self] = subst.str();
   return *this;
@@ -69,7 +64,6 @@ std::optional<StringRef> FmtContext::getSubstFor(StringRef placeholder) const {
 FmtContext::PHKind FmtContext::getPlaceHolderKind(StringRef str) {
   return StringSwitch<FmtContext::PHKind>(str)
       .Case("_builder", FmtContext::PHKind::Builder)
-      .Case("_op", FmtContext::PHKind::Op)
       .Case("_self", FmtContext::PHKind::Self)
       .Case("", FmtContext::PHKind::None)
       .Default(FmtContext::PHKind::Custom);
index 4ddfa2f..bd56f6b 100644 (file)
@@ -116,6 +116,11 @@ std::optional<StringRef> Interface::getExtraSharedClassDeclaration() const {
   return value.empty() ? std::optional<StringRef>() : value;
 }
 
+std::optional<StringRef> Interface::getExtraClassOf() const {
+  auto value = def->getValueAsString("extraClassOf");
+  return value.empty() ? std::optional<StringRef>() : value;
+}
+
 // Return the body for this method if it has one.
 std::optional<StringRef> Interface::getVerify() const {
   // Only OpInterface supports the verify method.
index ab04198..8129eb1 100644 (file)
@@ -4,6 +4,17 @@
 
 include "mlir/IR/OpBase.td"
 
+def ExtraClassOfInterface : OpInterface<"ExtraClassOfInterface"> {
+  let extraClassOf = "return $_op->someOtherMethod();";
+}
+
+// DECL: class ExtraClassOfInterface
+// DECL:   static bool classof(::mlir::Operation * base) {
+// DECL-NEXT:     if (!getInterfaceFor(base))
+// DECL-NEXT:       return false;
+// DECL-NEXT:     return base->someOtherMethod();
+// DECL-NEXT:   }
+
 def ExtraShardDeclsInterface : OpInterface<"ExtraShardDeclsInterface"> {
   let extraSharedClassDeclaration = [{
     bool sharedMethodDeclaration() {
index 83937f4..7ed29f9 100644 (file)
@@ -819,7 +819,7 @@ OpEmitter::OpEmitter(const Operator &op,
               formatExtraDefinitions(op)),
       staticVerifierEmitter(staticVerifierEmitter),
       emitHelper(op, /*emitForOp=*/true) {
-  verifyCtx.withOp("(*this->getOperation())");
+  verifyCtx.addSubst("_op", "(*this->getOperation())");
   verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()");
 
   genTraits();
index 9e84d19..363bec7 100644 (file)
@@ -108,6 +108,8 @@ protected:
   StringRef interfaceBaseType;
   /// The name of the typename for the value template.
   StringRef valueTemplate;
+  /// The name of the substituion variable for the value.
+  StringRef substVar;
   /// The format context to use for methods.
   tblgen::FmtContext nonStaticMethodFmt;
   tblgen::FmtContext traitMethodFmt;
@@ -121,11 +123,12 @@ struct AttrInterfaceGenerator : public InterfaceGenerator {
     valueType = "::mlir::Attribute";
     interfaceBaseType = "AttributeInterface";
     valueTemplate = "ConcreteAttr";
+    substVar = "_attr";
     StringRef castCode = "(tablegen_opaque_val.cast<ConcreteAttr>())";
-    nonStaticMethodFmt.addSubst("_attr", castCode).withSelf(castCode);
-    traitMethodFmt.addSubst("_attr",
+    nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode);
+    traitMethodFmt.addSubst(substVar,
                             "(*static_cast<const ConcreteAttr *>(this))");
-    extraDeclsFmt.addSubst("_attr", "(*this)");
+    extraDeclsFmt.addSubst(substVar, "(*this)");
   }
 };
 /// A specialized generator for operation interfaces.
@@ -135,12 +138,13 @@ struct OpInterfaceGenerator : public InterfaceGenerator {
     valueType = "::mlir::Operation *";
     interfaceBaseType = "OpInterface";
     valueTemplate = "ConcreteOp";
+    substVar = "_op";
     StringRef castCode = "(llvm::cast<ConcreteOp>(tablegen_opaque_val))";
     nonStaticMethodFmt.addSubst("_this", "impl")
-        .withOp(castCode)
+        .addSubst(substVar, castCode)
         .withSelf(castCode);
-    traitMethodFmt.withOp("(*static_cast<ConcreteOp *>(this))");
-    extraDeclsFmt.withOp("(*this)");
+    traitMethodFmt.addSubst(substVar, "(*static_cast<ConcreteOp *>(this))");
+    extraDeclsFmt.addSubst(substVar, "(*this)");
   }
 };
 /// A specialized generator for type interfaces.
@@ -150,11 +154,12 @@ struct TypeInterfaceGenerator : public InterfaceGenerator {
     valueType = "::mlir::Type";
     interfaceBaseType = "TypeInterface";
     valueTemplate = "ConcreteType";
+    substVar = "_type";
     StringRef castCode = "(tablegen_opaque_val.cast<ConcreteType>())";
-    nonStaticMethodFmt.addSubst("_type", castCode).withSelf(castCode);
-    traitMethodFmt.addSubst("_type",
+    nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode);
+    traitMethodFmt.addSubst(substVar,
                             "(*static_cast<const ConcreteType *>(this))");
-    extraDeclsFmt.addSubst("_type", "(*this)");
+    extraDeclsFmt.addSubst(substVar, "(*this)");
   }
 };
 } // namespace
@@ -434,7 +439,7 @@ void InterfaceGenerator::emitTraitDecl(const Interface &interface,
     assert(isa<OpInterface>(interface) && "only OpInterface supports 'verify'");
 
     tblgen::FmtContext verifyCtx;
-    verifyCtx.withOp("op");
+    verifyCtx.addSubst("_op", "op");
     os << llvm::formatv(
               "    static ::mlir::LogicalResult {0}(::mlir::Operation *op) ",
               (interface.verifyWithRegions() ? "verifyRegionTrait"
@@ -506,6 +511,17 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
           interface.getExtraSharedClassDeclaration())
     os << tblgen::tgfmt(*extraDecls, &extraDeclsFmt);
 
+  // Emit classof code if necessary.
+  if (std::optional<StringRef> extraClassOf = interface.getExtraClassOf()) {
+    auto extraClassOfFmt = tblgen::FmtContext();
+    extraClassOfFmt.addSubst(substVar, "base");
+    os << "  static bool classof(" << valueType << " base) {\n"
+       << "    if (!getInterfaceFor(base))\n"
+          "      return false;\n"
+       << "    " << tblgen::tgfmt(extraClassOf->trim(), &extraClassOfFmt)
+       << "\n  }\n";
+  }
+
   os << "};\n";
 
   os << "namespace detail {\n";
index 0cae408..ef00cb4 100644 (file)
@@ -105,12 +105,6 @@ TEST(FormatTest, PlaceHolderFmtStrWithBuilder) {
   EXPECT_THAT(result, StrEq("bbb"));
 }
 
-TEST(FormatTest, PlaceHolderFmtStrWithOp) {
-  FmtContext ctx;
-  std::string result = std::string(tgfmt("$_op", &ctx.withOp("ooo")));
-  EXPECT_THAT(result, StrEq("ooo"));
-}
-
 TEST(FormatTest, PlaceHolderMissingCtx) {
   std::string result = std::string(tgfmt("$_op", nullptr));
   EXPECT_THAT(result, StrEq("$_op<no-subst-found>"));