detail::ArrayTypeStorage> {
public:
using Base::Base;
+ // Zero layout specifies that is no layout
+ using LayoutInfo = uint64_t;
static bool kindof(unsigned kind) { return kind == TypeKind::Array; }
static ArrayType get(Type elementType, unsigned elementCount);
+ static ArrayType get(Type elementType, unsigned elementCount,
+ LayoutInfo layoutInfo);
+
unsigned getNumElements() const;
Type getElementType() const;
+
+ bool hasLayout() const;
+
+ uint64_t getArrayStride() const;
};
// SPIR-V image type
// Type Parsing
//===----------------------------------------------------------------------===//
+// Forward declarations.
+template <typename ValTy>
+static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect, Location loc,
+ StringRef spec);
+template <>
+Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect, Location loc,
+ StringRef spec);
+
+template <>
+Optional<uint64_t> parseAndVerify(SPIRVDialect const &dialect, Location loc,
+ StringRef spec);
+
// Parses "<number> x" from the beginning of `spec`.
static bool parseNumberX(StringRef &spec, int64_t &number) {
spec = spec.ltrim();
// | vector-type
// | spirv-type
//
-// array-type ::= `!spv.array<` integer-literal `x` element-type `>`
+// array-type ::= `!spv.array<` integer-literal `x` element-type
+// (`[` integer-literal `]`)? `>`
static Type parseArrayType(SPIRVDialect const &dialect, StringRef spec,
Location loc) {
if (!spec.consume_front("array<") || !spec.consume_back(">")) {
return Type();
}
+ ArrayType::LayoutInfo layoutInfo = 0;
+ size_t lastLSquare;
+
+ // Handle case when element type is not a trivial type
+ auto lastRDelimiter = spec.rfind('>');
+ if (lastRDelimiter != StringRef::npos) {
+ lastLSquare = spec.find('[', lastRDelimiter);
+ } else {
+ lastLSquare = spec.rfind('[');
+ }
+
+ if (lastLSquare != StringRef::npos) {
+ auto layoutSpec = spec.substr(lastLSquare);
+ auto layout =
+ parseAndVerify<ArrayType::LayoutInfo>(dialect, loc, layoutSpec);
+ if (!layout) {
+ return Type();
+ }
+
+ if (!(layoutInfo = layout.getValue())) {
+ emitError(loc, "ArrayStride must be greater than zero");
+ return Type();
+ }
+ spec = spec.substr(0, lastLSquare);
+ }
+
Type elementType = parseAndVerifyType(dialect, spec, loc);
if (!elementType)
return Type();
- return ArrayType::get(elementType, count);
+ return ArrayType::get(elementType, count, layoutInfo);
}
// TODO(ravishankarm) : Reorder methods to be utilities first and parse*Type
}
template <>
-Optional<spirv::StructType::LayoutInfo>
-parseAndVerify(SPIRVDialect const &dialect, Location loc, StringRef spec) {
+Optional<uint64_t> parseAndVerify(SPIRVDialect const &dialect, Location loc,
+ StringRef spec) {
uint64_t offsetVal = std::numeric_limits<uint64_t>::max();
if (!spec.consume_front("[")) {
emitError(loc, "expected '[' while parsing layout specification in '")
<< spec << "'";
return llvm::None;
}
+ spec = spec.trim();
if (spec.consumeInteger(10, offsetVal)) {
- emitError(
- loc,
- "expected unsigned integer to specify offset of member in struct: '")
+ emitError(loc, "expected unsigned integer to specify layout information: '")
<< spec << "'";
return llvm::None;
}
<< spec << "'";
return llvm::None;
}
- return spirv::StructType::LayoutInfo{offsetVal};
+ return offsetVal;
}
// Functor object to parse a comma separated list of specs. The function
//===----------------------------------------------------------------------===//
static void print(ArrayType type, llvm::raw_ostream &os) {
- os << "array<" << type.getNumElements() << " x " << type.getElementType()
- << ">";
+ os << "array<" << type.getNumElements() << " x " << type.getElementType();
+ if (type.hasLayout()) {
+ os << " [" << type.getArrayStride() << "]";
+ }
+ os << ">";
}
static void print(RuntimeArrayType type, llvm::raw_ostream &os) {
//===----------------------------------------------------------------------===//
struct spirv::detail::ArrayTypeStorage : public TypeStorage {
- using KeyTy = std::pair<Type, unsigned>;
+ using KeyTy = std::tuple<Type, unsigned, ArrayType::LayoutInfo>;
static ArrayTypeStorage *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
}
bool operator==(const KeyTy &key) const {
- return key == KeyTy(elementType, getSubclassData());
+ return key == KeyTy(elementType, getSubclassData(), layoutInfo);
}
ArrayTypeStorage(const KeyTy &key)
- : TypeStorage(key.second), elementType(key.first) {}
+ : TypeStorage(std::get<1>(key)), elementType(std::get<0>(key)),
+ layoutInfo(std::get<2>(key)) {}
Type elementType;
+ ArrayType::LayoutInfo layoutInfo;
};
ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
return Base::get(elementType.getContext(), TypeKind::Array, elementType,
- elementCount);
+ elementCount, 0);
+}
+
+ArrayType ArrayType::get(Type elementType, unsigned elementCount,
+ ArrayType::LayoutInfo layoutInfo) {
+ return Base::get(elementType.getContext(), TypeKind::Array, elementType,
+ elementCount, layoutInfo);
}
unsigned ArrayType::getNumElements() const {
Type ArrayType::getElementType() const { return getImpl()->elementType; }
+// ArrayStride must be greater than zero
+bool ArrayType::hasLayout() const { return getImpl()->layoutInfo; }
+
+uint64_t ArrayType::getArrayStride() const { return getImpl()->layoutInfo; }
+
//===----------------------------------------------------------------------===//
// CompositeType
//===----------------------------------------------------------------------===//
// Result <id> to decorations mapping.
DenseMap<uint32_t, NamedAttributeList> decorations;
+ // Result <id> to type decorations.
+ DenseMap<uint32_t, uint32_t> typeDecorations;
+
// List of instructions that are processed in a defered fashion (after an
// initial processing of the entire binary). Some operations like
// OpEntryPoint, and OpExecutionMode use forward references to function
opBuilder.getStringAttr(stringifyBuiltIn(
static_cast<spirv::BuiltIn>(words[2]))));
break;
+ case spirv::Decoration::ArrayStride:
+ if (words.size() != 3) {
+ return emitError(unknownLoc, "OpDecorate with ")
+ << decorationName << " needs a single integer literal";
+ }
+ typeDecorations[words[0]] = static_cast<uint32_t>(words[2]);
+ break;
default:
return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
}
<< defOp->getName();
}
- typeMap[operands[0]] = spirv::ArrayType::get(elementTy, count);
+ typeMap[operands[0]] = spirv::ArrayType::get(
+ elementTy, count, typeDecorations.lookup(operands[0]));
return success();
}
LogicalResult processDecoration(Location loc, uint32_t resultID,
NamedAttribute attr);
+ template <typename DType>
+ LogicalResult processTypeDecoration(Location loc, DType type,
+ uint32_t resultId) {
+ return emitError(loc, "unhandled decoraion for type:") << type;
+ }
+
//===--------------------------------------------------------------------===//
// Types
//===--------------------------------------------------------------------===//
/// Method for preparing basic SPIR-V type serialization. Returns the type's
/// opcode and operands for the instruction via `typeEnum` and `operands`.
- LogicalResult prepareBasicType(Location loc, Type type,
+ LogicalResult prepareBasicType(Location loc, Type type, uint32_t resultID,
spirv::Opcode &typeEnum,
SmallVectorImpl<uint32_t> &operands);
return encodeInstructionInto(decorations, spirv::Opcode::OpDecorate, args);
}
+namespace {
+template <>
+LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
+ Location loc, spirv::ArrayType type, uint32_t resultID) {
+ if (type.hasLayout()) {
+ // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
+ SmallVector<uint32_t, 3> args;
+ args.push_back(resultID);
+ args.push_back(static_cast<uint32_t>(spirv::Decoration::ArrayStride));
+ args.push_back(type.getArrayStride());
+ return encodeInstructionInto(decorations, spirv::Opcode::OpDecorate, args);
+ }
+ return success();
+}
+} // namespace
+
LogicalResult Serializer::processFuncOp(FuncOp op) {
uint32_t fnTypeID = 0;
// Generate type of the function.
if ((type.isa<FunctionType>() &&
succeeded(prepareFunctionType(loc, type.cast<FunctionType>(), typeEnum,
operands))) ||
- succeeded(prepareBasicType(loc, type, typeEnum, operands))) {
+ succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands))) {
typeIDMap[type] = typeID;
return encodeInstructionInto(typesGlobalValues, typeEnum, operands);
}
}
LogicalResult
-Serializer::prepareBasicType(Location loc, Type type, spirv::Opcode &typeEnum,
+Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
+ spirv::Opcode &typeEnum,
SmallVectorImpl<uint32_t> &operands) {
if (isVoidType(type)) {
typeEnum = spirv::Opcode::OpTypeVoid;
loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()),
/*isSpec=*/false)) {
operands.push_back(elementCountID);
- return success();
}
- return failure();
+ return processTypeDecoration(loc, arrayType, resultID);
}
if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
--- /dev/null
+// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
+
+func @spirvmodule() {
+ spv.module "Logical" "VulkanKHR" {
+ func @array_stride(%arg0 : !spv.ptr<!spv.array<4x!spv.array<4xf32 [4]> [128]>, StorageBuffer>,
+ %arg1 : i32, %arg2 : i32) {
+ // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32 [4]> [128]>, StorageBuffer>
+ %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr<!spv.array<4x!spv.array<4xf32 [4]> [128]>, StorageBuffer>
+ spv.Return
+ }
+ }
+ return
+}
// CHECK: func @vector_array_type(!spv.array<32 x vector<4xf32>>)
func @vector_array_type(!spv.array< 32 x vector<4xf32> >) -> ()
+// CHECK: func @array_type_stride(!spv.array<4 x !spv.array<4 x f32 [4]> [128]>)
+func @array_type_stride(!spv.array< 4 x !spv.array<4 x f32 [4]> [128]>) -> ()
+
// -----
// expected-error @+1 {{spv.array delimiter <...> mismatch}}
// -----
+// expected-error @+1 {{ArrayStride must be greater than zero}}
+func @array_type_zero_stide(!spv.array<4xi32 [0]>) -> ()
+
+// -----
+
//===----------------------------------------------------------------------===//
// PointerType
//===----------------------------------------------------------------------===//
// -----
-// expected-error @+1 {{expected unsigned integer to specify offset of member in struct}}
+// expected-error @+1 {{expected unsigned integer to specify layout info}}
func @struct_type_neg_offset(!spv.struct<f32 [-1]>) -> ()