Update the gen_spirv_dialect.py script to add opcodes from the SPIR-V
authorMahesh Ravishankar <ravishankarm@google.com>
Thu, 11 Jul 2019 00:33:28 +0000 (17:33 -0700)
committerjpienaar <jpienaar@google.com>
Fri, 12 Jul 2019 15:43:09 +0000 (08:43 -0700)
JSON spec into the SPIRBase.td file. This is done incrementally to
only import those opcodes that are needed, through use of the script
define_opcode.sh added.

PiperOrigin-RevId: 257517343

mlir/include/mlir/SPIRV/SPIRVBase.td
mlir/include/mlir/SPIRV/SPIRVOps.h
mlir/include/mlir/SPIRV/SPIRVOps.td
mlir/include/mlir/SPIRV/SPIRVStructureOps.td
mlir/include/mlir/SPIRV/SPIRVTypes.h
mlir/lib/SPIRV/Serialization/Deserializer.cpp
mlir/lib/SPIRV/Serialization/SPIRVBinaryUtils.h
mlir/lib/SPIRV/Serialization/Serializer.cpp
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
mlir/utils/spirv/define_opcodes.sh [new file with mode: 0755]
mlir/utils/spirv/gen_spirv_dialect.py

index b862f71c30757ea3ae04b19b27cc3d9716b4825b..8eb45aab64d9a0d2952049c54c23e9a8165ea8ad 100644 (file)
@@ -58,6 +58,45 @@ def SPV_Dialect : Dialect {
   let cppNamespace = "spirv";
 }
 
+//===----------------------------------------------------------------------===//
+// SPIR-V opcode specification
+//===----------------------------------------------------------------------===//
+
+class SPV_OpCode<string name, int val> {
+  // Name used as reference to retrieve the opcode
+  string opname = name;
+
+  // Opcode associated with the name
+  int opcode = val;
+}
+
+// Begin opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
+
+def SPV_OC_OpMemoryModel   : I32EnumAttrCase<"OpMemoryModel", 14>;
+def SPV_OC_OpEntryPoint    : I32EnumAttrCase<"OpEntryPoint", 15>;
+def SPV_OC_OpExecutionMode : I32EnumAttrCase<"OpExecutionMode", 16>;
+def SPV_OC_OpTypeVoid      : I32EnumAttrCase<"OpTypeVoid", 19>;
+def SPV_OC_OpTypeFunction  : I32EnumAttrCase<"OpTypeFunction", 33>;
+def SPV_OC_OpVariable      : I32EnumAttrCase<"OpVariable", 59>;
+def SPV_OC_OpLoad          : I32EnumAttrCase<"OpLoad", 61>;
+def SPV_OC_OpStore         : I32EnumAttrCase<"OpStore", 62>;
+def SPV_OC_OpFMul          : I32EnumAttrCase<"OpFMul", 133>;
+def SPV_OC_OpReturn        : I32EnumAttrCase<"OpReturn", 253>;
+
+def SPV_OpcodeAttr :
+    I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
+      SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode,
+      SPV_OC_OpTypeVoid, SPV_OC_OpTypeFunction, SPV_OC_OpVariable, SPV_OC_OpLoad,
+      SPV_OC_OpStore, SPV_OC_OpFMul, SPV_OC_OpReturn
+      ]> {
+    let returnType = "::mlir::spirv::Opcode";
+    let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";
+    let cppNamespace = "::mlir::spirv";
+}
+
+// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
+
+
 //===----------------------------------------------------------------------===//
 // SPIR-V type definitions
 //===----------------------------------------------------------------------===//
@@ -438,11 +477,8 @@ def ModuleOnly :
 // Base class for all SPIR-V ops.
 class SPV_Op<string mnemonic, list<OpTrait> traits = []> :
     Op<SPV_Dialect, mnemonic, traits> {
-  // Opcode for the binary format.
-  // See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html for
-  // the opcode for each operation. Ops that cannot be directly serialized will
-  // leave this field as unset.
-  int opcode = ?;
+
+  string spirvOpName = "Op" # mnemonic;
 
   // For each SPIR-V op, the following static functions need to be defined
   // in SPVOps.cpp:
@@ -454,6 +490,14 @@ class SPV_Op<string mnemonic, list<OpTrait> traits = []> :
   let parser = [{ return ::parse$cppClass(parser, result); }];
   let printer = [{ return ::print(*this, p); }];
   let verifier = [{ return ::verify(*this); }];
+
+  // By default the opcode to use for (de)serialization is obtained
+  // automatically from the SPIR-V spec. It assume the SPIR-V op being defined
+  // is ('Op' # mnemonic). The opcode value can be obtained by calling
+  // getOpcode<OpClass>().  If invoking this method is invalid or custom
+  // processing is needed for the op, set hasOpcode = 0 and specialize the
+  // getOpcode method.
+  int hasOpcode = 1;
 }
 
 #endif // SPIRV_BASE
index e0345fbc654de09d1b9c6b8b1e141fb28a2e7470..f3720a86b455c73aed944cf23b1fd545c41b3209 100644 (file)
@@ -22,7 +22,7 @@
 #ifndef MLIR_SPIRV_SPIRVOPS_H_
 #define MLIR_SPIRV_SPIRVOPS_H_
 
-#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Function.h"
 #include "mlir/SPIRV/SPIRVTypes.h"
 
 namespace mlir {
index afab62ab4e41eb41b7184bd6b05b24d6de1cf9c6..b57e00a4629f1c8161965ce576f941a34fca8c7e 100644 (file)
@@ -88,8 +88,6 @@ def SPV_EntryPointOp : SPV_Op<"EntryPoint", [ModuleOnly]> {
   );
 
   let results = (outs SPV_EntryPoint:$id);
-
-  let opcode = 15;
 }
 
 def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [ModuleOnly]> {
@@ -130,8 +128,6 @@ def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [ModuleOnly]> {
   );
 
   let verifier = [{ return success(); }];
-
-  let opcode = 16;
 }
 
 def SPV_FMulOp : SPV_Op<"FMul", [NoSideEffect, SameOperandsAndResultType]> {
@@ -159,8 +155,6 @@ def SPV_FMulOp : SPV_Op<"FMul", [NoSideEffect, SameOperandsAndResultType]> {
 
   // No additional verification needed in addition to the ODS-generated ones.
   let verifier = [{ return success(); }];
-
-  let opcode = 133;
 }
 
 def SPV_LoadOp : SPV_Op<"Load"> {
@@ -207,8 +201,6 @@ def SPV_LoadOp : SPV_Op<"Load"> {
   let results = (outs
     SPV_Type:$value
   );
-
-  let opcode = 61;
 }
 
 def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> {
@@ -226,8 +218,6 @@ def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> {
   let printer = [{ printNoIOOp(getOperation(), p); }];
 
   let verifier = [{ return verifyReturn(*this); }];
-
-  let opcode = 253;
 }
 
 def SPV_StoreOp : SPV_Op<"Store"> {
@@ -267,8 +257,6 @@ def SPV_StoreOp : SPV_Op<"Store"> {
     OptionalAttr<SPV_MemoryAccessAttr>:$memory_access,
     OptionalAttr<APIntAttr>:$alignment
   );
-
-  let opcode = 62;
 }
 
 def SPV_VariableOp : SPV_Op<"Variable"> {
@@ -320,8 +308,6 @@ def SPV_VariableOp : SPV_Op<"Variable"> {
   let results = (outs
     SPV_AnyPtr:$pointer
   );
-
-  let opcode = 59;
 }
 
 #endif // SPIRV_OPS
index 3b32dcd012a2957f26e7d039e208b869b5546aff..65ffeb73969998170c06c8787e71a075e5594a64 100644 (file)
@@ -88,6 +88,14 @@ def SPV_ModuleOp : SPV_Op<"module", []> {
   let regions = (region SizedRegion<1>:$body);
 
   let builders = [OpBuilder<"Builder *, OperationState *state">];
+
+  let hasOpcode = 0;
+
+  let extraClassDeclaration = [{
+    Block& getBlock() {
+      return this->getOperation()->getRegion(0).front();
+    }
+  }];
 }
 
 def SPV_ModuleEndOp : SPV_Op<"_module_end", [Terminator, ModuleOnly]> {
@@ -108,6 +116,8 @@ def SPV_ModuleEndOp : SPV_Op<"_module_end", [Terminator, ModuleOnly]> {
   let printer = [{ printNoIOOp(getOperation(), p); }];
 
   let verifier = [{ return success(); }];
+
+  let hasOpcode = 0;
 }
 
 def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> {
@@ -152,6 +162,8 @@ def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> {
   let results = (outs
     SPV_Type:$constant
   );
+
+  let hasOpcode = 0;
 }
 
 #endif // SPIRV_STRUCTURE_OPS
index 0bdb9f030234866ab8e8af7c8ff9470a1159b553..a389439f1bf8c87b588bfb9dc9eeb49d88134261 100644 (file)
@@ -22,7 +22,7 @@
 #ifndef MLIR_SPIRV_SPIRVTYPES_H_
 #define MLIR_SPIRV_SPIRVTYPES_H_
 
-#include "mlir/IR/TypeSupport.h"
+#include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/Types.h"
 
 // Pull in all enum type definitions and utility function declarations
index e7eaa99ef3c173e37fb0669b925536cf80722974..ab0d11e29235d7aadbbe0b22d03cc21d3203e38a 100644 (file)
@@ -67,7 +67,7 @@ private:
   LogicalResult processHeader();
 
   /// Processes a SPIR-V instruction with the given `opcode` and `operands`.
-  LogicalResult processInstruction(uint32_t opcode,
+  LogicalResult processInstruction(spirv::Opcode opcode,
                                    ArrayRef<uint32_t> operands);
 
   LogicalResult processMemoryModel(ArrayRef<uint32_t> operands);
@@ -123,7 +123,8 @@ LogicalResult Deserializer::deserialize() {
                        "insufficient words for the last instruction");
 
     auto operands = binary.slice(curOffset + 1, wordCount - 1);
-    if (failed(processInstruction(opcode, operands)))
+    if (failed(
+            processInstruction(static_cast<spirv::Opcode>(opcode), operands)))
       return failure();
 
     curOffset = nextOffset;
@@ -146,15 +147,16 @@ LogicalResult Deserializer::processHeader() {
   return success();
 }
 
-LogicalResult Deserializer::processInstruction(uint32_t opcode,
+LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
                                                ArrayRef<uint32_t> operands) {
   switch (opcode) {
-  case spirv::kOpMemoryModelOpcode:
+  case spirv::Opcode::OpMemoryModel:
     return processMemoryModel(operands);
   default:
     break;
   }
-  return emitError(unknownLoc, "NYI: opcode ") << opcode;
+  return emitError(unknownLoc, "NYI: opcode ")
+         << spirv::stringifyOpcode(opcode);
 }
 
 LogicalResult Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
index 157b68c7399c74ae95ff6d09206296bfcdba746d..80220956477d0d0e3aedce80e7a081a5d2d41be8 100644 (file)
@@ -22,6 +22,8 @@
 #ifndef MLIR_SPIRV_SERIALIZATION_SPIRV_BINARY_UTILS_H_
 #define MLIR_SPIRV_SERIALIZATION_SPIRV_BINARY_UTILS_H_
 
+#include "mlir/SPIRV/SPIRVOps.h"
+
 #include <cstdint>
 
 namespace mlir {
@@ -33,8 +35,7 @@ constexpr unsigned kHeaderWordCount = 5;
 /// SPIR-V magic number
 constexpr uint32_t kMagicNumber = 0x07230203;
 
-/// Opcode for SPIR-V OpMemoryModel
-constexpr uint32_t kOpMemoryModelOpcode = 14;
+#include "mlir/SPIRV/SPIRVSerialization.inc"
 
 } // end namespace spirv
 } // end namespace mlir
index 1e6780f81f5c471ab19030c2bf026856a19fc785..ed083e00b8052a756c47d613f9d1a86852e29750 100644 (file)
 
 using namespace mlir;
 
-static inline uint32_t getPrefixedOpcode(uint32_t wordCount, uint32_t opcode) {
+static inline uint32_t getPrefixedOpcode(uint32_t wordCount,
+                                         spirv::Opcode opcode) {
   assert(((wordCount >> 16) == 0) && "word count out of range!");
-  return (wordCount << 16) | opcode;
+  return (wordCount << 16) | static_cast<uint32_t>(opcode);
+}
+
+static inline void buildInstruction(spirv::Opcode op,
+                                    ArrayRef<uint32_t> operands,
+                                    SmallVectorImpl<uint32_t> &binary) {
+  uint32_t wordCount = 1 + operands.size();
+  binary.push_back(getPrefixedOpcode(wordCount, op));
+  binary.append(operands.begin(), operands.end());
 }
 
 namespace {
+
 /// A SPIR-V module serializer.
 ///
 /// A SPIR-V binary module is a single linear stream of instructions; each
@@ -70,7 +80,7 @@ private:
   spirv::ModuleOp module;
 
   /// The next available result <id>.
-  uint32_t nextID = 0;
+  uint32_t nextID = 1;
 
   // The following are for different SPIR-V instruction sections. They follow
   // the logical layout of a SPIR-V module.
@@ -88,10 +98,6 @@ private:
 };
 } // namespace
 
-namespace {
-#include "mlir/SPIRV/SPIRVSerialization.inc"
-}
-
 LogicalResult Serializer::serialize() {
   if (failed(module.verify()))
     return failure();
@@ -160,11 +166,7 @@ void Serializer::processMemoryModel() {
   uint32_t mm = module.getAttrOfType<IntegerAttr>("memory_model").getInt();
   uint32_t am = module.getAttrOfType<IntegerAttr>("addressing_model").getInt();
 
-  constexpr uint32_t kNumWords = 3;
-
-  memoryModel.reserve(kNumWords);
-  memoryModel.assign(
-      {getPrefixedOpcode(kNumWords, spirv::kOpMemoryModelOpcode), am, mm});
+  buildInstruction(spirv::Opcode::OpMemoryModel, {am, mm}, memoryModel);
 }
 
 LogicalResult spirv::serialize(spirv::ModuleOp module,
index f63c928bb4acf514416f6325e3ea533eed281445..b3288162dc088ad2e62cd0081653a4ef7b88493c 100644 (file)
@@ -25,6 +25,7 @@
 #include "mlir/TableGen/Operator.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/TableGen/Error.h"
@@ -41,18 +42,21 @@ using mlir::tblgen::Operator;
 // Writes the following function to `os`:
 //   inline uint32_t getOpcode(<op-class-name>) { return <opcode>; }
 static void emitGetOpcodeFunction(const llvm::Record &record,
-                                  const Operator &op, raw_ostream &os) {
-  if (llvm::isa<llvm::UnsetInit>(record.getValueInit("opcode")))
-    return;
-
-  os << formatv("inline uint32_t getOpcode({0}) {{ return {1}u; }\n",
-                op.getQualCppClassName(), record.getValueAsInt("opcode"));
+                                  Operator const &op, raw_ostream &os) {
+  if (record.getValueAsInt("hasOpcode")) {
+    os << formatv("template <> constexpr inline uint32_t getOpcode<{0}>()",
+                  op.getQualCppClassName())
+       << " {\n  return static_cast<uint32_t>("
+       << formatv("Opcode::Op{0});\n}\n", record.getValueAsString("opName"));
+  }
 }
 
 static bool emitSerializationUtils(const RecordKeeper &recordKeeper,
                                    raw_ostream &os) {
   llvm::emitSourceFileHeader("SPIR-V Serialization Utilities", os);
 
+  /// Define the function to get the opcode
+  os << "template <typename OpClass> inline constexpr uint32_t getOpcode();\n";
   auto defs = recordKeeper.getAllDerivedDefinitions("SPV_Op");
   for (const auto *def : defs) {
     Operator op(def);
diff --git a/mlir/utils/spirv/define_opcodes.sh b/mlir/utils/spirv/define_opcodes.sh
new file mode 100755 (executable)
index 0000000..e98d55c
--- /dev/null
@@ -0,0 +1,39 @@
+#!/bin/bash
+
+# Copyright 2019 The MLIR Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Script for defining map for opname to opcode using SPIR-V spec from the
+# Internet
+#
+# Run as:
+# ./define_opcode.sh (<op-name>)*
+#
+# For example:
+# ./define_opcode.sh OpTypeVoid OpTypeFunction
+#
+# If no op-name is specified, the existing opcodes are updated
+#
+# The 'instructions' list of spirv.core.grammar.json contains all instructions
+# in SPIR-V
+
+set -e
+set -x
+
+current_file="$(readlink -f "$0")"
+current_dir="$(dirname "$current_file")"
+
+python3 ${current_dir}/gen_spirv_dialect.py \
+  --base-td-path ${current_dir}/../../include/mlir/SPIRV/SPIRVBase.td \
+  --new-opcode $@
index 88f640c6ef853542d72b2e9085c8454cb21ec43a..8279f70765baee1c6f44d24ceb6400af348ce14a 100755 (executable)
@@ -32,6 +32,8 @@ SPIRV_HTML_SPEC_URL = 'https://www.khronos.org/registry/spir-v/specs/unified1/SP
 SPIRV_JSON_SPEC_URL = 'https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/spirv.core.grammar.json'
 
 AUTOGEN_ENUM_SECTION_MARKER = 'enum section. Generated from SPIR-V spec; DO NOT MODIFY!'
+AUTOGEN_INSTRUCTION_OPCODE_SECTION_MARKER = ('opcode section. Generated from '
+                                             'SPIR-V spec; DO NOT MODIFY!')
 
 
 def get_spirv_grammar_from_json_spec():
@@ -129,6 +131,76 @@ def gen_operand_kind_enum_attr(operand_kind):
   return kind_name, case_defs + '\n\n' + enum_attr
 
 
+def gen_opcode(instructions):
+  """ Generates the TableGen definition to map opname to opcode
+
+   Returns:
+       - A string containing the TableGen SPV_OpCode definition
+   """
+
+  max_len = max([len(inst['opname']) for inst in instructions])
+  def_fmt_str = 'def SPV_OC_{name} {colon:>{offset}} '\
+            'I32EnumAttrCase<"{name}", {value}>;'
+  opcode_defs = [
+      def_fmt_str.format(
+          name=inst['opname'],
+          value=inst['opcode'],
+          colon=':',
+          offset=(max_len + 1 - len(inst['opname']))) for inst in instructions
+  ]
+  opcode_str = '\n'.join(opcode_defs)
+
+  decl_fmt_str = 'SPV_OC_{name}'
+  opcode_list = [
+      decl_fmt_str.format(name=inst['opname']) for inst in instructions
+  ]
+  opcode_list = split_list_into_sublists(opcode_list, 6)
+  opcode_list = [
+      '{:6}'.format('') + ', '.join(sublist) for sublist in opcode_list
+  ]
+  opcode_list = ',\n'.join(opcode_list)
+  enum_attr = 'def SPV_OpcodeAttr :\n'\
+              '    I32EnumAttr<"{name}", "valid SPIR-V instructions", [\n'\
+              '{lst}\n'\
+              '      ]> {{\n'\
+              '    let returnType = "::mlir::spirv::{name}";\n'\
+              '    let convertFromStorage = '\
+              '"static_cast<::mlir::spirv::{name}>($_self.getInt())";\n'\
+              '    let cppNamespace = "::mlir::spirv";\n}}'.format(
+                  name='Opcode', lst=opcode_list)
+  return opcode_str + '\n\n' + enum_attr
+
+
+def update_td_opcodes(path, instructions, filter_list):
+
+  with open(path, 'r') as f:
+    content = f.read()
+
+  content = content.split(AUTOGEN_INSTRUCTION_OPCODE_SECTION_MARKER)
+  assert len(content) == 3
+
+  # Extend opcode list with existing list
+  import re
+  existing_opcodes = [k[11:] for k in re.findall('def SPV_OC_\w+', content[1])]
+  filter_list.extend(existing_opcodes)
+  filter_list = list(set(filter_list))
+
+  # Generate the opcode for all instructions in SPIR-V
+  filter_instrs = list(
+      filter(lambda inst: (inst['opname'] in filter_list), instructions))
+  # Sort instruction based on opcode
+  filter_instrs.sort(key=lambda inst: inst['opcode'])
+  opcode = gen_opcode(filter_instrs)
+
+  # Substitute the opcode
+  content = content[0] + AUTOGEN_INSTRUCTION_OPCODE_SECTION_MARKER + '\n\n' + \
+        opcode + '\n\n// End ' + AUTOGEN_INSTRUCTION_OPCODE_SECTION_MARKER \
+        + content[2]
+
+  with open(path, 'w') as f:
+    f.write(content)
+
+
 def update_td_enum_attrs(path, operand_kinds, filter_list):
   """Updates SPIRBase.td with new generated enum definitions.
 
@@ -136,7 +208,7 @@ def update_td_enum_attrs(path, operand_kinds, filter_list):
         - path: the path to SPIRBase.td
         - operand_kinds: a list containing all operand kinds' grammar
         - filter_list: a list containing new enums to add
-    """
+  """
   with open(path, 'r') as f:
     content = f.read()
 
@@ -175,8 +247,16 @@ if __name__ == '__main__':
                           help='Path to SPIRVBase.td')
   cli_parser.add_argument('--new-enum', dest='new_enum', type=str,
                           help='SPIR-V enum to be added to SPIRVBase.td')
+  cli_parser.add_argument(
+      '--new-opcodes',
+      dest='new_opcodes',
+      type=str,
+      nargs='*',
+      help='update SPIR-V opcodes in SPIRVBase.td')
   args = cli_parser.parse_args()
 
-  operand_kinds, _ = get_spirv_grammar_from_json_spec()
+  operand_kinds, instructions = get_spirv_grammar_from_json_spec()
 
   update_td_enum_attrs(args.base_td_path, operand_kinds, [args.new_enum])
+
+  update_td_opcodes(args.base_td_path, instructions, args.new_opcodes)