#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
-#include <bits/stdint-uintn.h>
#include <tuple>
// Forward declare enum classes related to op availability. Their definitions
public:
using Base::Base;
- // Type for specifying the offset of the struct members
- using OffsetInfo = uint32_t;
-
- // Type for specifying the decoration(s) on struct members
- struct MemberDecorationInfo {
- uint32_t memberIndex : 31;
- uint32_t hasValue : 1;
- Decoration decoration;
- uint32_t decorationValue;
-
- MemberDecorationInfo(uint32_t index, uint32_t hasValue,
- Decoration decoration, uint32_t decorationValue)
- : memberIndex(index), hasValue(hasValue), decoration(decoration),
- decorationValue(decorationValue) {}
-
- bool operator==(const MemberDecorationInfo &other) const {
- return (this->memberIndex == other.memberIndex) &&
- (this->decoration == other.decoration) &&
- (this->decorationValue == other.decorationValue);
- }
+ // Layout information used for members in a struct in SPIR-V
+ //
+ // TODO(ravishankarm) : For now this only supports the offset type, so uses
+ // uint64_t value to represent the offset, with
+ // std::numeric_limit<uint64_t>::max indicating no offset. Change this to
+ // something that can hold all the information needed for different member
+ // types
+ using LayoutInfo = uint64_t;
- bool operator<(const MemberDecorationInfo &other) const {
- return this->memberIndex < other.memberIndex ||
- (this->memberIndex == other.memberIndex &&
- static_cast<uint32_t>(this->decoration) <
- static_cast<uint32_t>(other.decoration));
- }
- };
+ using MemberDecorationInfo = std::pair<uint32_t, spirv::Decoration>;
static bool kindof(unsigned kind) { return kind == TypeKind::Struct; }
/// Construct a StructType with at least one member.
static StructType get(ArrayRef<Type> memberTypes,
- ArrayRef<OffsetInfo> offsetInfo = {},
+ ArrayRef<LayoutInfo> layoutInfo = {},
ArrayRef<MemberDecorationInfo> memberDecorations = {});
/// Construct a struct with no members.
ElementTypeRange getElementTypes() const;
- bool hasOffset() const;
+ bool hasLayout() const;
- uint64_t getMemberOffset(unsigned) const;
+ uint64_t getOffset(unsigned) const;
// Returns in `allMemberDecorations` the spirv::Decorations (apart from
// Offset) associated with all members of the StructType.
// Returns in `memberDecorations` all the spirv::Decorations (apart from
// Offset) associated with the `i`-th member of the StructType.
- void getMemberDecorations(unsigned i,
- SmallVectorImpl<StructType::MemberDecorationInfo>
- &memberDecorations) const;
+ void getMemberDecorations(
+ unsigned i, SmallVectorImpl<spirv::Decoration> &memberDecorations) const;
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<spirv::StorageClass> storage = llvm::None);
Optional<spirv::StorageClass> storage = llvm::None);
};
-llvm::hash_code
-hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
-
// SPIR-V cooperative matrix type
class CooperativeMatrixNVType
: public Type::TypeBase<CooperativeMatrixNVType, CompositeType,
}
SmallVector<Type, 4> memberTypes;
- SmallVector<spirv::StructType::OffsetInfo, 4> offsetInfo;
+ SmallVector<Size, 4> layoutInfo;
SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
Size structMemberOffset = 0;
decorateType(structType.getElementType(i), memberSize, memberAlignment);
structMemberOffset = llvm::alignTo(structMemberOffset, memberAlignment);
memberTypes.push_back(memberType);
- offsetInfo.push_back(
- static_cast<spirv::StructType::OffsetInfo>(structMemberOffset));
+ layoutInfo.push_back(structMemberOffset);
// If the member's size is the max value, it must be the last member and it
// must be a runtime array.
assert(memberSize != std::numeric_limits<Size>().max() ||
size = llvm::alignTo(structMemberOffset, maxMemberAlignment);
alignment = maxMemberAlignment;
structType.getMemberDecorations(memberDecorations);
- return spirv::StructType::get(memberTypes, offsetInfo, memberDecorations);
+ return spirv::StructType::get(memberTypes, layoutInfo, memberDecorations);
}
Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
case spirv::StorageClass::StorageBuffer:
case spirv::StorageClass::PushConstant:
case spirv::StorageClass::PhysicalStorageBuffer:
- return structType.hasOffset() || !structType.getNumElements();
+ return structType.hasLayout() || !structType.getNumElements();
default:
return true;
}
static ParseResult parseStructMemberDecorations(
SPIRVDialect const &dialect, DialectAsmParser &parser,
ArrayRef<Type> memberTypes,
- SmallVectorImpl<StructType::OffsetInfo> &offsetInfo,
+ SmallVectorImpl<StructType::LayoutInfo> &layoutInfo,
SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorationInfo) {
// Check if the first element is offset.
- llvm::SMLoc offsetLoc = parser.getCurrentLocation();
- StructType::OffsetInfo offset = 0;
- OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset);
- if (offsetParseResult.hasValue()) {
- if (failed(*offsetParseResult))
+ llvm::SMLoc layoutLoc = parser.getCurrentLocation();
+ StructType::LayoutInfo layout = 0;
+ OptionalParseResult layoutParseResult = parser.parseOptionalInteger(layout);
+ if (layoutParseResult.hasValue()) {
+ if (failed(*layoutParseResult))
return failure();
- if (offsetInfo.size() != memberTypes.size() - 1) {
- return parser.emitError(offsetLoc,
- "offset specification must be given for "
- "all members");
+ if (layoutInfo.size() != memberTypes.size() - 1) {
+ return parser.emitError(
+ layoutLoc, "layout specification must be given for all members");
}
- offsetInfo.push_back(offset);
+ layoutInfo.push_back(layout);
}
// Check for no spirv::Decorations.
if (succeeded(parser.parseOptionalRSquare()))
return success();
- // If there was an offset, make sure to parse the comma.
- if (offsetParseResult.hasValue() && parser.parseComma())
+ // If there was a layout, make sure to parse the comma.
+ if (layoutParseResult.hasValue() && parser.parseComma())
return failure();
// Check for spirv::Decorations.
if (!memberDecoration)
return failure();
- // Parse member decoration value if it exists.
- if (succeeded(parser.parseOptionalEqual())) {
- auto memberDecorationValue =
- parseAndVerifyInteger<uint32_t>(dialect, parser);
-
- if (!memberDecorationValue)
- return failure();
-
- memberDecorationInfo.emplace_back(
- static_cast<uint32_t>(memberTypes.size() - 1), 1,
- memberDecoration.getValue(), memberDecorationValue.getValue());
- } else {
- memberDecorationInfo.emplace_back(
- static_cast<uint32_t>(memberTypes.size() - 1), 0,
- memberDecoration.getValue(), 0);
- }
-
+ memberDecorationInfo.emplace_back(
+ static_cast<uint32_t>(memberTypes.size() - 1),
+ memberDecoration.getValue());
} while (succeeded(parser.parseOptionalComma()));
return parser.parseRSquare();
return StructType::getEmpty(dialect.getContext());
SmallVector<Type, 4> memberTypes;
- SmallVector<StructType::OffsetInfo, 4> offsetInfo;
+ SmallVector<StructType::LayoutInfo, 4> layoutInfo;
SmallVector<StructType::MemberDecorationInfo, 4> memberDecorationInfo;
do {
memberTypes.push_back(memberType);
if (succeeded(parser.parseOptionalLSquare())) {
- if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
+ if (parseStructMemberDecorations(dialect, parser, memberTypes, layoutInfo,
memberDecorationInfo)) {
return Type();
}
}
} while (succeeded(parser.parseOptionalComma()));
- if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
+ if (!layoutInfo.empty() && memberTypes.size() != layoutInfo.size()) {
parser.emitError(parser.getNameLoc(),
- "offset specification must be given for all members");
+ "layout specification must be given for all members");
return Type();
}
if (parser.parseGreater())
return Type();
- return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
+ return StructType::get(memberTypes, layoutInfo, memberDecorationInfo);
}
// spirv-type ::= array-type
os << "struct<";
auto printMember = [&](unsigned i) {
os << type.getElementType(i);
- SmallVector<spirv::StructType::MemberDecorationInfo, 0> decorations;
+ SmallVector<spirv::Decoration, 0> decorations;
type.getMemberDecorations(i, decorations);
- if (type.hasOffset() || !decorations.empty()) {
+ if (type.hasLayout() || !decorations.empty()) {
os << " [";
- if (type.hasOffset()) {
- os << type.getMemberOffset(i);
+ if (type.hasLayout()) {
+ os << type.getOffset(i);
if (!decorations.empty())
os << ", ";
}
- auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
- os << stringifyDecoration(decoration.decoration);
- if (decoration.hasValue) {
- os << "=" << decoration.decorationValue;
- }
+ auto eachFn = [&os](spirv::Decoration decoration) {
+ os << stringifyDecoration(decoration);
};
llvm::interleaveComma(decorations, os, eachFn);
os << "]";
struct spirv::detail::StructTypeStorage : public TypeStorage {
StructTypeStorage(
unsigned numMembers, Type const *memberTypes,
- StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
+ StructType::LayoutInfo const *layoutInfo, unsigned numMemberDecorations,
StructType::MemberDecorationInfo const *memberDecorationsInfo)
: TypeStorage(numMembers), memberTypes(memberTypes),
- offsetInfo(layoutInfo), numMemberDecorations(numMemberDecorations),
+ layoutInfo(layoutInfo), numMemberDecorations(numMemberDecorations),
memberDecorationsInfo(memberDecorationsInfo) {}
- using KeyTy = std::tuple<ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
+ using KeyTy = std::tuple<ArrayRef<Type>, ArrayRef<StructType::LayoutInfo>,
ArrayRef<StructType::MemberDecorationInfo>>;
bool operator==(const KeyTy &key) const {
return key ==
- KeyTy(getMemberTypes(), getOffsetInfo(), getMemberDecorationsInfo());
+ KeyTy(getMemberTypes(), getLayoutInfo(), getMemberDecorationsInfo());
}
static StructTypeStorage *construct(TypeStorageAllocator &allocator,
typesList = allocator.copyInto(keyTypes).data();
}
- const StructType::OffsetInfo *offsetInfoList = nullptr;
+ const StructType::LayoutInfo *layoutInfoList = nullptr;
if (!std::get<1>(key).empty()) {
- ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<1>(key);
- assert(keyOffsetInfo.size() == keyTypes.size() &&
- "size of offset information must be same as the size of number of "
+ ArrayRef<StructType::LayoutInfo> keyLayoutInfo = std::get<1>(key);
+ assert(keyLayoutInfo.size() == keyTypes.size() &&
+ "size of layout information must be same as the size of number of "
"elements");
- offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
+ layoutInfoList = allocator.copyInto(keyLayoutInfo).data();
}
const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
}
return new (allocator.allocate<StructTypeStorage>())
- StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
+ StructTypeStorage(keyTypes.size(), typesList, layoutInfoList,
numMemberDecorations, memberDecorationList);
}
return ArrayRef<Type>(memberTypes, getSubclassData());
}
- ArrayRef<StructType::OffsetInfo> getOffsetInfo() const {
- if (offsetInfo) {
- return ArrayRef<StructType::OffsetInfo>(offsetInfo, getSubclassData());
+ ArrayRef<StructType::LayoutInfo> getLayoutInfo() const {
+ if (layoutInfo) {
+ return ArrayRef<StructType::LayoutInfo>(layoutInfo, getSubclassData());
}
return {};
}
}
Type const *memberTypes;
- StructType::OffsetInfo const *offsetInfo;
+ StructType::LayoutInfo const *layoutInfo;
unsigned numMemberDecorations;
StructType::MemberDecorationInfo const *memberDecorationsInfo;
};
StructType
StructType::get(ArrayRef<Type> memberTypes,
- ArrayRef<StructType::OffsetInfo> offsetInfo,
+ ArrayRef<StructType::LayoutInfo> layoutInfo,
ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
assert(!memberTypes.empty() && "Struct needs at least one member type");
// Sort the decorations.
memberDecorations.begin(), memberDecorations.end());
llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
return Base::get(memberTypes.vec().front().getContext(), TypeKind::Struct,
- memberTypes, offsetInfo, sortedDecorations);
+ memberTypes, layoutInfo, sortedDecorations);
}
StructType StructType::getEmpty(MLIRContext *context) {
return Base::get(context, TypeKind::Struct, ArrayRef<Type>(),
- ArrayRef<StructType::OffsetInfo>(),
+ ArrayRef<StructType::LayoutInfo>(),
ArrayRef<StructType::MemberDecorationInfo>());
}
return ElementTypeRange(getImpl()->memberTypes, getNumElements());
}
-bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
+bool StructType::hasLayout() const { return getImpl()->layoutInfo; }
-uint64_t StructType::getMemberOffset(unsigned index) const {
+uint64_t StructType::getOffset(unsigned index) const {
assert(getNumElements() > index && "member index out of range");
- return getImpl()->offsetInfo[index];
+ return getImpl()->layoutInfo[index];
}
void StructType::getMemberDecorations(
}
void StructType::getMemberDecorations(
- unsigned index,
- SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
+ unsigned index, SmallVectorImpl<spirv::Decoration> &decorations) const {
assert(getNumElements() > index && "member index out of range");
auto memberDecorations = getImpl()->getMemberDecorationsInfo();
- decorationsInfo.clear();
- for (const auto &memberDecoration : memberDecorations) {
- if (memberDecoration.memberIndex == index) {
- decorationsInfo.push_back(memberDecoration);
+ decorations.clear();
+ for (auto &memberDecoration : memberDecorations) {
+ if (memberDecoration.first == index) {
+ decorations.push_back(memberDecoration.second);
}
- if (memberDecoration.memberIndex > index) {
+ if (memberDecoration.first > index) {
// Early exit since the decorations are stored sorted.
return;
}
elementType.cast<SPIRVType>().getCapabilities(capabilities, storage);
}
-llvm::hash_code spirv::hash_value(
- const StructType::MemberDecorationInfo &memberDecorationInfo) {
- return llvm::hash_combine(memberDecorationInfo.memberIndex,
- memberDecorationInfo.decoration);
-}
-
//===----------------------------------------------------------------------===//
// MatrixType
//===----------------------------------------------------------------------===//
memberTypes.push_back(memberType);
}
- SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo;
+ SmallVector<spirv::StructType::LayoutInfo, 0> layoutInfo;
SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
if (memberDecorationMap.count(operands[0])) {
auto &allMemberDecorations = memberDecorationMap[operands[0]];
for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
// Check for offset.
if (memberDecoration.first == spirv::Decoration::Offset) {
- // If offset info is empty, resize to the number of members;
- if (offsetInfo.empty()) {
- offsetInfo.resize(memberTypes.size());
+ // If layoutInfo is empty, resize to the number of members;
+ if (layoutInfo.empty()) {
+ layoutInfo.resize(memberTypes.size());
}
- offsetInfo[memberIndex] = memberDecoration.second[0];
+ layoutInfo[memberIndex] = memberDecoration.second[0];
} else {
if (!memberDecoration.second.empty()) {
- memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
- memberDecoration.first,
- memberDecoration.second[0]);
- } else {
- memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
- memberDecoration.first, 0);
+ return emitError(unknownLoc,
+ "unhandled OpMemberDecoration with decoration ")
+ << stringifyDecoration(memberDecoration.first)
+ << " which has additional operands";
}
+ memberDecorationsInfo.emplace_back(memberIndex,
+ memberDecoration.first);
}
}
}
}
}
typeMap[operands[0]] =
- spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
+ spirv::StructType::get(memberTypes, layoutInfo, memberDecorationsInfo);
// TODO(ravishankarm): Update StructType to have member name as attribute as
// well.
return success();
}
/// Process member decoration
- LogicalResult processMemberDecoration(
- uint32_t structID,
- const spirv::StructType::MemberDecorationInfo &memberDecorationInfo);
+ LogicalResult processMemberDecoration(uint32_t structID, uint32_t memberIndex,
+ spirv::Decoration decorationType,
+ ArrayRef<uint32_t> values = {});
//===--------------------------------------------------------------------===//
// Types
return success();
}
-LogicalResult Serializer::processMemberDecoration(
- uint32_t structID,
- const spirv::StructType::MemberDecorationInfo &memberDecoration) {
+LogicalResult
+Serializer::processMemberDecoration(uint32_t structID, uint32_t memberIndex,
+ spirv::Decoration decorationType,
+ ArrayRef<uint32_t> values) {
SmallVector<uint32_t, 4> args(
- {structID, memberDecoration.memberIndex,
- static_cast<uint32_t>(memberDecoration.decoration)});
- if (memberDecoration.hasValue) {
- args.push_back(memberDecoration.decorationValue);
+ {structID, memberIndex, static_cast<uint32_t>(decorationType)});
+ if (!values.empty()) {
+ args.append(values.begin(), values.end());
}
return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate,
args);
}
if (auto structType = type.dyn_cast<spirv::StructType>()) {
- bool hasOffset = structType.hasOffset();
+ bool hasLayout = structType.hasLayout();
for (auto elementIndex :
llvm::seq<uint32_t>(0, structType.getNumElements())) {
uint32_t elementTypeID = 0;
return failure();
}
operands.push_back(elementTypeID);
- if (hasOffset) {
+ if (hasLayout) {
// Decorate each struct member with an offset
- spirv::StructType::MemberDecorationInfo offsetDecoration{
- elementIndex, /*hasValue=*/1, spirv::Decoration::Offset,
- static_cast<uint32_t>(structType.getMemberOffset(elementIndex))};
- if (failed(processMemberDecoration(resultID, offsetDecoration))) {
+ if (failed(processMemberDecoration(
+ resultID, elementIndex, spirv::Decoration::Offset,
+ static_cast<uint32_t>(structType.getOffset(elementIndex))))) {
return emitError(loc, "cannot decorate ")
<< elementIndex << "-th member of " << structType
<< " with its offset";
SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
structType.getMemberDecorations(memberDecorations);
for (auto &memberDecoration : memberDecorations) {
- if (failed(processMemberDecoration(resultID, memberDecoration))) {
+ if (failed(processMemberDecoration(resultID, memberDecoration.first,
+ memberDecoration.second))) {
return emitError(loc, "cannot decorate ")
- << static_cast<uint32_t>(memberDecoration.memberIndex)
- << "-th member of " << structType << " with "
- << stringifyDecoration(memberDecoration.decoration);
+ << memberDecoration.first << "-th member of " << structType
+ << " with " << stringifyDecoration(memberDecoration.second);
}
}
typeEnum = spirv::Opcode::OpTypeStruct;
// CHECK: !spv.ptr<!spv.struct<f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]>, StorageBuffer>
spv.globalVariable @var6 : !spv.ptr<!spv.struct<f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]>, StorageBuffer>
- // CHECK: !spv.ptr<!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]>, StorageBuffer>
- spv.globalVariable @var7 : !spv.ptr<!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]>, StorageBuffer>
-
// CHECK: !spv.ptr<!spv.struct<>, StorageBuffer>
spv.globalVariable @empty : !spv.ptr<!spv.struct<>, StorageBuffer>
// CHECK: func @struct_type_with_decoration8(!spv.struct<f32, !spv.struct<i32 [0], f32 [4, NonReadable]>>)
func @struct_type_with_decoration8(!spv.struct<f32, !spv.struct<i32 [0], f32 [4, NonReadable]>>)
-// CHECK: func @struct_type_with_matrix_1(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]>)
-func @struct_type_with_matrix_1(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]>)
-
-// CHECK: func @struct_type_with_matrix_2(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, RowMajor, MatrixStride=16]>)
-func @struct_type_with_matrix_2(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, RowMajor, MatrixStride=16]>)
-
// CHECK: func @struct_empty(!spv.struct<>)
func @struct_empty(!spv.struct<>)
// -----
-// expected-error @+1 {{offset specification must be given for all members}}
+// expected-error @+1 {{layout specification must be given for all members}}
func @struct_type_missing_offset1((!spv.struct<f32, i32 [4]>) -> ()
// -----
-// expected-error @+1 {{offset specification must be given for all members}}
+// expected-error @+1 {{layout specification must be given for all members}}
func @struct_type_missing_offset2(!spv.struct<f32 [3], i32>) -> ()
// -----
// -----
-// expected-error @+1 {{expected ']'}}
-func @struct_type_missing_comma(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, RowMajor MatrixStride=16]>)
-
-// -----
-
-// expected-error @+1 {{expected integer value}}
-func @struct_missing_member_decorator_value(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, RowMajor, MatrixStride=]>)
-
-// -----
-
//===----------------------------------------------------------------------===//
// CooperativeMatrix
//===----------------------------------------------------------------------===//
Type getFloatStructType() {
OpBuilder opBuilder(module.body());
llvm::SmallVector<Type, 1> elementTypes{opBuilder.getF32Type()};
- llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0};
- auto structType = spirv::StructType::get(elementTypes, offsetInfo);
+ llvm::SmallVector<spirv::StructType::LayoutInfo, 1> layoutInfo{0};
+ auto structType = spirv::StructType::get(elementTypes, layoutInfo);
return structType;
}