Autogenerate (de)serialization for Extended Instruction Sets
authorMahesh Ravishankar <ravishankarm@google.com>
Tue, 17 Sep 2019 00:11:50 +0000 (17:11 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 17 Sep 2019 00:12:33 +0000 (17:12 -0700)
A generic mechanism for (de)serialization of extended instruction sets
is added with this CL. To facilitate this, a new class
"SPV_ExtendedInstSetOp" is added which is a base class for all
operations corresponding to extended instruction sets. The methods to
(de)serialization such ops as well as its dispatch is generated
automatically.

The behavior controlled by autogenSerialization and hasOpcode is also
slightly modified to enable this. They are now decoupled.
1) Setting hasOpcode=1 means the operation has a corresponding
   opcode in SPIR-V binary format, and its dispatch for
   (de)serialization is automatically generated.
2) Setting autogenSerialization=1 generates the function for
   (de)serialization automatically.
So now it is possible to have hasOpcode=0 and autogenSerialization=1
(for example SPV_ExtendedInstSetOp).

Since the dispatch functions is also auto-generated, the input file
needs to contain all operations. To this effect, SPIRVGLSLOps.td is
included into SPIRVOps.td. This makes the previously added
SPIRVGLSLOps.h and SPIRVGLSLOps.cpp unnecessary, and are deleted.

The SPIRVUtilsGen.cpp is also changed to make better use of
formatv,making the code more readable.

PiperOrigin-RevId: 269456263

16 files changed:
mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt
mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.h [deleted file]
mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td
mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
mlir/lib/Dialect/SPIRV/CMakeLists.txt
mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
mlir/lib/Dialect/SPIRV/SPIRVGLSLOps.cpp [deleted file]
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
mlir/test/Dialect/SPIRV/Serialization/glslops.mlir [new file with mode: 0644]
mlir/test/Dialect/SPIRV/glslops.mlir
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp

index 0c847f0..f1d6803 100644 (file)
@@ -3,11 +3,6 @@ mlir_tablegen(SPIRVOps.h.inc -gen-op-decls)
 mlir_tablegen(SPIRVOps.cpp.inc -gen-op-defs)
 add_public_tablegen_target(MLIRSPIRVOpsIncGen)
 
-set(LLVM_TARGET_DEFINITIONS SPIRVGLSLOps.td)
-mlir_tablegen(SPIRVGLSLOps.h.inc -gen-op-decls)
-mlir_tablegen(SPIRVGLSLOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRSPIRVGLSLOpsIncGen)
-
 set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
 mlir_tablegen(SPIRVIntEnums.h.inc -gen-enum-decls)
 mlir_tablegen(SPIRVIntEnums.cpp.inc -gen-enum-defs)
index b2b844d..d1abf9f 100644 (file)
@@ -75,6 +75,8 @@ class SPV_OpCode<string name, int val> {
 def SPV_OC_OpNop                    : I32EnumAttrCase<"OpNop", 0>;
 def SPV_OC_OpName                   : I32EnumAttrCase<"OpName", 5>;
 def SPV_OC_OpExtension              : I32EnumAttrCase<"OpExtension", 10>;
+def SPV_OC_OpExtInstImport          : I32EnumAttrCase<"OpExtInstImport", 11>;
+def SPV_OC_OpExtInst                : I32EnumAttrCase<"OpExtInst", 12>;
 def SPV_OC_OpMemoryModel            : I32EnumAttrCase<"OpMemoryModel", 14>;
 def SPV_OC_OpEntryPoint             : I32EnumAttrCase<"OpEntryPoint", 15>;
 def SPV_OC_OpExecutionMode          : I32EnumAttrCase<"OpExecutionMode", 16>;
@@ -154,21 +156,22 @@ def SPV_OC_OpReturnValue            : I32EnumAttrCase<"OpReturnValue", 254>;
 
 def SPV_OpcodeAttr :
     I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
-      SPV_OC_OpNop, SPV_OC_OpName, SPV_OC_OpExtension, SPV_OC_OpMemoryModel,
-      SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode, SPV_OC_OpCapability,
-      SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt, SPV_OC_OpTypeFloat,
-      SPV_OC_OpTypeVector, SPV_OC_OpTypeArray, SPV_OC_OpTypeStruct,
-      SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue,
-      SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite,
-      SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse,
-      SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction,
-      SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall,
-      SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain,
-      SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract,
-      SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul,
-      SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod,
-      SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect,
-      SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
+      SPV_OC_OpNop, SPV_OC_OpName, SPV_OC_OpExtension, SPV_OC_OpExtInstImport,
+      SPV_OC_OpExtInst, SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint,
+      SPV_OC_OpExecutionMode, SPV_OC_OpCapability, SPV_OC_OpTypeVoid,
+      SPV_OC_OpTypeBool, SPV_OC_OpTypeInt, SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector,
+      SPV_OC_OpTypeArray, SPV_OC_OpTypeStruct, SPV_OC_OpTypePointer,
+      SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse,
+      SPV_OC_OpConstant, SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull,
+      SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant,
+      SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
+      SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad,
+      SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
+      SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpIAdd,
+      SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul,
+      SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem,
+      SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect, SPV_OC_OpIEqual,
+      SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
       SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
       SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
       SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
@@ -1143,8 +1146,7 @@ class SPV_Op<string mnemonic, list<OpTrait> traits = []> :
   // Specifies whether this op has a direct corresponding SPIR-V binary
   // instruction opcode. The (de)serializer use this field to determine whether
   // to auto-generate an entry in the (de)serialization dispatch table for this
-  // op. If set, this field also futher enables `autogenSerialization` (see
-  // below for details).
+  // op.
   bit hasOpcode = 1;
 
   // Name of the corresponding SPIR-V op. Only valid to use when hasOpcode is 1.
@@ -1162,15 +1164,14 @@ class SPV_Op<string mnemonic, list<OpTrait> traits = []> :
   // these methods is required.
   //
   // Note:
-  //
   // 1) If hasOpcode is set but autogenSerialization is not set, the
   //    (de)serializer dispatch method still calls the above method for
   //    (de)serializing this op.
-  //
-  // 2) If hasOpcode is not set, then this field is not interpreted; this op's
-  //    (de)serialization method will not be auto-generated regardless. Neither
-  //    does the handling in the (de)serialization dispatch table. Both
-  //    (de)serializing this op and its dispatch should be handled manually.
+  // 2) If hasOpcode is not set, but autogenSerialization is set, the
+  //    above methods for (de)serialization are generated, but there is no
+  //    entry added in the dispatch tables to invoke these methods. The
+  //    dispatch needs to be handled manually. SPV_ExtInstOps are an
+  //    example of this.
   bit autogenSerialization = 1;
 }
 
@@ -1190,4 +1191,24 @@ class SPV_BinaryOp<string mnemonic, Type resultType, Type operandsType,
   let verifier = [{ return success(); }];
 }
 
+class SPV_ExtInstOp<string mnemonic, string setPrefix, string setName,
+                    int opcode, list<OpTrait> traits = []> :
+  SPV_Op<setPrefix # "." # mnemonic, traits> {
+
+  // Extended instruction sets have no direct opcode (they share the
+  // same `OpExtInst` instruction). So the hasOpcode field is set to
+  // false. So no entry corresponding to these ops are added in the
+  // dispatch functions for (de)serialization. The methods for
+  // (de)serialization are still automatically generated (since
+  // autogenSerialization remains 1). A separate method is generated
+  // for dispatching extended instruction set ops.
+  let hasOpcode = 0;
+
+  // Opcode within extended instruction set.
+  int extendedInstOpcode = opcode;
+
+  // Name used to import the extended instruction set.
+  string extendedInstSetName = setName;
+}
+
 #endif // SPIRV_BASE
index b9f8604..1d3796c 100644 (file)
@@ -250,6 +250,8 @@ def SPV_LoopOp : SPV_Op<"loop"> {
   }];
 
   let hasOpcode = 0;
+
+  let autogenSerialization = 0;
 }
 
 // -----
@@ -273,6 +275,8 @@ def SPV_MergeOp : SPV_Op<"_merge", [HasParent<"LoopOp">, Terminator]> {
   let printer = [{ printNoIOOp(getOperation(), p); }];
 
   let hasOpcode = 0;
+
+  let autogenSerialization = 0;
 }
 
 // -----
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.h
deleted file mode 100644 (file)
index b20b81c..0000000
+++ /dev/null
@@ -1,37 +0,0 @@
-//===- SPIRVGLSLOps.h - MLIR SPIR-V extended ops for GLSL --------*- C++-*-===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file declares the extended operations for GLSL in the SPIR-V dialect.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_SPIRV_SPIRVGLSLOPS_H_
-#define MLIR_DIALECT_SPIRV_SPIRVGLSLOPS_H_
-
-#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
-#include "mlir/IR/OpDefinition.h"
-
-namespace mlir {
-namespace spirv {
-
-#define GET_OP_CLASSES
-#include "mlir/Dialect/SPIRV/SPIRVGLSLOps.h.inc"
-
-} // namespace spirv
-} // namespace mlir
-
-#endif // MLIR_DIALECT_SPIRV_SPIRVGLSLOPS_H_
index e32a81d..61aac88 100644 (file)
@@ -34,17 +34,7 @@ include "mlir/Dialect/SPIRV/SPIRVBase.td"
 
 // Base class for all GLSL ops.
 class SPV_GLSLOp<string mnemonic, int opcode, list<OpTrait> traits = []> :
-  SPV_Op<"glsl." # mnemonic, traits> {
-
-  // Do not use the default auto-generation serializer/deserializer.
-  let hasOpcode = 0;
-
-  // Opcode within the extended instruction set.
-  int glslOpcode = opcode;
-
-  // Name used to refer to the extended instruction set.
-  string extensionSetName = "GLSL.std.450";
-}
+  SPV_ExtInstOp<mnemonic, "GLSL", "GLSL.std.450", opcode, traits>;
 
 // Base class for GLSL unary ops.
 class SPV_GLSLUnaryOp<string mnemonic, Type resultType, Type operandType,
index 8d1a19a..e351f47 100644 (file)
@@ -56,6 +56,12 @@ include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td"
 include "mlir/Dialect/SPIRV/SPIRVStructureOps.td"
 #endif // SPIRV_STRUCTURE_OPS
 
+#ifdef SPIRV_GLSL_OPS
+#else
+// Pull in ops for extended instruction set for GLSL
+include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td"
+#endif // SPIRV_GLSL_OPS
+
 // -----
 
 def SPV_AccessChainOp : SPV_Op<"AccessChain", [NoSideEffect]> {
index 8b2eb16..37e79d5 100644 (file)
@@ -65,6 +65,8 @@ def SPV_AddressOfOp : SPV_Op<"_address_of", [InFunctionScope, NoSideEffect]> {
   );
 
   let hasOpcode = 0;
+
+  let autogenSerialization = 0;
 }
 
 def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> {
@@ -120,6 +122,8 @@ def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> {
   }];
 
   let hasOpcode = 0;
+
+  let autogenSerialization = 0;
 }
 
 def SPV_EntryPointOp : SPV_Op<"EntryPoint", [InModuleScope]> {
@@ -174,6 +178,7 @@ def SPV_EntryPointOp : SPV_Op<"EntryPoint", [InModuleScope]> {
   );
 
   let results = (outs);
+
   let autogenSerialization = 0;
 }
 
@@ -233,6 +238,8 @@ def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [InModuleScope]> {
 
   let hasOpcode = 0;
 
+  let autogenSerialization = 0;
+
   let extraClassDeclaration = [{
     ::mlir::spirv::StorageClass storageClass() {
       return this->type().cast<::mlir::spirv::PointerType>().getStorageClass();
@@ -313,6 +320,8 @@ def SPV_ModuleOp : SPV_Op<"module",
 
   let hasOpcode = 0;
 
+  let autogenSerialization = 0;
+
   let extraClassDeclaration = [{
     Block& getBlock() {
       return this->getOperation()->getRegion(0).front();
@@ -340,6 +349,8 @@ def SPV_ModuleEndOp : SPV_Op<"_module_end", [InModuleScope, Terminator]> {
   let verifier = [{ return success(); }];
 
   let hasOpcode = 0;
+
+  let autogenSerialization = 0;
 }
 
 def SPV_ReferenceOfOp : SPV_Op<"_reference_of", [NoSideEffect]> {
@@ -376,6 +387,8 @@ def SPV_ReferenceOfOp : SPV_Op<"_reference_of", [NoSideEffect]> {
   );
 
   let hasOpcode = 0;
+
+  let autogenSerialization = 0;
 }
 
 def SPV_SpecConstantOp : SPV_Op<"specConstant", [InModuleScope]> {
@@ -417,6 +430,8 @@ def SPV_SpecConstantOp : SPV_Op<"specConstant", [InModuleScope]> {
   let results = (outs);
 
   let hasOpcode = 0;
+
+  let autogenSerialization = 0;
 }
 
 #endif // SPIRV_STRUCTURE_OPS
index f4e89d1..05f09d2 100644 (file)
@@ -1,7 +1,6 @@
 add_llvm_library(MLIRSPIRV
   DialectRegistration.cpp
   SPIRVDialect.cpp
-  SPIRVGLSLOps.cpp
   SPIRVOps.cpp
   SPIRVTypes.cpp
 
index c1c214f..4660aa8 100644 (file)
@@ -11,7 +11,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
-#include "mlir/Dialect/SPIRV/SPIRVGLSLOps.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
 #include "mlir/IR/Builders.h"
@@ -48,12 +47,6 @@ SPIRVDialect::SPIRVDialect(MLIRContext *context)
 #include "mlir/Dialect/SPIRV/SPIRVOps.cpp.inc"
       >();
 
-  // Add SPIR-V extension ops of GLSL.
-  addOperations<
-#define GET_OP_LIST
-#include "mlir/Dialect/SPIRV/SPIRVGLSLOps.cpp.inc"
-      >();
-
   // Allow unknown operations because SPIR-V is extensible.
   allowUnknownOperations();
 }
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVGLSLOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVGLSLOps.cpp
deleted file mode 100644 (file)
index b007aaf..0000000
+++ /dev/null
@@ -1,58 +0,0 @@
-//===- SPIRVGLSLOps.cpp - MLIR SPIR-V GLSL extended operations ------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file defines the operations in the SPIR-V extended instructions set for
-// GLSL
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/SPIRV/SPIRVGLSLOps.h"
-#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
-#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
-#include "mlir/IR/OpImplementation.h"
-
-using namespace mlir;
-
-//===----------------------------------------------------------------------===//
-// spv.glsl.UnaryOp
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseGLSLUnaryOp(OpAsmParser *parser,
-                                    OperationState *state) {
-  OpAsmParser::OperandType operandInfo;
-  Type type;
-  if (parser->parseOperand(operandInfo) || parser->parseColonType(type) ||
-      parser->resolveOperands(operandInfo, type, state->operands)) {
-    return failure();
-  }
-  state->addTypes(type);
-  return success();
-}
-
-static void printGLSLUnaryOp(Operation *unaryOp, OpAsmPrinter *printer) {
-  *printer << unaryOp->getName() << ' ' << *unaryOp->getOperand(0) << " : "
-           << unaryOp->getOperand(0)->getType();
-}
-
-namespace mlir {
-namespace spirv {
-
-#define GET_OP_CLASSES
-#include "mlir/Dialect/SPIRV/SPIRVGLSLOps.cpp.inc"
-
-} // namespace spirv
-} // namespace mlir
index 9766d6c..c8a8078 100644 (file)
@@ -914,7 +914,7 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter *printer) {
 }
 
 //===----------------------------------------------------------------------===//
-// spv.FuncionCall
+// spv.FunctionCall
 //===----------------------------------------------------------------------===//
 
 static ParseResult parseFunctionCallOp(OpAsmParser *parser,
@@ -1016,6 +1016,27 @@ static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
 }
 
 //===----------------------------------------------------------------------===//
+// spv.GLSL.UnaryOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseGLSLUnaryOp(OpAsmParser *parser,
+                                    OperationState *state) {
+  OpAsmParser::OperandType operandInfo;
+  Type type;
+  if (parser->parseOperand(operandInfo) || parser->parseColonType(type) ||
+      parser->resolveOperands(operandInfo, type, state->operands)) {
+    return failure();
+  }
+  state->addTypes(type);
+  return success();
+}
+
+static void printGLSLUnaryOp(Operation *unaryOp, OpAsmPrinter *printer) {
+  *printer << unaryOp->getName() << ' ' << *unaryOp->getOperand(0) << " : "
+           << unaryOp->getOperand(0)->getType();
+}
+
+//===----------------------------------------------------------------------===//
 // spv.globalVariable
 //===----------------------------------------------------------------------===//
 
index e5f4e06..23cd60e 100644 (file)
@@ -97,7 +97,11 @@ private:
 
   /// Processes the SPIR-V OpExtension with `operands` and updates bookkeeping
   /// in the deserializer.
-  LogicalResult processExtension(ArrayRef<uint32_t> operands);
+  LogicalResult processExtension(ArrayRef<uint32_t> words);
+
+  /// Processes the SPIR-V OpExtInstImport with `operands` and updates
+  /// bookkeeping in the deserializer.
+  LogicalResult processExtInstImport(ArrayRef<uint32_t> words);
 
   /// Attaches all collected extensions to `module` as an attribute.
   void attachExtensions();
@@ -300,6 +304,20 @@ private:
   LogicalResult dispatchToAutogenDeserialization(spirv::Opcode opcode,
                                                  ArrayRef<uint32_t> words);
 
+  /// Processes a SPIR-V OpExtInst with given `operands`. This slices the
+  /// entries of `operands` that specify the extended instruction set <id> and
+  /// the instruction opcode. The op deserializer is then invoked using the
+  /// other entries.
+  LogicalResult processExtInst(ArrayRef<uint32_t> operands);
+
+  /// Dispatches the deserialization of extended instruction set operation based
+  /// on the extended instruction set name, and instruction opcode. This is
+  /// autogenerated from ODS.
+  LogicalResult
+  dispatchToExtensionSetAutogenDeserialization(StringRef extensionSetName,
+                                               uint32_t instructionID,
+                                               ArrayRef<uint32_t> words);
+
   /// Method to deserialize an operation in the SPIR-V dialect that is a mirror
   /// of an instruction in the SPIR-V spec. This is auto generated if hasOpcode
   /// == 1 and autogenSerialization == 1 in ODS.
@@ -381,6 +399,9 @@ private:
   // Result <id> to member decorations.
   DenseMap<uint32_t, DenseMap<uint32_t, uint32_t>> memberDecorationMap;
 
+  // Result <id> to extended instruction set name.
+  DenseMap<uint32_t, StringRef> extendedInstSets;
+
   // List of instructions that are processed in a defered fashion (after an
   // initial processing of the entire binary). Some operations like
   // OpEntryPoint, and OpExecutionMode use forward references to function
@@ -487,16 +508,16 @@ void Deserializer::attachCapabilities() {
   module->setAttr("capabilities", opBuilder.getStrArrayAttr(caps));
 }
 
-LogicalResult Deserializer::processExtension(ArrayRef<uint32_t> operands) {
-  if (operands.empty()) {
+LogicalResult Deserializer::processExtension(ArrayRef<uint32_t> words) {
+  if (words.empty()) {
     return emitError(
         unknownLoc,
         "OpExtension must have a literal string for the extension name");
   }
 
   unsigned wordIndex = 0;
-  StringRef extName = decodeStringLiteral(operands, wordIndex);
-  if (wordIndex != operands.size()) {
+  StringRef extName = decodeStringLiteral(words, wordIndex);
+  if (wordIndex != words.size()) {
     return emitError(unknownLoc,
                      "unexpected trailing words in OpExtension instruction");
   }
@@ -505,6 +526,22 @@ LogicalResult Deserializer::processExtension(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+LogicalResult Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
+  if (words.size() < 2) {
+    return emitError(unknownLoc,
+                     "OpExtInstImport must have a result <id> and a literal "
+                     "string for the extensed instruction set name");
+  }
+
+  unsigned wordIndex = 1;
+  extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
+  if (wordIndex != words.size()) {
+    return emitError(unknownLoc,
+                     "unexpected trailing words in OpExtInstImport");
+  }
+  return success();
+}
+
 void Deserializer::attachExtensions() {
   if (extensions.empty())
     return;
@@ -1652,6 +1689,10 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
     return processCapability(operands);
   case spirv::Opcode::OpExtension:
     return processExtension(operands);
+  case spirv::Opcode::OpExtInst:
+    return processExtInst(operands);
+  case spirv::Opcode::OpExtInstImport:
+    return processExtInstImport(operands);
   case spirv::Opcode::OpMemoryModel:
     return processMemoryModel(operands);
   case spirv::Opcode::OpEntryPoint:
@@ -1714,6 +1755,22 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
   return dispatchToAutogenDeserialization(opcode, operands);
 }
 
+LogicalResult Deserializer::processExtInst(ArrayRef<uint32_t> operands) {
+  if (operands.size() < 4) {
+    return emitError(unknownLoc,
+                     "OpExtInst must have at least 4 operands, result type "
+                     "<id>, result <id>, set <id> and instruction opcode");
+  }
+  if (!extendedInstSets.count(operands[2])) {
+    return emitError(unknownLoc, "undefined set <id> in OpExtInst");
+  }
+  SmallVector<uint32_t, 4> slicedOperands;
+  slicedOperands.append(operands.begin(), std::next(operands.begin(), 2));
+  slicedOperands.append(std::next(operands.begin(), 4), operands.end());
+  return dispatchToExtensionSetAutogenDeserialization(
+      extendedInstSets[operands[2]], operands[3], slicedOperands);
+}
+
 namespace {
 
 template <>
index c31c9f3..ea50649 100644 (file)
@@ -278,6 +278,11 @@ private:
   // Operations
   //===--------------------------------------------------------------------===//
 
+  LogicalResult encodeExtensionInstruction(Operation *op,
+                                           StringRef extensionSetName,
+                                           uint32_t opcode,
+                                           ArrayRef<uint32_t> operands);
+
   uint32_t findValueID(Value *val) const { return valueIDMap.lookup(val); }
 
   LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp);
@@ -345,6 +350,9 @@ private:
 
   /// Map from results of normal operations to their <id>s.
   DenseMap<Value *, uint32_t> valueIDMap;
+
+  /// Map from extended instruction set name to <id>s.
+  llvm::StringMap<uint32_t> extendedInstSetIDMap;
 };
 } // namespace
 
@@ -1347,6 +1355,37 @@ LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
 // Operation
 //===----------------------------------------------------------------------===//
 
+LogicalResult Serializer::encodeExtensionInstruction(
+    Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
+    ArrayRef<uint32_t> operands) {
+  // Check if the extension has been imported.
+  auto &setID = extendedInstSetIDMap[extensionSetName];
+  if (!setID) {
+    setID = getNextID();
+    SmallVector<uint32_t, 16> importOperands;
+    importOperands.push_back(setID);
+    if (failed(encodeStringLiteralInto(importOperands, extensionSetName)) ||
+        failed(encodeInstructionInto(
+            extendedSets, spirv::Opcode::OpExtInstImport, importOperands))) {
+      return failure();
+    }
+  }
+
+  // The first two operands are the result type <id> and result <id>. The set
+  // <id> and the opcode need to be insert after this.
+  if (operands.size() < 2) {
+    return op->emitError("extended instructions must have a result encoding");
+  }
+  SmallVector<uint32_t, 8> extInstOperands;
+  extInstOperands.reserve(operands.size() + 2);
+  extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
+  extInstOperands.push_back(setID);
+  extInstOperands.push_back(extensionOpcode);
+  extInstOperands.append(std::next(operands.begin(), 2), operands.end());
+  return encodeInstructionInto(functions, spirv::Opcode::OpExtInst,
+                               extInstOperands);
+}
+
 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
   auto varName = addressOfOp.variable();
   auto variableID = findVariableID(varName);
diff --git a/mlir/test/Dialect/SPIRV/Serialization/glslops.mlir b/mlir/test/Dialect/SPIRV/Serialization/glslops.mlir
new file mode 100644 (file)
index 0000000..5dba6ce
--- /dev/null
@@ -0,0 +1,9 @@
+// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
+
+spv.module "Logical" "GLSL450" {
+  func @fmul(%arg0 : f32) {
+    // CHECK: {{%.*}} = spv.GLSL.Exp {{%.*}} : f32
+    %0 = spv.GLSL.Exp %arg0 : f32
+    spv.Return
+  }
+}
\ No newline at end of file
index 6ec900b..181f263 100644 (file)
@@ -1,18 +1,18 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
 
 //===----------------------------------------------------------------------===//
-// spv.glsl.Exp
+// spv.GLSL.Exp
 //===----------------------------------------------------------------------===//
 
 func @exp(%arg0 : f32) -> () {
-  // CHECK: spv.glsl.Exp {{%.*}} : f32
-  %2 = spv.glsl.Exp %arg0 : f32
+  // CHECK: spv.GLSL.Exp {{%.*}} : f32
+  %2 = spv.GLSL.Exp %arg0 : f32
   return
 }
 
 func @expvec(%arg0 : vector<3xf16>) -> () {
-  // CHECK: spv.glsl.Exp {{%.*}} : vector<3xf16>
-  %2 = spv.glsl.Exp %arg0 : vector<3xf16>
+  // CHECK: spv.GLSL.Exp {{%.*}} : vector<3xf16>
+  %2 = spv.GLSL.Exp %arg0 : vector<3xf16>
   return
 }
 
@@ -20,7 +20,7 @@ func @expvec(%arg0 : vector<3xf16>) -> () {
 
 func @exp(%arg0 : i32) -> () {
   // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values}}
-  %2 = spv.glsl.Exp %arg0 : i32
+  %2 = spv.GLSL.Exp %arg0 : i32
   return
 }
 
@@ -28,7 +28,7 @@ func @exp(%arg0 : i32) -> () {
 
 func @exp(%arg0 : vector<5xf32>) -> () {
   // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values of length 2/3/4}}
-  %2 = spv.glsl.Exp %arg0 : vector<5xf32>
+  %2 = spv.GLSL.Exp %arg0 : vector<5xf32>
   return
 }
 
@@ -36,7 +36,7 @@ func @exp(%arg0 : vector<5xf32>) -> () {
 
 func @exp(%arg0 : f32, %arg1 : f32) -> () {
   // expected-error @+1 {{expected ':'}}
-  %2 = spv.glsl.Exp %arg0, %arg1 : i32
+  %2 = spv.GLSL.Exp %arg0, %arg1 : i32
   return
 }
 
@@ -44,6 +44,6 @@ func @exp(%arg0 : f32, %arg1 : f32) -> () {
 
 func @exp(%arg0 : i32) -> () {
   // expected-error @+2 {{expected non-function type}}
-  %2 = spv.glsl.Exp %arg0 :
+  %2 = spv.GLSL.Exp %arg0 :
   return
 }
index ca65065..b3059a9 100644 (file)
@@ -28,6 +28,7 @@
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/raw_ostream.h"
@@ -41,7 +42,9 @@ using llvm::raw_ostream;
 using llvm::raw_string_ostream;
 using llvm::Record;
 using llvm::RecordKeeper;
+using llvm::SmallVector;
 using llvm::SMLoc;
+using llvm::StringMap;
 using llvm::StringRef;
 using llvm::Twine;
 using mlir::tblgen::Attribute;
@@ -59,36 +62,43 @@ using mlir::tblgen::Operator;
 static void emitGetOpcodeFunction(const Record *record, Operator const &op,
                                   raw_ostream &os) {
   os << formatv("template <> constexpr inline ::mlir::spirv::Opcode "
-                "getOpcode<{0}>()",
-                op.getQualCppClassName())
-     << " {\n  "
-     << formatv("return ::mlir::spirv::Opcode::{0};\n}\n",
+                "getOpcode<{0}>() {{\n",
+                op.getQualCppClassName());
+  os << formatv("  return ::mlir::spirv::Opcode::{0};\n",
                 record->getValueAsString("spirvOpName"));
+  os << "}\n";
 }
 
+/// Forward declaration of function to return the SPIR-V opcode corresponding to
+/// an operation. This function will be generated for all SPV_Op instances that
+/// have hasOpcode = 1.
 static void declareOpcodeFn(raw_ostream &os) {
   os << "template <typename OpClass> inline constexpr ::mlir::spirv::Opcode "
         "getOpcode();\n";
 }
 
+/// Generates code to serialize attributes of a SPV_Op `op` into `os`. The
+/// generates code extracts the attribute with name `attrName` from
+/// `operandList` of `op`.
 static void emitAttributeSerialization(const Attribute &attr,
-                                       ArrayRef<SMLoc> loc, llvm::StringRef op,
-                                       llvm::StringRef operandList,
-                                       llvm::StringRef attrName,
-                                       raw_ostream &os) {
-  os << "    auto attr = " << op << ".getAttr(\"" << attrName << "\");\n";
-  os << "    if (attr) {\n";
+                                       ArrayRef<SMLoc> loc, StringRef tabs,
+                                       StringRef opVar, StringRef operandList,
+                                       StringRef attrName, raw_ostream &os) {
+  os << tabs << formatv("auto attr = {0}.getAttr(\"{1}\");\n", opVar, attrName);
+  os << tabs << "if (attr) {\n";
   if (attr.getAttrDefName() == "I32ArrayAttr") {
     // Serialize all the elements of the array
-    os << "      for (auto attrElem : attr.cast<ArrayAttr>()) {\n";
-    os << "        " << operandList
-       << ".push_back(static_cast<uint32_t>(attrElem.cast<IntegerAttr>()."
-          "getValue().getZExtValue()));\n";
-    os << "      }\n";
+    os << tabs << "  for (auto attrElem : attr.cast<ArrayAttr>()) {\n";
+    os << tabs
+       << formatv("    {0}.push_back(static_cast<uint32_t>("
+                  "attrElem.cast<IntegerAttr>().getValue().getZExtValue()));\n",
+                  operandList);
+    os << tabs << "  }\n";
   } else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") {
-    os << "      " << operandList
-       << ".push_back(static_cast<uint32_t>(attr.cast<IntegerAttr>().getValue()"
-          ".getZExtValue()));\n";
+    os << tabs
+       << formatv("  {0}.push_back(static_cast<uint32_t>("
+                  "attr.cast<IntegerAttr>().getValue().getZExtValue()));\n",
+                  operandList);
   } else {
     PrintFatalError(
         loc,
@@ -96,122 +106,205 @@ static void emitAttributeSerialization(const Attribute &attr,
             "unhandled attribute type in SPIR-V serialization generation : '") +
             attr.getAttrDefName() + llvm::Twine("'"));
   }
-  os << "    }\n";
+  os << tabs << "}\n";
 }
 
-static void emitSerializationFunction(const Record *attrClass,
-                                      const Record *record, const Operator &op,
-                                      raw_ostream &os) {
-  // If the record has 'autogenSerialization' set to 0, nothing to do
-  if (!record->getValueAsBit("autogenSerialization")) {
-    return;
-  }
-  os << formatv("template <> LogicalResult\nSerializer::processOp<{0}>(\n"
-                "  {0} op)",
-                op.getQualCppClassName())
-     << " {\n";
-  os << "  SmallVector<uint32_t, 4> operands;\n";
-  os << "  SmallVector<StringRef, 2> elidedAttrs;\n";
-
-  // Serialize result information
-  if (op.getNumResults() == 1) {
-    os << "  uint32_t resultTypeID = 0;\n";
-    os << "  if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) "
-          "{\n";
-    os << "    return failure();\n";
-    os << "  }\n";
-    os << "  operands.push_back(resultTypeID);\n";
-    // Create an SSA result <id> for the op
-    os << "  auto resultID = getNextID();\n";
-    os << "  valueIDMap[op.getResult()] = resultID;\n";
-    os << "  operands.push_back(resultID);\n";
-  } else if (op.getNumResults() != 0) {
-    PrintFatalError(record->getLoc(), "SPIR-V ops can only zero or one result");
-  }
-
-  // Process arguments
+/// Generates code to serialize the operands of a SPV_Op `op` into `os`. The
+/// generated querries the SSA-ID if operand is a SSA-Value, or serializes the
+/// attributes. The `operands` vector is updated appropriately. `elidedAttrs`
+/// updated as well to include the serialized attributes.
+static void emitOperandSerialization(const Operator &op, ArrayRef<SMLoc> loc,
+                                     StringRef tabs, StringRef opVar,
+                                     StringRef operands, StringRef elidedAttrs,
+                                     raw_ostream &os) {
   auto operandNum = 0;
   for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
     auto argument = op.getArg(i);
-    os << "  {\n";
+    os << tabs << "{\n";
     if (argument.is<NamedTypeConstraint *>()) {
-      os << "    for (auto arg : op.getODSOperands(" << operandNum << ")) {\n";
-      os << "      auto argID = findValueID(arg);\n";
-      os << "      if (!argID) {\n";
-      os << "        emitError(op.getLoc(), \"operand " << operandNum
-         << " has a use before def\");\n";
-      os << "      }\n";
-      os << "      operands.push_back(argID);\n";
+      os << tabs
+         << formatv("  for (auto arg : {0}.getODSOperands({1})) {{\n", opVar,
+                    operandNum);
+      os << tabs << "    auto argID = findValueID(arg);\n";
+      os << tabs << "    if (!argID) {\n";
+      os << tabs
+         << formatv(
+                "      emitError({0}.getLoc(), \"operand {1} has a use before "
+                "def\");\n",
+                opVar, operandNum);
+      os << tabs << "    }\n";
+      os << tabs << formatv("    {0}.push_back(argID);\n", operands);
       os << "    }\n";
       operandNum++;
     } else {
       auto attr = argument.get<NamedAttribute *>();
+      auto newtabs = tabs.str() + "  ";
       emitAttributeSerialization(
           (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
-          record->getLoc(), "op", "operands", attr->name, os);
-      os << "    elidedAttrs.push_back(\"" << attr->name << "\");\n";
+          loc, newtabs, opVar, operands, attr->name, os);
+      os << newtabs
+         << formatv("{0}.push_back(\"{1}\");\n", elidedAttrs, attr->name);
     }
-    os << "  }\n";
+    os << tabs << "}\n";
   }
+}
 
-  os << formatv("  encodeInstructionInto("
-                "functions, spirv::getOpcode<{0}>(), operands);\n",
-                op.getQualCppClassName());
+/// Generates code to serializes the result of SPV_Op `op` into `os`. The
+/// generated gets the ID for the type of the result (if any), the SSA-ID of
+/// the result and updates `resultID` with the SSA-ID.
+static void emitResultSerialization(const Operator &op, ArrayRef<SMLoc> loc,
+                                    StringRef tabs, StringRef opVar,
+                                    StringRef operands, StringRef resultID,
+                                    raw_ostream &os) {
+  if (op.getNumResults() == 1) {
+    StringRef resultTypeID("resultTypeID");
+    os << tabs << formatv("uint32_t {0} = 0;\n", resultTypeID);
+    os << tabs
+       << formatv(
+              "if (failed(processType({0}.getLoc(), {0}.getType(), {1}))) {{\n",
+              opVar, resultTypeID);
+    os << tabs << "  return failure();\n";
+    os << tabs << "}\n";
+    os << tabs << formatv("{0}.push_back({1});\n", operands, resultTypeID);
+    // Create an SSA result <id> for the op
+    os << tabs << formatv("{0} = getNextID();\n", resultID);
+    os << tabs
+       << formatv("valueIDMap[{0}.getResult()] = {1};\n", opVar, resultID);
+    os << tabs << formatv("{0}.push_back({1});\n", operands, resultID);
+  } else if (op.getNumResults() != 0) {
+    PrintFatalError(loc, "SPIR-V ops can only have zero or one result");
+  }
+}
 
+/// Generates code to serialize attributes of SPV_Op `op` that become
+/// decorations on the `resultID` of the serialized operation `opVar` in the
+/// SPIR-V binary.
+static void emitDecorationSerialization(const Operator &op, StringRef tabs,
+                                        StringRef opVar, StringRef elidedAttrs,
+                                        StringRef resultID, raw_ostream &os) {
   if (op.getNumResults() == 1) {
     // All non-argument attributes translated into OpDecorate instruction
-    os << "  for (auto attr : op.getAttrs()) {\n";
-    os << "    if (llvm::any_of(elidedAttrs, [&](StringRef elided) { return "
-          "attr.first.is(elided); })) {\n";
-    os << "      continue;\n";
-    os << "    }\n";
-    os << "    if (failed(processDecoration(op.getLoc(), resultID, attr))) {\n";
-    os << "      return failure();";
-    os << "    }\n";
-    os << "  }\n";
+    os << tabs << formatv("for (auto attr : {0}.getAttrs()) {{\n", opVar);
+    os << tabs
+       << formatv("  if (llvm::any_of({0}, [&](StringRef elided)", elidedAttrs);
+    os << " {return attr.first.is(elided);})) {\n";
+    os << tabs << "    continue;\n";
+    os << tabs << "  }\n";
+    os << tabs
+       << formatv(
+              "  if (failed(processDecoration({0}.getLoc(), {1}, attr))) {{\n",
+              opVar, resultID);
+    os << tabs << "    return failure();\n";
+    os << tabs << "  }\n";
+    os << tabs << "}\n";
+  }
+}
+
+/// Generates code to serialize an SPV_Op `op` into `os`.
+static void emitSerializationFunction(const Record *attrClass,
+                                      const Record *record, const Operator &op,
+                                      raw_ostream &os) {
+  // If the record has 'autogenSerialization' set to 0, nothing to do
+  if (!record->getValueAsBit("autogenSerialization")) {
+    return;
+  }
+  StringRef opVar("op"), operands("operands"), elidedAttrs("elidedAttrs"),
+      resultID("resultID");
+  os << formatv(
+      "template <> LogicalResult\nSerializer::processOp<{0}>({0} {1}) {{\n",
+      op.getQualCppClassName(), opVar);
+  os << formatv("  SmallVector<uint32_t, 4> {0};\n", operands);
+  os << formatv("  SmallVector<StringRef, 2> {0};\n", elidedAttrs);
+
+  // Serialize result information.
+  if (op.getNumResults() == 1) {
+    os << formatv("  uint32_t {0} = 0;\n", resultID);
+    emitResultSerialization(op, record->getLoc(), "  ", opVar, operands,
+                            resultID, os);
+  }
+
+  // Process arguments.
+  emitOperandSerialization(op, record->getLoc(), "  ", opVar, operands,
+                           elidedAttrs, os);
+
+  if (record->isSubClassOf("SPV_ExtInstOp")) {
+    os << formatv("  encodeExtensionInstruction({0}, \"{1}\", {2}, {3});\n",
+                  opVar, record->getValueAsString("extendedInstSetName"),
+                  record->getValueAsInt("extendedInstOpcode"), operands);
+  } else {
+    os << formatv("  encodeInstructionInto("
+                  "functions, spirv::getOpcode<{0}>(), {1});\n",
+                  op.getQualCppClassName(), operands);
   }
 
+  // Process decorations.
+  emitDecorationSerialization(op, "  ", opVar, elidedAttrs, resultID, os);
+
   os << "  return success();\n";
   os << "}\n\n";
 }
 
-static void initDispatchSerializationFn(raw_ostream &os) {
-  os << "LogicalResult Serializer::dispatchToAutogenSerialization(Operation "
-        "*op) {\n ";
+/// Generates the prologue for the function that dispatches the serialization of
+/// the operation `opVar` based on its opcode.
+static void initDispatchSerializationFn(StringRef opVar, raw_ostream &os) {
+  os << formatv(
+      "LogicalResult Serializer::dispatchToAutogenSerialization(Operation "
+      "*{0}) {{\n ",
+      opVar);
 }
 
-static void emitSerializationDispatch(const Operator &op, raw_ostream &os) {
-  os << formatv(" if (isa<{0}>(op)) ", op.getQualCppClassName()) << "{\n";
-  os << "    ";
-  os << formatv("return processOp<{0}>(cast<{0}>(op));\n",
-                op.getQualCppClassName());
-  os << "  } else";
+/// Generates the body of the dispatch function. This function generates the
+/// check that if satisfied, will call the serialization function generated for
+/// the `op`.
+static void emitSerializationDispatch(const Operator &op, StringRef tabs,
+                                      StringRef opVar, raw_ostream &os) {
+  os << tabs
+     << formatv("if (isa<{0}>({1})) {{\n", op.getQualCppClassName(), opVar);
+  os << tabs
+     << formatv("  return processOp(cast<{0}>({1}));\n",
+                op.getQualCppClassName(), opVar);
+  os << tabs << "} else";
 }
 
-static void finalizeDispatchSerializationFn(raw_ostream &os) {
+/// Generates the epilogue for the function that dispatches the serialization of
+/// the operation.
+static void finalizeDispatchSerializationFn(StringRef opVar, raw_ostream &os) {
   os << " {\n";
-  os << "    return op->emitError(\"unhandled operation serialization\");\n";
+  os << formatv(
+      "    return {0}->emitError(\"unhandled operation serialization\");\n",
+      opVar);
   os << "  }\n";
   os << "  return success();\n";
   os << "}\n\n";
 }
 
-static void emitAttributeDeserialization(
-    const Attribute &attr, ArrayRef<SMLoc> loc, llvm::StringRef attrList,
-    llvm::StringRef attrName, llvm::StringRef operandsList,
-    llvm::StringRef wordIndex, llvm::StringRef wordCount, raw_ostream &os) {
+/// Generates code to deserialize the attribute of a SPV_Op into `os`. The
+/// generated code reads the `words` of the serialized instruction at
+/// position `wordIndex` and adds the deserialized attribute into `attrList`.
+static void emitAttributeDeserialization(const Attribute &attr,
+                                         ArrayRef<SMLoc> loc, StringRef tabs,
+                                         StringRef attrList, StringRef attrName,
+                                         StringRef words, StringRef wordIndex,
+                                         raw_ostream &os) {
   if (attr.getAttrDefName() == "I32ArrayAttr") {
-    os << "    SmallVector<Attribute, 4> attrListElems;\n";
-    os << "    while (" << wordIndex << " < " << wordCount << ") {\n";
-    os << "      attrListElems.push_back(opBuilder.getI32IntegerAttr("
-       << operandsList << "[" << wordIndex << "++]));\n";
-    os << "    }\n";
-    os << "    " << attrList << ".push_back(opBuilder.getNamedAttr(\""
-       << attrName << "\", opBuilder.getArrayAttr(attrListElems)));\n";
+    os << tabs << "SmallVector<Attribute, 4> attrListElems;\n";
+    os << tabs << formatv("while ({0} < {1}.size()) {{\n", wordIndex, words);
+    os << tabs
+       << formatv(
+              "  "
+              "attrListElems.push_back(opBuilder.getI32IntegerAttr({0}[{1}++]))"
+              ";\n",
+              words, wordIndex);
+    os << tabs << "}\n";
+    os << tabs
+       << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
+                  "opBuilder.getArrayAttr(attrListElems)));\n",
+                  attrList, attrName);
   } else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") {
-    os << "    " << attrList << ".push_back(opBuilder.getNamedAttr(\""
-       << attrName << "\", opBuilder.getI32IntegerAttr(" << operandsList << "["
-       << wordIndex << "++])));\n";
+    os << tabs
+       << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
+                  "opBuilder.getI32IntegerAttr({2}[{3}++])));\n",
+                  attrList, attrName, words, wordIndex);
   } else {
     PrintFatalError(
         loc, llvm::Twine(
@@ -220,143 +313,281 @@ static void emitAttributeDeserialization(
   }
 }
 
-static void emitDeserializationFunction(const Record *attrClass,
-                                        const Record *record,
-                                        const Operator &op, raw_ostream &os) {
-  // If the record has 'autogenSerialization' set to 0, nothing to do
-  if (!record->getValueAsBit("autogenSerialization")) {
-    return;
-  }
-  os << formatv("template <> "
-                "LogicalResult\nDeserializer::processOp<{0}>(ArrayRef<"
-                "uint32_t> words)",
-                op.getQualCppClassName());
-  os << " {\n";
-  os << "  SmallVector<Type, 1> resultTypes;\n";
-  os << "  size_t wordIndex = 0; (void)wordIndex;\n";
-
+/// Generates the code to deserialize the result of an SPV_Op `op` into
+/// `os`. The generated code gets the type of the result specified at
+/// `words`[`wordIndex`], the SSA ID for the result at position `wordIndex` + 1
+/// and updates the `resultType` and `valueID` with the parsed type and SSA ID,
+/// respectively.
+static void emitResultDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
+                                      StringRef tabs, StringRef words,
+                                      StringRef wordIndex,
+                                      StringRef resultTypes, StringRef valueID,
+                                      raw_ostream &os) {
   // Deserialize result information if it exists
-  bool hasResult = false;
   if (op.getNumResults() == 1) {
-    os << "  {\n";
-    os << "    if (wordIndex >= words.size()) {\n";
-    os << "      "
-       << formatv("return emitError(unknownLoc, \"expected result type <id> "
-                  "while deserializing {0}\");\n",
-                  op.getQualCppClassName());
-    os << "    }\n";
-    os << "    auto ty = getType(words[wordIndex]);\n";
-    os << "    if (!ty) {\n";
-    os << "      return emitError(unknownLoc, \"unknown type result <id> : "
-          "\") << words[wordIndex];\n";
-    os << "    }\n";
-    os << "    resultTypes.push_back(ty);\n";
-    os << "    wordIndex++;\n";
-    os << "  }\n";
-    os << "  if (wordIndex >= words.size()) {\n";
-    os << "    "
-       << formatv("return emitError(unknownLoc, \"expected result <id> while "
-                  "deserializing {0}\");\n",
-                  op.getQualCppClassName());
-    os << "  }\n";
-    os << "  uint32_t valueID = words[wordIndex++];\n";
-    hasResult = true;
+    os << tabs << "{\n";
+    os << tabs << formatv("  if ({0} >= {1}.size()) {{\n", wordIndex, words);
+    os << tabs
+       << formatv(
+              "    return emitError(unknownLoc, \"expected result type <id> "
+              "while deserializing {0}\");\n",
+              op.getQualCppClassName());
+    os << tabs << "  }\n";
+    os << tabs << formatv("  auto ty = getType({0}[{1}]);\n", words, wordIndex);
+    os << tabs << "  if (!ty) {\n";
+    os << tabs
+       << formatv(
+              "    return emitError(unknownLoc, \"unknown type result <id> : "
+              "\") << {0}[{1}];\n",
+              words, wordIndex);
+    os << tabs << "  }\n";
+    os << tabs << formatv("  {0}.push_back(ty);\n", resultTypes);
+    os << tabs << formatv("  {0}++;\n", wordIndex);
+    os << tabs << formatv("  if ({0} >= {1}.size()) {{\n", wordIndex, words);
+    os << tabs
+       << formatv(
+              "    return emitError(unknownLoc, \"expected result <id> while "
+              "deserializing {0}\");\n",
+              op.getQualCppClassName());
+    os << tabs << "  }\n";
+    os << tabs << "}\n";
+    os << tabs << formatv("{0} = {1}[{2}++];\n", valueID, words, wordIndex);
   } else if (op.getNumResults() != 0) {
-    PrintFatalError(record->getLoc(),
-                    "SPIR-V ops can have only zero or one result");
+    PrintFatalError(loc, "SPIR-V ops can have only zero or one result");
   }
+}
 
+/// Generates the code to deserialize the operands of an SPV_Op `op` into
+/// `os`. The generated code reads the `words` of the binary instruction, from
+/// position `wordIndex` to the end, and either gets the Value corresponding to
+/// the ID encoded, or deserializes the attributes encoded. The parsed
+/// operand(attribute) is added to the `operands` list or `attributes` list.
+static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
+                                       StringRef tabs, StringRef words,
+                                       StringRef wordIndex, StringRef operands,
+                                       StringRef attributes, raw_ostream &os) {
   // Process operands/attributes
-  os << "  SmallVector<Value *, 4> operands;\n";
-  os << "  SmallVector<NamedAttribute, 4> attributes;\n";
   unsigned operandNum = 0;
   for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
     auto argument = op.getArg(i);
     if (auto valueArg = argument.dyn_cast<NamedTypeConstraint *>()) {
       if (valueArg->isVariadic()) {
         if (i != e - 1) {
-          PrintFatalError(record->getLoc(),
+          PrintFatalError(loc,
                           "SPIR-V ops can have Variadic<..> argument only if "
                           "it's the last argument");
         }
-        os << "  for (; wordIndex < words.size(); ++wordIndex)";
+        os << tabs
+           << formatv("for (; {0} < {1}.size(); ++{0})", wordIndex, words);
       } else {
-        os << "  if (wordIndex < words.size())";
+        os << tabs << formatv("if ({0} < {1}.size())", wordIndex, words);
       }
       os << " {\n";
-      os << "    auto arg = getValue(words[wordIndex]);\n";
-      os << "    if (!arg) {\n";
-      os << "      return emitError(unknownLoc, \"unknown result <id> : \") << "
-            "words[wordIndex];\n";
-      os << "    }\n";
-      os << "    operands.push_back(arg);\n";
+      os << tabs
+         << formatv("  auto arg = getValue({0}[{1}]);\n", words, wordIndex);
+      os << tabs << "  if (!arg) {\n";
+      os << tabs
+         << formatv(
+                "    return emitError(unknownLoc, \"unknown result <id> : \") "
+                "<< {0}[{1}];\n",
+                words, wordIndex);
+      os << tabs << "  }\n";
+      os << tabs << formatv("  {0}.push_back(arg);\n", operands);
       if (!valueArg->isVariadic()) {
-        os << "    wordIndex++;\n";
+        os << tabs << formatv("  {0}++;\n", wordIndex);
       }
       operandNum++;
-      os << "  }\n";
+      os << tabs << "}\n";
     } else {
-      os << "  if (wordIndex < words.size()) {\n";
+      os << tabs << formatv("if ({0} < {1}.size()) {{\n", wordIndex, words);
       auto attr = argument.get<NamedAttribute *>();
+      auto newtabs = tabs.str() + "  ";
       emitAttributeDeserialization(
           (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
-          record->getLoc(), "attributes", attr->name, "words", "wordIndex",
-          "words.size()", os);
+          loc, newtabs, attributes, attr->name, words, wordIndex, os);
       os << "  }\n";
     }
   }
 
-  os << "  if (wordIndex != words.size()) {\n";
-  os << "    return emitError(unknownLoc, \"found more operands than expected "
-        "when deserializing "
-     << op.getQualCppClassName()
-     << ", only \") << wordIndex << \" of \" << words.size() << \" "
-        "processed\";\n";
-  os << "  }\n\n";
+  os << tabs << formatv("if ({0} != {1}.size()) {{\n", wordIndex, words);
+  os << tabs
+     << formatv(
+            "  return emitError(unknownLoc, \"found more operands than "
+            "expected when deserializing {0}, only \") << {1} << \" of \" << "
+            "{2}.size() << \" processed\";\n",
+            op.getQualCppClassName(), wordIndex, words);
+  os << tabs << "}\n\n";
+}
 
+/// Generates code to update the `attributes` vector with the attributes
+/// obtained from parsing the decorations in the SPIR-V binary associated with
+/// an <id> `valueID`
+static void emitDecorationDeserialization(const Operator &op, StringRef tabs,
+                                          StringRef valueID,
+                                          StringRef attributes,
+                                          raw_ostream &os) {
   // Import decorations parsed
   if (op.getNumResults() == 1) {
-    os << "  if (decorations.count(valueID)) {\n"
-       << "    auto attrs = decorations[valueID].getAttrs();\n"
-       << "    attributes.append(attrs.begin(), attrs.end());\n"
-       << "  }\n";
+    os << tabs << formatv("if (decorations.count({0})) {{\n", valueID);
+    os << tabs
+       << formatv("  auto attrs = decorations[{0}].getAttrs();\n", valueID);
+    os << tabs
+       << formatv("  {0}.append(attrs.begin(), attrs.end());\n", attributes);
+    os << tabs << "}\n";
   }
+}
 
-  os << formatv("  auto op = opBuilder.create<{0}>(unknownLoc, resultTypes, "
-                "operands, attributes); (void)op;\n",
-                op.getQualCppClassName());
-  if (hasResult) {
-    os << "  valueMap[valueID] = op.getResult();\n\n";
+/// Generates code to deserialize an SPV_Op `op` into `os`.
+static void emitDeserializationFunction(const Record *attrClass,
+                                        const Record *record,
+                                        const Operator &op, raw_ostream &os) {
+  // If the record has 'autogenSerialization' set to 0, nothing to do
+  if (!record->getValueAsBit("autogenSerialization")) {
+    return;
+  }
+  StringRef resultTypes("resultTypes"), valueID("valueID"), words("words"),
+      wordIndex("wordIndex"), opVar("op"), operands("operands"),
+      attributes("attributes");
+  os << formatv("template <> "
+                "LogicalResult\nDeserializer::processOp<{0}>(ArrayRef<"
+                "uint32_t> {1}) {{\n",
+                op.getQualCppClassName(), words);
+  os << formatv("  SmallVector<Type, 1> {0};\n", resultTypes);
+  os << formatv("  size_t {0} = 0; (void){0};\n", wordIndex);
+  os << formatv("  uint32_t {0} = 0; (void){0};\n", valueID);
+
+  // Deserialize result information
+  emitResultDeserialization(op, record->getLoc(), "  ", words, wordIndex,
+                            resultTypes, valueID, os);
+
+  os << formatv("  SmallVector<Value *, 4> {0};\n", operands);
+  os << formatv("  SmallVector<NamedAttribute, 4> {0};\n", attributes);
+  // Operand deserialization
+  emitOperandDeserialization(op, record->getLoc(), "  ", words, wordIndex,
+                             operands, attributes, os);
+
+  os << formatv(
+      "  auto {1} = opBuilder.create<{0}>(unknownLoc, {2}, {3}, {4}); "
+      "(void){1};\n",
+      op.getQualCppClassName(), opVar, resultTypes, operands, attributes);
+  if (op.getNumResults() == 1) {
+    os << formatv("  valueMap[{0}] = {1}.getResult();\n\n", valueID, opVar);
   }
 
+  // Decorations
+  emitDecorationDeserialization(op, "  ", valueID, attributes, os);
   os << "  return success();\n";
   os << "}\n\n";
 }
 
-static void initDispatchDeserializationFn(raw_ostream &os) {
-  os << "LogicalResult "
-        "Deserializer::dispatchToAutogenDeserialization(spirv::Opcode "
-        "opcode, ArrayRef<uint32_t> words) {\n";
-  os << "  switch (opcode) {\n";
+/// Generates the prologue for the function that dispatches the deserialization
+/// based on the `opcode`.
+static void initDispatchDeserializationFn(StringRef opcode, StringRef words,
+                                          raw_ostream &os) {
+  os << formatv(
+      "LogicalResult "
+      "Deserializer::dispatchToAutogenDeserialization(spirv::Opcode {0}, "
+      "ArrayRef<uint32_t> {1}) {{\n",
+      opcode, words);
+  os << formatv("  switch ({0}) {{\n", opcode);
 }
 
+/// Generates the body of the dispatch function, by generating the case label
+/// for an opcode and the call to the method to perform the deserialization.
 static void emitDeserializationDispatch(const Operator &op, const Record *def,
+                                        StringRef tabs, StringRef words,
                                         raw_ostream &os) {
-  os << formatv("  case spirv::Opcode::{0}:\n",
+  os << tabs
+     << formatv("case spirv::Opcode::{0}:\n",
                 def->getValueAsString("spirvOpName"));
-  os << formatv("    return processOp<{0}>(words);\n",
-                op.getQualCppClassName());
+  os << tabs
+     << formatv("  return processOp<{0}>({1});\n", op.getQualCppClassName(),
+                words);
 }
 
-static void finalizeDispatchDeserializationFn(raw_ostream &os) {
+/// Generates the epilogue for the function that dispatches the deserialization
+/// of the operation.
+static void finalizeDispatchDeserializationFn(StringRef opcode,
+                                              raw_ostream &os) {
   os << "  default:\n";
   os << "    ;\n";
   os << "  }\n";
-  os << "  return emitError(unknownLoc, \"unhandled deserialization of \") << "
-        "spirv::stringifyOpcode(opcode);\n";
+  os << formatv(
+      "  return emitError(unknownLoc, \"unhandled deserialization of \") << "
+      "spirv::stringifyOpcode({0});\n",
+      opcode);
   os << "}\n";
 }
 
+static void initExtendedSetDeserializationDispatch(StringRef extensionSetName,
+                                                   StringRef instructionID,
+                                                   StringRef words,
+                                                   raw_ostream &os) {
+  os << formatv("LogicalResult "
+                "Deserializer::dispatchToExtensionSetAutogenDeserialization("
+                "StringRef {0}, uint32_t {1}, ArrayRef<uint32_t> {2}) {{\n",
+                extensionSetName, instructionID, words);
+}
+
+static void
+emitExtendedSetDeserializationDispatch(const RecordKeeper &recordKeeper,
+                                       raw_ostream &os) {
+  StringRef extensionSetName("extensionSetName"),
+      instructionID("instructionID"), words("words");
+
+  // First iterate over all ops derived from SPV_ExtensionSetOps to get all
+  // extensionSets.
+
+  // For each of the extensions a separate raw_string_ostream is used to
+  // generate code into. These are then concatenated at the end. Since
+  // raw_string_ostream needs a string&, use a vector to store all the string
+  // that are captured by reference within raw_string_ostream.
+  StringMap<raw_string_ostream> extensionSets;
+  SmallVector<std::string, 1> extensionSetNames;
+
+  initExtendedSetDeserializationDispatch(extensionSetName, instructionID, words,
+                                         os);
+  auto defs = recordKeeper.getAllDerivedDefinitions("SPV_ExtInstOp");
+  for (const auto *def : defs) {
+    if (!def->getValueAsBit("autogenSerialization")) {
+      continue;
+    }
+    Operator op(def);
+    auto setName = def->getValueAsString("extendedInstSetName");
+    if (!extensionSets.count(setName)) {
+      extensionSetNames.push_back("");
+      extensionSets.try_emplace(setName, extensionSetNames.back());
+      auto &setos = extensionSets.find(setName)->second;
+      setos << formatv("  if ({0} == \"{1}\") {{\n", extensionSetName, setName);
+      setos << formatv("    switch ({0}) {{\n", instructionID);
+    }
+    auto &setos = extensionSets.find(setName)->second;
+    setos << formatv("    case {0}:\n",
+                     def->getValueAsInt("extendedInstOpcode"));
+    setos << formatv("      return processOp<{0}>({1});\n",
+                     op.getQualCppClassName(), words);
+  }
+
+  // Append the dispatch code for all the extended sets.
+  for (auto &extensionSet : extensionSets) {
+    os << extensionSet.second.str();
+    os << "    default:\n";
+    os << formatv(
+        "      return emitError(unknownLoc, \"unhandled deserializations of "
+        "\") << {0} << \" from extension set \" << {1};\n",
+        instructionID, extensionSetName);
+    os << "    }\n";
+    os << "  }\n";
+  }
+
+  os << formatv("  return emitError(unknownLoc, \"unhandled deserialization of "
+                "extended instruction set {0}\");\n",
+                extensionSetName);
+  os << "}\n";
+}
+
+/// Emits all the autogenerated serialization/deserializations functions for the
+/// SPV_Ops.
 static bool emitSerializationFns(const RecordKeeper &recordKeeper,
                                  raw_ostream &os) {
   llvm::emitSourceFileHeader("SPIR-V Serialization Utilities/Functions", os);
@@ -367,23 +598,31 @@ static bool emitSerializationFns(const RecordKeeper &recordKeeper,
       serFn(serFnString), deserFn(deserFnString), utils(utilsString);
   auto attrClass = recordKeeper.getClass("Attr");
 
+  // Emit the serialization and deserialization functions simulataneously.
   declareOpcodeFn(utils);
-  initDispatchSerializationFn(dSerFn);
-  initDispatchDeserializationFn(dDesFn);
+  StringRef opVar("op");
+  StringRef opcode("opcode"), words("words");
+
+  // Handle the SPIR-V ops.
+  initDispatchSerializationFn(opVar, dSerFn);
+  initDispatchDeserializationFn(opcode, words, dDesFn);
   auto defs = recordKeeper.getAllDerivedDefinitions("SPV_Op");
   for (const auto *def : defs) {
-    if (!def->getValueAsBit("hasOpcode")) {
-      continue;
-    }
     Operator op(def);
-    emitGetOpcodeFunction(def, op, utils);
     emitSerializationFunction(attrClass, def, op, serFn);
-    emitSerializationDispatch(op, dSerFn);
     emitDeserializationFunction(attrClass, def, op, deserFn);
-    emitDeserializationDispatch(op, def, dDesFn);
+    if (def->getValueAsBit("hasOpcode") || def->isSubClassOf("SPV_ExtInstOp")) {
+      emitSerializationDispatch(op, "  ", opVar, dSerFn);
+    }
+    if (def->getValueAsBit("hasOpcode")) {
+      emitGetOpcodeFunction(def, op, utils);
+      emitDeserializationDispatch(op, def, "  ", words, dDesFn);
+    }
   }
-  finalizeDispatchSerializationFn(dSerFn);
-  finalizeDispatchDeserializationFn(dDesFn);
+  finalizeDispatchSerializationFn(opVar, dSerFn);
+  finalizeDispatchDeserializationFn(opcode, dDesFn);
+
+  emitExtendedSetDeserializationDispatch(recordKeeper, dDesFn);
 
   os << "#ifdef GET_SPIRV_SERIALIZATION_UTILS\n";
   os << utils.str();
@@ -421,8 +660,8 @@ static void emitEnumGetSymbolizeFnDecl(raw_ostream &os) {
 static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr,
                                       raw_ostream &os) {
   auto enumName = enumAttr.getEnumClassName();
-  os << formatv("template <> inline StringRef attributeName<{0}>()", enumName)
-     << " {\n";
+  os << formatv("template <> inline StringRef attributeName<{0}>() {{\n",
+                enumName);
   os << "  "
      << formatv("static constexpr const char attrName[] = \"{0}\";\n",
                 mlir::convertToSnakeCase(enumName));
@@ -434,9 +673,9 @@ static void emitEnumGetSymbolizeFnDefn(const EnumAttr &enumAttr,
                                        raw_ostream &os) {
   auto enumName = enumAttr.getEnumClassName();
   auto strToSymFnName = enumAttr.getStringToSymbolFnName();
-  os << formatv("template <> inline SymbolizeFnTy<{0}> symbolizeEnum<{0}>()",
-                enumName)
-     << " {\n";
+  os << formatv(
+      "template <> inline SymbolizeFnTy<{0}> symbolizeEnum<{0}>() {{\n",
+      enumName);
   os << "  return " << strToSymFnName << ";\n";
   os << "}\n";
 }