From 82cf6051ee7157a2883210baab191345cbd075bc Mon Sep 17 00:00:00 2001 From: Denis Khalikov Date: Tue, 20 Aug 2019 11:02:57 -0700 Subject: [PATCH] [spirv] Support (de)serialization of spv.struct Support (de)serialization of spv.struct with offset decorations. Closes tensorflow/mlir#94 PiperOrigin-RevId: 264421427 --- mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td | 29 +++++---- .../Dialect/SPIRV/Serialization/Deserializer.cpp | 71 ++++++++++++++++++++++ .../lib/Dialect/SPIRV/Serialization/Serializer.cpp | 41 +++++++++++++ mlir/test/Dialect/SPIRV/Serialization/struct.mlir | 24 ++++++++ 4 files changed, 152 insertions(+), 13 deletions(-) create mode 100644 mlir/test/Dialect/SPIRV/Serialization/struct.mlir diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index cf87bfd..67a3ae7 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -83,6 +83,7 @@ def SPV_OC_OpTypeInt : I32EnumAttrCase<"OpTypeInt", 21>; def SPV_OC_OpTypeFloat : I32EnumAttrCase<"OpTypeFloat", 22>; def SPV_OC_OpTypeVector : I32EnumAttrCase<"OpTypeVector", 23>; def SPV_OC_OpTypeArray : I32EnumAttrCase<"OpTypeArray", 28>; +def SPV_OC_OpTypeStruct : I32EnumAttrCase<"OpTypeStruct", 30>; def SPV_OC_OpTypePointer : I32EnumAttrCase<"OpTypePointer", 32>; def SPV_OC_OpTypeFunction : I32EnumAttrCase<"OpTypeFunction", 33>; def SPV_OC_OpConstantTrue : I32EnumAttrCase<"OpConstantTrue", 41>; @@ -102,6 +103,7 @@ def SPV_OC_OpLoad : I32EnumAttrCase<"OpLoad", 61>; def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>; def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>; def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>; +def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>; def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>; def SPV_OC_OpIAdd : I32EnumAttrCase<"OpIAdd", 128>; def SPV_OC_OpFAdd : I32EnumAttrCase<"OpFAdd", 129>; @@ -135,19 +137,20 @@ def SPV_OpcodeAttr : SPV_OC_OpNop, SPV_OC_OpName, SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt, SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeArray, - SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue, - SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite, - SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse, - SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, - SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpVariable, - SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate, - SPV_OC_OpCompositeExtract, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, - SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, - SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, - SPV_OC_OpFMod, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, - SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, - SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, - SPV_OC_OpSLessThanEqual, SPV_OC_OpReturn, SPV_OC_OpReturnValue + SPV_OC_OpTypeStruct, SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, + SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse, SPV_OC_OpConstant, + SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, + SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, + SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, + SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain, + SPV_OC_OpDecorate,SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, + SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, + SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, + SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpIEqual, + SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, + SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, + SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, + SPV_OC_OpReturn, SPV_OC_OpReturnValue ]> { let returnType = "::mlir::spirv::Opcode"; let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())"; diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index a3d71ed..412487d 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -28,6 +28,7 @@ #include "mlir/IR/Location.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Support/StringExtras.h" +#include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/bit.h" @@ -84,6 +85,9 @@ private: /// Method to process an OpDecorate instruction. LogicalResult processDecoration(ArrayRef words); + // Method to process an OpMemberDecorate instruction. + LogicalResult processMemberDecoration(ArrayRef words); + /// Processes the SPIR-V function at the current `offset` into `binary`. /// The operands to the OpFunction instruction is passed in as ``operands`. /// This method processes each instruction inside the function and dispatches @@ -122,6 +126,8 @@ private: LogicalResult processFunctionType(ArrayRef operands); + LogicalResult processStructType(ArrayRef operands); + //===--------------------------------------------------------------------===// // Constant //===--------------------------------------------------------------------===// @@ -232,6 +238,9 @@ private: // Result to type decorations. DenseMap typeDecorations; + // Result to member decorations. + DenseMap> memberDecorationMap; + // List of instructions that are processed in a defered fashion (after an // initial processing of the entire binary). Some operations like // OpEntryPoint, and OpExecutionMode use forward references to function @@ -368,6 +377,23 @@ LogicalResult Deserializer::processDecoration(ArrayRef words) { return success(); } +LogicalResult Deserializer::processMemberDecoration(ArrayRef words) { + // The binary layout of OpMemberDecorate is different comparing to OpDecorate + if (words.size() != 4) { + return emitError(unknownLoc, "OpMemberDecorate must have 4 operands"); + } + + switch (static_cast(words[2])) { + case spirv::Decoration::Offset: + memberDecorationMap[words[0]][words[1]] = words[3]; + break; + default: + return emitError(unknownLoc, "unhandled OpMemberDecoration case: ") + << words[2]; + } + return success(); +} + LogicalResult Deserializer::processFunction(ArrayRef operands) { // Get the result type if (operands.size() != 4) { @@ -653,6 +679,8 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode, return processArrayType(operands); case spirv::Opcode::OpTypeFunction: return processFunctionType(operands); + case spirv::Opcode::OpTypeStruct: + return processStructType(operands); default: return emitError(unknownLoc, "unhandled type instruction"); } @@ -722,6 +750,46 @@ LogicalResult Deserializer::processFunctionType(ArrayRef operands) { return success(); } +LogicalResult Deserializer::processStructType(ArrayRef operands) { + // TODO(ravishankarm) : Regarding to the spec spv.struct must support zero + // amount of members. + if (operands.size() < 2) { + return emitError(unknownLoc, "OpTypeStruct must have at least 2 operand"); + } + + SmallVector memberTypes; + for (auto op : llvm::drop_begin(operands, 1)) { + Type memberType = getType(op); + if (!memberType) { + return emitError(unknownLoc, "OpTypeStruct references undefined ") + << op; + } + memberTypes.push_back(memberType); + } + + SmallVector layoutInfo; + // Check for layoutinfo + auto memberDecorationIt = memberDecorationMap.find(operands[0]); + if (memberDecorationIt != memberDecorationMap.end()) { + // Each member must have an offset + const auto &offsetDecorationMap = memberDecorationIt->second; + auto offsetDecorationMapEnd = offsetDecorationMap.end(); + for (auto memberIndex : llvm::seq(0, memberTypes.size())) { + // Check that specific member has an offset + auto offsetIt = offsetDecorationMap.find(memberIndex); + if (offsetIt == offsetDecorationMapEnd) { + return emitError(unknownLoc, "OpTypeStruct with ") + << operands[0] << " must have an offset for " << memberIndex + << "-th member"; + } + layoutInfo.push_back( + static_cast(offsetIt->second)); + } + } + typeMap[operands[0]] = spirv::StructType::get(memberTypes, layoutInfo); + return success(); +} + //===----------------------------------------------------------------------===// // Constant //===----------------------------------------------------------------------===// @@ -993,6 +1061,7 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, case spirv::Opcode::OpTypeVector: case spirv::Opcode::OpTypeArray: case spirv::Opcode::OpTypeFunction: + case spirv::Opcode::OpTypeStruct: case spirv::Opcode::OpTypePointer: return processType(opcode, operands); case spirv::Opcode::OpConstant: @@ -1015,6 +1084,8 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, return processConstantNull(operands); case spirv::Opcode::OpDecorate: return processDecoration(operands); + case spirv::Opcode::OpMemberDecorate: + return processMemberDecoration(operands); case spirv::Opcode::OpFunction: return processFunction(operands); default: diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index 575d995..bc0b706 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -28,6 +28,7 @@ #include "mlir/IR/Builders.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Support/StringExtras.h" +#include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/bit.h" #include "llvm/Support/raw_ostream.h" @@ -148,6 +149,11 @@ private: return emitError(loc, "unhandled decoraion for type:") << type; } + /// Process member decoration + LogicalResult processMemberDecoration(uint32_t structID, uint32_t memberNum, + spirv::Decoration decorationType, + uint32_t value); + //===--------------------------------------------------------------------===// // Types //===--------------------------------------------------------------------===// @@ -411,6 +417,16 @@ LogicalResult Serializer::processTypeDecoration( } return success(); } + +LogicalResult +Serializer::processMemberDecoration(uint32_t structID, uint32_t memberIndex, + spirv::Decoration decorationType, + uint32_t value) { + SmallVector args( + {structID, memberIndex, static_cast(decorationType), value}); + return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, + args); +} } // namespace LogicalResult Serializer::processFuncOp(FuncOp op) { @@ -618,6 +634,31 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID, return success(); } + if (auto structType = type.dyn_cast()) { + bool hasLayout = structType.hasLayout(); + for (auto elementIndex : + llvm::seq(0, structType.getNumElements())) { + uint32_t elementTypeID = 0; + if (failed(processType(loc, structType.getElementType(elementIndex), + elementTypeID))) { + return failure(); + } + operands.push_back(elementTypeID); + if (hasLayout) { + // Decorate each struct member with an offset + if (failed(processMemberDecoration( + resultID, elementIndex, spirv::Decoration::Offset, + static_cast(structType.getOffset(elementIndex))))) { + return emitError(loc, "cannot decorate ") + << elementIndex << "-th member of : " << structType + << "with its offset"; + } + } + } + typeEnum = spirv::Opcode::OpTypeStruct; + return success(); + } + // TODO(ravishankarm) : Handle other types. return emitError(loc, "unhandled type in serialization: ") << type; } diff --git a/mlir/test/Dialect/SPIRV/Serialization/struct.mlir b/mlir/test/Dialect/SPIRV/Serialization/struct.mlir new file mode 100644 index 0000000..ac885fb --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Serialization/struct.mlir @@ -0,0 +1,24 @@ +// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s + +func @spirvmodule() -> () { + spv.module "Logical" "VulkanKHR" { + // CHECK: !spv.ptr [0]>, Input> + spv.globalVariable @var0 bind(0, 1) : !spv.ptr [0]>, Input> + + // CHECK: !spv.ptr [4]> [4]>, Input> + spv.globalVariable @var1 bind(0, 2) : !spv.ptr [4]> [4]>, Input> + + // CHECK: !spv.ptr, StorageBuffer> + spv.globalVariable @var2 : !spv.ptr, StorageBuffer> + + // CHECK: !spv.ptr [0]> [4]> [0]>, StorageBuffer> + spv.globalVariable @var3 : !spv.ptr [0]> [4]> [0]>, StorageBuffer> + + // CHECK: !spv.ptr [0]>, Input>, + // CHECK-SAME: !spv.ptr [0]>, Output> + func @kernel_1(%arg0: !spv.ptr [0]>, Input>, %arg1: !spv.ptr [0]>, Output>) -> () { + spv.Return + } + } + return +} -- 2.7.4