[mlir] Add an 'cppNamespace' field to availability
authorLei Zhang <antiagainst@google.com>
Tue, 5 Oct 2021 13:32:35 +0000 (09:32 -0400)
committerLei Zhang <antiagainst@google.com>
Tue, 5 Oct 2021 13:38:09 +0000 (09:38 -0400)
This allows us to generate interfaces in a namespace,
following other TableGen'erated code.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D108311

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp

index 4e2e943..bfb03da 100644 (file)
@@ -19,6 +19,8 @@ include "mlir/IR/OpBase.td"
 class Availability {
   // The following are fields for controlling the generated C++ OpInterface.
 
+  // The namespace for the generated C++ OpInterface subclass.
+  string cppNamespace = ?;
   // The name for the generated C++ OpInterface subclass.
   string interfaceName = ?;
   // The documentation for the generated C++ OpInterface subclass.
index 539de6a..990af80 100644 (file)
@@ -125,6 +125,7 @@ def SPV_VersionAttr : SPV_I32EnumAttr<"Version", "valid SPIR-V version", [
 
 class MinVersion<I32EnumAttrCase min> : MinVersionBase<
     "QueryMinVersionInterface", SPV_VersionAttr, min> {
+  let cppNamespace = "::mlir::spirv";
   let interfaceDescription = [{
     Querying interface for minimal required SPIR-V version.
 
@@ -136,6 +137,7 @@ class MinVersion<I32EnumAttrCase min> : MinVersionBase<
 
 class MaxVersion<I32EnumAttrCase max> : MaxVersionBase<
     "QueryMaxVersionInterface", SPV_VersionAttr, max> {
+  let cppNamespace = "::mlir::spirv";
   let interfaceDescription = [{
     Querying interface for maximal supported SPIR-V version.
 
@@ -146,6 +148,7 @@ class MaxVersion<I32EnumAttrCase max> : MaxVersionBase<
 }
 
 class Extension<list<StrEnumAttrCase> extensions> : Availability {
+  let cppNamespace = "::mlir::spirv";
   let interfaceName = "QueryExtensionInterface";
   let interfaceDescription = [{
     Querying interface for required SPIR-V extensions.
@@ -189,6 +192,7 @@ class Extension<list<StrEnumAttrCase> extensions> : Availability {
 }
 
 class Capability<list<I32EnumAttrCase> capabilities> : Availability {
+  let cppNamespace = "::mlir::spirv";
   let interfaceName = "QueryCapabilityInterface";
   let interfaceDescription = [{
     Querying interface for required SPIR-V capabilities.
index c4f5fd1..e3d4e56 100644 (file)
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
 
+// TableGen'erated operation interfaces for querying versions, extensions, and
+// capabilities.
+#include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.h.inc"
+
 namespace mlir {
 class OpBuilder;
 
 namespace spirv {
 class VerCapExtAttr;
-
-// TableGen'erated operation interfaces for querying versions, extensions, and
-// capabilities.
-#include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.h.inc"
 } // namespace spirv
 } // namespace mlir
 
index f711c18..9a26622 100644 (file)
@@ -3915,14 +3915,9 @@ static LogicalResult verify(spirv::PtrAccessChainOp accessChainOp) {
   return verifyAccessChain(accessChainOp, accessChainOp.indices());
 }
 
-namespace mlir {
-namespace spirv {
-
 // TableGen'erated operation interfaces for querying versions, extensions, and
 // capabilities.
 #include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"
-} // namespace spirv
-} // namespace mlir
 
 // TablenGen'erated operation definitions.
 #define GET_OP_CLASSES
@@ -3932,6 +3927,5 @@ namespace mlir {
 namespace spirv {
 // TableGen'erated operation availability interface implementations.
 #include "mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc"
-
 } // namespace spirv
 } // namespace mlir
index 10c00fe..4053bc9 100644 (file)
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/CodeGenHelpers.h"
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Operator.h"
@@ -45,6 +46,7 @@ using mlir::tblgen::EnumAttr;
 using mlir::tblgen::EnumAttrCase;
 using mlir::tblgen::NamedAttribute;
 using mlir::tblgen::NamedTypeConstraint;
+using mlir::tblgen::NamespaceEmitter;
 using mlir::tblgen::Operator;
 
 //===----------------------------------------------------------------------===//
@@ -62,6 +64,9 @@ public:
   // instance.
   StringRef getClass() const;
 
+  // Returns the generated C++ interface's class namespace.
+  StringRef getInterfaceClassNamespace() const;
+
   // Returns the generated C++ interface's class name.
   StringRef getInterfaceClassName() const;
 
@@ -91,6 +96,9 @@ public:
   // Returns the concrete availability instance carried in this case.
   StringRef getMergeInstance() const;
 
+  // Returns the underlying LLVM TableGen Record.
+  const llvm::Record *getDef() const { return def; }
+
 private:
   // The TableGen definition of this availability.
   const llvm::Record *def;
@@ -112,6 +120,10 @@ StringRef Availability::getClass() const {
   return parentClass.front()->getName();
 }
 
+StringRef Availability::getInterfaceClassNamespace() const {
+  return def->getValueAsString("cppNamespace");
+}
+
 StringRef Availability::getInterfaceClassName() const {
   return def->getValueAsString("interfaceName");
 }
@@ -168,9 +180,16 @@ std::vector<Availability> getAvailabilities(const Record &def) {
 
 static void emitInterfaceDef(const Availability &availability,
                              raw_ostream &os) {
+
+  os << availability.getQueryFnRetType() << " ";
+
+  StringRef cppNamespace = availability.getInterfaceClassNamespace();
+  cppNamespace.consume_front("::");
+  if (!cppNamespace.empty())
+    os << cppNamespace << "::";
+
   StringRef methodName = availability.getQueryFnName();
-  os << availability.getQueryFnRetType() << " "
-     << availability.getInterfaceClassName() << "::" << methodName << "() {\n"
+  os << availability.getInterfaceClassName() << "::" << methodName << "() {\n"
      << "  return getImpl()->" << methodName << "(getImpl(), getOperation());\n"
      << "}\n";
 }
@@ -237,13 +256,16 @@ static void emitInterfaceDecl(const Availability &availability,
   std::string interfaceTraitsName =
       std::string(formatv("{0}Traits", interfaceName));
 
+  StringRef cppNamespace = availability.getInterfaceClassNamespace();
+  NamespaceEmitter nsEmitter(os, cppNamespace);
+
   // Emit the traits struct containing the concept and model declarations.
   os << "namespace detail {\n"
      << "struct " << interfaceTraitsName << " {\n";
   emitConceptDecl(availability, os);
   os << '\n';
   emitModelDecl(availability, os);
-  os << "};\n} // end namespace detail\n\n";
+  os << "};\n} // namespace detail\n\n";
 
   // Emit the main interface class declaration.
   os << "/*\n" << availability.getInterfaceDescription().trim() << "\n*/\n";