[spirv] Add Block decoration for spv.struct.
authorDenis Khalikov <dennis.khalikov@gmail.com>
Tue, 27 Aug 2019 17:41:07 +0000 (10:41 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 27 Aug 2019 17:41:42 +0000 (10:41 -0700)
Add Block decoration for top-level spv.struct.

Closes tensorflow/mlir#102

PiperOrigin-RevId: 265716241

mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
mlir/unittests/Dialect/SPIRV/CMakeLists.txt
mlir/unittests/Dialect/SPIRV/SerializationTest.cpp [new file with mode: 0644]

index 48044e9..d300725 100644 (file)
@@ -468,6 +468,13 @@ LogicalResult Deserializer::processDecoration(ArrayRef<uint32_t> words) {
     }
     typeDecorations[words[0]] = static_cast<uint32_t>(words[2]);
     break;
+  case spirv::Decoration::Block:
+    if (words.size() != 2) {
+      return emitError(unknownLoc, "OpDecoration with ")
+             << decorationName << "needs a single target <id>";
+    }
+    // Block decoration does not affect spv.struct type.
+    break;
   default:
     return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
   }
index 3f1b013..03973db 100644 (file)
@@ -174,6 +174,10 @@ private:
 
   bool isVoidType(Type type) const { return type.isa<NoneType>(); }
 
+  /// Returns true if the given type is a pointer type to a struct in Uniform or
+  /// StorageBuffer storage class.
+  bool isInterfaceStructPtrType(Type type) const;
+
   /// Main dispatch method for serializing a type. The result <id> of the
   /// serialized type will be returned as `typeID`.
   LogicalResult processType(Location loc, Type type, uint32_t &typeID);
@@ -558,6 +562,22 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
   if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) {
     return failure();
   }
+
+  if (isInterfaceStructPtrType(varOp.type())) {
+    auto structType = varOp.type()
+                          .cast<spirv::PointerType>()
+                          .getPointeeType()
+                          .cast<spirv::StructType>();
+    SmallVector<uint32_t, 2> args{
+        findTypeID(structType),
+        static_cast<uint32_t>(spirv::Decoration::Block)};
+    if (failed(encodeInstructionInto(decorations, spirv::Opcode::OpDecorate,
+                                     args))) {
+      return varOp.emitError("cannot decorate ")
+             << structType << " with Block decoration";
+    }
+  }
+
   elidedAttrs.push_back("type");
   SmallVector<uint32_t, 4> operands;
   operands.push_back(resultTypeID);
@@ -609,6 +629,17 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
 // Type
 //===----------------------------------------------------------------------===//
 
+bool Serializer::isInterfaceStructPtrType(Type type) const {
+  if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
+    auto storageClass = ptrType.getStorageClass();
+    if (storageClass == spirv::StorageClass::Uniform ||
+        storageClass == spirv::StorageClass::StorageBuffer) {
+      return ptrType.getPointeeType().isa<spirv::StructType>();
+    }
+  }
+  return false;
+}
+
 LogicalResult Serializer::processType(Location loc, Type type,
                                       uint32_t &typeID) {
   typeID = findTypeID(type);
index 4e85160..b444b5c 100644 (file)
@@ -1,8 +1,11 @@
 add_mlir_unittest(MLIRSPIRVTests
   DeserializationTest.cpp
+  SerializationTest.cpp
 )
 target_link_libraries(MLIRSPIRVTests
   PRIVATE
   MLIRSPIRV
   MLIRSPIRVSerialization)
 
+whole_archive_link(MLIRSPIRVTests MLIRSPIRV)
+
diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
new file mode 100644 (file)
index 0000000..65758a7
--- /dev/null
@@ -0,0 +1,124 @@
+//===- SerializationTest.cpp - SPIR-V Seserialization Tests -------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains corner case tests for the SPIR-V serializer that are not
+// covered by normal serialization and deserialization roundtripping.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/Serialization.h"
+#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "gmock/gmock.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Test Fixture
+//===----------------------------------------------------------------------===//
+
+class SerializationTest : public ::testing::Test {
+protected:
+  SerializationTest() { createModuleOp(); }
+
+  void createModuleOp() {
+    Builder builder(&context);
+    OperationState state(UnknownLoc::get(&context),
+                         spirv::ModuleOp::getOperationName());
+    state.addAttribute("addressing_model",
+                       builder.getI32IntegerAttr(static_cast<uint32_t>(
+                           spirv::AddressingModel::Logical)));
+    state.addAttribute("memory_model",
+                       builder.getI32IntegerAttr(
+                           static_cast<uint32_t>(spirv::MemoryModel::GLSL450)));
+    spirv::ModuleOp::build(&builder, &state);
+    module = cast<spirv::ModuleOp>(Operation::create(state));
+  }
+
+  Type getFloatStructType() {
+    OpBuilder opBuilder(module.body());
+    llvm::SmallVector<Type, 1> elementTypes{opBuilder.getF32Type()};
+    llvm::SmallVector<spirv::StructType::LayoutInfo, 1> layoutInfo{0};
+    auto structType = spirv::StructType::get(elementTypes, layoutInfo);
+    return structType;
+  }
+
+  void addGlobalVar(Type type, llvm::StringRef name) {
+    OpBuilder opBuilder(module.body());
+    auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform);
+    opBuilder.create<spirv::GlobalVariableOp>(
+        UnknownLoc::get(&context), opBuilder.getTypeAttr(ptrType),
+        opBuilder.getStringAttr(name), nullptr);
+  }
+
+  bool findInstruction(llvm::function_ref<bool(spirv::Opcode opcode,
+                                               ArrayRef<uint32_t> operands)>
+                           matchFn) {
+    auto binarySize = binary.size();
+    auto begin = binary.begin();
+    auto currOffset = spirv::kHeaderWordCount;
+
+    while (currOffset < binarySize) {
+      auto wordCount = binary[currOffset] >> 16;
+      if (!wordCount || (currOffset + wordCount > binarySize)) {
+        return false;
+      }
+      spirv::Opcode opcode =
+          static_cast<spirv::Opcode>(binary[currOffset] & 0xffff);
+
+      if (matchFn(opcode,
+                  llvm::ArrayRef<uint32_t>(begin + currOffset + 1,
+                                           begin + currOffset + wordCount))) {
+        return true;
+      }
+      currOffset += wordCount;
+    }
+    return false;
+  }
+
+protected:
+  MLIRContext context;
+  spirv::ModuleOp module;
+  SmallVector<uint32_t, 0> binary;
+};
+
+//===----------------------------------------------------------------------===//
+// Block decoration
+//===----------------------------------------------------------------------===//
+
+TEST_F(SerializationTest, BlockDecorationTest) {
+  auto structType = getFloatStructType();
+  addGlobalVar(structType, "var0");
+  ASSERT_TRUE(succeeded(spirv::serialize(module, binary)));
+  auto hasBlockDecoration = [](spirv::Opcode opcode,
+                               ArrayRef<uint32_t> operands) -> bool {
+    if (opcode != spirv::Opcode::OpDecorate || operands.size() != 2)
+      return false;
+    return operands[1] == static_cast<uint32_t>(spirv::Decoration::Block);
+  };
+  EXPECT_TRUE(findInstruction(hasBlockDecoration));
+}