[mlir][spirv] Create builtin variable in nearest symbol table
authorLei Zhang <antiagainst@google.com>
Sat, 25 Jan 2020 15:09:46 +0000 (10:09 -0500)
committerLei Zhang <antiagainst@google.com>
Sun, 26 Jan 2020 16:00:49 +0000 (11:00 -0500)
This commit changes the logic of `getBuiltinVariableValue` to get
or create the builtin variable in the nearest symbol table. This
will allow us to use this function in other partial conversion
cases where we haven't created the spv.module yet.

Differential Revision: https://reviews.llvm.org/D73416

mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp

index 6c8db62..86abd06 100644 (file)
@@ -83,8 +83,10 @@ private:
   llvm::SmallSet<Capability, 8> givenCapabilities; /// Allowed capabilities
 };
 
-/// Returns a value that represents a builtin variable value within the SPIR-V
-/// module.
+/// Returns the value for the given `builtin` variable. This function gets or
+/// inserts the global variable associated for the builtin within the nearest
+/// enclosing op that has a symbol table. Returns null Value if such an
+/// enclosing op cannot be found.
 Value getBuiltinVariableValue(Operation *op, BuiltIn builtin,
                               OpBuilder &builder);
 
index 35eb7e1..aacb0f6 100644 (file)
@@ -221,11 +221,11 @@ void mlir::populateBuiltinFuncToSPIRVPatterns(
 // Builtin Variables
 //===----------------------------------------------------------------------===//
 
-/// Look through all global variables in `moduleOp` and check if there is a
-/// spv.globalVariable that has the same `builtin` attribute.
-static spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp,
+static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
                                                   spirv::BuiltIn builtin) {
-  for (auto varOp : moduleOp.getBlock().getOps<spirv::GlobalVariableOp>()) {
+  // Look through all global variables in the given `body` block and check if
+  // there is a spv.globalVariable that has the same `builtin` attribute.
+  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
     if (auto builtinAttr = varOp.getAttrOfType<StringAttr>(
             spirv::SPIRVDialect::getAttributeName(
                 spirv::Decoration::BuiltIn))) {
@@ -243,16 +243,16 @@ static std::string getBuiltinVarName(spirv::BuiltIn builtin) {
   return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__";
 }
 
-/// Gets or inserts a global variable for a builtin within a module.
+/// Gets or inserts a global variable for a builtin within `body` block.
 static spirv::GlobalVariableOp
-getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp, Location loc,
-                           spirv::BuiltIn builtin, OpBuilder &builder) {
-  if (auto varOp = getBuiltinVariable(moduleOp, builtin)) {
+getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
+                           OpBuilder &builder) {
+  if (auto varOp = getBuiltinVariable(body, builtin))
     return varOp;
-  }
-  auto ip = builder.saveInsertionPoint();
-  builder.setInsertionPointToStart(&moduleOp.getBlock());
-  auto name = getBuiltinVarName(builtin);
+
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(&body);
+
   spirv::GlobalVariableOp newVarOp;
   switch (builtin) {
   case spirv::BuiltIn::NumWorkgroups:
@@ -263,6 +263,7 @@ getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp, Location loc,
     auto ptrType = spirv::PointerType::get(
         VectorType::get({3}, builder.getIntegerType(32)),
         spirv::StorageClass::Input);
+    std::string name = getBuiltinVarName(builtin);
     newVarOp =
         builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
     break;
@@ -271,22 +272,22 @@ getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp, Location loc,
     emitError(loc, "unimplemented builtin variable generation for ")
         << stringifyBuiltIn(builtin);
   }
-  builder.restoreInsertionPoint(ip);
   return newVarOp;
 }
 
-/// Gets the global variable associated with a builtin and add
-/// it if it doesn't exist.
 Value mlir::spirv::getBuiltinVariableValue(Operation *op,
                                            spirv::BuiltIn builtin,
                                            OpBuilder &builder) {
-  auto moduleOp = op->getParentOfType<spirv::ModuleOp>();
-  if (!moduleOp) {
-    op->emitError("expected operation to be within a SPIR-V module");
+  Operation *parent = op->getParentOp();
+  while (parent && !parent->hasTrait<OpTrait::SymbolTable>())
+    parent = parent->getParentOp();
+  if (!parent) {
+    op->emitError("expected operation to be within a module-like op");
     return nullptr;
   }
-  spirv::GlobalVariableOp varOp =
-      getOrInsertBuiltinVariable(moduleOp, op->getLoc(), builtin, builder);
+
+  spirv::GlobalVariableOp varOp = getOrInsertBuiltinVariable(
+      *parent->getRegion(0).begin(), op->getLoc(), builtin, builder);
   Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
   return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
 }