[spirv] Add array and run-time array types
authorLei Zhang <antiagainst@google.com>
Mon, 10 Jun 2019 19:21:44 +0000 (12:21 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Tue, 11 Jun 2019 17:12:53 +0000 (10:12 -0700)
PiperOrigin-RevId: 252458108

mlir/include/mlir/SPIRV/SPIRVDialect.h
mlir/include/mlir/SPIRV/SPIRVTypes.h [new file with mode: 0644]
mlir/lib/SPIRV/CMakeLists.txt
mlir/lib/SPIRV/SPIRVDialect.cpp
mlir/lib/SPIRV/SPIRVTypes.cpp [new file with mode: 0644]
mlir/test/SPIRV/types.mlir [new file with mode: 0644]

index 14a576b..cf3d0af 100644 (file)
@@ -32,6 +32,12 @@ namespace spirv {
 class SPIRVDialect : public Dialect {
 public:
   explicit SPIRVDialect(MLIRContext *context);
+
+  /// Parses a type registered to this dialect.
+  Type parseType(llvm::StringRef spec, Location loc) const override;
+
+  /// Prints a type registered to this dialect.
+  void printType(Type type, llvm::raw_ostream &os) const override;
 };
 
 } // end namespace spirv
diff --git a/mlir/include/mlir/SPIRV/SPIRVTypes.h b/mlir/include/mlir/SPIRV/SPIRVTypes.h
new file mode 100644 (file)
index 0000000..2e2c819
--- /dev/null
@@ -0,0 +1,74 @@
+//===- SPIRVTypes.h - MLIR SPIR-V Types -------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file declares the types in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SPIRV_SPIRVTYPES_H_
+#define MLIR_SPIRV_SPIRVTYPES_H_
+
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+namespace spirv {
+
+namespace detail {
+struct ArrayTypeStorage;
+struct RuntimeArrayTypeStorage;
+} // namespace detail
+
+namespace TypeKind {
+enum Kind {
+  Array = Type::FIRST_SPIRV_TYPE,
+  RuntimeArray,
+};
+}
+
+// SPIR-V array type
+class ArrayType
+    : public Type::TypeBase<ArrayType, Type, detail::ArrayTypeStorage> {
+public:
+  using Base::Base;
+
+  static bool kindof(unsigned kind) { return kind == TypeKind::Array; }
+
+  static ArrayType get(Type elementType, int64_t elementCount);
+
+  Type getElementType();
+
+  int64_t getElementCount();
+};
+
+// SPIR-V run-time array type
+class RuntimeArrayType
+    : public Type::TypeBase<RuntimeArrayType, Type,
+                            detail::RuntimeArrayTypeStorage> {
+public:
+  using Base::Base;
+
+  static bool kindof(unsigned kind) { return kind == TypeKind::RuntimeArray; }
+
+  static RuntimeArrayType get(Type elementType);
+
+  Type getElementType();
+};
+
+} // end namespace spirv
+} // end namespace mlir
+
+#endif // MLIR_SPIRV_SPIRVTYPES_H_
index 622a7b4..c539016 100644 (file)
@@ -2,6 +2,7 @@ add_llvm_library(MLIRSPIRV
   DialectRegistration.cpp
   SPIRVDialect.cpp
   SPIRVOps.cpp
+  SPIRVTypes.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/SPIRV
index 77b9301..15dce6d 100644 (file)
 
 #include "mlir/SPIRV/SPIRVDialect.h"
 
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/StandardTypes.h"
 #include "mlir/SPIRV/SPIRVOps.h"
+#include "mlir/SPIRV/SPIRVTypes.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/Support/raw_ostream.h"
 
-namespace mlir {
-namespace spirv {
+using namespace mlir;
+using namespace mlir::spirv;
+
+//===----------------------------------------------------------------------===//
+// SPIR-V Dialect
+//===----------------------------------------------------------------------===//
 
 SPIRVDialect::SPIRVDialect(MLIRContext *context) : Dialect("spv", context) {
+  addTypes<ArrayType, RuntimeArrayType>();
+
   addOperations<
 #define GET_OP_LIST
 #include "mlir/SPIRV/SPIRVOps.cpp.inc"
@@ -41,5 +53,129 @@ SPIRVDialect::SPIRVDialect(MLIRContext *context) : Dialect("spv", context) {
   allowUnknownOperations();
 }
 
-} // namespace spirv
-} // namespace mlir
+//===----------------------------------------------------------------------===//
+// Type Parsing
+//===----------------------------------------------------------------------===//
+
+// TODO(b/133530217): The following implements some type parsing logic. It is
+// intended to be short-lived and used just before the main parser logic gets
+// exposed to dialects. So there is little type checking inside.
+
+static Type parseScalarType(StringRef spec, Builder builder) {
+  return llvm::StringSwitch<Type>(spec)
+      .Case("f32", builder.getF32Type())
+      .Case("i32", builder.getIntegerType(32))
+      .Case("f16", builder.getF16Type())
+      .Case("i16", builder.getIntegerType(16))
+      .Default(Type());
+}
+
+// Parses "<number> x" from the beginning of `spec`.
+static bool parseNumberX(StringRef &spec, int64_t &number) {
+  spec = spec.ltrim();
+  if (spec.empty() || !llvm::isDigit(spec.front()))
+    return false;
+
+  number = 0;
+  do {
+    number = number * 10 + spec.front() - '0';
+    spec = spec.drop_front();
+  } while (!spec.empty() && llvm::isDigit(spec.front()));
+
+  spec = spec.ltrim();
+  if (!spec.consume_front("x"))
+    return false;
+
+  return true;
+}
+
+static Type parseVectorType(StringRef spec, Builder builder) {
+  if (!spec.consume_front("vector<") || !spec.consume_back(">"))
+    return Type();
+
+  int64_t count = 0;
+  if (!parseNumberX(spec, count))
+    return Type();
+
+  spec = spec.trim();
+  auto scalarType = parseScalarType(spec, builder);
+  if (!scalarType)
+    return Type();
+
+  return VectorType::get({count}, scalarType);
+}
+
+static Type parseArrayType(StringRef spec, Builder builder) {
+  if (!spec.consume_front("array<") || !spec.consume_back(">"))
+    return Type();
+
+  Type elementType;
+  int64_t count = 0;
+
+  spec = spec.trim();
+  if (!parseNumberX(spec, count))
+    return Type();
+
+  spec = spec.ltrim();
+  if (spec.startswith("vector")) {
+    elementType = parseVectorType(spec, builder);
+  } else {
+    elementType = parseScalarType(spec, builder);
+  }
+  if (!elementType)
+    return Type();
+
+  return ArrayType::get(elementType, count);
+}
+
+static Type parseRuntimeArrayType(StringRef spec, Builder builder) {
+  if (!spec.consume_front("rtarray<") || !spec.consume_back(">"))
+    return Type();
+
+  Type elementType;
+  spec = spec.trim();
+  if (spec.startswith("vector")) {
+    elementType = parseVectorType(spec, builder);
+  } else {
+    elementType = parseScalarType(spec, builder);
+  }
+  if (!elementType)
+    return Type();
+
+  return RuntimeArrayType::get(elementType);
+}
+
+Type SPIRVDialect::parseType(StringRef spec, Location loc) const {
+  Builder builder(getContext());
+
+  if (auto type = parseArrayType(spec, builder))
+    return type;
+  if (auto type = parseRuntimeArrayType(spec, builder))
+    return type;
+
+  getContext()->emitError(loc, "unknown SPIR-V type: ") << spec;
+  return Type();
+}
+
+//===----------------------------------------------------------------------===//
+// Type Printing
+//===----------------------------------------------------------------------===//
+
+static void print(ArrayType type, llvm::raw_ostream &os) {
+  os << "array<" << type.getElementCount() << " x " << type.getElementType()
+     << ">";
+}
+
+static void print(RuntimeArrayType type, llvm::raw_ostream &os) {
+  os << "rtarray<" << type.getElementType() << ">";
+}
+
+void SPIRVDialect::printType(Type type, llvm::raw_ostream &os) const {
+  if (auto t = type.dyn_cast<ArrayType>()) {
+    print(t, os);
+  } else if (auto t = type.dyn_cast<RuntimeArrayType>()) {
+    print(t, os);
+  } else {
+    llvm_unreachable("unhandled SPIR-V type");
+  }
+}
diff --git a/mlir/lib/SPIRV/SPIRVTypes.cpp b/mlir/lib/SPIRV/SPIRVTypes.cpp
new file mode 100644 (file)
index 0000000..1e24675
--- /dev/null
@@ -0,0 +1,84 @@
+//===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines the types in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/SPIRV/SPIRVTypes.h"
+
+using namespace mlir;
+using namespace mlir::spirv;
+
+//===----------------------------------------------------------------------===//
+// ArrayType
+//===----------------------------------------------------------------------===//
+
+struct spirv::detail::ArrayTypeStorage : public TypeStorage {
+  using KeyTy = std::pair<Type, int64_t>;
+
+  static ArrayTypeStorage *construct(TypeStorageAllocator &allocator,
+                                     const KeyTy &key) {
+    return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
+  }
+
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(elementType, elementCount);
+  }
+
+  ArrayTypeStorage(const KeyTy &key)
+      : elementType(key.first), elementCount(key.second) {}
+
+  Type elementType;
+  int64_t elementCount;
+};
+
+ArrayType ArrayType::get(Type elementType, int64_t elementCount) {
+  return Base::get(elementType.getContext(), TypeKind::Array, elementType,
+                   elementCount);
+}
+
+Type ArrayType::getElementType() { return getImpl()->elementType; }
+
+int64_t ArrayType::getElementCount() { return getImpl()->elementCount; }
+
+//===----------------------------------------------------------------------===//
+// RuntimeArrayType
+//===----------------------------------------------------------------------===//
+
+struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage {
+  using KeyTy = Type;
+
+  static RuntimeArrayTypeStorage *construct(TypeStorageAllocator &allocator,
+                                            const KeyTy &key) {
+    return new (allocator.allocate<RuntimeArrayTypeStorage>())
+        RuntimeArrayTypeStorage(key);
+  }
+
+  bool operator==(const KeyTy &key) const { return elementType == key; }
+
+  RuntimeArrayTypeStorage(const KeyTy &key) : elementType(key) {}
+
+  Type elementType;
+};
+
+RuntimeArrayType RuntimeArrayType::get(Type elementType) {
+  return Base::get(elementType.getContext(), TypeKind::RuntimeArray,
+                   elementType);
+}
+
+Type RuntimeArrayType::getElementType() { return getImpl()->elementType; }
diff --git a/mlir/test/SPIRV/types.mlir b/mlir/test/SPIRV/types.mlir
new file mode 100644 (file)
index 0000000..9a62f9d
--- /dev/null
@@ -0,0 +1,45 @@
+// RUN: mlir-opt -split-input-file -verify %s | FileCheck %s
+
+// TODO(b/133530217): Add more tests after switching to the generic parser.
+
+//===----------------------------------------------------------------------===//
+// ArrayType
+//===----------------------------------------------------------------------===//
+
+// CHECK: func @scalar_array_type(!spv.array<16 x f32>, !spv.array<8 x i32>)
+func @scalar_array_type(!spv.array<16xf32>, !spv.array<8 x i32>) -> ()
+
+// CHECK: func @vector_array_type(!spv.array<32 x vector<4xf32>>)
+func @vector_array_type(!spv.array< 32 x vector<4xf32> >) -> ()
+
+// -----
+
+// expected-error @+1 {{unknown SPIR-V type}}
+func @missing_count(!spv.array<f32>) -> ()
+
+// -----
+
+// expected-error @+1 {{unknown SPIR-V type}}
+func @missing_x(!spv.array<4 f32>) -> ()
+
+// -----
+
+// expected-error @+1 {{unknown SPIR-V type}}
+func @more_than_one_dim(!spv.array<4x3xf32>) -> ()
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// RuntimeArrayType
+//===----------------------------------------------------------------------===//
+
+// CHECK: func @scalar_runtime_array_type(!spv.rtarray<f32>, !spv.rtarray<i32>)
+func @scalar_runtime_array_type(!spv.rtarray<f32>, !spv.rtarray<i32>) -> ()
+
+// CHECK: func @vector_runtime_array_type(!spv.rtarray<vector<4xf32>>)
+func @vector_runtime_array_type(!spv.rtarray< vector<4xf32> >) -> ()
+
+// -----
+
+// expected-error @+1 {{unknown SPIR-V type}}
+func @redundant_count(!spv.rtarray<4xf32>) -> ()