From: Mahesh Ravishankar Date: Tue, 2 Jul 2019 19:30:34 +0000 (-0700) Subject: Add support for SPIR-V Struct Types. Current support is limited to X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c73edeec139fdf3695e0156cac5261e2e30dfbc2;p=platform%2Fupstream%2Fllvm.git Add support for SPIR-V Struct Types. Current support is limited to supporting only Offset decorations PiperOrigin-RevId: 256216704 --- diff --git a/mlir/g3doc/Dialects/SPIR-V.md b/mlir/g3doc/Dialects/SPIR-V.md index 6101d3b..e149540 100644 --- a/mlir/g3doc/Dialects/SPIR-V.md +++ b/mlir/g3doc/Dialects/SPIR-V.md @@ -152,6 +152,24 @@ For example, !spv.rtarray> ``` +### Struct type + +This corresponds to SPIR-V [struct type][StructType]. Its syntax is + +``` {.ebnf} +struct-type ::= `!spv.struct<` spirv-type (` [` integer-literal `]` )? + (`, ` spirv-type ( ` [` integer-literal `] ` )? )* `>` +``` + +For Example, + +``` {.mlir} +!spv.struct +!spv.struct +!spv.struct> +!spv.struct +``` + ## Serialization The serialization library provides two entry points, `mlir::spirv::serialize()` @@ -168,7 +186,8 @@ for now). For the latter, please use the assembler/disassembler in the [SPIR-V]: https://www.khronos.org/registry/spir-v/ [ArrayType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeArray +[ImageType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeImage [PointerType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypePointer [RuntimeArrayType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeRuntimeArray -[ImageType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeImage +[StructType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#Structure [SPIRV-Tools]: https://github.com/KhronosGroup/SPIRV-Tools diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h index adaa7d9..48c7cb3 100644 --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -145,10 +145,10 @@ public: unsigned getKind() const; /// Return the LLVMContext in which this type was uniqued. - MLIRContext *getContext(); + MLIRContext *getContext() const; /// Get the dialect this type is registered to. - Dialect &getDialect(); + Dialect &getDialect() const; // Convenience predicates. This is only for floating point types, // derived types should use isa/dyn_cast. diff --git a/mlir/include/mlir/SPIRV/SPIRVDialect.h b/mlir/include/mlir/SPIRV/SPIRVDialect.h index 4272a72..66345ee 100644 --- a/mlir/include/mlir/SPIRV/SPIRVDialect.h +++ b/mlir/include/mlir/SPIRV/SPIRVDialect.h @@ -38,22 +38,6 @@ public: /// Prints a type registered to this dialect. void printType(Type type, llvm::raw_ostream &os) const override; - -private: - /// Parses `spec` as a type and verifies it can be used in SPIR-V types. - Type parseAndVerifyType(StringRef spec, Location loc) const; - - /// Parses `spec` as a SPIR-V array type. - Type parseArrayType(StringRef spec, Location loc) const; - - /// Parses `spec` as a SPIR-V pointer type. - Type parsePointerType(StringRef spec, Location loc) const; - - /// Parses `spec` as a SPIR-V run-time array type. - Type parseRuntimeArrayType(StringRef spec, Location loc) const; - - /// Parses `spec` as a SPIR-V image type - Type parseImageType(StringRef spec, Location loc) const; }; } // end namespace spirv diff --git a/mlir/include/mlir/SPIRV/SPIRVOps.td b/mlir/include/mlir/SPIRV/SPIRVOps.td index 3ce4f64..7951bf8 100644 --- a/mlir/include/mlir/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/SPIRV/SPIRVOps.td @@ -83,7 +83,7 @@ def SPV_LoadOp : SPV_Op<"Load"> { ### Custom assembly form ``` {.ebnf} - memory-access ::= `"None"` | `"Volatile"` | `"Aligned"` integer-literal + memory-access ::= `"None"` | `"Volatile"` | `"Aligned", ` integer-literal | `"NonTemporal"` load-op ::= ssa-id ` = spv.Load ` storage-class ssa-use @@ -118,6 +118,8 @@ def SPV_LoadOp : SPV_Op<"Load"> { return "alignment"; } }]; + + let opcode = 61; } def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> { @@ -157,7 +159,7 @@ def SPV_StoreOp : SPV_Op<"Store"> { ``` {.ebnf} store-op ::= `spv.Store ` storage-class ssa-use `, ` ssa-use `, ` - (memory-access)? : spirv-element-type + (`[` memory-access `]`)? `:` spirv-element-type ``` For example: @@ -185,6 +187,8 @@ def SPV_StoreOp : SPV_Op<"Store"> { return "alignment"; } }]; + + let opcode = 62; } def SPV_VariableOp : SPV_Op<"Variable"> { diff --git a/mlir/include/mlir/SPIRV/SPIRVTypes.h b/mlir/include/mlir/SPIRV/SPIRVTypes.h index 6ade59b..fbb0ce0 100644 --- a/mlir/include/mlir/SPIRV/SPIRVTypes.h +++ b/mlir/include/mlir/SPIRV/SPIRVTypes.h @@ -37,14 +37,16 @@ struct ArrayTypeStorage; struct ImageTypeStorage; struct PointerTypeStorage; struct RuntimeArrayTypeStorage; +struct StructTypeStorage; } // namespace detail namespace TypeKind { enum Kind { Array = Type::FIRST_SPIRV_TYPE, - ImageType, + Image, Pointer, RuntimeArray, + Struct, }; } @@ -58,9 +60,9 @@ public: static ArrayType get(Type elementType, int64_t elementCount); - Type getElementType(); + Type getElementType() const; - int64_t getElementCount(); + int64_t getElementCount() const; }; // SPIR-V pointer type @@ -73,9 +75,10 @@ public: static PointerType get(Type pointeeType, StorageClass storageClass); - Type getPointeeType(); + Type getPointeeType() const; - StorageClass getStorageClass(); + StorageClass getStorageClass() const; + StringRef getStorageClassStr() const; }; // SPIR-V run-time array type @@ -89,16 +92,17 @@ public: static RuntimeArrayType get(Type elementType); - Type getElementType(); + Type getElementType() const; }; // SPIR-V image type +// TODO(ravishankarm) : Move this in alphabetical order class ImageType : public Type::TypeBase { public: using Base::Base; - static bool kindof(unsigned kind) { return kind == TypeKind::ImageType; } + static bool kindof(unsigned kind) { return kind == TypeKind::Image; } static ImageType get(Type elementType, Dim dim, @@ -118,16 +122,45 @@ public: get(std::tuple); - Type getElementType(); - Dim getDim(); - ImageDepthInfo getDepthInfo(); - ImageArrayedInfo getArrayedInfo(); - ImageSamplingInfo getSamplingInfo(); - ImageSamplerUseInfo getSamplerUseInfo(); - ImageFormat getImageFormat(); + Type getElementType() const; + Dim getDim() const; + ImageDepthInfo getDepthInfo() const; + ImageArrayedInfo getArrayedInfo() const; + ImageSamplingInfo getSamplingInfo() const; + ImageSamplerUseInfo getSamplerUseInfo() const; + ImageFormat getImageFormat() const; // TODO(ravishankarm): Add support for Access qualifier }; +// SPIR-V struct type +class StructType + : public Type::TypeBase { + +public: + using Base::Base; + + // Layout information used for members in a struct in SPIR-V + // + // TODO(ravishankarm) : For now this only supports the offset type, so uses + // uint64_t value to represent the offset, with + // std::numeric_limit::max indicating no offset. Change this to + // something that can hold all the information needed for different member + // types + using LayoutInfo = uint64_t; + + static bool kindof(unsigned kind) { return kind == TypeKind::Struct; } + + static StructType get(ArrayRef memberTypes); + + static StructType get(ArrayRef memberTypes, + ArrayRef layoutInfo); + + size_t getNumMembers() const; + Type getMemberType(size_t) const; + bool hasLayout() const; + uint64_t getOffset(size_t) const; +}; + } // end namespace spirv } // end namespace mlir diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp index 78bfc47..cd75176a 100644 --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -27,9 +27,9 @@ using namespace mlir::detail; unsigned Type::getKind() const { return impl->getKind(); } /// Get the dialect this type is registered to. -Dialect &Type::getDialect() { return impl->getDialect(); } +Dialect &Type::getDialect() const { return impl->getDialect(); } -MLIRContext *Type::getContext() { return getDialect().getContext(); } +MLIRContext *Type::getContext() const { return getDialect().getContext(); } unsigned Type::getSubclassData() const { return impl->getSubclassData(); } void Type::setSubclassData(unsigned val) { impl->setSubclassData(val); } diff --git a/mlir/lib/SPIRV/SPIRVDialect.cpp b/mlir/lib/SPIRV/SPIRVDialect.cpp index 816d673..67dd549 100644 --- a/mlir/lib/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/SPIRV/SPIRVDialect.cpp @@ -1,26 +1,16 @@ //===- LLVMDialect.cpp - MLIR SPIR-V dialect ------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines the SPIR-V dialect in MLIR. // //===----------------------------------------------------------------------===// #include "mlir/SPIRV/SPIRVDialect.h" - #include "mlir/IR/MLIRContext.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Parser.h" @@ -32,8 +22,6 @@ #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/raw_ostream.h" -#include - using namespace mlir; using namespace mlir::spirv; @@ -43,7 +31,7 @@ using namespace mlir::spirv; SPIRVDialect::SPIRVDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) { - addTypes(); + addTypes(); addOperations< #define GET_OP_LIST @@ -77,8 +65,9 @@ static bool parseNumberX(StringRef &spec, int64_t &number) { return true; } -static Type parseAndVerifyTypeImpl(SPIRVDialect const &dialect, Location loc, - StringRef spec) { +static Type parseAndVerifyType(SPIRVDialect const &dialect, StringRef spec, + Location loc) { + spec = spec.trim(); auto *context = dialect.getContext(); auto type = mlir::parseType(spec.trim(), context); if (!type) { @@ -116,17 +105,14 @@ static Type parseAndVerifyTypeImpl(SPIRVDialect const &dialect, Location loc, return type; } -Type SPIRVDialect::parseAndVerifyType(StringRef spec, Location loc) const { - return parseAndVerifyTypeImpl(*this, loc, spec); -} - // element-type ::= integer-type // | floating-point-type // | vector-type // | spirv-type // // array-type ::= `!spv.array<` integer-literal `x` element-type `>` -Type SPIRVDialect::parseArrayType(StringRef spec, Location loc) const { +static Type parseArrayType(SPIRVDialect const &dialect, StringRef spec, + Location loc) { if (!spec.consume_front("array<") || !spec.consume_back(">")) { emitError(loc, "spv.array delimiter <...> mismatch"); return Type(); @@ -145,20 +131,24 @@ Type SPIRVDialect::parseArrayType(StringRef spec, Location loc) const { return Type(); } - Type elementType = parseAndVerifyType(spec, loc); + Type elementType = parseAndVerifyType(dialect, spec, loc); if (!elementType) return Type(); return ArrayType::get(elementType, count); } +// TODO(ravishankarm) : Reorder methods to be utilities first and parse*Type +// methods in alphabetical order +// // storage-class ::= `UniformConstant` // | `Uniform` // | `Workgroup` // | // // pointer-type ::= `!spv.ptr<` element-type `,` storage-class `>` -Type SPIRVDialect::parsePointerType(StringRef spec, Location loc) const { +static Type parsePointerType(SPIRVDialect const &dialect, StringRef spec, + Location loc) { if (!spec.consume_front("ptr<") || !spec.consume_back(">")) { emitError(loc, "spv.ptr delimiter <...> mismatch"); return Type(); @@ -186,7 +176,7 @@ Type SPIRVDialect::parsePointerType(StringRef spec, Location loc) const { return Type(); } - auto pointeeType = parseAndVerifyType(ptSpec, loc); + auto pointeeType = parseAndVerifyType(dialect, ptSpec, loc); if (!pointeeType) return Type(); @@ -194,7 +184,8 @@ Type SPIRVDialect::parsePointerType(StringRef spec, Location loc) const { } // runtime-array-type ::= `!spv.rtarray<` element-type `>` -Type SPIRVDialect::parseRuntimeArrayType(StringRef spec, Location loc) const { +static Type parseRuntimeArrayType(SPIRVDialect const &dialect, StringRef spec, + Location loc) { if (!spec.consume_front("rtarray<") || !spec.consume_back(">")) { emitError(loc, "spv.rtarray delimiter <...> mismatch"); return Type(); @@ -205,7 +196,7 @@ Type SPIRVDialect::parseRuntimeArrayType(StringRef spec, Location loc) const { return Type(); } - Type elementType = parseAndVerifyType(spec, loc); + Type elementType = parseAndVerifyType(dialect, spec, loc); if (!elementType) return Type(); @@ -215,8 +206,8 @@ Type SPIRVDialect::parseRuntimeArrayType(StringRef spec, Location loc) const { // Specialize this function to parse each of the parameters that define an // ImageType template -Optional parseAndVerify(SPIRVDialect const &dialect, Location loc, - StringRef spec) { +static Optional parseAndVerify(SPIRVDialect const &dialect, Location loc, + StringRef spec) { emitError(loc, "unexpected parameter while parsing '") << spec << "'"; return llvm::None; } @@ -225,15 +216,20 @@ template <> Optional parseAndVerify(SPIRVDialect const &dialect, Location loc, StringRef spec) { // TODO(ravishankarm): Further verify that the element type can be sampled - return parseAndVerifyTypeImpl(dialect, loc, spec); + auto ty = parseAndVerifyType(dialect, spec, loc); + if (!ty) { + return llvm::None; + } + return ty; } template <> Optional parseAndVerify(SPIRVDialect const &dialect, Location loc, StringRef spec) { auto dim = symbolizeDim(spec); - if (!dim) + if (!dim) { emitError(loc, "unknown Dim in Image type: '") << spec << "'"; + } return dim; } @@ -242,8 +238,9 @@ Optional parseAndVerify(SPIRVDialect const &dialect, Location loc, StringRef spec) { auto depth = symbolizeImageDepthInfo(spec); - if (!depth) + if (!depth) { emitError(loc, "unknown ImageDepthInfo in Image type: '") << spec << "'"; + } return depth; } @@ -252,8 +249,9 @@ Optional parseAndVerify(SPIRVDialect const &dialect, Location loc, StringRef spec) { auto arrayedInfo = symbolizeImageArrayedInfo(spec); - if (!arrayedInfo) + if (!arrayedInfo) { emitError(loc, "unknown ImageArrayedInfo in Image type: '") << spec << "'"; + } return arrayedInfo; } @@ -262,8 +260,9 @@ Optional parseAndVerify(SPIRVDialect const &dialect, Location loc, StringRef spec) { auto samplingInfo = symbolizeImageSamplingInfo(spec); - if (!samplingInfo) + if (!samplingInfo) { emitError(loc, "unknown ImageSamplingInfo in Image type: '") << spec << "'"; + } return samplingInfo; } @@ -272,9 +271,10 @@ Optional parseAndVerify(SPIRVDialect const &dialect, Location loc, StringRef spec) { auto samplerUseInfo = symbolizeImageSamplerUseInfo(spec); - if (!samplerUseInfo) + if (!samplerUseInfo) { emitError(loc, "unknown ImageSamplerUseInfo in Image type: '") << spec << "'"; + } return samplerUseInfo; } @@ -283,11 +283,41 @@ Optional parseAndVerify(SPIRVDialect const &dialect, Location loc, StringRef spec) { auto format = symbolizeImageFormat(spec); - if (!format) + if (!format) { emitError(loc, "unknown ImageFormat in Image type: '") << spec << "'"; + } return format; } +template <> +Optional +parseAndVerify(SPIRVDialect const &dialect, Location loc, StringRef spec) { + uint64_t offsetVal = std::numeric_limits::max(); + if (!spec.consume_front("[")) { + emitError(loc, "expected '[' while parsing layout specification in '") + << spec << "'"; + return llvm::None; + } + if (spec.consumeInteger(10, offsetVal)) { + emitError( + loc, + "expected unsigned integer to specify offset of member in struct: '") + << spec << "'"; + return llvm::None; + } + spec = spec.trim(); + if (!spec.consume_front("]")) { + emitError(loc, "missing ']' in decorations spec: '") << spec << "'"; + return llvm::None; + } + if (spec != "") { + emitError(loc, "unexpected extra tokens in layout information: '") + << spec << "'"; + return llvm::None; + } + return spirv::StructType::LayoutInfo{offsetVal}; +} + // Functor object to parse a comma separated list of specs. The function // parseAndVerify does the actual parsing and verification of individual // elements. This is a functor since parsing the last element of the list @@ -350,7 +380,8 @@ template struct parseCommaSeparatedList { // image-type ::= `!spv.image<` element-type `,` dim `,` depth-info `,` // arrayed-info `,` sampling-info `,` // sampler-use-info `,` format `>` -Type SPIRVDialect::parseImageType(StringRef spec, Location loc) const { +static Type parseImageType(SPIRVDialect const &dialect, StringRef spec, + Location loc) { if (!spec.consume_front("image<") || !spec.consume_back(">")) { emitError(loc, "spv.image delimiter <...> mismatch"); return Type(); @@ -359,7 +390,7 @@ Type SPIRVDialect::parseImageType(StringRef spec, Location loc) const { auto value = parseCommaSeparatedList{}(*this, loc, spec); + ImageFormat>{}(dialect, loc, spec); if (!value) { return Type(); } @@ -367,15 +398,151 @@ Type SPIRVDialect::parseImageType(StringRef spec, Location loc) const { return ImageType::get(value.getValue()); } +// Method to parse one member of a struct (including Layout information) +static ParseResult +parseStructElement(SPIRVDialect const &dialect, StringRef spec, Location loc, + SmallVectorImpl &memberTypes, + SmallVectorImpl &layoutInfo) { + // Check for a '[' ']' + auto lastLSquare = spec.rfind('['); + auto typeSpec = spec.substr(0, lastLSquare); + auto layoutSpec = (lastLSquare == StringRef::npos ? StringRef("") + : spec.substr(lastLSquare)); + auto type = parseAndVerify(dialect, loc, typeSpec); + if (!type) { + return failure(); + } + memberTypes.push_back(type.getValue()); + if (layoutSpec.empty()) { + return success(); + } + if (layoutInfo.size() != memberTypes.size() - 1) { + emitError(loc, "layout specification must be given for all members"); + return failure(); + } + auto layout = + parseAndVerify(dialect, loc, layoutSpec); + if (!layout) { + return failure(); + } + layoutInfo.push_back(layout.getValue()); + return success(); +} + +// Helper method to record the position of the corresponding '>' for every '<' +// encountered when parsing the string left to right. The relative position of +// '>' w.r.t to the '<' is recorded. +static bool +computeMatchingRAngles(Location loc, StringRef const &spec, + SmallVectorImpl &matchingRAngleOffset) { + SmallVector openBrackets; + for (size_t i = 0, e = spec.size(); i != e; ++i) { + if (spec[i] == '<') { + openBrackets.push_back(i); + } else if (spec[i] == '>') { + if (openBrackets.empty()) { + emitError(loc, "unbalanced '<' in '") << spec << "'"; + return false; + } + matchingRAngleOffset.push_back(i - openBrackets.pop_back_val()); + } + } + return true; +} + +static ParseResult +parseStructHelper(SPIRVDialect const &dialect, StringRef spec, Location loc, + ArrayRef matchingRAngleOffset, + SmallVectorImpl &memberTypes, + SmallVectorImpl &layoutInfo) { + // Check if the occurrence of ',' or '<' is before. If former, split using + // ','. If latter, split using matching '>' to get the entire type + // description + auto firstComma = spec.find(','); + auto firstLAngle = spec.find('<'); + if (firstLAngle == StringRef::npos && firstComma == StringRef::npos) { + return parseStructElement(dialect, spec, loc, memberTypes, layoutInfo); + } + if (firstLAngle == StringRef::npos || firstComma < firstLAngle) { + // Parse the type before the ',' + if (parseStructElement(dialect, spec.substr(0, firstComma), loc, + memberTypes, layoutInfo)) { + return failure(); + } + return parseStructHelper(dialect, spec.substr(firstComma + 1).ltrim(), loc, + matchingRAngleOffset, memberTypes, layoutInfo); + } + auto matchingRAngle = matchingRAngleOffset.front() + firstLAngle; + // Find the next ',' or '>' + auto endLoc = std::min(spec.find(',', matchingRAngle + 1), spec.size()); + if (parseStructElement(dialect, spec.substr(0, endLoc), loc, memberTypes, + layoutInfo)) { + return failure(); + } + auto rest = spec.substr(endLoc + 1).ltrim(); + if (rest.empty()) { + return success(); + } + if (rest.front() == ',') { + return parseStructHelper( + dialect, rest.drop_front().trim(), loc, + ArrayRef(std::next(matchingRAngleOffset.begin()), + matchingRAngleOffset.end()), + memberTypes, layoutInfo); + } + emitError(loc, "unexpected string : '") << rest << "'"; + return failure(); +} + +// struct-type ::= `!spv.struct<` spirv-type (` [` integer-literal `]`)? +// (`, ` spirv-type ( ` [` integer-literal `] ` )? )* +static Type parseStructType(SPIRVDialect const &dialect, StringRef spec, + Location loc) { + if (!spec.consume_front("struct<") || !spec.consume_back(">")) { + emitError(loc, "spv.struct delimiter <...> mismatch"); + return Type(); + } + + if (spec.trim().empty()) { + emitError(loc, "expected SPIR-V type"); + return Type(); + } + + SmallVector memberTypes; + SmallVector layoutInfo; + SmallVector matchingRAngleOffset; + if (!computeMatchingRAngles(loc, spec, matchingRAngleOffset) || + parseStructHelper(dialect, spec, loc, matchingRAngleOffset, memberTypes, + layoutInfo)) { + return Type(); + } + if (layoutInfo.empty()) { + return StructType::get(memberTypes); + } + if (memberTypes.size() != layoutInfo.size()) { + emitError(loc, "layout specification must be given for all members"); + return Type(); + } + return StructType::get(memberTypes, layoutInfo); +} + +// spirv-type ::= array-type +// | element-type +// | image-type +// | pointer-type +// | runtime-array-type +// | struct-type Type SPIRVDialect::parseType(StringRef spec, Location loc) const { if (spec.startswith("array")) - return parseArrayType(spec, loc); + return parseArrayType(*this, spec, loc); if (spec.startswith("image")) - return parseImageType(spec, loc); + return parseImageType(*this, spec, loc); if (spec.startswith("ptr")) - return parsePointerType(spec, loc); + return parsePointerType(*this, spec, loc); if (spec.startswith("rtarray")) - return parseRuntimeArrayType(spec, loc); + return parseRuntimeArrayType(*this, spec, loc); + if (spec.startswith("struct")) + return parseStructType(*this, spec, loc); emitError(loc, "unknown SPIR-V type: ") << spec; return Type(); @@ -408,6 +575,19 @@ static void print(ImageType type, llvm::raw_ostream &os) { << stringifyImageFormat(type.getImageFormat()) << ">"; } +static void print(StructType type, llvm::raw_ostream &os) { + os << "struct<"; + std::string sep = ""; + for (size_t i = 0, e = type.getNumMembers(); i != e; ++i) { + os << sep << type.getMemberType(i); + if (type.hasLayout()) { + os << " [" << type.getOffset(i) << "]"; + } + sep = ", "; + } + os << ">"; +} + void SPIRVDialect::printType(Type type, llvm::raw_ostream &os) const { switch (type.getKind()) { case TypeKind::Array: @@ -419,9 +599,12 @@ void SPIRVDialect::printType(Type type, llvm::raw_ostream &os) const { case TypeKind::RuntimeArray: print(type.cast(), os); return; - case TypeKind::ImageType: + case TypeKind::Image: print(type.cast(), os); return; + case TypeKind::Struct: + print(type.cast(), os); + return; default: llvm_unreachable("unhandled SPIR-V type"); } diff --git a/mlir/lib/SPIRV/SPIRVTypes.cpp b/mlir/lib/SPIRV/SPIRVTypes.cpp index 23acd65..f62f3e1 100644 --- a/mlir/lib/SPIRV/SPIRVTypes.cpp +++ b/mlir/lib/SPIRV/SPIRVTypes.cpp @@ -56,9 +56,9 @@ ArrayType ArrayType::get(Type elementType, int64_t elementCount) { elementCount); } -Type ArrayType::getElementType() { return getImpl()->elementType; } +Type ArrayType::getElementType() const { return getImpl()->elementType; } -int64_t ArrayType::getElementCount() { return getImpl()->elementCount; } +int64_t ArrayType::getElementCount() const { return getImpl()->elementCount; } //===----------------------------------------------------------------------===// // ImageType @@ -216,28 +216,32 @@ ImageType ImageType::get(std::tuple value) { - return Base::get(std::get<0>(value).getContext(), TypeKind::ImageType, value); + return Base::get(std::get<0>(value).getContext(), TypeKind::Image, value); } -Type ImageType::getElementType() { return getImpl()->elementType; } +Type ImageType::getElementType() const { return getImpl()->elementType; } -Dim ImageType::getDim() { return getImpl()->getDim(); } +Dim ImageType::getDim() const { return getImpl()->getDim(); } -ImageDepthInfo ImageType::getDepthInfo() { return getImpl()->getDepthInfo(); } +ImageDepthInfo ImageType::getDepthInfo() const { + return getImpl()->getDepthInfo(); +} -ImageArrayedInfo ImageType::getArrayedInfo() { +ImageArrayedInfo ImageType::getArrayedInfo() const { return getImpl()->getArrayedInfo(); } -ImageSamplingInfo ImageType::getSamplingInfo() { +ImageSamplingInfo ImageType::getSamplingInfo() const { return getImpl()->getSamplingInfo(); } -ImageSamplerUseInfo ImageType::getSamplerUseInfo() { +ImageSamplerUseInfo ImageType::getSamplerUseInfo() const { return getImpl()->getSamplerUseInfo(); } -ImageFormat ImageType::getImageFormat() { return getImpl()->getImageFormat(); } +ImageFormat ImageType::getImageFormat() const { + return getImpl()->getImageFormat(); +} //===----------------------------------------------------------------------===// // PointerType @@ -274,12 +278,16 @@ PointerType PointerType::get(Type pointeeType, StorageClass storageClass) { storageClass); } -Type PointerType::getPointeeType() { return getImpl()->pointeeType; } +Type PointerType::getPointeeType() const { return getImpl()->pointeeType; } -StorageClass PointerType::getStorageClass() { +StorageClass PointerType::getStorageClass() const { return getImpl()->getStorageClass(); } +StringRef PointerType::getStorageClassStr() const { + return stringifyStorageClass(getStorageClass()); +} + //===----------------------------------------------------------------------===// // RuntimeArrayType //===----------------------------------------------------------------------===// @@ -305,4 +313,88 @@ RuntimeArrayType RuntimeArrayType::get(Type elementType) { elementType); } -Type RuntimeArrayType::getElementType() { return getImpl()->elementType; } +Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; } + +//===----------------------------------------------------------------------===// +// StructType +//===----------------------------------------------------------------------===// + +struct spirv::detail::StructTypeStorage : public TypeStorage { + StructTypeStorage(unsigned numMembers, Type const *memberTypes, + StructType::LayoutInfo const *layoutInfo) + : TypeStorage(numMembers), memberTypes(memberTypes), + layoutInfo(layoutInfo) {} + + using KeyTy = std::pair, ArrayRef>; + bool operator==(const KeyTy &key) const { + return key == KeyTy(getMemberTypes(), getLayoutInfo()); + } + + static StructTypeStorage *construct(TypeStorageAllocator &allocator, + const KeyTy &key) { + ArrayRef keyTypes = key.first; + + // Copy the member type and layout information into the bump pointer + auto typesList = allocator.copyInto(keyTypes).data(); + + const StructType::LayoutInfo *layoutInfoList = nullptr; + if (!key.second.empty()) { + ArrayRef keyLayoutInfo = key.second; + assert(keyLayoutInfo.size() == keyTypes.size() && + "size of layout information must be same as the size of number of " + "elements"); + layoutInfoList = allocator.copyInto(keyLayoutInfo).data(); + } + + return new (allocator.allocate()) + StructTypeStorage(keyTypes.size(), typesList, layoutInfoList); + } + + ArrayRef getMemberTypes() const { + return ArrayRef(memberTypes, getSubclassData()); + } + + ArrayRef getLayoutInfo() const { + if (layoutInfo) { + return ArrayRef(layoutInfo, getSubclassData()); + } + return ArrayRef(nullptr, size_t(0)); + } + + Type const *memberTypes; + StructType::LayoutInfo const *layoutInfo; +}; + +StructType StructType::get(ArrayRef memberTypes) { + assert(!memberTypes.empty() && "Struct needs at least one member type"); + ArrayRef noLayout(nullptr, size_t(0)); + return Base::get(memberTypes[0].getContext(), TypeKind::Struct, memberTypes, + noLayout); +} + +StructType StructType::get(ArrayRef memberTypes, + ArrayRef layoutInfo) { + assert(!memberTypes.empty() && "Struct needs at least one member type"); + return Base::get(memberTypes.vec().front().getContext(), TypeKind::Struct, + memberTypes, layoutInfo); +} + +size_t StructType::getNumMembers() const { + return getImpl()->getSubclassData(); +} + +Type StructType::getMemberType(size_t i) const { + assert( + getNumMembers() > i && + "element index is more than number of members of the SPIR-V StructType"); + return getImpl()->memberTypes[i]; +} + +bool StructType::hasLayout() const { return getImpl()->layoutInfo; } + +uint64_t StructType::getOffset(size_t i) const { + assert( + getNumMembers() > i && + "element index is more than number of members of the SPIR-V StructType"); + return getImpl()->layoutInfo[i]; +} diff --git a/mlir/test/SPIRV/types.mlir b/mlir/test/SPIRV/types.mlir index 857871a..ffd0fb3 100644 --- a/mlir/test/SPIRV/types.mlir +++ b/mlir/test/SPIRV/types.mlir @@ -200,3 +200,51 @@ func @image_parameters_nocomma_4(!spv.image) -> () +// ----- + +//===----------------------------------------------------------------------===// +// StructType +//===----------------------------------------------------------------------===// + +// CHECK: func @struct_type(!spv.struct) +func @struct_type(!spv.struct) -> () + +// CHECK: func @struct_type2(!spv.struct) +func @struct_type2(!spv.struct) -> () + +// CHECK: func @struct_type_simple(!spv.struct>) +func @struct_type_simple(!spv.struct>) -> () + +// CHECK: func @struct_type_with_offset(!spv.struct) +func @struct_type_with_offset(!spv.struct) -> () + +// CHECK: func @nested_struct(!spv.struct>) +func @nested_struct(!spv.struct>) + +// CHECK: func @nested_struct_with_offset(!spv.struct [4]>) +func @nested_struct_with_offset(!spv.struct [4]>) + +// ----- + +// expected-error @+1 {{layout specification must be given for all members}} +func @struct_type_missing_offset1((!spv.struct) -> () + +// ----- + +// expected-error @+1 {{layout specification must be given for all members}} +func @struct_type_missing_offset2(!spv.struct) -> () + +// ----- + +// expected-error @+1 {{cannot parse type: f32 i32}} +func @struct_type_missing_comma1(!spv.struct) -> () + +// ----- + +// expected-error @+1 {{unexpected extra tokens in layout information: ' i32'}} +func @struct_type_missing_comma2(!spv.struct) -> () + +// ----- + +// expected-error @+1 {{expected unsigned integer to specify offset of member in struct}} +func @struct_type_neg_offset(!spv.struct) -> ()