// Builtin Variables
//===----------------------------------------------------------------------===//
-/// Look through all global variables in `moduleOp` and check if there is a
-/// spv.globalVariable that has the same `builtin` attribute.
-static spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp,
+static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
spirv::BuiltIn builtin) {
- for (auto varOp : moduleOp.getBlock().getOps<spirv::GlobalVariableOp>()) {
+ // Look through all global variables in the given `body` block and check if
+ // there is a spv.globalVariable that has the same `builtin` attribute.
+ for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
if (auto builtinAttr = varOp.getAttrOfType<StringAttr>(
spirv::SPIRVDialect::getAttributeName(
spirv::Decoration::BuiltIn))) {
return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__";
}
-/// Gets or inserts a global variable for a builtin within a module.
+/// Gets or inserts a global variable for a builtin within `body` block.
static spirv::GlobalVariableOp
-getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp, Location loc,
- spirv::BuiltIn builtin, OpBuilder &builder) {
- if (auto varOp = getBuiltinVariable(moduleOp, builtin)) {
+getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
+ OpBuilder &builder) {
+ if (auto varOp = getBuiltinVariable(body, builtin))
return varOp;
- }
- auto ip = builder.saveInsertionPoint();
- builder.setInsertionPointToStart(&moduleOp.getBlock());
- auto name = getBuiltinVarName(builtin);
+
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointToStart(&body);
+
spirv::GlobalVariableOp newVarOp;
switch (builtin) {
case spirv::BuiltIn::NumWorkgroups:
auto ptrType = spirv::PointerType::get(
VectorType::get({3}, builder.getIntegerType(32)),
spirv::StorageClass::Input);
+ std::string name = getBuiltinVarName(builtin);
newVarOp =
builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
break;
emitError(loc, "unimplemented builtin variable generation for ")
<< stringifyBuiltIn(builtin);
}
- builder.restoreInsertionPoint(ip);
return newVarOp;
}
-/// Gets the global variable associated with a builtin and add
-/// it if it doesn't exist.
Value mlir::spirv::getBuiltinVariableValue(Operation *op,
spirv::BuiltIn builtin,
OpBuilder &builder) {
- auto moduleOp = op->getParentOfType<spirv::ModuleOp>();
- if (!moduleOp) {
- op->emitError("expected operation to be within a SPIR-V module");
+ Operation *parent = op->getParentOp();
+ while (parent && !parent->hasTrait<OpTrait::SymbolTable>())
+ parent = parent->getParentOp();
+ if (!parent) {
+ op->emitError("expected operation to be within a module-like op");
return nullptr;
}
- spirv::GlobalVariableOp varOp =
- getOrInsertBuiltinVariable(moduleOp, op->getLoc(), builtin, builder);
+
+ spirv::GlobalVariableOp varOp = getOrInsertBuiltinVariable(
+ *parent->getRegion(0).begin(), op->getLoc(), builtin, builder);
Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
}