[mlir][spirv] Enhance structure type member decoration handling
authorHazemAbdelhafez <23439402+HazemAbdelhafez@users.noreply.github.com>
Wed, 10 Jun 2020 23:15:55 +0000 (19:15 -0400)
committerLei Zhang <antiagainst@google.com>
Wed, 10 Jun 2020 23:25:03 +0000 (19:25 -0400)
Modify structure type in SPIR-V dialect to support:
1) Multiple decorations per structure member
2) Key-value based decorations (e.g., MatrixStride)

This commit kept the Offset decoration separate from members'
decorations container for easier implementation and logical clarity.
As such, all references to Structure layoutinfo are now offsetinfo,
and any member layout defining decoration (e.g., RowMajor for Matrix)
will be add to the members' decorations container along with its
value if any.

Differential Revision: https://reviews.llvm.org/D81426

mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
mlir/include/mlir/IR/DialectImplementation.h
mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
mlir/lib/Parser/Parser.cpp
mlir/test/Dialect/SPIRV/Serialization/struct.mlir
mlir/test/Dialect/SPIRV/types.mlir
mlir/unittests/Dialect/SPIRV/SerializationTest.cpp

index b718039..3addfd3 100644 (file)
@@ -276,22 +276,39 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
 public:
   using Base::Base;
 
-  // Layout information used for members in a struct in SPIR-V
-  //
-  // TODO(ravishankarm) : For now this only supports the offset type, so uses
-  // uint64_t value to represent the offset, with
-  // std::numeric_limit<uint64_t>::max indicating no offset. Change this to
-  // something that can hold all the information needed for different member
-  // types
-  using LayoutInfo = uint64_t;
+  // Type for specifying the offset of the struct members
+  using OffsetInfo = uint32_t;
+
+  // Type for specifying the decoration(s) on struct members
+  struct MemberDecorationInfo {
+    uint32_t memberIndex : 31;
+    uint32_t hasValue : 1;
+    Decoration decoration;
+    uint32_t decorationValue;
+
+    MemberDecorationInfo(uint32_t index, uint32_t hasValue,
+                         Decoration decoration, uint32_t decorationValue)
+        : memberIndex(index), hasValue(hasValue), decoration(decoration),
+          decorationValue(decorationValue) {}
+
+    bool operator==(const MemberDecorationInfo &other) const {
+      return (this->memberIndex == other.memberIndex) &&
+             (this->decoration == other.decoration) &&
+             (this->decorationValue == other.decorationValue);
+    }
 
-  using MemberDecorationInfo = std::pair<uint32_t, spirv::Decoration>;
+    bool operator<(const MemberDecorationInfo &other) const {
+      return this->memberIndex < other.memberIndex ||
+             (this->memberIndex == other.memberIndex &&
+              this->decoration < other.decoration);
+    }
+  };
 
   static bool kindof(unsigned kind) { return kind == TypeKind::Struct; }
 
   /// Construct a StructType with at least one member.
   static StructType get(ArrayRef<Type> memberTypes,
-                        ArrayRef<LayoutInfo> layoutInfo = {},
+                        ArrayRef<OffsetInfo> offsetInfo = {},
                         ArrayRef<MemberDecorationInfo> memberDecorations = {});
 
   /// Construct a struct with no members.
@@ -323,9 +340,9 @@ public:
 
   ElementTypeRange getElementTypes() const;
 
-  bool hasLayout() const;
+  bool hasOffset() const;
 
-  uint64_t getOffset(unsigned) const;
+  uint64_t getMemberOffset(unsigned) const;
 
   // Returns in `allMemberDecorations` the spirv::Decorations (apart from
   // Offset) associated with all members of the StructType.
@@ -334,8 +351,9 @@ public:
 
   // Returns in `memberDecorations` all the spirv::Decorations (apart from
   // Offset) associated with the `i`-th member of the StructType.
-  void getMemberDecorations(
-      unsigned i, SmallVectorImpl<spirv::Decoration> &memberDecorations) const;
+  void getMemberDecorations(unsigned i,
+                            SmallVectorImpl<StructType::MemberDecorationInfo>
+                                &memberDecorations) const;
 
   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                      Optional<spirv::StorageClass> storage = llvm::None);
@@ -343,6 +361,9 @@ public:
                        Optional<spirv::StorageClass> storage = llvm::None);
 };
 
+llvm::hash_code
+hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
+
 // SPIR-V cooperative matrix type
 class CooperativeMatrixNVType
     : public Type::TypeBase<CooperativeMatrixNVType, CompositeType,
index a891c8c..e2d7e2c 100644 (file)
@@ -200,6 +200,9 @@ public:
   /// Parse a `=` token.
   virtual ParseResult parseEqual() = 0;
 
+  /// Parse a `=` token if present.
+  virtual ParseResult parseOptionalEqual() = 0;
+
   /// Parse a given keyword.
   ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") {
     auto loc = getCurrentLocation();
index f953f95..e456f49 100644 (file)
@@ -32,7 +32,7 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
   }
 
   SmallVector<Type, 4> memberTypes;
-  SmallVector<Size, 4> layoutInfo;
+  SmallVector<spirv::StructType::OffsetInfo, 4> offsetInfo;
   SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
 
   Size structMemberOffset = 0;
@@ -46,7 +46,8 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
         decorateType(structType.getElementType(i), memberSize, memberAlignment);
     structMemberOffset = llvm::alignTo(structMemberOffset, memberAlignment);
     memberTypes.push_back(memberType);
-    layoutInfo.push_back(structMemberOffset);
+    offsetInfo.push_back(
+        static_cast<spirv::StructType::OffsetInfo>(structMemberOffset));
     // If the member's size is the max value, it must be the last member and it
     // must be a runtime array.
     assert(memberSize != std::numeric_limits<Size>().max() ||
@@ -66,7 +67,7 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
   size = llvm::alignTo(structMemberOffset, maxMemberAlignment);
   alignment = maxMemberAlignment;
   structType.getMemberDecorations(memberDecorations);
-  return spirv::StructType::get(memberTypes, layoutInfo, memberDecorations);
+  return spirv::StructType::get(memberTypes, offsetInfo, memberDecorations);
 }
 
 Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
@@ -168,7 +169,7 @@ bool VulkanLayoutUtils::isLegalType(Type type) {
   case spirv::StorageClass::StorageBuffer:
   case spirv::StorageClass::PushConstant:
   case spirv::StorageClass::PhysicalStorageBuffer:
-    return structType.hasLayout() || !structType.getNumElements();
+    return structType.hasOffset() || !structType.getNumElements();
   default:
     return true;
   }
index 455064f..43e70a1 100644 (file)
@@ -535,30 +535,31 @@ static Type parseImageType(SPIRVDialect const &dialect,
 static ParseResult parseStructMemberDecorations(
     SPIRVDialect const &dialect, DialectAsmParser &parser,
     ArrayRef<Type> memberTypes,
-    SmallVectorImpl<StructType::LayoutInfo> &layoutInfo,
+    SmallVectorImpl<StructType::OffsetInfo> &offsetInfo,
     SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorationInfo) {
 
   // Check if the first element is offset.
-  llvm::SMLoc layoutLoc = parser.getCurrentLocation();
-  StructType::LayoutInfo layout = 0;
-  OptionalParseResult layoutParseResult = parser.parseOptionalInteger(layout);
-  if (layoutParseResult.hasValue()) {
-    if (failed(*layoutParseResult))
+  llvm::SMLoc offsetLoc = parser.getCurrentLocation();
+  StructType::OffsetInfo offset = 0;
+  OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset);
+  if (offsetParseResult.hasValue()) {
+    if (failed(*offsetParseResult))
       return failure();
 
-    if (layoutInfo.size() != memberTypes.size() - 1) {
-      return parser.emitError(
-          layoutLoc, "layout specification must be given for all members");
+    if (offsetInfo.size() != memberTypes.size() - 1) {
+      return parser.emitError(offsetLoc,
+                              "offset specification must be given for "
+                              "all members");
     }
-    layoutInfo.push_back(layout);
+    offsetInfo.push_back(offset);
   }
 
   // Check for no spirv::Decorations.
   if (succeeded(parser.parseOptionalRSquare()))
     return success();
 
-  // If there was a layout, make sure to parse the comma.
-  if (layoutParseResult.hasValue() && parser.parseComma())
+  // If there was an offset, make sure to parse the comma.
+  if (offsetParseResult.hasValue() && parser.parseComma())
     return failure();
 
   // Check for spirv::Decorations.
@@ -567,9 +568,23 @@ static ParseResult parseStructMemberDecorations(
     if (!memberDecoration)
       return failure();
 
-    memberDecorationInfo.emplace_back(
-        static_cast<uint32_t>(memberTypes.size() - 1),
-        memberDecoration.getValue());
+    // Parse member decoration value if it exists.
+    if (succeeded(parser.parseOptionalEqual())) {
+      auto memberDecorationValue =
+          parseAndVerifyInteger<uint32_t>(dialect, parser);
+
+      if (!memberDecorationValue)
+        return failure();
+
+      memberDecorationInfo.emplace_back(
+          static_cast<uint32_t>(memberTypes.size() - 1), 1,
+          memberDecoration.getValue(), memberDecorationValue.getValue());
+    } else {
+      memberDecorationInfo.emplace_back(
+          static_cast<uint32_t>(memberTypes.size() - 1), 0,
+          memberDecoration.getValue(), 0);
+    }
+
   } while (succeeded(parser.parseOptionalComma()));
 
   return parser.parseRSquare();
@@ -587,7 +602,7 @@ static Type parseStructType(SPIRVDialect const &dialect,
     return StructType::getEmpty(dialect.getContext());
 
   SmallVector<Type, 4> memberTypes;
-  SmallVector<StructType::LayoutInfo, 4> layoutInfo;
+  SmallVector<StructType::OffsetInfo, 4> offsetInfo;
   SmallVector<StructType::MemberDecorationInfo, 4> memberDecorationInfo;
 
   do {
@@ -597,21 +612,21 @@ static Type parseStructType(SPIRVDialect const &dialect,
     memberTypes.push_back(memberType);
 
     if (succeeded(parser.parseOptionalLSquare())) {
-      if (parseStructMemberDecorations(dialect, parser, memberTypes, layoutInfo,
+      if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
                                        memberDecorationInfo)) {
         return Type();
       }
     }
   } while (succeeded(parser.parseOptionalComma()));
 
-  if (!layoutInfo.empty() && memberTypes.size() != layoutInfo.size()) {
+  if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
     parser.emitError(parser.getNameLoc(),
-                     "layout specification must be given for all members");
+                     "offset specification must be given for all members");
     return Type();
   }
   if (parser.parseGreater())
     return Type();
-  return StructType::get(memberTypes, layoutInfo, memberDecorationInfo);
+  return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
 }
 
 // spirv-type ::= array-type
@@ -679,17 +694,20 @@ static void print(StructType type, DialectAsmPrinter &os) {
   os << "struct<";
   auto printMember = [&](unsigned i) {
     os << type.getElementType(i);
-    SmallVector<spirv::Decoration, 0> decorations;
+    SmallVector<spirv::StructType::MemberDecorationInfo, 0> decorations;
     type.getMemberDecorations(i, decorations);
-    if (type.hasLayout() || !decorations.empty()) {
+    if (type.hasOffset() || !decorations.empty()) {
       os << " [";
-      if (type.hasLayout()) {
-        os << type.getOffset(i);
+      if (type.hasOffset()) {
+        os << type.getMemberOffset(i);
         if (!decorations.empty())
           os << ", ";
       }
-      auto eachFn = [&os](spirv::Decoration decoration) {
-        os << stringifyDecoration(decoration);
+      auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
+        os << stringifyDecoration(decoration.decoration);
+        if (decoration.hasValue) {
+          os << "=" << decoration.decorationValue;
+        }
       };
       llvm::interleaveComma(decorations, os, eachFn);
       os << "]";
index 0226f51..963f539 100644 (file)
@@ -874,17 +874,17 @@ void SPIRVType::getCapabilities(
 struct spirv::detail::StructTypeStorage : public TypeStorage {
   StructTypeStorage(
       unsigned numMembers, Type const *memberTypes,
-      StructType::LayoutInfo const *layoutInfo, unsigned numMemberDecorations,
+      StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
       StructType::MemberDecorationInfo const *memberDecorationsInfo)
       : TypeStorage(numMembers), memberTypes(memberTypes),
-        layoutInfo(layoutInfo), numMemberDecorations(numMemberDecorations),
+        offsetInfo(layoutInfo), numMemberDecorations(numMemberDecorations),
         memberDecorationsInfo(memberDecorationsInfo) {}
 
-  using KeyTy = std::tuple<ArrayRef<Type>, ArrayRef<StructType::LayoutInfo>,
+  using KeyTy = std::tuple<ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
                            ArrayRef<StructType::MemberDecorationInfo>>;
   bool operator==(const KeyTy &key) const {
     return key ==
-           KeyTy(getMemberTypes(), getLayoutInfo(), getMemberDecorationsInfo());
+           KeyTy(getMemberTypes(), getOffsetInfo(), getMemberDecorationsInfo());
   }
 
   static StructTypeStorage *construct(TypeStorageAllocator &allocator,
@@ -897,13 +897,13 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
       typesList = allocator.copyInto(keyTypes).data();
     }
 
-    const StructType::LayoutInfo *layoutInfoList = nullptr;
+    const StructType::OffsetInfo *offsetInfoList = nullptr;
     if (!std::get<1>(key).empty()) {
-      ArrayRef<StructType::LayoutInfo> keyLayoutInfo = std::get<1>(key);
-      assert(keyLayoutInfo.size() == keyTypes.size() &&
-             "size of layout information must be same as the size of number of "
+      ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<1>(key);
+      assert(keyOffsetInfo.size() == keyTypes.size() &&
+             "size of offset information must be same as the size of number of "
              "elements");
-      layoutInfoList = allocator.copyInto(keyLayoutInfo).data();
+      offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
     }
 
     const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
@@ -914,7 +914,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
       memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
     }
     return new (allocator.allocate<StructTypeStorage>())
-        StructTypeStorage(keyTypes.size(), typesList, layoutInfoList,
+        StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
                           numMemberDecorations, memberDecorationList);
   }
 
@@ -922,9 +922,9 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
     return ArrayRef<Type>(memberTypes, getSubclassData());
   }
 
-  ArrayRef<StructType::LayoutInfo> getLayoutInfo() const {
-    if (layoutInfo) {
-      return ArrayRef<StructType::LayoutInfo>(layoutInfo, getSubclassData());
+  ArrayRef<StructType::OffsetInfo> getOffsetInfo() const {
+    if (offsetInfo) {
+      return ArrayRef<StructType::OffsetInfo>(offsetInfo, getSubclassData());
     }
     return {};
   }
@@ -938,14 +938,14 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
   }
 
   Type const *memberTypes;
-  StructType::LayoutInfo const *layoutInfo;
+  StructType::OffsetInfo const *offsetInfo;
   unsigned numMemberDecorations;
   StructType::MemberDecorationInfo const *memberDecorationsInfo;
 };
 
 StructType
 StructType::get(ArrayRef<Type> memberTypes,
-                ArrayRef<StructType::LayoutInfo> layoutInfo,
+                ArrayRef<StructType::OffsetInfo> offsetInfo,
                 ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
   assert(!memberTypes.empty() && "Struct needs at least one member type");
   // Sort the decorations.
@@ -953,12 +953,12 @@ StructType::get(ArrayRef<Type> memberTypes,
       memberDecorations.begin(), memberDecorations.end());
   llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
   return Base::get(memberTypes.vec().front().getContext(), TypeKind::Struct,
-                   memberTypes, layoutInfo, sortedDecorations);
+                   memberTypes, offsetInfo, sortedDecorations);
 }
 
 StructType StructType::getEmpty(MLIRContext *context) {
   return Base::get(context, TypeKind::Struct, ArrayRef<Type>(),
-                   ArrayRef<StructType::LayoutInfo>(),
+                   ArrayRef<StructType::OffsetInfo>(),
                    ArrayRef<StructType::MemberDecorationInfo>());
 }
 
@@ -975,11 +975,11 @@ StructType::ElementTypeRange StructType::getElementTypes() const {
   return ElementTypeRange(getImpl()->memberTypes, getNumElements());
 }
 
-bool StructType::hasLayout() const { return getImpl()->layoutInfo; }
+bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
 
-uint64_t StructType::getOffset(unsigned index) const {
+uint64_t StructType::getMemberOffset(unsigned index) const {
   assert(getNumElements() > index && "member index out of range");
-  return getImpl()->layoutInfo[index];
+  return getImpl()->offsetInfo[index];
 }
 
 void StructType::getMemberDecorations(
@@ -992,15 +992,16 @@ void StructType::getMemberDecorations(
 }
 
 void StructType::getMemberDecorations(
-    unsigned index, SmallVectorImpl<spirv::Decoration> &decorations) const {
+    unsigned index,
+    SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
   assert(getNumElements() > index && "member index out of range");
   auto memberDecorations = getImpl()->getMemberDecorationsInfo();
-  decorations.clear();
-  for (auto &memberDecoration : memberDecorations) {
-    if (memberDecoration.first == index) {
-      decorations.push_back(memberDecoration.second);
+  decorationsInfo.clear();
+  for (const auto &memberDecoration : memberDecorations) {
+    if (memberDecoration.memberIndex == index) {
+      decorationsInfo.push_back(memberDecoration);
     }
-    if (memberDecoration.first > index) {
+    if (memberDecoration.memberIndex > index) {
       // Early exit since the decorations are stored sorted.
       return;
     }
@@ -1020,6 +1021,12 @@ void StructType::getCapabilities(
     elementType.cast<SPIRVType>().getCapabilities(capabilities, storage);
 }
 
+llvm::hash_code spirv::hash_value(
+    const StructType::MemberDecorationInfo &memberDecorationInfo) {
+  return llvm::hash_combine(memberDecorationInfo.memberIndex,
+                            memberDecorationInfo.decoration);
+}
+
 //===----------------------------------------------------------------------===//
 // MatrixType
 //===----------------------------------------------------------------------===//
index ecd79d7..4bb9e6d 100644 (file)
@@ -1305,7 +1305,7 @@ LogicalResult Deserializer::processStructType(ArrayRef<uint32_t> operands) {
     memberTypes.push_back(memberType);
   }
 
-  SmallVector<spirv::StructType::LayoutInfo, 0> layoutInfo;
+  SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo;
   SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
   if (memberDecorationMap.count(operands[0])) {
     auto &allMemberDecorations = memberDecorationMap[operands[0]];
@@ -1314,27 +1314,27 @@ LogicalResult Deserializer::processStructType(ArrayRef<uint32_t> operands) {
         for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
           // Check for offset.
           if (memberDecoration.first == spirv::Decoration::Offset) {
-            // If layoutInfo is empty, resize to the number of members;
-            if (layoutInfo.empty()) {
-              layoutInfo.resize(memberTypes.size());
+            // If offset info is empty, resize to the number of members;
+            if (offsetInfo.empty()) {
+              offsetInfo.resize(memberTypes.size());
             }
-            layoutInfo[memberIndex] = memberDecoration.second[0];
+            offsetInfo[memberIndex] = memberDecoration.second[0];
           } else {
             if (!memberDecoration.second.empty()) {
-              return emitError(unknownLoc,
-                               "unhandled OpMemberDecoration with decoration ")
-                     << stringifyDecoration(memberDecoration.first)
-                     << " which has additional operands";
+              memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
+                                                 memberDecoration.first,
+                                                 memberDecoration.second[0]);
+            } else {
+              memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
+                                                 memberDecoration.first, 0);
             }
-            memberDecorationsInfo.emplace_back(memberIndex,
-                                               memberDecoration.first);
           }
         }
       }
     }
   }
   typeMap[operands[0]] =
-      spirv::StructType::get(memberTypes, layoutInfo, memberDecorationsInfo);
+      spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
   // TODO(ravishankarm): Update StructType to have member name as attribute as
   // well.
   return success();
index 81f8732..f864187 100644 (file)
@@ -227,9 +227,9 @@ private:
   }
 
   /// Process member decoration
-  LogicalResult processMemberDecoration(uint32_t structID, uint32_t memberIndex,
-                                        spirv::Decoration decorationType,
-                                        ArrayRef<uint32_t> values = {});
+  LogicalResult processMemberDecoration(
+      uint32_t structID,
+      const spirv::StructType::MemberDecorationInfo &memberDecorationInfo);
 
   //===--------------------------------------------------------------------===//
   // Types
@@ -736,14 +736,14 @@ LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
   return success();
 }
 
-LogicalResult
-Serializer::processMemberDecoration(uint32_t structID, uint32_t memberIndex,
-                                    spirv::Decoration decorationType,
-                                    ArrayRef<uint32_t> values) {
+LogicalResult Serializer::processMemberDecoration(
+    uint32_t structID,
+    const spirv::StructType::MemberDecorationInfo &memberDecoration) {
   SmallVector<uint32_t, 4> args(
-      {structID, memberIndex, static_cast<uint32_t>(decorationType)});
-  if (!values.empty()) {
-    args.append(values.begin(), values.end());
+      {structID, memberDecoration.memberIndex,
+       static_cast<uint32_t>(memberDecoration.decoration)});
+  if (memberDecoration.hasValue) {
+    args.push_back(memberDecoration.decorationValue);
   }
   return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate,
                                args);
@@ -1070,7 +1070,7 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
   }
 
   if (auto structType = type.dyn_cast<spirv::StructType>()) {
-    bool hasLayout = structType.hasLayout();
+    bool hasOffset = structType.hasOffset();
     for (auto elementIndex :
          llvm::seq<uint32_t>(0, structType.getNumElements())) {
       uint32_t elementTypeID = 0;
@@ -1079,11 +1079,12 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
         return failure();
       }
       operands.push_back(elementTypeID);
-      if (hasLayout) {
+      if (hasOffset) {
         // Decorate each struct member with an offset
-        if (failed(processMemberDecoration(
-                resultID, elementIndex, spirv::Decoration::Offset,
-                static_cast<uint32_t>(structType.getOffset(elementIndex))))) {
+        spirv::StructType::MemberDecorationInfo offsetDecoration{
+            elementIndex, /*hasValue=*/1, spirv::Decoration::Offset,
+            static_cast<uint32_t>(structType.getMemberOffset(elementIndex))};
+        if (failed(processMemberDecoration(resultID, offsetDecoration))) {
           return emitError(loc, "cannot decorate ")
                  << elementIndex << "-th member of " << structType
                  << " with its offset";
@@ -1093,11 +1094,11 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
     SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
     structType.getMemberDecorations(memberDecorations);
     for (auto &memberDecoration : memberDecorations) {
-      if (failed(processMemberDecoration(resultID, memberDecoration.first,
-                                         memberDecoration.second))) {
+      if (failed(processMemberDecoration(resultID, memberDecoration))) {
         return emitError(loc, "cannot decorate ")
-               << memberDecoration.first << "-th member of " << structType
-               << " with " << stringifyDecoration(memberDecoration.second);
+               << static_cast<uint32_t>(memberDecoration.memberIndex)
+               << "-th member of " << structType << " with "
+               << stringifyDecoration(memberDecoration.decoration);
       }
     }
     typeEnum = spirv::Opcode::OpTypeStruct;
index 88e576c..be96a39 100644 (file)
@@ -544,6 +544,11 @@ public:
     return parser.parseToken(Token::equal, "expected '='");
   }
 
+  /// Parse a `=` token if present.
+  ParseResult parseOptionalEqual() override {
+    return success(parser.consumeIf(Token::equal));
+  }
+
   /// Parse a '<' token.
   ParseResult parseLess() override {
     return parser.parseToken(Token::less, "expected '<'");
index 3066462..fff591d 100644 (file)
@@ -22,6 +22,9 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
   // CHECK: !spv.ptr<!spv.struct<f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]>, StorageBuffer>
   spv.globalVariable @var6 : !spv.ptr<!spv.struct<f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]>, StorageBuffer>
 
+  // CHECK: !spv.ptr<!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]>, StorageBuffer>
+  spv.globalVariable @var7 : !spv.ptr<!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]>, StorageBuffer>
+
   // CHECK: !spv.ptr<!spv.struct<>, StorageBuffer>
   spv.globalVariable @empty : !spv.ptr<!spv.struct<>, StorageBuffer>
 
index 1d1a186..d5eb073 100644 (file)
@@ -275,17 +275,23 @@ func @struct_type_with_decoration7(!spv.struct<f32 [0], !spv.struct<i32, f32 [No
 // CHECK: func @struct_type_with_decoration8(!spv.struct<f32, !spv.struct<i32 [0], f32 [4, NonReadable]>>)
 func @struct_type_with_decoration8(!spv.struct<f32, !spv.struct<i32 [0], f32 [4, NonReadable]>>)
 
+// CHECK: func @struct_type_with_matrix_1(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]>)
+func @struct_type_with_matrix_1(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]>)
+
+// CHECK: func @struct_type_with_matrix_2(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, RowMajor, MatrixStride=16]>)
+func @struct_type_with_matrix_2(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, RowMajor, MatrixStride=16]>)
+
 // CHECK: func @struct_empty(!spv.struct<>)
 func @struct_empty(!spv.struct<>)
 
 // -----
 
-// expected-error @+1 {{layout specification must be given for all members}}
+// expected-error @+1 {{offset specification must be given for all members}}
 func @struct_type_missing_offset1((!spv.struct<f32, i32 [4]>) -> ()
 
 // -----
 
-// expected-error @+1 {{layout specification must be given for all members}}
+// expected-error @+1 {{offset specification must be given for all members}}
 func @struct_type_missing_offset2(!spv.struct<f32 [3], i32>) -> ()
 
 // -----
@@ -330,6 +336,16 @@ func @struct_type_missing_comma(!spv.struct<f32 [0, NonWritable NonReadable], i3
 
 // -----
 
+// expected-error @+1 {{expected ']'}}
+func @struct_type_missing_comma(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, RowMajor MatrixStride=16]>)
+
+// -----
+
+// expected-error @+1 {{expected integer value}}
+func @struct_missing_member_decorator_value(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, RowMajor, MatrixStride=]>)
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // CooperativeMatrix
 //===----------------------------------------------------------------------===//
index 06c417c..340bfd9 100644 (file)
@@ -58,8 +58,8 @@ protected:
   Type getFloatStructType() {
     OpBuilder opBuilder(module.body());
     llvm::SmallVector<Type, 1> elementTypes{opBuilder.getF32Type()};
-    llvm::SmallVector<spirv::StructType::LayoutInfo, 1> layoutInfo{0};
-    auto structType = spirv::StructType::get(elementTypes, layoutInfo);
+    llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0};
+    auto structType = spirv::StructType::get(elementTypes, offsetInfo);
     return structType;
   }