[spirv] NFC: adjust `encode*` function signatures in Serializer
authorLei Zhang <antiagainst@google.com>
Mon, 22 Jul 2019 13:00:47 +0000 (06:00 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 22 Jul 2019 13:01:19 +0000 (06:01 -0700)
* Let them return `LogicalResult` so we can chain them together
  with other functions returning `LogicalResult`.
* Added "Into" as the suffix to the function name and made the
  `binary` as the first parameter so that it reads more naturally.

PiperOrigin-RevId: 259311636

mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp

index d29e717..3a48124 100644 (file)
 
 using namespace mlir;
 
+/// Returns the word-count-prefixed opcode for an SPIR-V instruction.
 static inline uint32_t getPrefixedOpcode(uint32_t wordCount,
                                          spirv::Opcode opcode) {
   assert(((wordCount >> 16) == 0) && "word count out of range!");
   return (wordCount << 16) | static_cast<uint32_t>(opcode);
 }
 
-static inline void buildInstruction(spirv::Opcode op,
-                                    ArrayRef<uint32_t> operands,
-                                    SmallVectorImpl<uint32_t> &binary) {
+/// Encodes an SPIR-V instruction with the given `opcode` and `operands` into
+/// the given `binary` vector.
+static LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary,
+                                           spirv::Opcode op,
+                                           ArrayRef<uint32_t> operands) {
   uint32_t wordCount = 1 + operands.size();
   binary.push_back(getPrefixedOpcode(wordCount, op));
   if (!operands.empty()) {
     binary.append(operands.begin(), operands.end());
   }
+  return success();
 }
 
-static inline void encodeStringLiteral(StringRef literal,
-                                       SmallVectorImpl<uint32_t> &buffer) {
-  // Encoding is the literal + null termination
+/// Encodes an SPIR-V `literal` string into the given `binary` vector.
+static LogicalResult encodeStringLiteralInto(SmallVectorImpl<uint32_t> &binary,
+                                             StringRef literal) {
+  // We need to encode the literal and the null termination.
   auto encodingSize = literal.size() / 4 + 1;
-  auto bufferStartSize = buffer.size();
-  buffer.resize(bufferStartSize + encodingSize, 0);
-  std::memcpy(buffer.data() + bufferStartSize, literal.data(), literal.size());
+  auto bufferStartSize = binary.size();
+  binary.resize(bufferStartSize + encodingSize, 0);
+  std::memcpy(binary.data() + bufferStartSize, literal.data(), literal.size());
+  return success();
 }
 
 namespace {
@@ -267,8 +273,8 @@ LogicalResult Serializer::processMemoryModel() {
   uint32_t mm = module.getAttrOfType<IntegerAttr>("memory_model").getInt();
   uint32_t am = module.getAttrOfType<IntegerAttr>("addressing_model").getInt();
 
-  buildInstruction(spirv::Opcode::OpMemoryModel, {am, mm}, memoryModel);
-  return success();
+  return encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel,
+                               {am, mm});
 }
 
 LogicalResult Serializer::processFuncOp(FuncOp op) {
@@ -296,13 +302,13 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
   // TODO : Support other function control options.
   operands.push_back(static_cast<uint32_t>(spirv::FunctionControl::None));
   operands.push_back(fnTypeID);
-  buildInstruction(spirv::Opcode::OpFunction, operands, functions);
+  encodeInstructionInto(functions, spirv::Opcode::OpFunction, operands);
 
   // Add function name.
   SmallVector<uint32_t, 4> nameOperands;
   nameOperands.push_back(funcID);
-  encodeStringLiteral(op.getName(), nameOperands);
-  buildInstruction(spirv::Opcode::OpName, nameOperands, names);
+  encodeStringLiteralInto(nameOperands, op.getName());
+  encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
 
   // Declare the parameters.
   for (auto arg : op.getArguments()) {
@@ -312,8 +318,8 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
     }
     auto argValueID = getNextID();
     valueIDMap[arg] = argValueID;
-    buildInstruction(spirv::Opcode::OpFunctionParameter,
-                     {argTypeID, argValueID}, functions);
+    encodeInstructionInto(functions, spirv::Opcode::OpFunctionParameter,
+                          {argTypeID, argValueID});
   }
 
   // Process the body.
@@ -330,9 +336,7 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
   }
 
   // Insert Function End.
-  buildInstruction(spirv::Opcode::OpFunctionEnd, {}, functions);
-
-  return success();
+  return encodeInstructionInto(functions, spirv::Opcode::OpFunctionEnd, {});
 }
 
 //===----------------------------------------------------------------------===//
@@ -353,9 +357,8 @@ LogicalResult Serializer::processType(Location loc, Type type,
        succeeded(processFunctionType(loc, type.cast<FunctionType>(), typeEnum,
                                      operands))) ||
       succeeded(processBasicType(loc, type, typeEnum, operands))) {
-    buildInstruction(typeEnum, operands, typesGlobalValues);
     typeIDMap[type] = typeID;
-    return success();
+    return encodeInstructionInto(typesGlobalValues, typeEnum, operands);
   }
   return failure();
 }
@@ -441,7 +444,7 @@ Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
   }
   operands.push_back(funcID);
   // Add the name of the function.
-  encodeStringLiteral(op.fn(), operands);
+  encodeStringLiteralInto(operands, op.fn());
 
   // Add the interface values.
   for (auto val : op.interface()) {
@@ -453,8 +456,8 @@ Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
     }
     operands.push_back(id);
   }
-  buildInstruction(spirv::Opcode::OpEntryPoint, operands, entryPoints);
-  return success();
+  return encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint,
+                               operands);
 }
 
 template <>
@@ -481,8 +484,8 @@ Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
           intVal.cast<IntegerAttr>().getValue().getZExtValue()));
     }
   }
-  buildInstruction(spirv::Opcode::OpExecutionMode, operands, executionModes);
-  return success();
+  return encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode,
+                               operands);
 }
 
 // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
index d9f229b..4595f25 100644 (file)
@@ -146,9 +146,9 @@ static void emitSerializationFunction(const Record *record, const Operator &op,
     os << "  }\n";
   }
 
-  os << formatv(
-      "  buildInstruction(spirv::getOpcode<{0}>(), operands, functions);\n",
-      op.getQualCppClassName());
+  os << formatv("  encodeInstructionInto("
+                "functions, spirv::getOpcode<{0}>(), operands);\n",
+                op.getQualCppClassName());
   os << "  return success();\n";
   os << "}\n\n";
 }