From 801efec9e64b6a490a5e6dd465872b67a2b79df3 Mon Sep 17 00:00:00 2001 From: Mahesh Ravishankar Date: Wed, 10 Jul 2019 17:33:28 -0700 Subject: [PATCH] Update the gen_spirv_dialect.py script to add opcodes from the SPIR-V JSON spec into the SPIRBase.td file. This is done incrementally to only import those opcodes that are needed, through use of the script define_opcode.sh added. PiperOrigin-RevId: 257517343 --- mlir/include/mlir/SPIRV/SPIRVBase.td | 54 ++++++++++-- mlir/include/mlir/SPIRV/SPIRVOps.h | 2 +- mlir/include/mlir/SPIRV/SPIRVOps.td | 14 ---- mlir/include/mlir/SPIRV/SPIRVStructureOps.td | 12 +++ mlir/include/mlir/SPIRV/SPIRVTypes.h | 2 +- mlir/lib/SPIRV/Serialization/Deserializer.cpp | 12 +-- .../SPIRV/Serialization/SPIRVBinaryUtils.h | 5 +- mlir/lib/SPIRV/Serialization/Serializer.cpp | 26 +++--- mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | 16 ++-- mlir/utils/spirv/define_opcodes.sh | 39 +++++++++ mlir/utils/spirv/gen_spirv_dialect.py | 84 ++++++++++++++++++- 11 files changed, 218 insertions(+), 48 deletions(-) create mode 100755 mlir/utils/spirv/define_opcodes.sh diff --git a/mlir/include/mlir/SPIRV/SPIRVBase.td b/mlir/include/mlir/SPIRV/SPIRVBase.td index b862f71c3075..8eb45aab64d9 100644 --- a/mlir/include/mlir/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/SPIRV/SPIRVBase.td @@ -58,6 +58,45 @@ def SPV_Dialect : Dialect { let cppNamespace = "spirv"; } +//===----------------------------------------------------------------------===// +// SPIR-V opcode specification +//===----------------------------------------------------------------------===// + +class SPV_OpCode { + // Name used as reference to retrieve the opcode + string opname = name; + + // Opcode associated with the name + int opcode = val; +} + +// Begin opcode section. Generated from SPIR-V spec; DO NOT MODIFY! + +def SPV_OC_OpMemoryModel : I32EnumAttrCase<"OpMemoryModel", 14>; +def SPV_OC_OpEntryPoint : I32EnumAttrCase<"OpEntryPoint", 15>; +def SPV_OC_OpExecutionMode : I32EnumAttrCase<"OpExecutionMode", 16>; +def SPV_OC_OpTypeVoid : I32EnumAttrCase<"OpTypeVoid", 19>; +def SPV_OC_OpTypeFunction : I32EnumAttrCase<"OpTypeFunction", 33>; +def SPV_OC_OpVariable : I32EnumAttrCase<"OpVariable", 59>; +def SPV_OC_OpLoad : I32EnumAttrCase<"OpLoad", 61>; +def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>; +def SPV_OC_OpFMul : I32EnumAttrCase<"OpFMul", 133>; +def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>; + +def SPV_OpcodeAttr : + I32EnumAttr<"Opcode", "valid SPIR-V instructions", [ + SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode, + SPV_OC_OpTypeVoid, SPV_OC_OpTypeFunction, SPV_OC_OpVariable, SPV_OC_OpLoad, + SPV_OC_OpStore, SPV_OC_OpFMul, SPV_OC_OpReturn + ]> { + let returnType = "::mlir::spirv::Opcode"; + let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())"; + let cppNamespace = "::mlir::spirv"; +} + +// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY! + + //===----------------------------------------------------------------------===// // SPIR-V type definitions //===----------------------------------------------------------------------===// @@ -438,11 +477,8 @@ def ModuleOnly : // Base class for all SPIR-V ops. class SPV_Op traits = []> : Op { - // Opcode for the binary format. - // See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html for - // the opcode for each operation. Ops that cannot be directly serialized will - // leave this field as unset. - int opcode = ?; + + string spirvOpName = "Op" # mnemonic; // For each SPIR-V op, the following static functions need to be defined // in SPVOps.cpp: @@ -454,6 +490,14 @@ class SPV_Op traits = []> : let parser = [{ return ::parse$cppClass(parser, result); }]; let printer = [{ return ::print(*this, p); }]; let verifier = [{ return ::verify(*this); }]; + + // By default the opcode to use for (de)serialization is obtained + // automatically from the SPIR-V spec. It assume the SPIR-V op being defined + // is ('Op' # mnemonic). The opcode value can be obtained by calling + // getOpcode(). If invoking this method is invalid or custom + // processing is needed for the op, set hasOpcode = 0 and specialize the + // getOpcode method. + int hasOpcode = 1; } #endif // SPIRV_BASE diff --git a/mlir/include/mlir/SPIRV/SPIRVOps.h b/mlir/include/mlir/SPIRV/SPIRVOps.h index e0345fbc654d..f3720a86b455 100644 --- a/mlir/include/mlir/SPIRV/SPIRVOps.h +++ b/mlir/include/mlir/SPIRV/SPIRVOps.h @@ -22,7 +22,7 @@ #ifndef MLIR_SPIRV_SPIRVOPS_H_ #define MLIR_SPIRV_SPIRVOPS_H_ -#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Function.h" #include "mlir/SPIRV/SPIRVTypes.h" namespace mlir { diff --git a/mlir/include/mlir/SPIRV/SPIRVOps.td b/mlir/include/mlir/SPIRV/SPIRVOps.td index afab62ab4e41..b57e00a4629f 100644 --- a/mlir/include/mlir/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/SPIRV/SPIRVOps.td @@ -88,8 +88,6 @@ def SPV_EntryPointOp : SPV_Op<"EntryPoint", [ModuleOnly]> { ); let results = (outs SPV_EntryPoint:$id); - - let opcode = 15; } def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [ModuleOnly]> { @@ -130,8 +128,6 @@ def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [ModuleOnly]> { ); let verifier = [{ return success(); }]; - - let opcode = 16; } def SPV_FMulOp : SPV_Op<"FMul", [NoSideEffect, SameOperandsAndResultType]> { @@ -159,8 +155,6 @@ def SPV_FMulOp : SPV_Op<"FMul", [NoSideEffect, SameOperandsAndResultType]> { // No additional verification needed in addition to the ODS-generated ones. let verifier = [{ return success(); }]; - - let opcode = 133; } def SPV_LoadOp : SPV_Op<"Load"> { @@ -207,8 +201,6 @@ def SPV_LoadOp : SPV_Op<"Load"> { let results = (outs SPV_Type:$value ); - - let opcode = 61; } def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> { @@ -226,8 +218,6 @@ def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> { let printer = [{ printNoIOOp(getOperation(), p); }]; let verifier = [{ return verifyReturn(*this); }]; - - let opcode = 253; } def SPV_StoreOp : SPV_Op<"Store"> { @@ -267,8 +257,6 @@ def SPV_StoreOp : SPV_Op<"Store"> { OptionalAttr:$memory_access, OptionalAttr:$alignment ); - - let opcode = 62; } def SPV_VariableOp : SPV_Op<"Variable"> { @@ -320,8 +308,6 @@ def SPV_VariableOp : SPV_Op<"Variable"> { let results = (outs SPV_AnyPtr:$pointer ); - - let opcode = 59; } #endif // SPIRV_OPS diff --git a/mlir/include/mlir/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/SPIRV/SPIRVStructureOps.td index 3b32dcd012a2..65ffeb739699 100644 --- a/mlir/include/mlir/SPIRV/SPIRVStructureOps.td +++ b/mlir/include/mlir/SPIRV/SPIRVStructureOps.td @@ -88,6 +88,14 @@ def SPV_ModuleOp : SPV_Op<"module", []> { let regions = (region SizedRegion<1>:$body); let builders = [OpBuilder<"Builder *, OperationState *state">]; + + let hasOpcode = 0; + + let extraClassDeclaration = [{ + Block& getBlock() { + return this->getOperation()->getRegion(0).front(); + } + }]; } def SPV_ModuleEndOp : SPV_Op<"_module_end", [Terminator, ModuleOnly]> { @@ -108,6 +116,8 @@ def SPV_ModuleEndOp : SPV_Op<"_module_end", [Terminator, ModuleOnly]> { let printer = [{ printNoIOOp(getOperation(), p); }]; let verifier = [{ return success(); }]; + + let hasOpcode = 0; } def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> { @@ -152,6 +162,8 @@ def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> { let results = (outs SPV_Type:$constant ); + + let hasOpcode = 0; } #endif // SPIRV_STRUCTURE_OPS diff --git a/mlir/include/mlir/SPIRV/SPIRVTypes.h b/mlir/include/mlir/SPIRV/SPIRVTypes.h index 0bdb9f030234..a389439f1bf8 100644 --- a/mlir/include/mlir/SPIRV/SPIRVTypes.h +++ b/mlir/include/mlir/SPIRV/SPIRVTypes.h @@ -22,7 +22,7 @@ #ifndef MLIR_SPIRV_SPIRVTYPES_H_ #define MLIR_SPIRV_SPIRVTYPES_H_ -#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/StandardTypes.h" #include "mlir/IR/Types.h" // Pull in all enum type definitions and utility function declarations diff --git a/mlir/lib/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/SPIRV/Serialization/Deserializer.cpp index e7eaa99ef3c1..ab0d11e29235 100644 --- a/mlir/lib/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/SPIRV/Serialization/Deserializer.cpp @@ -67,7 +67,7 @@ private: LogicalResult processHeader(); /// Processes a SPIR-V instruction with the given `opcode` and `operands`. - LogicalResult processInstruction(uint32_t opcode, + LogicalResult processInstruction(spirv::Opcode opcode, ArrayRef operands); LogicalResult processMemoryModel(ArrayRef operands); @@ -123,7 +123,8 @@ LogicalResult Deserializer::deserialize() { "insufficient words for the last instruction"); auto operands = binary.slice(curOffset + 1, wordCount - 1); - if (failed(processInstruction(opcode, operands))) + if (failed( + processInstruction(static_cast(opcode), operands))) return failure(); curOffset = nextOffset; @@ -146,15 +147,16 @@ LogicalResult Deserializer::processHeader() { return success(); } -LogicalResult Deserializer::processInstruction(uint32_t opcode, +LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, ArrayRef operands) { switch (opcode) { - case spirv::kOpMemoryModelOpcode: + case spirv::Opcode::OpMemoryModel: return processMemoryModel(operands); default: break; } - return emitError(unknownLoc, "NYI: opcode ") << opcode; + return emitError(unknownLoc, "NYI: opcode ") + << spirv::stringifyOpcode(opcode); } LogicalResult Deserializer::processMemoryModel(ArrayRef operands) { diff --git a/mlir/lib/SPIRV/Serialization/SPIRVBinaryUtils.h b/mlir/lib/SPIRV/Serialization/SPIRVBinaryUtils.h index 157b68c7399c..80220956477d 100644 --- a/mlir/lib/SPIRV/Serialization/SPIRVBinaryUtils.h +++ b/mlir/lib/SPIRV/Serialization/SPIRVBinaryUtils.h @@ -22,6 +22,8 @@ #ifndef MLIR_SPIRV_SERIALIZATION_SPIRV_BINARY_UTILS_H_ #define MLIR_SPIRV_SERIALIZATION_SPIRV_BINARY_UTILS_H_ +#include "mlir/SPIRV/SPIRVOps.h" + #include namespace mlir { @@ -33,8 +35,7 @@ constexpr unsigned kHeaderWordCount = 5; /// SPIR-V magic number constexpr uint32_t kMagicNumber = 0x07230203; -/// Opcode for SPIR-V OpMemoryModel -constexpr uint32_t kOpMemoryModelOpcode = 14; +#include "mlir/SPIRV/SPIRVSerialization.inc" } // end namespace spirv } // end namespace mlir diff --git a/mlir/lib/SPIRV/Serialization/Serializer.cpp b/mlir/lib/SPIRV/Serialization/Serializer.cpp index 1e6780f81f5c..ed083e00b805 100644 --- a/mlir/lib/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/SPIRV/Serialization/Serializer.cpp @@ -29,12 +29,22 @@ using namespace mlir; -static inline uint32_t getPrefixedOpcode(uint32_t wordCount, uint32_t opcode) { +static inline uint32_t getPrefixedOpcode(uint32_t wordCount, + spirv::Opcode opcode) { assert(((wordCount >> 16) == 0) && "word count out of range!"); - return (wordCount << 16) | opcode; + return (wordCount << 16) | static_cast(opcode); +} + +static inline void buildInstruction(spirv::Opcode op, + ArrayRef operands, + SmallVectorImpl &binary) { + uint32_t wordCount = 1 + operands.size(); + binary.push_back(getPrefixedOpcode(wordCount, op)); + binary.append(operands.begin(), operands.end()); } namespace { + /// A SPIR-V module serializer. /// /// A SPIR-V binary module is a single linear stream of instructions; each @@ -70,7 +80,7 @@ private: spirv::ModuleOp module; /// The next available result . - uint32_t nextID = 0; + uint32_t nextID = 1; // The following are for different SPIR-V instruction sections. They follow // the logical layout of a SPIR-V module. @@ -88,10 +98,6 @@ private: }; } // namespace -namespace { -#include "mlir/SPIRV/SPIRVSerialization.inc" -} - LogicalResult Serializer::serialize() { if (failed(module.verify())) return failure(); @@ -160,11 +166,7 @@ void Serializer::processMemoryModel() { uint32_t mm = module.getAttrOfType("memory_model").getInt(); uint32_t am = module.getAttrOfType("addressing_model").getInt(); - constexpr uint32_t kNumWords = 3; - - memoryModel.reserve(kNumWords); - memoryModel.assign( - {getPrefixedOpcode(kNumWords, spirv::kOpMemoryModelOpcode), am, mm}); + buildInstruction(spirv::Opcode::OpMemoryModel, {am, mm}, memoryModel); } LogicalResult spirv::serialize(spirv::ModuleOp module, diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index f63c928bb4ac..b3288162dc08 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -25,6 +25,7 @@ #include "mlir/TableGen/Operator.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" #include "llvm/TableGen/Error.h" @@ -41,18 +42,21 @@ using mlir::tblgen::Operator; // Writes the following function to `os`: // inline uint32_t getOpcode() { return ; } static void emitGetOpcodeFunction(const llvm::Record &record, - const Operator &op, raw_ostream &os) { - if (llvm::isa(record.getValueInit("opcode"))) - return; - - os << formatv("inline uint32_t getOpcode({0}) {{ return {1}u; }\n", - op.getQualCppClassName(), record.getValueAsInt("opcode")); + Operator const &op, raw_ostream &os) { + if (record.getValueAsInt("hasOpcode")) { + os << formatv("template <> constexpr inline uint32_t getOpcode<{0}>()", + op.getQualCppClassName()) + << " {\n return static_cast(" + << formatv("Opcode::Op{0});\n}\n", record.getValueAsString("opName")); + } } static bool emitSerializationUtils(const RecordKeeper &recordKeeper, raw_ostream &os) { llvm::emitSourceFileHeader("SPIR-V Serialization Utilities", os); + /// Define the function to get the opcode + os << "template inline constexpr uint32_t getOpcode();\n"; auto defs = recordKeeper.getAllDerivedDefinitions("SPV_Op"); for (const auto *def : defs) { Operator op(def); diff --git a/mlir/utils/spirv/define_opcodes.sh b/mlir/utils/spirv/define_opcodes.sh new file mode 100755 index 000000000000..e98d55c71285 --- /dev/null +++ b/mlir/utils/spirv/define_opcodes.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +# Copyright 2019 The MLIR Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Script for defining map for opname to opcode using SPIR-V spec from the +# Internet +# +# Run as: +# ./define_opcode.sh ()* +# +# For example: +# ./define_opcode.sh OpTypeVoid OpTypeFunction +# +# If no op-name is specified, the existing opcodes are updated +# +# The 'instructions' list of spirv.core.grammar.json contains all instructions +# in SPIR-V + +set -e +set -x + +current_file="$(readlink -f "$0")" +current_dir="$(dirname "$current_file")" + +python3 ${current_dir}/gen_spirv_dialect.py \ + --base-td-path ${current_dir}/../../include/mlir/SPIRV/SPIRVBase.td \ + --new-opcode $@ diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py index 88f640c6ef85..8279f70765ba 100755 --- a/mlir/utils/spirv/gen_spirv_dialect.py +++ b/mlir/utils/spirv/gen_spirv_dialect.py @@ -32,6 +32,8 @@ SPIRV_HTML_SPEC_URL = 'https://www.khronos.org/registry/spir-v/specs/unified1/SP SPIRV_JSON_SPEC_URL = 'https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/spirv.core.grammar.json' AUTOGEN_ENUM_SECTION_MARKER = 'enum section. Generated from SPIR-V spec; DO NOT MODIFY!' +AUTOGEN_INSTRUCTION_OPCODE_SECTION_MARKER = ('opcode section. Generated from ' + 'SPIR-V spec; DO NOT MODIFY!') def get_spirv_grammar_from_json_spec(): @@ -129,6 +131,76 @@ def gen_operand_kind_enum_attr(operand_kind): return kind_name, case_defs + '\n\n' + enum_attr +def gen_opcode(instructions): + """ Generates the TableGen definition to map opname to opcode + + Returns: + - A string containing the TableGen SPV_OpCode definition + """ + + max_len = max([len(inst['opname']) for inst in instructions]) + def_fmt_str = 'def SPV_OC_{name} {colon:>{offset}} '\ + 'I32EnumAttrCase<"{name}", {value}>;' + opcode_defs = [ + def_fmt_str.format( + name=inst['opname'], + value=inst['opcode'], + colon=':', + offset=(max_len + 1 - len(inst['opname']))) for inst in instructions + ] + opcode_str = '\n'.join(opcode_defs) + + decl_fmt_str = 'SPV_OC_{name}' + opcode_list = [ + decl_fmt_str.format(name=inst['opname']) for inst in instructions + ] + opcode_list = split_list_into_sublists(opcode_list, 6) + opcode_list = [ + '{:6}'.format('') + ', '.join(sublist) for sublist in opcode_list + ] + opcode_list = ',\n'.join(opcode_list) + enum_attr = 'def SPV_OpcodeAttr :\n'\ + ' I32EnumAttr<"{name}", "valid SPIR-V instructions", [\n'\ + '{lst}\n'\ + ' ]> {{\n'\ + ' let returnType = "::mlir::spirv::{name}";\n'\ + ' let convertFromStorage = '\ + '"static_cast<::mlir::spirv::{name}>($_self.getInt())";\n'\ + ' let cppNamespace = "::mlir::spirv";\n}}'.format( + name='Opcode', lst=opcode_list) + return opcode_str + '\n\n' + enum_attr + + +def update_td_opcodes(path, instructions, filter_list): + + with open(path, 'r') as f: + content = f.read() + + content = content.split(AUTOGEN_INSTRUCTION_OPCODE_SECTION_MARKER) + assert len(content) == 3 + + # Extend opcode list with existing list + import re + existing_opcodes = [k[11:] for k in re.findall('def SPV_OC_\w+', content[1])] + filter_list.extend(existing_opcodes) + filter_list = list(set(filter_list)) + + # Generate the opcode for all instructions in SPIR-V + filter_instrs = list( + filter(lambda inst: (inst['opname'] in filter_list), instructions)) + # Sort instruction based on opcode + filter_instrs.sort(key=lambda inst: inst['opcode']) + opcode = gen_opcode(filter_instrs) + + # Substitute the opcode + content = content[0] + AUTOGEN_INSTRUCTION_OPCODE_SECTION_MARKER + '\n\n' + \ + opcode + '\n\n// End ' + AUTOGEN_INSTRUCTION_OPCODE_SECTION_MARKER \ + + content[2] + + with open(path, 'w') as f: + f.write(content) + + def update_td_enum_attrs(path, operand_kinds, filter_list): """Updates SPIRBase.td with new generated enum definitions. @@ -136,7 +208,7 @@ def update_td_enum_attrs(path, operand_kinds, filter_list): - path: the path to SPIRBase.td - operand_kinds: a list containing all operand kinds' grammar - filter_list: a list containing new enums to add - """ + """ with open(path, 'r') as f: content = f.read() @@ -175,8 +247,16 @@ if __name__ == '__main__': help='Path to SPIRVBase.td') cli_parser.add_argument('--new-enum', dest='new_enum', type=str, help='SPIR-V enum to be added to SPIRVBase.td') + cli_parser.add_argument( + '--new-opcodes', + dest='new_opcodes', + type=str, + nargs='*', + help='update SPIR-V opcodes in SPIRVBase.td') args = cli_parser.parse_args() - operand_kinds, _ = get_spirv_grammar_from_json_spec() + operand_kinds, instructions = get_spirv_grammar_from_json_spec() update_td_enum_attrs(args.base_td_path, operand_kinds, [args.new_enum]) + + update_td_opcodes(args.base_td_path, instructions, args.new_opcodes) -- 2.34.1