[spirv] Extend spv.array with Layoutinfo
authorDenis Khalikov <dennis.khalikov@gmail.com>
Fri, 16 Aug 2019 17:17:47 +0000 (10:17 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 16 Aug 2019 17:18:14 +0000 (10:18 -0700)
Extend spv.array with Layoutinfo to support (de)serialization.

Closes tensorflow/mlir#80

PiperOrigin-RevId: 263795304

mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
mlir/test/Dialect/SPIRV/Serialization/array_stride.mlir [new file with mode: 0644]
mlir/test/Dialect/SPIRV/types.mlir

index 264fed3c5aeb2b0b8064e8bd7b58776a9e7d46a5..b25c7a3091747a25892ca4b0dd9335541b05b256 100644 (file)
@@ -73,14 +73,23 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
                                         detail::ArrayTypeStorage> {
 public:
   using Base::Base;
+  // Zero layout specifies that is no layout
+  using LayoutInfo = uint64_t;
 
   static bool kindof(unsigned kind) { return kind == TypeKind::Array; }
 
   static ArrayType get(Type elementType, unsigned elementCount);
 
+  static ArrayType get(Type elementType, unsigned elementCount,
+                       LayoutInfo layoutInfo);
+
   unsigned getNumElements() const;
 
   Type getElementType() const;
+
+  bool hasLayout() const;
+
+  uint64_t getArrayStride() const;
 };
 
 // SPIR-V image type
index 622bb221b3fcb96f3ee293144adaec1ceea0f578..40d877a7225ab7551d7dd2c42aadbd63c7124ea5 100644 (file)
@@ -53,6 +53,18 @@ SPIRVDialect::SPIRVDialect(MLIRContext *context)
 // Type Parsing
 //===----------------------------------------------------------------------===//
 
+// Forward declarations.
+template <typename ValTy>
+static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect, Location loc,
+                                      StringRef spec);
+template <>
+Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect, Location loc,
+                                    StringRef spec);
+
+template <>
+Optional<uint64_t> parseAndVerify(SPIRVDialect const &dialect, Location loc,
+                                  StringRef spec);
+
 // Parses "<number> x" from the beginning of `spec`.
 static bool parseNumberX(StringRef &spec, int64_t &number) {
   spec = spec.ltrim();
@@ -150,7 +162,8 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect, StringRef spec,
 //                | vector-type
 //                | spirv-type
 //
-// array-type ::= `!spv.array<` integer-literal `x` element-type `>`
+// array-type ::= `!spv.array<` integer-literal `x` element-type
+//                (`[` integer-literal `]`)? `>`
 static Type parseArrayType(SPIRVDialect const &dialect, StringRef spec,
                            Location loc) {
   if (!spec.consume_front("array<") || !spec.consume_back(">")) {
@@ -171,11 +184,37 @@ static Type parseArrayType(SPIRVDialect const &dialect, StringRef spec,
     return Type();
   }
 
+  ArrayType::LayoutInfo layoutInfo = 0;
+  size_t lastLSquare;
+
+  // Handle case when element type is not a trivial type
+  auto lastRDelimiter = spec.rfind('>');
+  if (lastRDelimiter != StringRef::npos) {
+    lastLSquare = spec.find('[', lastRDelimiter);
+  } else {
+    lastLSquare = spec.rfind('[');
+  }
+
+  if (lastLSquare != StringRef::npos) {
+    auto layoutSpec = spec.substr(lastLSquare);
+    auto layout =
+        parseAndVerify<ArrayType::LayoutInfo>(dialect, loc, layoutSpec);
+    if (!layout) {
+      return Type();
+    }
+
+    if (!(layoutInfo = layout.getValue())) {
+      emitError(loc, "ArrayStride must be greater than zero");
+      return Type();
+    }
+    spec = spec.substr(0, lastLSquare);
+  }
+
   Type elementType = parseAndVerifyType(dialect, spec, loc);
   if (!elementType)
     return Type();
 
-  return ArrayType::get(elementType, count);
+  return ArrayType::get(elementType, count, layoutInfo);
 }
 
 // TODO(ravishankarm) : Reorder methods to be utilities first and parse*Type
@@ -267,18 +306,17 @@ Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect, Location loc,
 }
 
 template <>
-Optional<spirv::StructType::LayoutInfo>
-parseAndVerify(SPIRVDialect const &dialect, Location loc, StringRef spec) {
+Optional<uint64_t> parseAndVerify(SPIRVDialect const &dialect, Location loc,
+                                  StringRef spec) {
   uint64_t offsetVal = std::numeric_limits<uint64_t>::max();
   if (!spec.consume_front("[")) {
     emitError(loc, "expected '[' while parsing layout specification in '")
         << spec << "'";
     return llvm::None;
   }
+  spec = spec.trim();
   if (spec.consumeInteger(10, offsetVal)) {
-    emitError(
-        loc,
-        "expected unsigned integer to specify offset of member in struct: '")
+    emitError(loc, "expected unsigned integer to specify layout information: '")
         << spec << "'";
     return llvm::None;
   }
@@ -292,7 +330,7 @@ parseAndVerify(SPIRVDialect const &dialect, Location loc, StringRef spec) {
         << spec << "'";
     return llvm::None;
   }
-  return spirv::StructType::LayoutInfo{offsetVal};
+  return offsetVal;
 }
 
 // Functor object to parse a comma separated list of specs. The function
@@ -530,8 +568,11 @@ Type SPIRVDialect::parseType(StringRef spec, Location loc) const {
 //===----------------------------------------------------------------------===//
 
 static void print(ArrayType type, llvm::raw_ostream &os) {
-  os << "array<" << type.getNumElements() << " x " << type.getElementType()
-     << ">";
+  os << "array<" << type.getNumElements() << " x " << type.getElementType();
+  if (type.hasLayout()) {
+    os << " [" << type.getArrayStride() << "]";
+  }
+  os << ">";
 }
 
 static void print(RuntimeArrayType type, llvm::raw_ostream &os) {
index 345d13d42aaed5e9fa02708bbc650e1d279472a1..f79db01998f400d7f50c5bfeb3a82653c2be35d9 100644 (file)
@@ -34,7 +34,7 @@ using namespace mlir::spirv;
 //===----------------------------------------------------------------------===//
 
 struct spirv::detail::ArrayTypeStorage : public TypeStorage {
-  using KeyTy = std::pair<Type, unsigned>;
+  using KeyTy = std::tuple<Type, unsigned, ArrayType::LayoutInfo>;
 
   static ArrayTypeStorage *construct(TypeStorageAllocator &allocator,
                                      const KeyTy &key) {
@@ -42,18 +42,26 @@ struct spirv::detail::ArrayTypeStorage : public TypeStorage {
   }
 
   bool operator==(const KeyTy &key) const {
-    return key == KeyTy(elementType, getSubclassData());
+    return key == KeyTy(elementType, getSubclassData(), layoutInfo);
   }
 
   ArrayTypeStorage(const KeyTy &key)
-      : TypeStorage(key.second), elementType(key.first) {}
+      : TypeStorage(std::get<1>(key)), elementType(std::get<0>(key)),
+        layoutInfo(std::get<2>(key)) {}
 
   Type elementType;
+  ArrayType::LayoutInfo layoutInfo;
 };
 
 ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
   return Base::get(elementType.getContext(), TypeKind::Array, elementType,
-                   elementCount);
+                   elementCount, 0);
+}
+
+ArrayType ArrayType::get(Type elementType, unsigned elementCount,
+                         ArrayType::LayoutInfo layoutInfo) {
+  return Base::get(elementType.getContext(), TypeKind::Array, elementType,
+                   elementCount, layoutInfo);
 }
 
 unsigned ArrayType::getNumElements() const {
@@ -62,6 +70,11 @@ unsigned ArrayType::getNumElements() const {
 
 Type ArrayType::getElementType() const { return getImpl()->elementType; }
 
+// ArrayStride must be greater than zero
+bool ArrayType::hasLayout() const { return getImpl()->layoutInfo; }
+
+uint64_t ArrayType::getArrayStride() const { return getImpl()->layoutInfo; }
+
 //===----------------------------------------------------------------------===//
 // CompositeType
 //===----------------------------------------------------------------------===//
index 217f9b190dd6158ec4511984e289631c1bfe66d0..1aad7173dc6cf6fb9551700e722f0f2359e85c58 100644 (file)
@@ -207,6 +207,9 @@ private:
   // Result <id> to decorations mapping.
   DenseMap<uint32_t, NamedAttributeList> decorations;
 
+  // Result <id> to type decorations.
+  DenseMap<uint32_t, uint32_t> typeDecorations;
+
   // List of instructions that are processed in a defered fashion (after an
   // initial processing of the entire binary). Some operations like
   // OpEntryPoint, and OpExecutionMode use forward references to function
@@ -330,6 +333,13 @@ LogicalResult Deserializer::processDecoration(ArrayRef<uint32_t> words) {
                               opBuilder.getStringAttr(stringifyBuiltIn(
                                   static_cast<spirv::BuiltIn>(words[2]))));
     break;
+  case spirv::Decoration::ArrayStride:
+    if (words.size() != 3) {
+      return emitError(unknownLoc, "OpDecorate with ")
+             << decorationName << " needs a single integer literal";
+    }
+    typeDecorations[words[0]] = static_cast<uint32_t>(words[2]);
+    break;
   default:
     return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
   }
@@ -590,7 +600,8 @@ LogicalResult Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
            << defOp->getName();
   }
 
-  typeMap[operands[0]] = spirv::ArrayType::get(elementTy, count);
+  typeMap[operands[0]] = spirv::ArrayType::get(
+      elementTy, count, typeDecorations.lookup(operands[0]));
   return success();
 }
 
index 8b55873c5c0c143dcf1d1d452aadf470a7cb15c7..d06363a1a8c51104db7698cdeb17594306f3de80 100644 (file)
@@ -132,6 +132,12 @@ private:
   LogicalResult processDecoration(Location loc, uint32_t resultID,
                                   NamedAttribute attr);
 
+  template <typename DType>
+  LogicalResult processTypeDecoration(Location loc, DType type,
+                                      uint32_t resultId) {
+    return emitError(loc, "unhandled decoraion for type:") << type;
+  }
+
   //===--------------------------------------------------------------------===//
   // Types
   //===--------------------------------------------------------------------===//
@@ -148,7 +154,7 @@ private:
 
   /// Method for preparing basic SPIR-V type serialization. Returns the type's
   /// opcode and operands for the instruction via `typeEnum` and `operands`.
-  LogicalResult prepareBasicType(Location loc, Type type,
+  LogicalResult prepareBasicType(Location loc, Type type, uint32_t resultID,
                                  spirv::Opcode &typeEnum,
                                  SmallVectorImpl<uint32_t> &operands);
 
@@ -366,6 +372,22 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
   return encodeInstructionInto(decorations, spirv::Opcode::OpDecorate, args);
 }
 
+namespace {
+template <>
+LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
+    Location loc, spirv::ArrayType type, uint32_t resultID) {
+  if (type.hasLayout()) {
+    // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
+    SmallVector<uint32_t, 3> args;
+    args.push_back(resultID);
+    args.push_back(static_cast<uint32_t>(spirv::Decoration::ArrayStride));
+    args.push_back(type.getArrayStride());
+    return encodeInstructionInto(decorations, spirv::Opcode::OpDecorate, args);
+  }
+  return success();
+}
+} // namespace
+
 LogicalResult Serializer::processFuncOp(FuncOp op) {
   uint32_t fnTypeID = 0;
   // Generate type of the function.
@@ -445,7 +467,7 @@ LogicalResult Serializer::processType(Location loc, Type type,
   if ((type.isa<FunctionType>() &&
        succeeded(prepareFunctionType(loc, type.cast<FunctionType>(), typeEnum,
                                      operands))) ||
-      succeeded(prepareBasicType(loc, type, typeEnum, operands))) {
+      succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands))) {
     typeIDMap[type] = typeID;
     return encodeInstructionInto(typesGlobalValues, typeEnum, operands);
   }
@@ -453,7 +475,8 @@ LogicalResult Serializer::processType(Location loc, Type type,
 }
 
 LogicalResult
-Serializer::prepareBasicType(Location loc, Type type, spirv::Opcode &typeEnum,
+Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
+                             spirv::Opcode &typeEnum,
                              SmallVectorImpl<uint32_t> &operands) {
   if (isVoidType(type)) {
     typeEnum = spirv::Opcode::OpTypeVoid;
@@ -501,9 +524,8 @@ Serializer::prepareBasicType(Location loc, Type type, spirv::Opcode &typeEnum,
             loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()),
             /*isSpec=*/false)) {
       operands.push_back(elementCountID);
-      return success();
     }
-    return failure();
+    return processTypeDecoration(loc, arrayType, resultID);
   }
 
   if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
diff --git a/mlir/test/Dialect/SPIRV/Serialization/array_stride.mlir b/mlir/test/Dialect/SPIRV/Serialization/array_stride.mlir
new file mode 100644 (file)
index 0000000..b7229e8
--- /dev/null
@@ -0,0 +1,13 @@
+// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
+
+func @spirvmodule() {
+  spv.module "Logical" "VulkanKHR" {
+    func @array_stride(%arg0 : !spv.ptr<!spv.array<4x!spv.array<4xf32 [4]> [128]>, StorageBuffer>,
+                       %arg1 : i32, %arg2 : i32) {
+      // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32 [4]> [128]>, StorageBuffer>
+      %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr<!spv.array<4x!spv.array<4xf32 [4]> [128]>, StorageBuffer>
+      spv.Return
+    }
+  }
+  return
+}
index 58d16cf887e4b8654e1fd9e64067f9ac2596cc20..2bfadae6b7376520266370a493950fc4e1c68106 100644 (file)
@@ -12,6 +12,9 @@ func @scalar_array_type(!spv.array<16xf32>, !spv.array<8 x i32>) -> ()
 // CHECK: func @vector_array_type(!spv.array<32 x vector<4xf32>>)
 func @vector_array_type(!spv.array< 32 x vector<4xf32> >) -> ()
 
+// CHECK: func @array_type_stride(!spv.array<4 x !spv.array<4 x f32 [4]> [128]>)
+func @array_type_stride(!spv.array< 4 x !spv.array<4 x f32 [4]> [128]>) -> ()
+
 // -----
 
 // expected-error @+1 {{spv.array delimiter <...> mismatch}}
@@ -74,6 +77,11 @@ func @llvm_type(!spv.array<4x!llvm.i32>) -> ()
 
 // -----
 
+// expected-error @+1 {{ArrayStride must be greater than zero}}
+func @array_type_zero_stide(!spv.array<4xi32 [0]>) -> ()
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // PointerType
 //===----------------------------------------------------------------------===//
@@ -246,5 +254,5 @@ func @struct_type_missing_comma2(!spv.struct<f32 [0] i32>) -> ()
 
 // -----
 
-//  expected-error @+1 {{expected unsigned integer to specify offset of member in struct}}
+//  expected-error @+1 {{expected unsigned integer to specify layout info}}
 func @struct_type_neg_offset(!spv.struct<f32 [-1]>) -> ()