From 2d86ad79f0021eb5612f1b37d5a3de5160e919fb Mon Sep 17 00:00:00 2001 From: Mahesh Ravishankar Date: Mon, 16 Sep 2019 17:11:50 -0700 Subject: [PATCH] Autogenerate (de)serialization for Extended Instruction Sets A generic mechanism for (de)serialization of extended instruction sets is added with this CL. To facilitate this, a new class "SPV_ExtendedInstSetOp" is added which is a base class for all operations corresponding to extended instruction sets. The methods to (de)serialization such ops as well as its dispatch is generated automatically. The behavior controlled by autogenSerialization and hasOpcode is also slightly modified to enable this. They are now decoupled. 1) Setting hasOpcode=1 means the operation has a corresponding opcode in SPIR-V binary format, and its dispatch for (de)serialization is automatically generated. 2) Setting autogenSerialization=1 generates the function for (de)serialization automatically. So now it is possible to have hasOpcode=0 and autogenSerialization=1 (for example SPV_ExtendedInstSetOp). Since the dispatch functions is also auto-generated, the input file needs to contain all operations. To this effect, SPIRVGLSLOps.td is included into SPIRVOps.td. This makes the previously added SPIRVGLSLOps.h and SPIRVGLSLOps.cpp unnecessary, and are deleted. The SPIRVUtilsGen.cpp is also changed to make better use of formatv,making the code more readable. PiperOrigin-RevId: 269456263 --- mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt | 5 - mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td | 67 ++- .../mlir/Dialect/SPIRV/SPIRVControlFlowOps.td | 4 + mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.h | 37 -- mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td | 12 +- mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td | 6 + .../mlir/Dialect/SPIRV/SPIRVStructureOps.td | 15 + mlir/lib/Dialect/SPIRV/CMakeLists.txt | 1 - mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp | 7 - mlir/lib/Dialect/SPIRV/SPIRVGLSLOps.cpp | 58 -- mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 23 +- .../Dialect/SPIRV/Serialization/Deserializer.cpp | 67 ++- .../lib/Dialect/SPIRV/Serialization/Serializer.cpp | 39 ++ mlir/test/Dialect/SPIRV/Serialization/glslops.mlir | 9 + mlir/test/Dialect/SPIRV/glslops.mlir | 18 +- mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | 633 ++++++++++++++------- 16 files changed, 647 insertions(+), 354 deletions(-) delete mode 100644 mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.h delete mode 100644 mlir/lib/Dialect/SPIRV/SPIRVGLSLOps.cpp create mode 100644 mlir/test/Dialect/SPIRV/Serialization/glslops.mlir diff --git a/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt b/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt index 0c847f0..f1d6803 100644 --- a/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt @@ -3,11 +3,6 @@ mlir_tablegen(SPIRVOps.h.inc -gen-op-decls) mlir_tablegen(SPIRVOps.cpp.inc -gen-op-defs) add_public_tablegen_target(MLIRSPIRVOpsIncGen) -set(LLVM_TARGET_DEFINITIONS SPIRVGLSLOps.td) -mlir_tablegen(SPIRVGLSLOps.h.inc -gen-op-decls) -mlir_tablegen(SPIRVGLSLOps.cpp.inc -gen-op-defs) -add_public_tablegen_target(MLIRSPIRVGLSLOpsIncGen) - set(LLVM_TARGET_DEFINITIONS SPIRVBase.td) mlir_tablegen(SPIRVIntEnums.h.inc -gen-enum-decls) mlir_tablegen(SPIRVIntEnums.cpp.inc -gen-enum-defs) diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index b2b844d..d1abf9f 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -75,6 +75,8 @@ class SPV_OpCode { def SPV_OC_OpNop : I32EnumAttrCase<"OpNop", 0>; def SPV_OC_OpName : I32EnumAttrCase<"OpName", 5>; def SPV_OC_OpExtension : I32EnumAttrCase<"OpExtension", 10>; +def SPV_OC_OpExtInstImport : I32EnumAttrCase<"OpExtInstImport", 11>; +def SPV_OC_OpExtInst : I32EnumAttrCase<"OpExtInst", 12>; def SPV_OC_OpMemoryModel : I32EnumAttrCase<"OpMemoryModel", 14>; def SPV_OC_OpEntryPoint : I32EnumAttrCase<"OpEntryPoint", 15>; def SPV_OC_OpExecutionMode : I32EnumAttrCase<"OpExecutionMode", 16>; @@ -154,21 +156,22 @@ def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>; def SPV_OpcodeAttr : I32EnumAttr<"Opcode", "valid SPIR-V instructions", [ - SPV_OC_OpNop, SPV_OC_OpName, SPV_OC_OpExtension, SPV_OC_OpMemoryModel, - SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode, SPV_OC_OpCapability, - SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt, SPV_OC_OpTypeFloat, - SPV_OC_OpTypeVector, SPV_OC_OpTypeArray, SPV_OC_OpTypeStruct, - SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue, - SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite, - SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse, - SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, - SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, - SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain, - SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, - SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, - SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, - SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect, - SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, + SPV_OC_OpNop, SPV_OC_OpName, SPV_OC_OpExtension, SPV_OC_OpExtInstImport, + SPV_OC_OpExtInst, SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, + SPV_OC_OpExecutionMode, SPV_OC_OpCapability, SPV_OC_OpTypeVoid, + SPV_OC_OpTypeBool, SPV_OC_OpTypeInt, SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, + SPV_OC_OpTypeArray, SPV_OC_OpTypeStruct, SPV_OC_OpTypePointer, + SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse, + SPV_OC_OpConstant, SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull, + SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant, + SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, + SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad, + SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate, + SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpIAdd, + SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, + SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, + SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect, SPV_OC_OpIEqual, + SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, @@ -1143,8 +1146,7 @@ class SPV_Op traits = []> : // Specifies whether this op has a direct corresponding SPIR-V binary // instruction opcode. The (de)serializer use this field to determine whether // to auto-generate an entry in the (de)serialization dispatch table for this - // op. If set, this field also futher enables `autogenSerialization` (see - // below for details). + // op. bit hasOpcode = 1; // Name of the corresponding SPIR-V op. Only valid to use when hasOpcode is 1. @@ -1162,15 +1164,14 @@ class SPV_Op traits = []> : // these methods is required. // // Note: - // // 1) If hasOpcode is set but autogenSerialization is not set, the // (de)serializer dispatch method still calls the above method for // (de)serializing this op. - // - // 2) If hasOpcode is not set, then this field is not interpreted; this op's - // (de)serialization method will not be auto-generated regardless. Neither - // does the handling in the (de)serialization dispatch table. Both - // (de)serializing this op and its dispatch should be handled manually. + // 2) If hasOpcode is not set, but autogenSerialization is set, the + // above methods for (de)serialization are generated, but there is no + // entry added in the dispatch tables to invoke these methods. The + // dispatch needs to be handled manually. SPV_ExtInstOps are an + // example of this. bit autogenSerialization = 1; } @@ -1190,4 +1191,24 @@ class SPV_BinaryOp traits = []> : + SPV_Op { + + // Extended instruction sets have no direct opcode (they share the + // same `OpExtInst` instruction). So the hasOpcode field is set to + // false. So no entry corresponding to these ops are added in the + // dispatch functions for (de)serialization. The methods for + // (de)serialization are still automatically generated (since + // autogenSerialization remains 1). A separate method is generated + // for dispatching extended instruction set ops. + let hasOpcode = 0; + + // Opcode within extended instruction set. + int extendedInstOpcode = opcode; + + // Name used to import the extended instruction set. + string extendedInstSetName = setName; +} + #endif // SPIRV_BASE diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td index b9f8604..1d3796c 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td @@ -250,6 +250,8 @@ def SPV_LoopOp : SPV_Op<"loop"> { }]; let hasOpcode = 0; + + let autogenSerialization = 0; } // ----- @@ -273,6 +275,8 @@ def SPV_MergeOp : SPV_Op<"_merge", [HasParent<"LoopOp">, Terminator]> { let printer = [{ printNoIOOp(getOperation(), p); }]; let hasOpcode = 0; + + let autogenSerialization = 0; } // ----- diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.h deleted file mode 100644 index b20b81c..0000000 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.h +++ /dev/null @@ -1,37 +0,0 @@ -//===- SPIRVGLSLOps.h - MLIR SPIR-V extended ops for GLSL --------*- C++-*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This file declares the extended operations for GLSL in the SPIR-V dialect. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_SPIRV_SPIRVGLSLOPS_H_ -#define MLIR_DIALECT_SPIRV_SPIRVGLSLOPS_H_ - -#include "mlir/Dialect/SPIRV/SPIRVTypes.h" -#include "mlir/IR/OpDefinition.h" - -namespace mlir { -namespace spirv { - -#define GET_OP_CLASSES -#include "mlir/Dialect/SPIRV/SPIRVGLSLOps.h.inc" - -} // namespace spirv -} // namespace mlir - -#endif // MLIR_DIALECT_SPIRV_SPIRVGLSLOPS_H_ diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td index e32a81d..61aac88 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td @@ -34,17 +34,7 @@ include "mlir/Dialect/SPIRV/SPIRVBase.td" // Base class for all GLSL ops. class SPV_GLSLOp traits = []> : - SPV_Op<"glsl." # mnemonic, traits> { - - // Do not use the default auto-generation serializer/deserializer. - let hasOpcode = 0; - - // Opcode within the extended instruction set. - int glslOpcode = opcode; - - // Name used to refer to the extended instruction set. - string extensionSetName = "GLSL.std.450"; -} + SPV_ExtInstOp; // Base class for GLSL unary ops. class SPV_GLSLUnaryOp { diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td index 8b2eb16..37e79d5 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td @@ -65,6 +65,8 @@ def SPV_AddressOfOp : SPV_Op<"_address_of", [InFunctionScope, NoSideEffect]> { ); let hasOpcode = 0; + + let autogenSerialization = 0; } def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> { @@ -120,6 +122,8 @@ def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> { }]; let hasOpcode = 0; + + let autogenSerialization = 0; } def SPV_EntryPointOp : SPV_Op<"EntryPoint", [InModuleScope]> { @@ -174,6 +178,7 @@ def SPV_EntryPointOp : SPV_Op<"EntryPoint", [InModuleScope]> { ); let results = (outs); + let autogenSerialization = 0; } @@ -233,6 +238,8 @@ def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [InModuleScope]> { let hasOpcode = 0; + let autogenSerialization = 0; + let extraClassDeclaration = [{ ::mlir::spirv::StorageClass storageClass() { return this->type().cast<::mlir::spirv::PointerType>().getStorageClass(); @@ -313,6 +320,8 @@ def SPV_ModuleOp : SPV_Op<"module", let hasOpcode = 0; + let autogenSerialization = 0; + let extraClassDeclaration = [{ Block& getBlock() { return this->getOperation()->getRegion(0).front(); @@ -340,6 +349,8 @@ def SPV_ModuleEndOp : SPV_Op<"_module_end", [InModuleScope, Terminator]> { let verifier = [{ return success(); }]; let hasOpcode = 0; + + let autogenSerialization = 0; } def SPV_ReferenceOfOp : SPV_Op<"_reference_of", [NoSideEffect]> { @@ -376,6 +387,8 @@ def SPV_ReferenceOfOp : SPV_Op<"_reference_of", [NoSideEffect]> { ); let hasOpcode = 0; + + let autogenSerialization = 0; } def SPV_SpecConstantOp : SPV_Op<"specConstant", [InModuleScope]> { @@ -417,6 +430,8 @@ def SPV_SpecConstantOp : SPV_Op<"specConstant", [InModuleScope]> { let results = (outs); let hasOpcode = 0; + + let autogenSerialization = 0; } #endif // SPIRV_STRUCTURE_OPS diff --git a/mlir/lib/Dialect/SPIRV/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/CMakeLists.txt index f4e89d1..05f09d2 100644 --- a/mlir/lib/Dialect/SPIRV/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/CMakeLists.txt @@ -1,7 +1,6 @@ add_llvm_library(MLIRSPIRV DialectRegistration.cpp SPIRVDialect.cpp - SPIRVGLSLOps.cpp SPIRVOps.cpp SPIRVTypes.cpp diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp index c1c214f..4660aa8 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -11,7 +11,6 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/SPIRVDialect.h" -#include "mlir/Dialect/SPIRV/SPIRVGLSLOps.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/IR/Builders.h" @@ -48,12 +47,6 @@ SPIRVDialect::SPIRVDialect(MLIRContext *context) #include "mlir/Dialect/SPIRV/SPIRVOps.cpp.inc" >(); - // Add SPIR-V extension ops of GLSL. - addOperations< -#define GET_OP_LIST -#include "mlir/Dialect/SPIRV/SPIRVGLSLOps.cpp.inc" - >(); - // Allow unknown operations because SPIR-V is extensible. allowUnknownOperations(); } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVGLSLOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVGLSLOps.cpp deleted file mode 100644 index b007aaf..0000000 --- a/mlir/lib/Dialect/SPIRV/SPIRVGLSLOps.cpp +++ /dev/null @@ -1,58 +0,0 @@ -//===- SPIRVGLSLOps.cpp - MLIR SPIR-V GLSL extended operations ------------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This file defines the operations in the SPIR-V extended instructions set for -// GLSL -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/SPIRV/SPIRVGLSLOps.h" -#include "mlir/Dialect/SPIRV/SPIRVDialect.h" -#include "mlir/Dialect/SPIRV/SPIRVTypes.h" -#include "mlir/IR/OpImplementation.h" - -using namespace mlir; - -//===----------------------------------------------------------------------===// -// spv.glsl.UnaryOp -//===----------------------------------------------------------------------===// - -static ParseResult parseGLSLUnaryOp(OpAsmParser *parser, - OperationState *state) { - OpAsmParser::OperandType operandInfo; - Type type; - if (parser->parseOperand(operandInfo) || parser->parseColonType(type) || - parser->resolveOperands(operandInfo, type, state->operands)) { - return failure(); - } - state->addTypes(type); - return success(); -} - -static void printGLSLUnaryOp(Operation *unaryOp, OpAsmPrinter *printer) { - *printer << unaryOp->getName() << ' ' << *unaryOp->getOperand(0) << " : " - << unaryOp->getOperand(0)->getType(); -} - -namespace mlir { -namespace spirv { - -#define GET_OP_CLASSES -#include "mlir/Dialect/SPIRV/SPIRVGLSLOps.cpp.inc" - -} // namespace spirv -} // namespace mlir diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 9766d6c..c8a8078 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -914,7 +914,7 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter *printer) { } //===----------------------------------------------------------------------===// -// spv.FuncionCall +// spv.FunctionCall //===----------------------------------------------------------------------===// static ParseResult parseFunctionCallOp(OpAsmParser *parser, @@ -1016,6 +1016,27 @@ static LogicalResult verify(spirv::FunctionCallOp functionCallOp) { } //===----------------------------------------------------------------------===// +// spv.GLSL.UnaryOp +//===----------------------------------------------------------------------===// + +static ParseResult parseGLSLUnaryOp(OpAsmParser *parser, + OperationState *state) { + OpAsmParser::OperandType operandInfo; + Type type; + if (parser->parseOperand(operandInfo) || parser->parseColonType(type) || + parser->resolveOperands(operandInfo, type, state->operands)) { + return failure(); + } + state->addTypes(type); + return success(); +} + +static void printGLSLUnaryOp(Operation *unaryOp, OpAsmPrinter *printer) { + *printer << unaryOp->getName() << ' ' << *unaryOp->getOperand(0) << " : " + << unaryOp->getOperand(0)->getType(); +} + +//===----------------------------------------------------------------------===// // spv.globalVariable //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index e5f4e06..23cd60e 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -97,7 +97,11 @@ private: /// Processes the SPIR-V OpExtension with `operands` and updates bookkeeping /// in the deserializer. - LogicalResult processExtension(ArrayRef operands); + LogicalResult processExtension(ArrayRef words); + + /// Processes the SPIR-V OpExtInstImport with `operands` and updates + /// bookkeeping in the deserializer. + LogicalResult processExtInstImport(ArrayRef words); /// Attaches all collected extensions to `module` as an attribute. void attachExtensions(); @@ -300,6 +304,20 @@ private: LogicalResult dispatchToAutogenDeserialization(spirv::Opcode opcode, ArrayRef words); + /// Processes a SPIR-V OpExtInst with given `operands`. This slices the + /// entries of `operands` that specify the extended instruction set and + /// the instruction opcode. The op deserializer is then invoked using the + /// other entries. + LogicalResult processExtInst(ArrayRef operands); + + /// Dispatches the deserialization of extended instruction set operation based + /// on the extended instruction set name, and instruction opcode. This is + /// autogenerated from ODS. + LogicalResult + dispatchToExtensionSetAutogenDeserialization(StringRef extensionSetName, + uint32_t instructionID, + ArrayRef words); + /// Method to deserialize an operation in the SPIR-V dialect that is a mirror /// of an instruction in the SPIR-V spec. This is auto generated if hasOpcode /// == 1 and autogenSerialization == 1 in ODS. @@ -381,6 +399,9 @@ private: // Result to member decorations. DenseMap> memberDecorationMap; + // Result to extended instruction set name. + DenseMap extendedInstSets; + // List of instructions that are processed in a defered fashion (after an // initial processing of the entire binary). Some operations like // OpEntryPoint, and OpExecutionMode use forward references to function @@ -487,16 +508,16 @@ void Deserializer::attachCapabilities() { module->setAttr("capabilities", opBuilder.getStrArrayAttr(caps)); } -LogicalResult Deserializer::processExtension(ArrayRef operands) { - if (operands.empty()) { +LogicalResult Deserializer::processExtension(ArrayRef words) { + if (words.empty()) { return emitError( unknownLoc, "OpExtension must have a literal string for the extension name"); } unsigned wordIndex = 0; - StringRef extName = decodeStringLiteral(operands, wordIndex); - if (wordIndex != operands.size()) { + StringRef extName = decodeStringLiteral(words, wordIndex); + if (wordIndex != words.size()) { return emitError(unknownLoc, "unexpected trailing words in OpExtension instruction"); } @@ -505,6 +526,22 @@ LogicalResult Deserializer::processExtension(ArrayRef operands) { return success(); } +LogicalResult Deserializer::processExtInstImport(ArrayRef words) { + if (words.size() < 2) { + return emitError(unknownLoc, + "OpExtInstImport must have a result and a literal " + "string for the extensed instruction set name"); + } + + unsigned wordIndex = 1; + extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex); + if (wordIndex != words.size()) { + return emitError(unknownLoc, + "unexpected trailing words in OpExtInstImport"); + } + return success(); +} + void Deserializer::attachExtensions() { if (extensions.empty()) return; @@ -1652,6 +1689,10 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, return processCapability(operands); case spirv::Opcode::OpExtension: return processExtension(operands); + case spirv::Opcode::OpExtInst: + return processExtInst(operands); + case spirv::Opcode::OpExtInstImport: + return processExtInstImport(operands); case spirv::Opcode::OpMemoryModel: return processMemoryModel(operands); case spirv::Opcode::OpEntryPoint: @@ -1714,6 +1755,22 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, return dispatchToAutogenDeserialization(opcode, operands); } +LogicalResult Deserializer::processExtInst(ArrayRef operands) { + if (operands.size() < 4) { + return emitError(unknownLoc, + "OpExtInst must have at least 4 operands, result type " + ", result , set and instruction opcode"); + } + if (!extendedInstSets.count(operands[2])) { + return emitError(unknownLoc, "undefined set in OpExtInst"); + } + SmallVector slicedOperands; + slicedOperands.append(operands.begin(), std::next(operands.begin(), 2)); + slicedOperands.append(std::next(operands.begin(), 4), operands.end()); + return dispatchToExtensionSetAutogenDeserialization( + extendedInstSets[operands[2]], operands[3], slicedOperands); +} + namespace { template <> diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index c31c9f3..ea50649 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -278,6 +278,11 @@ private: // Operations //===--------------------------------------------------------------------===// + LogicalResult encodeExtensionInstruction(Operation *op, + StringRef extensionSetName, + uint32_t opcode, + ArrayRef operands); + uint32_t findValueID(Value *val) const { return valueIDMap.lookup(val); } LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp); @@ -345,6 +350,9 @@ private: /// Map from results of normal operations to their s. DenseMap valueIDMap; + + /// Map from extended instruction set name to s. + llvm::StringMap extendedInstSetIDMap; }; } // namespace @@ -1347,6 +1355,37 @@ LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) { // Operation //===----------------------------------------------------------------------===// +LogicalResult Serializer::encodeExtensionInstruction( + Operation *op, StringRef extensionSetName, uint32_t extensionOpcode, + ArrayRef operands) { + // Check if the extension has been imported. + auto &setID = extendedInstSetIDMap[extensionSetName]; + if (!setID) { + setID = getNextID(); + SmallVector importOperands; + importOperands.push_back(setID); + if (failed(encodeStringLiteralInto(importOperands, extensionSetName)) || + failed(encodeInstructionInto( + extendedSets, spirv::Opcode::OpExtInstImport, importOperands))) { + return failure(); + } + } + + // The first two operands are the result type and result . The set + // and the opcode need to be insert after this. + if (operands.size() < 2) { + return op->emitError("extended instructions must have a result encoding"); + } + SmallVector extInstOperands; + extInstOperands.reserve(operands.size() + 2); + extInstOperands.append(operands.begin(), std::next(operands.begin(), 2)); + extInstOperands.push_back(setID); + extInstOperands.push_back(extensionOpcode); + extInstOperands.append(std::next(operands.begin(), 2), operands.end()); + return encodeInstructionInto(functions, spirv::Opcode::OpExtInst, + extInstOperands); +} + LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) { auto varName = addressOfOp.variable(); auto variableID = findVariableID(varName); diff --git a/mlir/test/Dialect/SPIRV/Serialization/glslops.mlir b/mlir/test/Dialect/SPIRV/Serialization/glslops.mlir new file mode 100644 index 0000000..5dba6ce --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Serialization/glslops.mlir @@ -0,0 +1,9 @@ +// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s + +spv.module "Logical" "GLSL450" { + func @fmul(%arg0 : f32) { + // CHECK: {{%.*}} = spv.GLSL.Exp {{%.*}} : f32 + %0 = spv.GLSL.Exp %arg0 : f32 + spv.Return + } +} \ No newline at end of file diff --git a/mlir/test/Dialect/SPIRV/glslops.mlir b/mlir/test/Dialect/SPIRV/glslops.mlir index 6ec900b..181f263 100644 --- a/mlir/test/Dialect/SPIRV/glslops.mlir +++ b/mlir/test/Dialect/SPIRV/glslops.mlir @@ -1,18 +1,18 @@ // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s //===----------------------------------------------------------------------===// -// spv.glsl.Exp +// spv.GLSL.Exp //===----------------------------------------------------------------------===// func @exp(%arg0 : f32) -> () { - // CHECK: spv.glsl.Exp {{%.*}} : f32 - %2 = spv.glsl.Exp %arg0 : f32 + // CHECK: spv.GLSL.Exp {{%.*}} : f32 + %2 = spv.GLSL.Exp %arg0 : f32 return } func @expvec(%arg0 : vector<3xf16>) -> () { - // CHECK: spv.glsl.Exp {{%.*}} : vector<3xf16> - %2 = spv.glsl.Exp %arg0 : vector<3xf16> + // CHECK: spv.GLSL.Exp {{%.*}} : vector<3xf16> + %2 = spv.GLSL.Exp %arg0 : vector<3xf16> return } @@ -20,7 +20,7 @@ func @expvec(%arg0 : vector<3xf16>) -> () { func @exp(%arg0 : i32) -> () { // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values}} - %2 = spv.glsl.Exp %arg0 : i32 + %2 = spv.GLSL.Exp %arg0 : i32 return } @@ -28,7 +28,7 @@ func @exp(%arg0 : i32) -> () { func @exp(%arg0 : vector<5xf32>) -> () { // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values of length 2/3/4}} - %2 = spv.glsl.Exp %arg0 : vector<5xf32> + %2 = spv.GLSL.Exp %arg0 : vector<5xf32> return } @@ -36,7 +36,7 @@ func @exp(%arg0 : vector<5xf32>) -> () { func @exp(%arg0 : f32, %arg1 : f32) -> () { // expected-error @+1 {{expected ':'}} - %2 = spv.glsl.Exp %arg0, %arg1 : i32 + %2 = spv.GLSL.Exp %arg0, %arg1 : i32 return } @@ -44,6 +44,6 @@ func @exp(%arg0 : f32, %arg1 : f32) -> () { func @exp(%arg0 : i32) -> () { // expected-error @+2 {{expected non-function type}} - %2 = spv.glsl.Exp %arg0 : + %2 = spv.GLSL.Exp %arg0 : return } diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index ca65065..b3059a9 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -28,6 +28,7 @@ #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" @@ -41,7 +42,9 @@ using llvm::raw_ostream; using llvm::raw_string_ostream; using llvm::Record; using llvm::RecordKeeper; +using llvm::SmallVector; using llvm::SMLoc; +using llvm::StringMap; using llvm::StringRef; using llvm::Twine; using mlir::tblgen::Attribute; @@ -59,36 +62,43 @@ using mlir::tblgen::Operator; static void emitGetOpcodeFunction(const Record *record, Operator const &op, raw_ostream &os) { os << formatv("template <> constexpr inline ::mlir::spirv::Opcode " - "getOpcode<{0}>()", - op.getQualCppClassName()) - << " {\n " - << formatv("return ::mlir::spirv::Opcode::{0};\n}\n", + "getOpcode<{0}>() {{\n", + op.getQualCppClassName()); + os << formatv(" return ::mlir::spirv::Opcode::{0};\n", record->getValueAsString("spirvOpName")); + os << "}\n"; } +/// Forward declaration of function to return the SPIR-V opcode corresponding to +/// an operation. This function will be generated for all SPV_Op instances that +/// have hasOpcode = 1. static void declareOpcodeFn(raw_ostream &os) { os << "template inline constexpr ::mlir::spirv::Opcode " "getOpcode();\n"; } +/// Generates code to serialize attributes of a SPV_Op `op` into `os`. The +/// generates code extracts the attribute with name `attrName` from +/// `operandList` of `op`. static void emitAttributeSerialization(const Attribute &attr, - ArrayRef loc, llvm::StringRef op, - llvm::StringRef operandList, - llvm::StringRef attrName, - raw_ostream &os) { - os << " auto attr = " << op << ".getAttr(\"" << attrName << "\");\n"; - os << " if (attr) {\n"; + ArrayRef loc, StringRef tabs, + StringRef opVar, StringRef operandList, + StringRef attrName, raw_ostream &os) { + os << tabs << formatv("auto attr = {0}.getAttr(\"{1}\");\n", opVar, attrName); + os << tabs << "if (attr) {\n"; if (attr.getAttrDefName() == "I32ArrayAttr") { // Serialize all the elements of the array - os << " for (auto attrElem : attr.cast()) {\n"; - os << " " << operandList - << ".push_back(static_cast(attrElem.cast()." - "getValue().getZExtValue()));\n"; - os << " }\n"; + os << tabs << " for (auto attrElem : attr.cast()) {\n"; + os << tabs + << formatv(" {0}.push_back(static_cast(" + "attrElem.cast().getValue().getZExtValue()));\n", + operandList); + os << tabs << " }\n"; } else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") { - os << " " << operandList - << ".push_back(static_cast(attr.cast().getValue()" - ".getZExtValue()));\n"; + os << tabs + << formatv(" {0}.push_back(static_cast(" + "attr.cast().getValue().getZExtValue()));\n", + operandList); } else { PrintFatalError( loc, @@ -96,122 +106,205 @@ static void emitAttributeSerialization(const Attribute &attr, "unhandled attribute type in SPIR-V serialization generation : '") + attr.getAttrDefName() + llvm::Twine("'")); } - os << " }\n"; + os << tabs << "}\n"; } -static void emitSerializationFunction(const Record *attrClass, - const Record *record, const Operator &op, - raw_ostream &os) { - // If the record has 'autogenSerialization' set to 0, nothing to do - if (!record->getValueAsBit("autogenSerialization")) { - return; - } - os << formatv("template <> LogicalResult\nSerializer::processOp<{0}>(\n" - " {0} op)", - op.getQualCppClassName()) - << " {\n"; - os << " SmallVector operands;\n"; - os << " SmallVector elidedAttrs;\n"; - - // Serialize result information - if (op.getNumResults() == 1) { - os << " uint32_t resultTypeID = 0;\n"; - os << " if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) " - "{\n"; - os << " return failure();\n"; - os << " }\n"; - os << " operands.push_back(resultTypeID);\n"; - // Create an SSA result for the op - os << " auto resultID = getNextID();\n"; - os << " valueIDMap[op.getResult()] = resultID;\n"; - os << " operands.push_back(resultID);\n"; - } else if (op.getNumResults() != 0) { - PrintFatalError(record->getLoc(), "SPIR-V ops can only zero or one result"); - } - - // Process arguments +/// Generates code to serialize the operands of a SPV_Op `op` into `os`. The +/// generated querries the SSA-ID if operand is a SSA-Value, or serializes the +/// attributes. The `operands` vector is updated appropriately. `elidedAttrs` +/// updated as well to include the serialized attributes. +static void emitOperandSerialization(const Operator &op, ArrayRef loc, + StringRef tabs, StringRef opVar, + StringRef operands, StringRef elidedAttrs, + raw_ostream &os) { auto operandNum = 0; for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) { auto argument = op.getArg(i); - os << " {\n"; + os << tabs << "{\n"; if (argument.is()) { - os << " for (auto arg : op.getODSOperands(" << operandNum << ")) {\n"; - os << " auto argID = findValueID(arg);\n"; - os << " if (!argID) {\n"; - os << " emitError(op.getLoc(), \"operand " << operandNum - << " has a use before def\");\n"; - os << " }\n"; - os << " operands.push_back(argID);\n"; + os << tabs + << formatv(" for (auto arg : {0}.getODSOperands({1})) {{\n", opVar, + operandNum); + os << tabs << " auto argID = findValueID(arg);\n"; + os << tabs << " if (!argID) {\n"; + os << tabs + << formatv( + " emitError({0}.getLoc(), \"operand {1} has a use before " + "def\");\n", + opVar, operandNum); + os << tabs << " }\n"; + os << tabs << formatv(" {0}.push_back(argID);\n", operands); os << " }\n"; operandNum++; } else { auto attr = argument.get(); + auto newtabs = tabs.str() + " "; emitAttributeSerialization( (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr), - record->getLoc(), "op", "operands", attr->name, os); - os << " elidedAttrs.push_back(\"" << attr->name << "\");\n"; + loc, newtabs, opVar, operands, attr->name, os); + os << newtabs + << formatv("{0}.push_back(\"{1}\");\n", elidedAttrs, attr->name); } - os << " }\n"; + os << tabs << "}\n"; } +} - os << formatv(" encodeInstructionInto(" - "functions, spirv::getOpcode<{0}>(), operands);\n", - op.getQualCppClassName()); +/// Generates code to serializes the result of SPV_Op `op` into `os`. The +/// generated gets the ID for the type of the result (if any), the SSA-ID of +/// the result and updates `resultID` with the SSA-ID. +static void emitResultSerialization(const Operator &op, ArrayRef loc, + StringRef tabs, StringRef opVar, + StringRef operands, StringRef resultID, + raw_ostream &os) { + if (op.getNumResults() == 1) { + StringRef resultTypeID("resultTypeID"); + os << tabs << formatv("uint32_t {0} = 0;\n", resultTypeID); + os << tabs + << formatv( + "if (failed(processType({0}.getLoc(), {0}.getType(), {1}))) {{\n", + opVar, resultTypeID); + os << tabs << " return failure();\n"; + os << tabs << "}\n"; + os << tabs << formatv("{0}.push_back({1});\n", operands, resultTypeID); + // Create an SSA result for the op + os << tabs << formatv("{0} = getNextID();\n", resultID); + os << tabs + << formatv("valueIDMap[{0}.getResult()] = {1};\n", opVar, resultID); + os << tabs << formatv("{0}.push_back({1});\n", operands, resultID); + } else if (op.getNumResults() != 0) { + PrintFatalError(loc, "SPIR-V ops can only have zero or one result"); + } +} +/// Generates code to serialize attributes of SPV_Op `op` that become +/// decorations on the `resultID` of the serialized operation `opVar` in the +/// SPIR-V binary. +static void emitDecorationSerialization(const Operator &op, StringRef tabs, + StringRef opVar, StringRef elidedAttrs, + StringRef resultID, raw_ostream &os) { if (op.getNumResults() == 1) { // All non-argument attributes translated into OpDecorate instruction - os << " for (auto attr : op.getAttrs()) {\n"; - os << " if (llvm::any_of(elidedAttrs, [&](StringRef elided) { return " - "attr.first.is(elided); })) {\n"; - os << " continue;\n"; - os << " }\n"; - os << " if (failed(processDecoration(op.getLoc(), resultID, attr))) {\n"; - os << " return failure();"; - os << " }\n"; - os << " }\n"; + os << tabs << formatv("for (auto attr : {0}.getAttrs()) {{\n", opVar); + os << tabs + << formatv(" if (llvm::any_of({0}, [&](StringRef elided)", elidedAttrs); + os << " {return attr.first.is(elided);})) {\n"; + os << tabs << " continue;\n"; + os << tabs << " }\n"; + os << tabs + << formatv( + " if (failed(processDecoration({0}.getLoc(), {1}, attr))) {{\n", + opVar, resultID); + os << tabs << " return failure();\n"; + os << tabs << " }\n"; + os << tabs << "}\n"; + } +} + +/// Generates code to serialize an SPV_Op `op` into `os`. +static void emitSerializationFunction(const Record *attrClass, + const Record *record, const Operator &op, + raw_ostream &os) { + // If the record has 'autogenSerialization' set to 0, nothing to do + if (!record->getValueAsBit("autogenSerialization")) { + return; + } + StringRef opVar("op"), operands("operands"), elidedAttrs("elidedAttrs"), + resultID("resultID"); + os << formatv( + "template <> LogicalResult\nSerializer::processOp<{0}>({0} {1}) {{\n", + op.getQualCppClassName(), opVar); + os << formatv(" SmallVector {0};\n", operands); + os << formatv(" SmallVector {0};\n", elidedAttrs); + + // Serialize result information. + if (op.getNumResults() == 1) { + os << formatv(" uint32_t {0} = 0;\n", resultID); + emitResultSerialization(op, record->getLoc(), " ", opVar, operands, + resultID, os); + } + + // Process arguments. + emitOperandSerialization(op, record->getLoc(), " ", opVar, operands, + elidedAttrs, os); + + if (record->isSubClassOf("SPV_ExtInstOp")) { + os << formatv(" encodeExtensionInstruction({0}, \"{1}\", {2}, {3});\n", + opVar, record->getValueAsString("extendedInstSetName"), + record->getValueAsInt("extendedInstOpcode"), operands); + } else { + os << formatv(" encodeInstructionInto(" + "functions, spirv::getOpcode<{0}>(), {1});\n", + op.getQualCppClassName(), operands); } + // Process decorations. + emitDecorationSerialization(op, " ", opVar, elidedAttrs, resultID, os); + os << " return success();\n"; os << "}\n\n"; } -static void initDispatchSerializationFn(raw_ostream &os) { - os << "LogicalResult Serializer::dispatchToAutogenSerialization(Operation " - "*op) {\n "; +/// Generates the prologue for the function that dispatches the serialization of +/// the operation `opVar` based on its opcode. +static void initDispatchSerializationFn(StringRef opVar, raw_ostream &os) { + os << formatv( + "LogicalResult Serializer::dispatchToAutogenSerialization(Operation " + "*{0}) {{\n ", + opVar); } -static void emitSerializationDispatch(const Operator &op, raw_ostream &os) { - os << formatv(" if (isa<{0}>(op)) ", op.getQualCppClassName()) << "{\n"; - os << " "; - os << formatv("return processOp<{0}>(cast<{0}>(op));\n", - op.getQualCppClassName()); - os << " } else"; +/// Generates the body of the dispatch function. This function generates the +/// check that if satisfied, will call the serialization function generated for +/// the `op`. +static void emitSerializationDispatch(const Operator &op, StringRef tabs, + StringRef opVar, raw_ostream &os) { + os << tabs + << formatv("if (isa<{0}>({1})) {{\n", op.getQualCppClassName(), opVar); + os << tabs + << formatv(" return processOp(cast<{0}>({1}));\n", + op.getQualCppClassName(), opVar); + os << tabs << "} else"; } -static void finalizeDispatchSerializationFn(raw_ostream &os) { +/// Generates the epilogue for the function that dispatches the serialization of +/// the operation. +static void finalizeDispatchSerializationFn(StringRef opVar, raw_ostream &os) { os << " {\n"; - os << " return op->emitError(\"unhandled operation serialization\");\n"; + os << formatv( + " return {0}->emitError(\"unhandled operation serialization\");\n", + opVar); os << " }\n"; os << " return success();\n"; os << "}\n\n"; } -static void emitAttributeDeserialization( - const Attribute &attr, ArrayRef loc, llvm::StringRef attrList, - llvm::StringRef attrName, llvm::StringRef operandsList, - llvm::StringRef wordIndex, llvm::StringRef wordCount, raw_ostream &os) { +/// Generates code to deserialize the attribute of a SPV_Op into `os`. The +/// generated code reads the `words` of the serialized instruction at +/// position `wordIndex` and adds the deserialized attribute into `attrList`. +static void emitAttributeDeserialization(const Attribute &attr, + ArrayRef loc, StringRef tabs, + StringRef attrList, StringRef attrName, + StringRef words, StringRef wordIndex, + raw_ostream &os) { if (attr.getAttrDefName() == "I32ArrayAttr") { - os << " SmallVector attrListElems;\n"; - os << " while (" << wordIndex << " < " << wordCount << ") {\n"; - os << " attrListElems.push_back(opBuilder.getI32IntegerAttr(" - << operandsList << "[" << wordIndex << "++]));\n"; - os << " }\n"; - os << " " << attrList << ".push_back(opBuilder.getNamedAttr(\"" - << attrName << "\", opBuilder.getArrayAttr(attrListElems)));\n"; + os << tabs << "SmallVector attrListElems;\n"; + os << tabs << formatv("while ({0} < {1}.size()) {{\n", wordIndex, words); + os << tabs + << formatv( + " " + "attrListElems.push_back(opBuilder.getI32IntegerAttr({0}[{1}++]))" + ";\n", + words, wordIndex); + os << tabs << "}\n"; + os << tabs + << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " + "opBuilder.getArrayAttr(attrListElems)));\n", + attrList, attrName); } else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") { - os << " " << attrList << ".push_back(opBuilder.getNamedAttr(\"" - << attrName << "\", opBuilder.getI32IntegerAttr(" << operandsList << "[" - << wordIndex << "++])));\n"; + os << tabs + << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " + "opBuilder.getI32IntegerAttr({2}[{3}++])));\n", + attrList, attrName, words, wordIndex); } else { PrintFatalError( loc, llvm::Twine( @@ -220,143 +313,281 @@ static void emitAttributeDeserialization( } } -static void emitDeserializationFunction(const Record *attrClass, - const Record *record, - const Operator &op, raw_ostream &os) { - // If the record has 'autogenSerialization' set to 0, nothing to do - if (!record->getValueAsBit("autogenSerialization")) { - return; - } - os << formatv("template <> " - "LogicalResult\nDeserializer::processOp<{0}>(ArrayRef<" - "uint32_t> words)", - op.getQualCppClassName()); - os << " {\n"; - os << " SmallVector resultTypes;\n"; - os << " size_t wordIndex = 0; (void)wordIndex;\n"; - +/// Generates the code to deserialize the result of an SPV_Op `op` into +/// `os`. The generated code gets the type of the result specified at +/// `words`[`wordIndex`], the SSA ID for the result at position `wordIndex` + 1 +/// and updates the `resultType` and `valueID` with the parsed type and SSA ID, +/// respectively. +static void emitResultDeserialization(const Operator &op, ArrayRef loc, + StringRef tabs, StringRef words, + StringRef wordIndex, + StringRef resultTypes, StringRef valueID, + raw_ostream &os) { // Deserialize result information if it exists - bool hasResult = false; if (op.getNumResults() == 1) { - os << " {\n"; - os << " if (wordIndex >= words.size()) {\n"; - os << " " - << formatv("return emitError(unknownLoc, \"expected result type " - "while deserializing {0}\");\n", - op.getQualCppClassName()); - os << " }\n"; - os << " auto ty = getType(words[wordIndex]);\n"; - os << " if (!ty) {\n"; - os << " return emitError(unknownLoc, \"unknown type result : " - "\") << words[wordIndex];\n"; - os << " }\n"; - os << " resultTypes.push_back(ty);\n"; - os << " wordIndex++;\n"; - os << " }\n"; - os << " if (wordIndex >= words.size()) {\n"; - os << " " - << formatv("return emitError(unknownLoc, \"expected result while " - "deserializing {0}\");\n", - op.getQualCppClassName()); - os << " }\n"; - os << " uint32_t valueID = words[wordIndex++];\n"; - hasResult = true; + os << tabs << "{\n"; + os << tabs << formatv(" if ({0} >= {1}.size()) {{\n", wordIndex, words); + os << tabs + << formatv( + " return emitError(unknownLoc, \"expected result type " + "while deserializing {0}\");\n", + op.getQualCppClassName()); + os << tabs << " }\n"; + os << tabs << formatv(" auto ty = getType({0}[{1}]);\n", words, wordIndex); + os << tabs << " if (!ty) {\n"; + os << tabs + << formatv( + " return emitError(unknownLoc, \"unknown type result : " + "\") << {0}[{1}];\n", + words, wordIndex); + os << tabs << " }\n"; + os << tabs << formatv(" {0}.push_back(ty);\n", resultTypes); + os << tabs << formatv(" {0}++;\n", wordIndex); + os << tabs << formatv(" if ({0} >= {1}.size()) {{\n", wordIndex, words); + os << tabs + << formatv( + " return emitError(unknownLoc, \"expected result while " + "deserializing {0}\");\n", + op.getQualCppClassName()); + os << tabs << " }\n"; + os << tabs << "}\n"; + os << tabs << formatv("{0} = {1}[{2}++];\n", valueID, words, wordIndex); } else if (op.getNumResults() != 0) { - PrintFatalError(record->getLoc(), - "SPIR-V ops can have only zero or one result"); + PrintFatalError(loc, "SPIR-V ops can have only zero or one result"); } +} +/// Generates the code to deserialize the operands of an SPV_Op `op` into +/// `os`. The generated code reads the `words` of the binary instruction, from +/// position `wordIndex` to the end, and either gets the Value corresponding to +/// the ID encoded, or deserializes the attributes encoded. The parsed +/// operand(attribute) is added to the `operands` list or `attributes` list. +static void emitOperandDeserialization(const Operator &op, ArrayRef loc, + StringRef tabs, StringRef words, + StringRef wordIndex, StringRef operands, + StringRef attributes, raw_ostream &os) { // Process operands/attributes - os << " SmallVector operands;\n"; - os << " SmallVector attributes;\n"; unsigned operandNum = 0; for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) { auto argument = op.getArg(i); if (auto valueArg = argument.dyn_cast()) { if (valueArg->isVariadic()) { if (i != e - 1) { - PrintFatalError(record->getLoc(), + PrintFatalError(loc, "SPIR-V ops can have Variadic<..> argument only if " "it's the last argument"); } - os << " for (; wordIndex < words.size(); ++wordIndex)"; + os << tabs + << formatv("for (; {0} < {1}.size(); ++{0})", wordIndex, words); } else { - os << " if (wordIndex < words.size())"; + os << tabs << formatv("if ({0} < {1}.size())", wordIndex, words); } os << " {\n"; - os << " auto arg = getValue(words[wordIndex]);\n"; - os << " if (!arg) {\n"; - os << " return emitError(unknownLoc, \"unknown result : \") << " - "words[wordIndex];\n"; - os << " }\n"; - os << " operands.push_back(arg);\n"; + os << tabs + << formatv(" auto arg = getValue({0}[{1}]);\n", words, wordIndex); + os << tabs << " if (!arg) {\n"; + os << tabs + << formatv( + " return emitError(unknownLoc, \"unknown result : \") " + "<< {0}[{1}];\n", + words, wordIndex); + os << tabs << " }\n"; + os << tabs << formatv(" {0}.push_back(arg);\n", operands); if (!valueArg->isVariadic()) { - os << " wordIndex++;\n"; + os << tabs << formatv(" {0}++;\n", wordIndex); } operandNum++; - os << " }\n"; + os << tabs << "}\n"; } else { - os << " if (wordIndex < words.size()) {\n"; + os << tabs << formatv("if ({0} < {1}.size()) {{\n", wordIndex, words); auto attr = argument.get(); + auto newtabs = tabs.str() + " "; emitAttributeDeserialization( (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr), - record->getLoc(), "attributes", attr->name, "words", "wordIndex", - "words.size()", os); + loc, newtabs, attributes, attr->name, words, wordIndex, os); os << " }\n"; } } - os << " if (wordIndex != words.size()) {\n"; - os << " return emitError(unknownLoc, \"found more operands than expected " - "when deserializing " - << op.getQualCppClassName() - << ", only \") << wordIndex << \" of \" << words.size() << \" " - "processed\";\n"; - os << " }\n\n"; + os << tabs << formatv("if ({0} != {1}.size()) {{\n", wordIndex, words); + os << tabs + << formatv( + " return emitError(unknownLoc, \"found more operands than " + "expected when deserializing {0}, only \") << {1} << \" of \" << " + "{2}.size() << \" processed\";\n", + op.getQualCppClassName(), wordIndex, words); + os << tabs << "}\n\n"; +} +/// Generates code to update the `attributes` vector with the attributes +/// obtained from parsing the decorations in the SPIR-V binary associated with +/// an `valueID` +static void emitDecorationDeserialization(const Operator &op, StringRef tabs, + StringRef valueID, + StringRef attributes, + raw_ostream &os) { // Import decorations parsed if (op.getNumResults() == 1) { - os << " if (decorations.count(valueID)) {\n" - << " auto attrs = decorations[valueID].getAttrs();\n" - << " attributes.append(attrs.begin(), attrs.end());\n" - << " }\n"; + os << tabs << formatv("if (decorations.count({0})) {{\n", valueID); + os << tabs + << formatv(" auto attrs = decorations[{0}].getAttrs();\n", valueID); + os << tabs + << formatv(" {0}.append(attrs.begin(), attrs.end());\n", attributes); + os << tabs << "}\n"; } +} - os << formatv(" auto op = opBuilder.create<{0}>(unknownLoc, resultTypes, " - "operands, attributes); (void)op;\n", - op.getQualCppClassName()); - if (hasResult) { - os << " valueMap[valueID] = op.getResult();\n\n"; +/// Generates code to deserialize an SPV_Op `op` into `os`. +static void emitDeserializationFunction(const Record *attrClass, + const Record *record, + const Operator &op, raw_ostream &os) { + // If the record has 'autogenSerialization' set to 0, nothing to do + if (!record->getValueAsBit("autogenSerialization")) { + return; + } + StringRef resultTypes("resultTypes"), valueID("valueID"), words("words"), + wordIndex("wordIndex"), opVar("op"), operands("operands"), + attributes("attributes"); + os << formatv("template <> " + "LogicalResult\nDeserializer::processOp<{0}>(ArrayRef<" + "uint32_t> {1}) {{\n", + op.getQualCppClassName(), words); + os << formatv(" SmallVector {0};\n", resultTypes); + os << formatv(" size_t {0} = 0; (void){0};\n", wordIndex); + os << formatv(" uint32_t {0} = 0; (void){0};\n", valueID); + + // Deserialize result information + emitResultDeserialization(op, record->getLoc(), " ", words, wordIndex, + resultTypes, valueID, os); + + os << formatv(" SmallVector {0};\n", operands); + os << formatv(" SmallVector {0};\n", attributes); + // Operand deserialization + emitOperandDeserialization(op, record->getLoc(), " ", words, wordIndex, + operands, attributes, os); + + os << formatv( + " auto {1} = opBuilder.create<{0}>(unknownLoc, {2}, {3}, {4}); " + "(void){1};\n", + op.getQualCppClassName(), opVar, resultTypes, operands, attributes); + if (op.getNumResults() == 1) { + os << formatv(" valueMap[{0}] = {1}.getResult();\n\n", valueID, opVar); } + // Decorations + emitDecorationDeserialization(op, " ", valueID, attributes, os); os << " return success();\n"; os << "}\n\n"; } -static void initDispatchDeserializationFn(raw_ostream &os) { - os << "LogicalResult " - "Deserializer::dispatchToAutogenDeserialization(spirv::Opcode " - "opcode, ArrayRef words) {\n"; - os << " switch (opcode) {\n"; +/// Generates the prologue for the function that dispatches the deserialization +/// based on the `opcode`. +static void initDispatchDeserializationFn(StringRef opcode, StringRef words, + raw_ostream &os) { + os << formatv( + "LogicalResult " + "Deserializer::dispatchToAutogenDeserialization(spirv::Opcode {0}, " + "ArrayRef {1}) {{\n", + opcode, words); + os << formatv(" switch ({0}) {{\n", opcode); } +/// Generates the body of the dispatch function, by generating the case label +/// for an opcode and the call to the method to perform the deserialization. static void emitDeserializationDispatch(const Operator &op, const Record *def, + StringRef tabs, StringRef words, raw_ostream &os) { - os << formatv(" case spirv::Opcode::{0}:\n", + os << tabs + << formatv("case spirv::Opcode::{0}:\n", def->getValueAsString("spirvOpName")); - os << formatv(" return processOp<{0}>(words);\n", - op.getQualCppClassName()); + os << tabs + << formatv(" return processOp<{0}>({1});\n", op.getQualCppClassName(), + words); } -static void finalizeDispatchDeserializationFn(raw_ostream &os) { +/// Generates the epilogue for the function that dispatches the deserialization +/// of the operation. +static void finalizeDispatchDeserializationFn(StringRef opcode, + raw_ostream &os) { os << " default:\n"; os << " ;\n"; os << " }\n"; - os << " return emitError(unknownLoc, \"unhandled deserialization of \") << " - "spirv::stringifyOpcode(opcode);\n"; + os << formatv( + " return emitError(unknownLoc, \"unhandled deserialization of \") << " + "spirv::stringifyOpcode({0});\n", + opcode); os << "}\n"; } +static void initExtendedSetDeserializationDispatch(StringRef extensionSetName, + StringRef instructionID, + StringRef words, + raw_ostream &os) { + os << formatv("LogicalResult " + "Deserializer::dispatchToExtensionSetAutogenDeserialization(" + "StringRef {0}, uint32_t {1}, ArrayRef {2}) {{\n", + extensionSetName, instructionID, words); +} + +static void +emitExtendedSetDeserializationDispatch(const RecordKeeper &recordKeeper, + raw_ostream &os) { + StringRef extensionSetName("extensionSetName"), + instructionID("instructionID"), words("words"); + + // First iterate over all ops derived from SPV_ExtensionSetOps to get all + // extensionSets. + + // For each of the extensions a separate raw_string_ostream is used to + // generate code into. These are then concatenated at the end. Since + // raw_string_ostream needs a string&, use a vector to store all the string + // that are captured by reference within raw_string_ostream. + StringMap extensionSets; + SmallVector extensionSetNames; + + initExtendedSetDeserializationDispatch(extensionSetName, instructionID, words, + os); + auto defs = recordKeeper.getAllDerivedDefinitions("SPV_ExtInstOp"); + for (const auto *def : defs) { + if (!def->getValueAsBit("autogenSerialization")) { + continue; + } + Operator op(def); + auto setName = def->getValueAsString("extendedInstSetName"); + if (!extensionSets.count(setName)) { + extensionSetNames.push_back(""); + extensionSets.try_emplace(setName, extensionSetNames.back()); + auto &setos = extensionSets.find(setName)->second; + setos << formatv(" if ({0} == \"{1}\") {{\n", extensionSetName, setName); + setos << formatv(" switch ({0}) {{\n", instructionID); + } + auto &setos = extensionSets.find(setName)->second; + setos << formatv(" case {0}:\n", + def->getValueAsInt("extendedInstOpcode")); + setos << formatv(" return processOp<{0}>({1});\n", + op.getQualCppClassName(), words); + } + + // Append the dispatch code for all the extended sets. + for (auto &extensionSet : extensionSets) { + os << extensionSet.second.str(); + os << " default:\n"; + os << formatv( + " return emitError(unknownLoc, \"unhandled deserializations of " + "\") << {0} << \" from extension set \" << {1};\n", + instructionID, extensionSetName); + os << " }\n"; + os << " }\n"; + } + + os << formatv(" return emitError(unknownLoc, \"unhandled deserialization of " + "extended instruction set {0}\");\n", + extensionSetName); + os << "}\n"; +} + +/// Emits all the autogenerated serialization/deserializations functions for the +/// SPV_Ops. static bool emitSerializationFns(const RecordKeeper &recordKeeper, raw_ostream &os) { llvm::emitSourceFileHeader("SPIR-V Serialization Utilities/Functions", os); @@ -367,23 +598,31 @@ static bool emitSerializationFns(const RecordKeeper &recordKeeper, serFn(serFnString), deserFn(deserFnString), utils(utilsString); auto attrClass = recordKeeper.getClass("Attr"); + // Emit the serialization and deserialization functions simulataneously. declareOpcodeFn(utils); - initDispatchSerializationFn(dSerFn); - initDispatchDeserializationFn(dDesFn); + StringRef opVar("op"); + StringRef opcode("opcode"), words("words"); + + // Handle the SPIR-V ops. + initDispatchSerializationFn(opVar, dSerFn); + initDispatchDeserializationFn(opcode, words, dDesFn); auto defs = recordKeeper.getAllDerivedDefinitions("SPV_Op"); for (const auto *def : defs) { - if (!def->getValueAsBit("hasOpcode")) { - continue; - } Operator op(def); - emitGetOpcodeFunction(def, op, utils); emitSerializationFunction(attrClass, def, op, serFn); - emitSerializationDispatch(op, dSerFn); emitDeserializationFunction(attrClass, def, op, deserFn); - emitDeserializationDispatch(op, def, dDesFn); + if (def->getValueAsBit("hasOpcode") || def->isSubClassOf("SPV_ExtInstOp")) { + emitSerializationDispatch(op, " ", opVar, dSerFn); + } + if (def->getValueAsBit("hasOpcode")) { + emitGetOpcodeFunction(def, op, utils); + emitDeserializationDispatch(op, def, " ", words, dDesFn); + } } - finalizeDispatchSerializationFn(dSerFn); - finalizeDispatchDeserializationFn(dDesFn); + finalizeDispatchSerializationFn(opVar, dSerFn); + finalizeDispatchDeserializationFn(opcode, dDesFn); + + emitExtendedSetDeserializationDispatch(recordKeeper, dDesFn); os << "#ifdef GET_SPIRV_SERIALIZATION_UTILS\n"; os << utils.str(); @@ -421,8 +660,8 @@ static void emitEnumGetSymbolizeFnDecl(raw_ostream &os) { static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr, raw_ostream &os) { auto enumName = enumAttr.getEnumClassName(); - os << formatv("template <> inline StringRef attributeName<{0}>()", enumName) - << " {\n"; + os << formatv("template <> inline StringRef attributeName<{0}>() {{\n", + enumName); os << " " << formatv("static constexpr const char attrName[] = \"{0}\";\n", mlir::convertToSnakeCase(enumName)); @@ -434,9 +673,9 @@ static void emitEnumGetSymbolizeFnDefn(const EnumAttr &enumAttr, raw_ostream &os) { auto enumName = enumAttr.getEnumClassName(); auto strToSymFnName = enumAttr.getStringToSymbolFnName(); - os << formatv("template <> inline SymbolizeFnTy<{0}> symbolizeEnum<{0}>()", - enumName) - << " {\n"; + os << formatv( + "template <> inline SymbolizeFnTy<{0}> symbolizeEnum<{0}>() {{\n", + enumName); os << " return " << strToSymFnName << ";\n"; os << "}\n"; } -- 2.7.4