[mlir][spirv] Add serialization control to emit symbol name
authorLei Zhang <antiagainst@google.com>
Fri, 10 Dec 2021 18:57:29 +0000 (13:57 -0500)
committerLei Zhang <antiagainst@google.com>
Sat, 11 Dec 2021 00:20:49 +0000 (19:20 -0500)
In SPIR-V, symbol names are encoded as `OpName` instructions.
They are not semantic impacting and can be omitted, which can
reduce the binary size.

Reviewed By: scotttodd

Differential Revision: https://reviews.llvm.org/D115531

mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h
mlir/include/mlir/Target/SPIRV/Serialization.h
mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
mlir/lib/Target/SPIRV/Serialization/Serializer.h
mlir/lib/Target/SPIRV/TranslateRegistration.cpp
mlir/unittests/Dialect/SPIRV/SerializationTest.cpp

index dc2b8c7..a155b19 100644 (file)
@@ -13,8 +13,8 @@
 #ifndef MLIR_TARGET_SPIRV_BINARY_UTILS_H_
 #define MLIR_TARGET_SPIRV_BINARY_UTILS_H_
 
-#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
-#include "mlir/Support/LogicalResult.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+#include "mlir/Support/LLVM.h"
 
 #include <cstdint>
 
@@ -41,6 +41,16 @@ uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode);
 /// Encodes an SPIR-V `literal` string into the given `binary` vector.
 LogicalResult encodeStringLiteralInto(SmallVectorImpl<uint32_t> &binary,
                                       StringRef literal);
+
+/// Decodes a string literal in `words` starting at `wordIndex`. Update the
+/// latter to point to the position in words after the string literal.
+inline StringRef decodeStringLiteral(ArrayRef<uint32_t> words,
+                                     unsigned &wordIndex) {
+  StringRef str(reinterpret_cast<const char *>(words.data() + wordIndex));
+  wordIndex += str.size() / 4 + 1;
+  return str;
+}
+
 } // namespace spirv
 } // namespace mlir
 
index 25033e2..498f390 100644 (file)
@@ -22,11 +22,18 @@ class MLIRContext;
 namespace spirv {
 class ModuleOp;
 
+struct SerializationOptions {
+  /// Whether to emit `OpName` instructions for SPIR-V symbol ops.
+  bool emitSymbolName = true;
+  /// Whether to emit `OpLine` location information for SPIR-V ops.
+  bool emitDebugInfo = false;
+};
+
 /// Serializes the given SPIR-V `module` and writes to `binary`. On failure,
 /// reports errors to the error handler registered with the MLIR context for
 /// `module`.
 LogicalResult serialize(ModuleOp module, SmallVectorImpl<uint32_t> &binary,
-                        bool emitDebugInfo = false);
+                        const SerializationOptions &options = {});
 
 } // namespace spirv
 } // namespace mlir
index c01362d..1c77dfd 100644 (file)
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Location.h"
+#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"
index 17060dd..8351f3a 100644 (file)
 #include "llvm/ADT/StringRef.h"
 #include <cstdint>
 
-//===----------------------------------------------------------------------===//
-// Utility Functions
-//===----------------------------------------------------------------------===//
-
-/// Decodes a string literal in `words` starting at `wordIndex`. Update the
-/// latter to point to the position in words after the string literal.
-static inline llvm::StringRef
-decodeStringLiteral(llvm::ArrayRef<uint32_t> words, unsigned &wordIndex) {
-  llvm::StringRef str(reinterpret_cast<const char *>(words.data() + wordIndex));
-  wordIndex += str.size() / 4 + 1;
-  return str;
-}
-
 namespace mlir {
 namespace spirv {
 
index 33b886b..7d4d118 100644 (file)
 namespace mlir {
 LogicalResult spirv::serialize(spirv::ModuleOp module,
                                SmallVectorImpl<uint32_t> &binary,
-                               bool emitDebugInfo) {
+                               const SerializationOptions &options) {
   if (!module.vce_triple().hasValue())
     return module.emitError(
         "module must have 'vce_triple' attribute to be serializeable");
 
-  Serializer serializer(module, emitDebugInfo);
+  Serializer serializer(module, options);
 
   if (failed(serializer.serialize()))
     return failure();
index 6bd5ff1..bcead6e 100644 (file)
@@ -81,9 +81,9 @@ LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary,
   return success();
 }
 
-Serializer::Serializer(spirv::ModuleOp module, bool emitDebugInfo)
-    : module(module), mlirBuilder(module.getContext()),
-      emitDebugInfo(emitDebugInfo) {}
+Serializer::Serializer(spirv::ModuleOp module,
+                       const SerializationOptions &options)
+    : module(module), mlirBuilder(module.getContext()), options(options) {}
 
 LogicalResult Serializer::serialize() {
   LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
@@ -172,7 +172,7 @@ void Serializer::processCapability() {
 }
 
 void Serializer::processDebugInfo() {
-  if (!emitDebugInfo)
+  if (!options.emitDebugInfo)
     return;
   auto fileLoc = module.getLoc().dyn_cast<FileLineColLoc>();
   auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>";
@@ -254,12 +254,13 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
 
 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
   assert(!name.empty() && "unexpected empty string for OpName");
+  if (!options.emitSymbolName)
+    return success();
 
   SmallVector<uint32_t, 4> nameOperands;
   nameOperands.push_back(resultID);
-  if (failed(spirv::encodeStringLiteralInto(nameOperands, name))) {
+  if (failed(spirv::encodeStringLiteralInto(nameOperands, name)))
     return failure();
-  }
   return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
 }
 
@@ -1170,7 +1171,7 @@ LogicalResult Serializer::emitDecoration(uint32_t target,
 
 LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
                                         Location loc) {
-  if (!emitDebugInfo)
+  if (!options.emitDebugInfo)
     return success();
 
   if (lastProcessedWasMergeInst) {
index 9ee288a..5f4a4e9 100644 (file)
@@ -15,6 +15,7 @@
 
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/Target/SPIRV/Serialization.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/raw_ostream.h"
@@ -42,7 +43,8 @@ LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary,
 class Serializer {
 public:
   /// Creates a serializer for the given SPIR-V `module`.
-  explicit Serializer(spirv::ModuleOp module, bool emitDebugInfo = false);
+  explicit Serializer(spirv::ModuleOp module,
+                      const SerializationOptions &options);
 
   /// Serializes the remembered SPIR-V module.
   LogicalResult serialize();
@@ -316,8 +318,8 @@ private:
   /// An MLIR builder for getting MLIR constructs.
   mlir::Builder mlirBuilder;
 
-  /// A flag which indicates if the debuginfo should be emitted.
-  bool emitDebugInfo = false;
+  /// Serialization options.
+  SerializationOptions options;
 
   /// A flag which indicates if the last processed instruction was a merge
   /// instruction.
index 989de41..e63a68c 100644 (file)
@@ -40,7 +40,7 @@ static OwningModuleRef deserializeModule(const llvm::MemoryBuffer *input,
   context->loadDialect<spirv::SPIRVDialect>();
 
   // Make sure the input stream can be treated as a stream of SPIR-V words
-  auto start = input->getBufferStart();
+  auto *start = input->getBufferStart();
   auto size = input->getBufferSize();
   if (size % sizeof(uint32_t) != 0) {
     emitError(UnknownLoc::get(context))
@@ -94,8 +94,7 @@ static LogicalResult serializeModule(ModuleOp module, raw_ostream &output) {
   if (spirvModules.size() != 1)
     return module.emitError("found more than one 'spv.module' op");
 
-  if (failed(
-          spirv::serialize(spirvModules[0], binary, /*emitDebuginfo=*/false)))
+  if (failed(spirv::serialize(spirvModules[0], binary)))
     return failure();
 
   output.write(reinterpret_cast<char *>(binary.data()),
@@ -133,7 +132,9 @@ static LogicalResult roundTripModule(ModuleOp srcModule, bool emitDebugInfo,
   if (std::next(spirvModules.begin()) != spirvModules.end())
     return srcModule.emitError("found more than one 'spv.module' op");
 
-  if (failed(spirv::serialize(*spirvModules.begin(), binary, emitDebugInfo)))
+  spirv::SerializationOptions options;
+  options.emitDebugInfo = emitDebugInfo;
+  if (failed(spirv::serialize(*spirvModules.begin(), binary, options)))
     return failure();
 
   MLIRContext deserializationContext(context->getDialectRegistry());
index a3ae8bc..9222b0c 100644 (file)
@@ -37,10 +37,11 @@ class SerializationTest : public ::testing::Test {
 protected:
   SerializationTest() {
     context.getOrLoadDialect<mlir::spirv::SPIRVDialect>();
-    createModuleOp();
+    initModuleOp();
   }
 
-  void createModuleOp() {
+  /// Initializes an empty SPIR-V module op.
+  void initModuleOp() {
     OpBuilder builder(&context);
     OperationState state(UnknownLoc::get(&context),
                          spirv::ModuleOp::getOperationName());
@@ -58,27 +59,29 @@ protected:
     module = cast<spirv::ModuleOp>(Operation::create(state));
   }
 
-  Type getFloatStructType() {
-    OpBuilder opBuilder(module->getRegion());
-    llvm::SmallVector<Type, 1> elementTypes{opBuilder.getF32Type()};
+  /// Gets the `struct { float }` type.
+  spirv::StructType getFloatStructType() {
+    OpBuilder builder(module->getRegion());
+    llvm::SmallVector<Type, 1> elementTypes{builder.getF32Type()};
     llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0};
-    auto structType = spirv::StructType::get(elementTypes, offsetInfo);
-    return structType;
+    return spirv::StructType::get(elementTypes, offsetInfo);
   }
 
-  void addGlobalVar(Type type, llvm::StringRef name) {
-    OpBuilder opBuilder(module->getRegion());
+  /// Inserts a global variable of the given `type` and `name`.
+  spirv::GlobalVariableOp addGlobalVar(Type type, llvm::StringRef name) {
+    OpBuilder builder(module->getRegion());
     auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform);
-    opBuilder.create<spirv::GlobalVariableOp>(
+    return builder.create<spirv::GlobalVariableOp>(
         UnknownLoc::get(&context), TypeAttr::get(ptrType),
-        opBuilder.getStringAttr(name), nullptr);
+        builder.getStringAttr(name), nullptr);
   }
 
+  /// Returns true if we can find a matching instruction in the SPIR-V blob.
   bool findInstruction(llvm::function_ref<bool(spirv::Opcode opcode,
                                                ArrayRef<uint32_t> operands)>
                            matchFn) {
     auto binarySize = binary.size();
-    auto begin = binary.begin();
+    auto *begin = binary.begin();
     auto currOffset = spirv::kHeaderWordCount;
 
     while (currOffset < binarySize) {
@@ -109,10 +112,12 @@ protected:
 // Block decoration
 //===----------------------------------------------------------------------===//
 
-TEST_F(SerializationTest, BlockDecorationTest) {
+TEST_F(SerializationTest, ContainsBlockDecoration) {
   auto structType = getFloatStructType();
   addGlobalVar(structType, "var0");
+
   ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
+
   auto hasBlockDecoration = [](spirv::Opcode opcode,
                                ArrayRef<uint32_t> operands) -> bool {
     if (opcode != spirv::Opcode::OpDecorate || operands.size() != 2)
@@ -121,3 +126,35 @@ TEST_F(SerializationTest, BlockDecorationTest) {
   };
   EXPECT_TRUE(findInstruction(hasBlockDecoration));
 }
+
+TEST_F(SerializationTest, ContainsSymbolName) {
+  auto structType = getFloatStructType();
+  addGlobalVar(structType, "var0");
+
+  spirv::SerializationOptions options;
+  options.emitSymbolName = true;
+  ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary, options)));
+
+  auto hasVarName = [](spirv::Opcode opcode, ArrayRef<uint32_t> operands) {
+    unsigned index = 1; // Skip the result <id>
+    return opcode == spirv::Opcode::OpName &&
+           spirv::decodeStringLiteral(operands, index) == "var0";
+  };
+  EXPECT_TRUE(findInstruction(hasVarName));
+}
+
+TEST_F(SerializationTest, DoesNotContainSymbolName) {
+  auto structType = getFloatStructType();
+  addGlobalVar(structType, "var0");
+
+  spirv::SerializationOptions options;
+  options.emitSymbolName = false;
+  ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary, options)));
+
+  auto hasVarName = [](spirv::Opcode opcode, ArrayRef<uint32_t> operands) {
+    unsigned index = 1; // Skip the result <id>
+    return opcode == spirv::Opcode::OpName &&
+           spirv::decodeStringLiteral(operands, index) == "var0";
+  };
+  EXPECT_FALSE(findInstruction(hasVarName));
+}