[mlir][spirv] Avoid duplicated Block decoration during serialization
authorLei Zhang <antiagainst@google.com>
Fri, 10 Dec 2021 18:57:46 +0000 (13:57 -0500)
committerLei Zhang <antiagainst@google.com>
Sat, 11 Dec 2021 00:20:49 +0000 (19:20 -0500)
It's legal per the Vulkan / SPIR-V spec; still it's better to avoid
such duplication to have cleaner blob and reduce the binary size.

Reviewed By: scotttodd

Differential Revision: https://reviews.llvm.org/D115532

mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
mlir/unittests/Dialect/SPIRV/SerializationTest.cpp

index d877c1b..ef558aa 100644 (file)
@@ -316,18 +316,6 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
     return failure();
   }
 
-  if (isInterfaceStructPtrType(varOp.type())) {
-    auto structType = varOp.type()
-                          .cast<spirv::PointerType>()
-                          .getPointeeType()
-                          .cast<spirv::StructType>();
-    if (failed(
-            emitDecoration(getTypeID(structType), spirv::Decoration::Block))) {
-      return varOp.emitError("cannot decorate ")
-             << structType << " with Block decoration";
-    }
-  }
-
   elidedAttrs.push_back("type");
   SmallVector<uint32_t, 4> operands;
   operands.push_back(resultTypeID);
index bcead6e..bd618ec 100644 (file)
@@ -331,9 +331,9 @@ LogicalResult
 Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
                             SetVector<StringRef> &serializationCtx) {
   typeID = getTypeID(type);
-  if (typeID) {
+  if (typeID)
     return success();
-  }
+
   typeID = getNextID();
   SmallVector<uint32_t, 4> operands;
 
@@ -499,6 +499,14 @@ LogicalResult Serializer::prepareBasicType(
     typeEnum = spirv::Opcode::OpTypePointer;
     operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
     operands.push_back(pointeeTypeID);
+
+    if (isInterfaceStructPtrType(ptrType)) {
+      if (failed(emitDecoration(getTypeID(pointeeStruct),
+                                spirv::Decoration::Block)))
+        return emitError(loc, "cannot decorate ")
+               << pointeeStruct << " with Block decoration";
+    }
+
     return success();
   }
 
index 9222b0c..f17bc53 100644 (file)
@@ -76,27 +76,29 @@ protected:
         builder.getStringAttr(name), nullptr);
   }
 
+  /// Handles a SPIR-V instruction with the given `opcode` and `operand`.
+  /// Returns true to interrupt.
+  using HandleFn = llvm::function_ref<bool(spirv::Opcode opcode,
+                                           ArrayRef<uint32_t> operands)>;
+
   /// Returns true if we can find a matching instruction in the SPIR-V blob.
-  bool findInstruction(llvm::function_ref<bool(spirv::Opcode opcode,
-                                               ArrayRef<uint32_t> operands)>
-                           matchFn) {
+  bool scanInstruction(HandleFn handleFn) {
     auto binarySize = binary.size();
     auto *begin = binary.begin();
     auto currOffset = spirv::kHeaderWordCount;
 
     while (currOffset < binarySize) {
       auto wordCount = binary[currOffset] >> 16;
-      if (!wordCount || (currOffset + wordCount > binarySize)) {
+      if (!wordCount || (currOffset + wordCount > binarySize))
         return false;
-      }
+
       spirv::Opcode opcode =
           static_cast<spirv::Opcode>(binary[currOffset] & 0xffff);
-
-      if (matchFn(opcode,
-                  llvm::ArrayRef<uint32_t>(begin + currOffset + 1,
-                                           begin + currOffset + wordCount))) {
+      llvm::ArrayRef<uint32_t> operands(begin + currOffset + 1,
+                                        begin + currOffset + wordCount);
+      if (handleFn(opcode, operands))
         return true;
-      }
+
       currOffset += wordCount;
     }
     return false;
@@ -119,12 +121,32 @@ TEST_F(SerializationTest, ContainsBlockDecoration) {
   ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
 
   auto hasBlockDecoration = [](spirv::Opcode opcode,
-                               ArrayRef<uint32_t> operands) -> bool {
-    if (opcode != spirv::Opcode::OpDecorate || operands.size() != 2)
-      return false;
-    return operands[1] == static_cast<uint32_t>(spirv::Decoration::Block);
+                               ArrayRef<uint32_t> operands) {
+    return opcode == spirv::Opcode::OpDecorate && operands.size() == 2 &&
+           operands[1] == static_cast<uint32_t>(spirv::Decoration::Block);
+  };
+  EXPECT_TRUE(scanInstruction(hasBlockDecoration));
+}
+
+TEST_F(SerializationTest, ContainsNoDuplicatedBlockDecoration) {
+  auto structType = getFloatStructType();
+  // Two global variables using the same type should not decorate the type with
+  // duplicated `Block` decorations.
+  addGlobalVar(structType, "var0");
+  addGlobalVar(structType, "var1");
+
+  ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
+
+  unsigned count = 0;
+  auto countBlockDecoration = [&count](spirv::Opcode opcode,
+                                       ArrayRef<uint32_t> operands) {
+    if (opcode == spirv::Opcode::OpDecorate && operands.size() == 2 &&
+        operands[1] == static_cast<uint32_t>(spirv::Decoration::Block))
+      ++count;
+    return false;
   };
-  EXPECT_TRUE(findInstruction(hasBlockDecoration));
+  ASSERT_FALSE(scanInstruction(countBlockDecoration));
+  EXPECT_EQ(count, 1u);
 }
 
 TEST_F(SerializationTest, ContainsSymbolName) {
@@ -140,7 +162,7 @@ TEST_F(SerializationTest, ContainsSymbolName) {
     return opcode == spirv::Opcode::OpName &&
            spirv::decodeStringLiteral(operands, index) == "var0";
   };
-  EXPECT_TRUE(findInstruction(hasVarName));
+  EXPECT_TRUE(scanInstruction(hasVarName));
 }
 
 TEST_F(SerializationTest, DoesNotContainSymbolName) {
@@ -156,5 +178,5 @@ TEST_F(SerializationTest, DoesNotContainSymbolName) {
     return opcode == spirv::Opcode::OpName &&
            spirv::decodeStringLiteral(operands, index) == "var0";
   };
-  EXPECT_FALSE(findInstruction(hasVarName));
+  EXPECT_FALSE(scanInstruction(hasVarName));
 }