From 85d7fb3324a6442e865c87ea766992ab096f8859 Mon Sep 17 00:00:00 2001 From: Hanhan Wang Date: Tue, 12 Nov 2019 18:58:36 -0800 Subject: [PATCH] Make VariableOp instructions be in the first block in the function. Since VariableOp is serialized during processBlock, we add two more fields, `functionHeader` and `functionBody`, to collect instructions for a function. After all the blocks have been processed, we append them to the `functions`. Also, fix a bug in processGlobalVariableOp. The global variables should be encoded into `typesGlobalValues`. PiperOrigin-RevId: 280105366 --- .../lib/Dialect/SPIRV/Serialization/Serializer.cpp | 117 +++++++++++++++++---- .../test/Dialect/SPIRV/Serialization/constant.mlir | 13 +++ mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | 2 +- 3 files changed, 111 insertions(+), 21 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index 0213420..0ff79d9 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -177,6 +177,8 @@ private: /// Processes a SPIR-V function op. LogicalResult processFuncOp(FuncOp op); + LogicalResult processVariableOp(spirv::VariableOp op); + /// Process a SPIR-V GlobalVariableOp LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp); @@ -374,6 +376,19 @@ private: SmallVector typesGlobalValues; SmallVector functions; + /// `functionHeader` contains all the instructions that must be in the first + /// block in the function, and `functionBody` contains the rest. After + /// processing FuncOp, the encoded instructions of a function are appended to + /// `functions`. An example of instructions in `functionHeader` in order: + /// OpFunction ... + /// OpFunctionParameter ... + /// OpFunctionParameter ... + /// OpLabel ... + /// OpVariable ... + /// OpVariable ... + SmallVector functionHeader; + SmallVector functionBody; + /// Map from type used in SPIR-V module to their s. DenseMap typeIDMap; @@ -671,6 +686,7 @@ Serializer::processMemberDecoration(uint32_t structID, uint32_t memberIndex, LogicalResult Serializer::processFuncOp(FuncOp op) { LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n"); + assert(functionHeader.empty() && functionBody.empty()); uint32_t fnTypeID = 0; // Generate type of the function. @@ -694,7 +710,7 @@ LogicalResult Serializer::processFuncOp(FuncOp op) { // TODO : Support other function control options. operands.push_back(static_cast(spirv::FunctionControl::None)); operands.push_back(fnTypeID); - encodeInstructionInto(functions, spirv::Opcode::OpFunction, operands); + encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands); // Add function name. if (failed(processName(funcID, op.getName()))) { @@ -709,7 +725,7 @@ LogicalResult Serializer::processFuncOp(FuncOp op) { } auto argValueID = getNextID(); valueIDMap[arg] = argValueID; - encodeInstructionInto(functions, spirv::Opcode::OpFunctionParameter, + encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter, {argTypeID, argValueID}); } @@ -718,9 +734,18 @@ LogicalResult Serializer::processFuncOp(FuncOp op) { return op.emitError("external function is unhandled"); } + // Some instructions (e.g., OpVariable) in a function must be in the first + // block in the function. These instructions will be put in functionHeader. + // Thus, we put the label in functionHeader first, and omit it from the first + // block. + encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel, + {getOrCreateBlockID(&op.front())}); + processBlock(&op.front(), /*omitLabel=*/true); if (failed(visitInPrettyBlockOrder( - &op.front(), [&](Block *block) { return processBlock(block); }))) + &op.front(), [&](Block *block) { return processBlock(block); }, + /*skipHeader=*/true))) { return failure(); + } // There might be OpPhi instructions who have value references needing to fix. for (auto deferredValue : deferredPhiValues) { @@ -730,14 +755,63 @@ LogicalResult Serializer::processFuncOp(FuncOp op) { << " to id = " << id << '\n'); assert(id && "OpPhi references undefined value!"); for (size_t offset : deferredValue.second) - functions[offset] = id; + functionBody[offset] = id; } deferredPhiValues.clear(); LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName() << "' --\n"); // Insert OpFunctionEnd. - return encodeInstructionInto(functions, spirv::Opcode::OpFunctionEnd, {}); + if (failed(encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, + {}))) { + return failure(); + } + + functions.append(functionHeader.begin(), functionHeader.end()); + functions.append(functionBody.begin(), functionBody.end()); + functionHeader.clear(); + functionBody.clear(); + + return success(); +} + +LogicalResult Serializer::processVariableOp(spirv::VariableOp op) { + SmallVector operands; + SmallVector elidedAttrs; + uint32_t resultID = 0; + uint32_t resultTypeID = 0; + if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) { + return failure(); + } + operands.push_back(resultTypeID); + resultID = getNextID(); + valueIDMap[op.getResult()] = resultID; + operands.push_back(resultID); + auto attr = op.getAttr(spirv::attributeName()); + if (attr) { + operands.push_back(static_cast( + attr.cast().getValue().getZExtValue())); + } + elidedAttrs.push_back(spirv::attributeName()); + for (auto arg : op.getODSOperands(0)) { + auto argID = getValueID(arg); + if (!argID) { + return emitError(op.getLoc(), "operand 0 has a use before def"); + } + operands.push_back(argID); + } + encodeInstructionInto(functionHeader, spirv::getOpcode(), + operands); + for (auto attr : op.getAttrs()) { + if (llvm::any_of(elidedAttrs, + [&](StringRef elided) { return attr.first.is(elided); })) { + continue; + } + if (failed(processDecoration(op.getLoc(), resultID, attr))) { + return failure(); + } + } + return success(); } LogicalResult @@ -789,7 +863,7 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) { elidedAttrs.push_back("initializer"); } - if (failed(encodeInstructionInto(functions, spirv::Opcode::OpVariable, + if (failed(encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, operands))) { elidedAttrs.push_back("initializer"); return failure(); @@ -1360,7 +1434,7 @@ Serializer::processBlock(Block *block, bool omitLabel, << "[block] " << block << " (id = " << blockID << ")\n"); // Emit OpLabel for this block. - encodeInstructionInto(functions, spirv::Opcode::OpLabel, {blockID}); + encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID}); } // Emit OpPhi instructions for block arguments, if any. @@ -1431,7 +1505,7 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { // The op generating this value hasn't been visited yet so we don't have // an assigned yet. Record this to fix up later. LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n"); - deferredPhiValues[value].push_back(functions.size() + 1 + + deferredPhiValues[value].push_back(functionBody.size() + 1 + phiArgs.size()); } else { LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n"); @@ -1441,7 +1515,7 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { phiArgs.push_back(predBlockId); } - encodeInstructionInto(functions, spirv::Opcode::OpPhi, phiArgs); + encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs); valueIDMap[arg] = phiID; } @@ -1465,7 +1539,7 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) { auto emitSelectionMerge = [&]() { // TODO(antiagainst): properly support loop control here encodeInstructionInto( - functions, spirv::Opcode::OpSelectionMerge, + functionBody, spirv::Opcode::OpSelectionMerge, {mergeID, static_cast(spirv::LoopControl::None)}); }; // For structured selection, we cannot have blocks in the selection construct @@ -1489,7 +1563,7 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) { // contains a spv._merge op, itself. But we need to have an OpLabel // instruction to start a new SPIR-V block for ops following this SelectionOp. // The block should use the for the merge block. - return encodeInstructionInto(functions, spirv::Opcode::OpLabel, {mergeID}); + return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); } LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { @@ -1513,7 +1587,7 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { // preceding and following ops. So we need to emit unconditional branches to // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow // afterwards. - encodeInstructionInto(functions, spirv::Opcode::OpBranch, {headerID}); + encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID}); // We omit the LoopOp's entry block and start serialization from the loop // header block. The entry block should not contain any additional ops other @@ -1532,7 +1606,7 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { auto emitLoopMerge = [&]() { // TODO(antiagainst): properly support loop control here encodeInstructionInto( - functions, spirv::Opcode::OpLoopMerge, + functionBody, spirv::Opcode::OpLoopMerge, {mergeID, continueID, static_cast(spirv::LoopControl::None)}); }; if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge))) @@ -1554,7 +1628,7 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { // a spv._merge op, itself. But we need to have an OpLabel instruction to // start a new SPIR-V block for ops following this LoopOp. The block should // use the for the merge block. - return encodeInstructionInto(functions, spirv::Opcode::OpLabel, {mergeID}); + return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); } LogicalResult Serializer::processBranchConditionalOp( @@ -1569,12 +1643,12 @@ LogicalResult Serializer::processBranchConditionalOp( arguments.push_back(val.cast().getInt()); } - return encodeInstructionInto(functions, spirv::Opcode::OpBranchConditional, + return encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional, arguments); } LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) { - return encodeInstructionInto(functions, spirv::Opcode::OpBranch, + return encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {getOrCreateBlockID(branchOp.getTarget())}); } @@ -1610,7 +1684,7 @@ LogicalResult Serializer::encodeExtensionInstruction( extInstOperands.push_back(setID); extInstOperands.push_back(extensionOpcode); extInstOperands.append(std::next(operands.begin(), 2), operands.end()); - return encodeInstructionInto(functions, spirv::Opcode::OpExtInst, + return encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst, extInstOperands); } @@ -1658,6 +1732,9 @@ LogicalResult Serializer::processOperation(Operation *op) { if (auto fnOp = dyn_cast(op)) { return processFuncOp(fnOp); } + if (auto varOp = dyn_cast(op)) { + return processVariableOp(varOp); + } if (auto varOp = dyn_cast(op)) { return processGlobalVariableOp(varOp); } @@ -1736,7 +1813,7 @@ Serializer::processOp(spirv::ControlBarrierOp op) { operands.push_back(operand); } - return encodeInstructionInto(functions, spirv::Opcode::OpControlBarrier, + return encodeInstructionInto(functionBody, spirv::Opcode::OpControlBarrier, operands); } @@ -1783,7 +1860,7 @@ Serializer::processOp(spirv::MemoryBarrierOp op) { operands.push_back(operand); } - return encodeInstructionInto(functions, spirv::Opcode::OpMemoryBarrier, + return encodeInstructionInto(functionBody, spirv::Opcode::OpMemoryBarrier, operands); } @@ -1814,7 +1891,7 @@ Serializer::processOp(spirv::FunctionCallOp op) { valueIDMap[op.getResult(0)] = funcCallID; } - return encodeInstructionInto(functions, spirv::Opcode::OpFunctionCall, + return encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, operands); } diff --git a/mlir/test/Dialect/SPIRV/Serialization/constant.mlir b/mlir/test/Dialect/SPIRV/Serialization/constant.mlir index acfa40a..9531209 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/constant.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/constant.mlir @@ -165,4 +165,17 @@ spv.module "Logical" "GLSL450" { %1 = spv.IAdd %0, %0 : i32 spv.ReturnValue %1 : i32 } + + // CHECK-LABEL: @const_variable + func @const_variable(%arg0 : i32, %arg1 : i32) -> () { + // CHECK: %[[CONST:.*]] = spv.constant 5 : i32 + // CHECK: spv.Variable init(%[[CONST]]) : !spv.ptr + // CHECK: spv.IAdd %arg0, %arg1 + %0 = spv.IAdd %arg0, %arg1 : i32 + %1 = spv.constant 5 : i32 + %2 = spv.Variable init(%1) : !spv.ptr + %3 = spv.Load "Function" %2 : i32 + %4 = spv.IAdd %0, %3 : i32 + spv.Return + } } diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index 34bfd40..f39295a 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -231,7 +231,7 @@ static void emitSerializationFunction(const Record *attrClass, record->getValueAsInt("extendedInstOpcode"), operands); } else { os << formatv(" encodeInstructionInto(" - "functions, spirv::getOpcode<{0}>(), {1});\n", + "functionBody, spirv::getOpcode<{0}>(), {1});\n", op.getQualCppClassName(), operands); } -- 2.7.4