[spirv] Support (de)serialization of spv.struct
authorDenis Khalikov <dennis.khalikov@gmail.com>
Tue, 20 Aug 2019 18:02:57 +0000 (11:02 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 20 Aug 2019 18:03:42 +0000 (11:03 -0700)
Support (de)serialization of spv.struct with offset decorations.

Closes tensorflow/mlir#94

PiperOrigin-RevId: 264421427

mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
mlir/test/Dialect/SPIRV/Serialization/struct.mlir [new file with mode: 0644]

index cf87bfd..67a3ae7 100644 (file)
@@ -83,6 +83,7 @@ def SPV_OC_OpTypeInt               : I32EnumAttrCase<"OpTypeInt", 21>;
 def SPV_OC_OpTypeFloat             : I32EnumAttrCase<"OpTypeFloat", 22>;
 def SPV_OC_OpTypeVector            : I32EnumAttrCase<"OpTypeVector", 23>;
 def SPV_OC_OpTypeArray             : I32EnumAttrCase<"OpTypeArray", 28>;
+def SPV_OC_OpTypeStruct            : I32EnumAttrCase<"OpTypeStruct", 30>;
 def SPV_OC_OpTypePointer           : I32EnumAttrCase<"OpTypePointer", 32>;
 def SPV_OC_OpTypeFunction          : I32EnumAttrCase<"OpTypeFunction", 33>;
 def SPV_OC_OpConstantTrue          : I32EnumAttrCase<"OpConstantTrue", 41>;
@@ -102,6 +103,7 @@ def SPV_OC_OpLoad                  : I32EnumAttrCase<"OpLoad", 61>;
 def SPV_OC_OpStore                 : I32EnumAttrCase<"OpStore", 62>;
 def SPV_OC_OpAccessChain           : I32EnumAttrCase<"OpAccessChain", 65>;
 def SPV_OC_OpDecorate              : I32EnumAttrCase<"OpDecorate", 71>;
+def SPV_OC_OpMemberDecorate        : I32EnumAttrCase<"OpMemberDecorate", 72>;
 def SPV_OC_OpCompositeExtract      : I32EnumAttrCase<"OpCompositeExtract", 81>;
 def SPV_OC_OpIAdd                  : I32EnumAttrCase<"OpIAdd", 128>;
 def SPV_OC_OpFAdd                  : I32EnumAttrCase<"OpFAdd", 129>;
@@ -135,19 +137,20 @@ def SPV_OpcodeAttr :
       SPV_OC_OpNop, SPV_OC_OpName, SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint,
       SPV_OC_OpExecutionMode, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt,
       SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeArray,
-      SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue,
-      SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite,
-      SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse,
-      SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction,
-      SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpVariable,
-      SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
-      SPV_OC_OpCompositeExtract, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub,
-      SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv,
-      SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem,
-      SPV_OC_OpFMod, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
-      SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
-      SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
-      SPV_OC_OpSLessThanEqual, SPV_OC_OpReturn, SPV_OC_OpReturnValue
+      SPV_OC_OpTypeStruct, SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction,
+      SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse, SPV_OC_OpConstant,
+      SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue,
+      SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite,
+      SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd,
+      SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain,
+      SPV_OC_OpDecorate,SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract,
+      SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul,
+      SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod,
+      SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpIEqual,
+      SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
+      SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
+      SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
+      SPV_OC_OpReturn, SPV_OC_OpReturnValue
       ]> {
     let returnType = "::mlir::spirv::Opcode";
     let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";
index a3d71ed..412487d 100644 (file)
@@ -28,6 +28,7 @@
 #include "mlir/IR/Location.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/StringExtras.h"
+#include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/bit.h"
 
@@ -84,6 +85,9 @@ private:
   /// Method to process an OpDecorate instruction.
   LogicalResult processDecoration(ArrayRef<uint32_t> words);
 
+  // Method to process an OpMemberDecorate instruction.
+  LogicalResult processMemberDecoration(ArrayRef<uint32_t> words);
+
   /// Processes the SPIR-V function at the current `offset` into `binary`.
   /// The operands to the OpFunction instruction is passed in as ``operands`.
   /// This method processes each instruction inside the function and dispatches
@@ -122,6 +126,8 @@ private:
 
   LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
 
+  LogicalResult processStructType(ArrayRef<uint32_t> operands);
+
   //===--------------------------------------------------------------------===//
   // Constant
   //===--------------------------------------------------------------------===//
@@ -232,6 +238,9 @@ private:
   // Result <id> to type decorations.
   DenseMap<uint32_t, uint32_t> typeDecorations;
 
+  // Result <id> to member decorations.
+  DenseMap<uint32_t, DenseMap<uint32_t, uint32_t>> memberDecorationMap;
+
   // List of instructions that are processed in a defered fashion (after an
   // initial processing of the entire binary). Some operations like
   // OpEntryPoint, and OpExecutionMode use forward references to function
@@ -368,6 +377,23 @@ LogicalResult Deserializer::processDecoration(ArrayRef<uint32_t> words) {
   return success();
 }
 
+LogicalResult Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
+  // The binary layout of OpMemberDecorate is different comparing to OpDecorate
+  if (words.size() != 4) {
+    return emitError(unknownLoc, "OpMemberDecorate must have 4 operands");
+  }
+
+  switch (static_cast<spirv::Decoration>(words[2])) {
+  case spirv::Decoration::Offset:
+    memberDecorationMap[words[0]][words[1]] = words[3];
+    break;
+  default:
+    return emitError(unknownLoc, "unhandled OpMemberDecoration case: ")
+           << words[2];
+  }
+  return success();
+}
+
 LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
   // Get the result type
   if (operands.size() != 4) {
@@ -653,6 +679,8 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
     return processArrayType(operands);
   case spirv::Opcode::OpTypeFunction:
     return processFunctionType(operands);
+  case spirv::Opcode::OpTypeStruct:
+    return processStructType(operands);
   default:
     return emitError(unknownLoc, "unhandled type instruction");
   }
@@ -722,6 +750,46 @@ LogicalResult Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+LogicalResult Deserializer::processStructType(ArrayRef<uint32_t> operands) {
+  // TODO(ravishankarm) : Regarding to the spec spv.struct must support zero
+  // amount of members.
+  if (operands.size() < 2) {
+    return emitError(unknownLoc, "OpTypeStruct must have at least 2 operand");
+  }
+
+  SmallVector<Type, 0> memberTypes;
+  for (auto op : llvm::drop_begin(operands, 1)) {
+    Type memberType = getType(op);
+    if (!memberType) {
+      return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
+             << op;
+    }
+    memberTypes.push_back(memberType);
+  }
+
+  SmallVector<spirv::StructType::LayoutInfo, 0> layoutInfo;
+  // Check for layoutinfo
+  auto memberDecorationIt = memberDecorationMap.find(operands[0]);
+  if (memberDecorationIt != memberDecorationMap.end()) {
+    // Each member must have an offset
+    const auto &offsetDecorationMap = memberDecorationIt->second;
+    auto offsetDecorationMapEnd = offsetDecorationMap.end();
+    for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
+      // Check that specific member has an offset
+      auto offsetIt = offsetDecorationMap.find(memberIndex);
+      if (offsetIt == offsetDecorationMapEnd) {
+        return emitError(unknownLoc, "OpTypeStruct with <id> ")
+               << operands[0] << " must have an offset for " << memberIndex
+               << "-th member";
+      }
+      layoutInfo.push_back(
+          static_cast<spirv::StructType::LayoutInfo>(offsetIt->second));
+    }
+  }
+  typeMap[operands[0]] = spirv::StructType::get(memberTypes, layoutInfo);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Constant
 //===----------------------------------------------------------------------===//
@@ -993,6 +1061,7 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
   case spirv::Opcode::OpTypeVector:
   case spirv::Opcode::OpTypeArray:
   case spirv::Opcode::OpTypeFunction:
+  case spirv::Opcode::OpTypeStruct:
   case spirv::Opcode::OpTypePointer:
     return processType(opcode, operands);
   case spirv::Opcode::OpConstant:
@@ -1015,6 +1084,8 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
     return processConstantNull(operands);
   case spirv::Opcode::OpDecorate:
     return processDecoration(operands);
+  case spirv::Opcode::OpMemberDecorate:
+    return processMemberDecoration(operands);
   case spirv::Opcode::OpFunction:
     return processFunction(operands);
   default:
index 575d995..bc0b706 100644 (file)
@@ -28,6 +28,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/StringExtras.h"
+#include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/bit.h"
 #include "llvm/Support/raw_ostream.h"
@@ -148,6 +149,11 @@ private:
     return emitError(loc, "unhandled decoraion for type:") << type;
   }
 
+  /// Process member decoration
+  LogicalResult processMemberDecoration(uint32_t structID, uint32_t memberNum,
+                                        spirv::Decoration decorationType,
+                                        uint32_t value);
+
   //===--------------------------------------------------------------------===//
   // Types
   //===--------------------------------------------------------------------===//
@@ -411,6 +417,16 @@ LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
   }
   return success();
 }
+
+LogicalResult
+Serializer::processMemberDecoration(uint32_t structID, uint32_t memberIndex,
+                                    spirv::Decoration decorationType,
+                                    uint32_t value) {
+  SmallVector<uint32_t, 4> args(
+      {structID, memberIndex, static_cast<uint32_t>(decorationType), value});
+  return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate,
+                               args);
+}
 } // namespace
 
 LogicalResult Serializer::processFuncOp(FuncOp op) {
@@ -618,6 +634,31 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
     return success();
   }
 
+  if (auto structType = type.dyn_cast<spirv::StructType>()) {
+    bool hasLayout = structType.hasLayout();
+    for (auto elementIndex :
+         llvm::seq<uint32_t>(0, structType.getNumElements())) {
+      uint32_t elementTypeID = 0;
+      if (failed(processType(loc, structType.getElementType(elementIndex),
+                             elementTypeID))) {
+        return failure();
+      }
+      operands.push_back(elementTypeID);
+      if (hasLayout) {
+        // Decorate each struct member with an offset
+        if (failed(processMemberDecoration(
+                resultID, elementIndex, spirv::Decoration::Offset,
+                static_cast<uint32_t>(structType.getOffset(elementIndex))))) {
+          return emitError(loc, "cannot decorate ")
+                 << elementIndex << "-th member of : " << structType
+                 << "with its offset";
+        }
+      }
+    }
+    typeEnum = spirv::Opcode::OpTypeStruct;
+    return success();
+  }
+
   // TODO(ravishankarm) : Handle other types.
   return emitError(loc, "unhandled type in serialization: ") << type;
 }
diff --git a/mlir/test/Dialect/SPIRV/Serialization/struct.mlir b/mlir/test/Dialect/SPIRV/Serialization/struct.mlir
new file mode 100644 (file)
index 0000000..ac885fb
--- /dev/null
@@ -0,0 +1,24 @@
+// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
+
+func @spirvmodule() -> () {
+  spv.module "Logical" "VulkanKHR" {
+    // CHECK: !spv.ptr<!spv.struct<!spv.array<128 x f32 [4]> [0]>, Input>
+    spv.globalVariable @var0 bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<128 x f32 [4]> [0]>, Input>
+
+    // CHECK: !spv.ptr<!spv.struct<f32 [0], !spv.struct<f32 [0], !spv.array<16 x f32 [4]> [4]> [4]>, Input>
+    spv.globalVariable @var1 bind(0, 2) : !spv.ptr<!spv.struct<f32 [0], !spv.struct<f32 [0], !spv.array<16 x f32 [4]> [4]> [4]>, Input>
+
+    // CHECK: !spv.ptr<!spv.struct<f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38]>, StorageBuffer>
+    spv.globalVariable @var2 : !spv.ptr<!spv.struct<f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38]>, StorageBuffer>
+
+    // CHECK: !spv.ptr<!spv.struct<!spv.array<128 x !spv.struct<!spv.array<128 x f32 [4]> [0]> [4]> [0]>, StorageBuffer>
+    spv.globalVariable @var3 : !spv.ptr<!spv.struct<!spv.array<128 x !spv.struct<!spv.array<128 x f32 [4]> [0]> [4]> [0]>, StorageBuffer>
+
+    // CHECK: !spv.ptr<!spv.struct<!spv.array<128 x f32 [4]> [0]>, Input>,
+    // CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<128 x f32 [4]> [0]>, Output>
+    func @kernel_1(%arg0: !spv.ptr<!spv.struct<!spv.array<128 x f32 [4]> [0]>, Input>, %arg1: !spv.ptr<!spv.struct<!spv.array<128 x f32 [4]> [0]>, Output>) -> () {
+      spv.Return
+    }
+  }
+  return
+}