[mlir][spirv] Allow custom mangling of SPIRV built-in global variables
authorVictor Perez <victor.perez@codeplay.com>
Wed, 28 Jun 2023 08:18:20 +0000 (09:18 +0100)
committerVictor Perez <victor.perez@codeplay.com>
Fri, 30 Jun 2023 12:20:42 +0000 (13:20 +0100)
The SPIR-V spec does not specify the mangling for these variables, so
the conversion to SPIR-V should be flexible enough to allow adding a
custom prefix and suffix to the core name.

Differential Revision: https://reviews.llvm.org/D153951

Signed-off-by: Victor Perez <victor.perez@codeplay.com>
mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

index e3b5e24..ba3e8ae 100644 (file)
@@ -140,8 +140,13 @@ class AccessChainOp;
 /// Returns the value for the given `builtin` variable. This function gets or
 /// inserts the global variable associated for the builtin within the nearest
 /// symbol table enclosing `op`. Returns null Value on error.
+///
+/// The global name being generated will be mangled using `preffix` and
+/// `suffix`.
 Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType,
-                              OpBuilder &builder);
+                              OpBuilder &builder,
+                              StringRef prefix = "__builtin__",
+                              StringRef suffix = "__");
 
 /// Gets the value at the given `offset` of the push constant storage with a
 /// total of `elementCount` `integerType` integers. A global variable will be
index 793b025..9fe2f8b 100644 (file)
@@ -702,14 +702,16 @@ static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
 }
 
 /// Gets name of global variable for a builtin.
-static std::string getBuiltinVarName(spirv::BuiltIn builtin) {
-  return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__";
+static std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
+                                     StringRef suffix) {
+  return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
 }
 
 /// Gets or inserts a global variable for a builtin within `body` block.
 static spirv::GlobalVariableOp
 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
-                           Type integerType, OpBuilder &builder) {
+                           Type integerType, OpBuilder &builder,
+                           StringRef prefix, StringRef suffix) {
   if (auto varOp = getBuiltinVariable(body, builtin))
     return varOp;
 
@@ -725,7 +727,7 @@ getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
   case spirv::BuiltIn::GlobalInvocationId: {
     auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
                                            spirv::StorageClass::Input);
-    std::string name = getBuiltinVarName(builtin);
+    std::string name = getBuiltinVarName(builtin, prefix, suffix);
     newVarOp =
         builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
     break;
@@ -735,7 +737,7 @@ getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
   case spirv::BuiltIn::SubgroupSize: {
     auto ptrType =
         spirv::PointerType::get(integerType, spirv::StorageClass::Input);
-    std::string name = getBuiltinVarName(builtin);
+    std::string name = getBuiltinVarName(builtin, prefix, suffix);
     newVarOp =
         builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
     break;
@@ -749,8 +751,8 @@ getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
 
 Value mlir::spirv::getBuiltinVariableValue(Operation *op,
                                            spirv::BuiltIn builtin,
-                                           Type integerType,
-                                           OpBuilder &builder) {
+                                           Type integerType, OpBuilder &builder,
+                                           StringRef prefix, StringRef suffix) {
   Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
   if (!parent) {
     op->emitError("expected operation to be within a module-like op");
@@ -759,7 +761,7 @@ Value mlir::spirv::getBuiltinVariableValue(Operation *op,
 
   spirv::GlobalVariableOp varOp =
       getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
-                                 builtin, integerType, builder);
+                                 builtin, integerType, builder, prefix, suffix);
   Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
   return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
 }