class SPIRVDialect : public Dialect {
public:
explicit SPIRVDialect(MLIRContext *context);
+
+ /// Parses a type registered to this dialect.
+ Type parseType(llvm::StringRef spec, Location loc) const override;
+
+ /// Prints a type registered to this dialect.
+ void printType(Type type, llvm::raw_ostream &os) const override;
};
} // end namespace spirv
--- /dev/null
+//===- SPIRVTypes.h - MLIR SPIR-V Types -------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file declares the types in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SPIRV_SPIRVTYPES_H_
+#define MLIR_SPIRV_SPIRVTYPES_H_
+
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+namespace spirv {
+
+namespace detail {
+struct ArrayTypeStorage;
+struct RuntimeArrayTypeStorage;
+} // namespace detail
+
+namespace TypeKind {
+enum Kind {
+ Array = Type::FIRST_SPIRV_TYPE,
+ RuntimeArray,
+};
+}
+
+// SPIR-V array type
+class ArrayType
+ : public Type::TypeBase<ArrayType, Type, detail::ArrayTypeStorage> {
+public:
+ using Base::Base;
+
+ static bool kindof(unsigned kind) { return kind == TypeKind::Array; }
+
+ static ArrayType get(Type elementType, int64_t elementCount);
+
+ Type getElementType();
+
+ int64_t getElementCount();
+};
+
+// SPIR-V run-time array type
+class RuntimeArrayType
+ : public Type::TypeBase<RuntimeArrayType, Type,
+ detail::RuntimeArrayTypeStorage> {
+public:
+ using Base::Base;
+
+ static bool kindof(unsigned kind) { return kind == TypeKind::RuntimeArray; }
+
+ static RuntimeArrayType get(Type elementType);
+
+ Type getElementType();
+};
+
+} // end namespace spirv
+} // end namespace mlir
+
+#endif // MLIR_SPIRV_SPIRVTYPES_H_
DialectRegistration.cpp
SPIRVDialect.cpp
SPIRVOps.cpp
+ SPIRVTypes.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/SPIRV
#include "mlir/SPIRV/SPIRVDialect.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/StandardTypes.h"
#include "mlir/SPIRV/SPIRVOps.h"
+#include "mlir/SPIRV/SPIRVTypes.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/Support/raw_ostream.h"
-namespace mlir {
-namespace spirv {
+using namespace mlir;
+using namespace mlir::spirv;
+
+//===----------------------------------------------------------------------===//
+// SPIR-V Dialect
+//===----------------------------------------------------------------------===//
SPIRVDialect::SPIRVDialect(MLIRContext *context) : Dialect("spv", context) {
+ addTypes<ArrayType, RuntimeArrayType>();
+
addOperations<
#define GET_OP_LIST
#include "mlir/SPIRV/SPIRVOps.cpp.inc"
allowUnknownOperations();
}
-} // namespace spirv
-} // namespace mlir
+//===----------------------------------------------------------------------===//
+// Type Parsing
+//===----------------------------------------------------------------------===//
+
+// TODO(b/133530217): The following implements some type parsing logic. It is
+// intended to be short-lived and used just before the main parser logic gets
+// exposed to dialects. So there is little type checking inside.
+
+static Type parseScalarType(StringRef spec, Builder builder) {
+ return llvm::StringSwitch<Type>(spec)
+ .Case("f32", builder.getF32Type())
+ .Case("i32", builder.getIntegerType(32))
+ .Case("f16", builder.getF16Type())
+ .Case("i16", builder.getIntegerType(16))
+ .Default(Type());
+}
+
+// Parses "<number> x" from the beginning of `spec`.
+static bool parseNumberX(StringRef &spec, int64_t &number) {
+ spec = spec.ltrim();
+ if (spec.empty() || !llvm::isDigit(spec.front()))
+ return false;
+
+ number = 0;
+ do {
+ number = number * 10 + spec.front() - '0';
+ spec = spec.drop_front();
+ } while (!spec.empty() && llvm::isDigit(spec.front()));
+
+ spec = spec.ltrim();
+ if (!spec.consume_front("x"))
+ return false;
+
+ return true;
+}
+
+static Type parseVectorType(StringRef spec, Builder builder) {
+ if (!spec.consume_front("vector<") || !spec.consume_back(">"))
+ return Type();
+
+ int64_t count = 0;
+ if (!parseNumberX(spec, count))
+ return Type();
+
+ spec = spec.trim();
+ auto scalarType = parseScalarType(spec, builder);
+ if (!scalarType)
+ return Type();
+
+ return VectorType::get({count}, scalarType);
+}
+
+static Type parseArrayType(StringRef spec, Builder builder) {
+ if (!spec.consume_front("array<") || !spec.consume_back(">"))
+ return Type();
+
+ Type elementType;
+ int64_t count = 0;
+
+ spec = spec.trim();
+ if (!parseNumberX(spec, count))
+ return Type();
+
+ spec = spec.ltrim();
+ if (spec.startswith("vector")) {
+ elementType = parseVectorType(spec, builder);
+ } else {
+ elementType = parseScalarType(spec, builder);
+ }
+ if (!elementType)
+ return Type();
+
+ return ArrayType::get(elementType, count);
+}
+
+static Type parseRuntimeArrayType(StringRef spec, Builder builder) {
+ if (!spec.consume_front("rtarray<") || !spec.consume_back(">"))
+ return Type();
+
+ Type elementType;
+ spec = spec.trim();
+ if (spec.startswith("vector")) {
+ elementType = parseVectorType(spec, builder);
+ } else {
+ elementType = parseScalarType(spec, builder);
+ }
+ if (!elementType)
+ return Type();
+
+ return RuntimeArrayType::get(elementType);
+}
+
+Type SPIRVDialect::parseType(StringRef spec, Location loc) const {
+ Builder builder(getContext());
+
+ if (auto type = parseArrayType(spec, builder))
+ return type;
+ if (auto type = parseRuntimeArrayType(spec, builder))
+ return type;
+
+ getContext()->emitError(loc, "unknown SPIR-V type: ") << spec;
+ return Type();
+}
+
+//===----------------------------------------------------------------------===//
+// Type Printing
+//===----------------------------------------------------------------------===//
+
+static void print(ArrayType type, llvm::raw_ostream &os) {
+ os << "array<" << type.getElementCount() << " x " << type.getElementType()
+ << ">";
+}
+
+static void print(RuntimeArrayType type, llvm::raw_ostream &os) {
+ os << "rtarray<" << type.getElementType() << ">";
+}
+
+void SPIRVDialect::printType(Type type, llvm::raw_ostream &os) const {
+ if (auto t = type.dyn_cast<ArrayType>()) {
+ print(t, os);
+ } else if (auto t = type.dyn_cast<RuntimeArrayType>()) {
+ print(t, os);
+ } else {
+ llvm_unreachable("unhandled SPIR-V type");
+ }
+}
--- /dev/null
+//===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines the types in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/SPIRV/SPIRVTypes.h"
+
+using namespace mlir;
+using namespace mlir::spirv;
+
+//===----------------------------------------------------------------------===//
+// ArrayType
+//===----------------------------------------------------------------------===//
+
+struct spirv::detail::ArrayTypeStorage : public TypeStorage {
+ using KeyTy = std::pair<Type, int64_t>;
+
+ static ArrayTypeStorage *construct(TypeStorageAllocator &allocator,
+ const KeyTy &key) {
+ return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
+ }
+
+ bool operator==(const KeyTy &key) const {
+ return key == KeyTy(elementType, elementCount);
+ }
+
+ ArrayTypeStorage(const KeyTy &key)
+ : elementType(key.first), elementCount(key.second) {}
+
+ Type elementType;
+ int64_t elementCount;
+};
+
+ArrayType ArrayType::get(Type elementType, int64_t elementCount) {
+ return Base::get(elementType.getContext(), TypeKind::Array, elementType,
+ elementCount);
+}
+
+Type ArrayType::getElementType() { return getImpl()->elementType; }
+
+int64_t ArrayType::getElementCount() { return getImpl()->elementCount; }
+
+//===----------------------------------------------------------------------===//
+// RuntimeArrayType
+//===----------------------------------------------------------------------===//
+
+struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage {
+ using KeyTy = Type;
+
+ static RuntimeArrayTypeStorage *construct(TypeStorageAllocator &allocator,
+ const KeyTy &key) {
+ return new (allocator.allocate<RuntimeArrayTypeStorage>())
+ RuntimeArrayTypeStorage(key);
+ }
+
+ bool operator==(const KeyTy &key) const { return elementType == key; }
+
+ RuntimeArrayTypeStorage(const KeyTy &key) : elementType(key) {}
+
+ Type elementType;
+};
+
+RuntimeArrayType RuntimeArrayType::get(Type elementType) {
+ return Base::get(elementType.getContext(), TypeKind::RuntimeArray,
+ elementType);
+}
+
+Type RuntimeArrayType::getElementType() { return getImpl()->elementType; }
--- /dev/null
+// RUN: mlir-opt -split-input-file -verify %s | FileCheck %s
+
+// TODO(b/133530217): Add more tests after switching to the generic parser.
+
+//===----------------------------------------------------------------------===//
+// ArrayType
+//===----------------------------------------------------------------------===//
+
+// CHECK: func @scalar_array_type(!spv.array<16 x f32>, !spv.array<8 x i32>)
+func @scalar_array_type(!spv.array<16xf32>, !spv.array<8 x i32>) -> ()
+
+// CHECK: func @vector_array_type(!spv.array<32 x vector<4xf32>>)
+func @vector_array_type(!spv.array< 32 x vector<4xf32> >) -> ()
+
+// -----
+
+// expected-error @+1 {{unknown SPIR-V type}}
+func @missing_count(!spv.array<f32>) -> ()
+
+// -----
+
+// expected-error @+1 {{unknown SPIR-V type}}
+func @missing_x(!spv.array<4 f32>) -> ()
+
+// -----
+
+// expected-error @+1 {{unknown SPIR-V type}}
+func @more_than_one_dim(!spv.array<4x3xf32>) -> ()
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// RuntimeArrayType
+//===----------------------------------------------------------------------===//
+
+// CHECK: func @scalar_runtime_array_type(!spv.rtarray<f32>, !spv.rtarray<i32>)
+func @scalar_runtime_array_type(!spv.rtarray<f32>, !spv.rtarray<i32>) -> ()
+
+// CHECK: func @vector_runtime_array_type(!spv.rtarray<vector<4xf32>>)
+func @vector_runtime_array_type(!spv.rtarray< vector<4xf32> >) -> ()
+
+// -----
+
+// expected-error @+1 {{unknown SPIR-V type}}
+func @redundant_count(!spv.rtarray<4xf32>) -> ()