[spirv] Fix the entry block to start with OpLabel
authorLei Zhang <antiagainst@google.com>
Tue, 27 Aug 2019 17:50:58 +0000 (10:50 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 27 Aug 2019 17:51:26 +0000 (10:51 -0700)
Each basic block in SPIR-V must start with an OpLabel instruction.
We don't support control flow yet, so this CL just makes sure that
the entry block follows this rule and is valid.

PiperOrigin-RevId: 265718841

mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp

index c9e7343..538891e 100644 (file)
@@ -131,6 +131,7 @@ def SPV_OC_OpULessThan             : I32EnumAttrCase<"OpULessThan", 176>;
 def SPV_OC_OpSLessThan             : I32EnumAttrCase<"OpSLessThan", 177>;
 def SPV_OC_OpULessThanEqual        : I32EnumAttrCase<"OpULessThanEqual", 178>;
 def SPV_OC_OpSLessThanEqual        : I32EnumAttrCase<"OpSLessThanEqual", 179>;
+def SPV_OC_OpLabel                 : I32EnumAttrCase<"OpLabel", 248>;
 def SPV_OC_OpReturn                : I32EnumAttrCase<"OpReturn", 253>;
 def SPV_OC_OpReturnValue           : I32EnumAttrCase<"OpReturnValue", 254>;
 
@@ -153,7 +154,7 @@ def SPV_OpcodeAttr :
       SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
       SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
       SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
-      SPV_OC_OpReturn, SPV_OC_OpReturnValue
+      SPV_OC_OpLabel, SPV_OC_OpReturn, SPV_OC_OpReturnValue
       ]> {
     let returnType = "::mlir::spirv::Opcode";
     let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";
index d300725..dc0d886 100644 (file)
@@ -1,3 +1,4 @@
+//===- Deserializer.cpp - MLIR SPIR-V Deserialization ---------------------===//
 //
 // Copyright 2019 The MLIR Authors.
 //
@@ -43,6 +44,11 @@ static inline StringRef decodeStringLiteral(ArrayRef<uint32_t> words,
   return str;
 }
 
+// Extracts the opcode from the given first word of a SPIR-V instruction.
+static inline spirv::Opcode extractOpcode(uint32_t word) {
+  return static_cast<spirv::Opcode>(word & 0xffff);
+}
+
 namespace {
 /// A SPIR-V module serializer.
 ///
@@ -177,6 +183,13 @@ private:
   LogicalResult processConstantNull(ArrayRef<uint32_t> operands);
 
   //===--------------------------------------------------------------------===//
+  // Control flow
+  //===--------------------------------------------------------------------===//
+
+  /// Processes a SPIR-V OpLabel instruction with the given `operands`.
+  LogicalResult processLabel(ArrayRef<uint32_t> operands);
+
+  //===--------------------------------------------------------------------===//
   // Instruction
   //===--------------------------------------------------------------------===//
 
@@ -195,6 +208,9 @@ private:
   sliceInstruction(spirv::Opcode &opcode, ArrayRef<uint32_t> &operands,
                    Optional<spirv::Opcode> expectedOpcode = llvm::None);
 
+  /// Returns the next instruction's opcode if exists.
+  Optional<spirv::Opcode> peekOpcode();
+
   /// Processes a SPIR-V instruction with the given `opcode` and `operands`.
   /// This method is the main entrance for handling SPIR-V instruction; it
   /// checks the instruction opcode and dispatches to the corresponding handler.
@@ -581,10 +597,18 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
     }
   }
 
-  // Create a new builder for building the body
+  // Create a new builder for building the body.
   OpBuilder funcBody(funcOp.getBody());
   std::swap(funcBody, opBuilder);
 
+  // Make sure the first basic block, if exists, starts with an OpLabel
+  // instruction.
+  if (auto nextOpcode = peekOpcode()) {
+    if (*nextOpcode != spirv::Opcode::OpFunctionEnd &&
+        *nextOpcode != spirv::Opcode::OpLabel)
+      return emitError(unknownLoc, "a basic block must start with OpLabel");
+  }
+
   spirv::Opcode opcode = spirv::Opcode::OpNop;
   ArrayRef<uint32_t> instOperands;
   while (succeeded(sliceInstruction(opcode, instOperands,
@@ -597,9 +621,12 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
   if (opcode != spirv::Opcode::OpFunctionEnd) {
     return failure();
   }
+
+  // Process OpFunctionEnd.
   if (!instOperands.empty()) {
     return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
   }
+
   std::swap(funcBody, opBuilder);
   return success();
 }
@@ -1125,6 +1152,18 @@ LogicalResult Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
 }
 
 //===----------------------------------------------------------------------===//
+// Control flow
+//===----------------------------------------------------------------------===//
+
+LogicalResult Deserializer::processLabel(ArrayRef<uint32_t> operands) {
+  if (operands.size() != 1) {
+    return emitError(unknownLoc, "OpLabel should only have result <id>");
+  }
+  // TODO(antiagainst): support basic blocks and control flow properly.
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // Instruction
 //===----------------------------------------------------------------------===//
 
@@ -1173,12 +1212,18 @@ Deserializer::sliceInstruction(spirv::Opcode &opcode,
   if (nextOffset > binarySize)
     return emitError(unknownLoc, "insufficient words for the last instruction");
 
-  opcode = static_cast<spirv::Opcode>(binary[curOffset] & 0xffff);
+  opcode = extractOpcode(binary[curOffset]);
   operands = binary.slice(curOffset + 1, wordCount - 1);
   curOffset = nextOffset;
   return success();
 }
 
+Optional<spirv::Opcode> Deserializer::peekOpcode() {
+  if (curOffset >= binary.size())
+    return llvm::None;
+  return extractOpcode(binary[curOffset]);
+}
+
 LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
                                                ArrayRef<uint32_t> operands,
                                                bool deferInstructions) {
@@ -1237,6 +1282,8 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
     return processMemberDecoration(operands);
   case spirv::Opcode::OpFunction:
     return processFunction(operands);
+  case spirv::Opcode::OpLabel:
+    return processLabel(operands);
   default:
     break;
   }
index 03973db..43a1d08 100644 (file)
@@ -543,6 +543,8 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
   }
 
   for (auto &b : op) {
+    // TODO(antiagainst): support basic blocks and control flow properly.
+    encodeInstructionInto(functions, spirv::Opcode::OpLabel, {getNextID()});
     for (auto &op : b) {
       if (failed(processOperation(&op))) {
         return failure();
index e4b3ee5..5262d57 100644 (file)
@@ -111,11 +111,9 @@ protected:
     return id;
   }
 
-  uint32_t addFunctionEnd() {
-    auto id = nextID++;
-    addInstruction(spirv::Opcode::OpFunctionEnd, {id});
-    return id;
-  }
+  void addFunctionEnd() { addInstruction(spirv::Opcode::OpFunctionEnd, {}); }
+
+  void addReturn() { addInstruction(spirv::Opcode::OpReturn, {}); }
 
 protected:
   SmallVector<uint32_t, 5> binary;
@@ -201,3 +199,29 @@ TEST_F(DeserializationTest, FunctionMissingParameterFailure) {
   ASSERT_EQ(llvm::None, deserialize());
   expectDiagnostic("expected OpFunctionParameter instruction");
 }
+
+TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) {
+  addHeader();
+  auto voidType = addVoidType();
+  auto fnType = addFunctionType(voidType, {});
+  addFunction(voidType, fnType);
+  // Missing OpLabel
+  addReturn();
+  addFunctionEnd();
+
+  ASSERT_EQ(llvm::None, deserialize());
+  expectDiagnostic("a basic block must start with OpLabel");
+}
+
+TEST_F(DeserializationTest, FunctionMalformedLabelFailure) {
+  addHeader();
+  auto voidType = addVoidType();
+  auto fnType = addFunctionType(voidType, {});
+  addFunction(voidType, fnType);
+  addInstruction(spirv::Opcode::OpLabel, {}); // Malformed OpLabel
+  addReturn();
+  addFunctionEnd();
+
+  ASSERT_EQ(llvm::None, deserialize());
+  expectDiagnostic("OpLabel should only have result <id>");
+}