Adding validation code for OpSwitch limits
authorEhsan Nasiri <ehsann@google.com>
Fri, 25 Nov 2016 14:26:26 +0000 (09:26 -0500)
committerDavid Neto <dneto@google.com>
Wed, 30 Nov 2016 20:36:05 +0000 (15:36 -0500)
The number of (literal, label) pairs passed to OpSwitch may not exceed
16,383. Added code to validate this and added unit tests for it.

Also fixed a typo in another validor error message.

source/validate_cfg.cpp
source/validate_instruction.cpp
test/val/val_cfg_test.cpp
test/val/val_limits_test.cpp

index b01b88d..ab866e3 100644 (file)
@@ -198,7 +198,7 @@ void printDominatorList(const BasicBlock& b) {
 spv_result_t FirstBlockAssert(ValidationState_t& _, uint32_t target) {
   if (_.current_function().IsFirstBlock(target)) {
     return _.diag(SPV_ERROR_INVALID_CFG)
-           << "First block " << _.getIdName(target) << " of funciton "
+           << "First block " << _.getIdName(target) << " of function "
            << _.getIdName(_.current_function().id()) << " is targeted by block "
            << _.getIdName(_.current_function().current_block()->id());
   }
index b9f3f2e..3fdd362 100644 (file)
@@ -156,6 +156,26 @@ spv_result_t LimitCheckStruct(ValidationState_t& _,
   return SPV_SUCCESS;
 }
 
+// Checks that the number of (literal, label) pairs in OpSwitch is within the
+// limit.
+spv_result_t LimitCheckSwitch(ValidationState_t& _,
+                              const spv_parsed_instruction_t* inst) {
+  if (SpvOpSwitch == inst->opcode) {
+    // The instruction syntax is as follows:
+    // OpSwitch <selector ID> <Default ID> literal label literal label ...
+    // literal,label pairs come after the first 2 operands.
+    // It is guaranteed at this point that num_operands is an even numner.
+    unsigned int num_pairs = (inst->num_operands - 2) / 2;
+    const unsigned int num_pairs_limit = 16383;
+    if (num_pairs > num_pairs_limit) {
+      return _.diag(SPV_ERROR_INVALID_BINARY)
+             << "Number of (literal, label) pairs in OpSwitch (" << num_pairs
+             << ") exceeds the limit (" << num_pairs_limit << ").";
+    }
+  }
+  return SPV_SUCCESS;
+}
+
 spv_result_t InstructionPass(ValidationState_t& _,
                              const spv_parsed_instruction_t* inst) {
   const SpvOp opcode = static_cast<SpvOp>(inst->opcode);
@@ -198,6 +218,7 @@ spv_result_t InstructionPass(ValidationState_t& _,
   if (auto error = CapCheck(_, inst)) return error;
   if (auto error = LimitCheckIdBound(_, inst)) return error;
   if (auto error = LimitCheckStruct(_, inst)) return error;
+  if (auto error = LimitCheckSwitch(_, inst)) return error;
 
   // All instruction checks have passed.
   return SPV_SUCCESS;
index eb30827..5082760 100644 (file)
 
 #include "gmock/gmock.h"
 
+#include "source/diagnostic.h"
+#include "source/validate.h"
 #include "test_fixture.h"
 #include "unit_spirv.h"
 #include "val_fixtures.h"
-#include "source/diagnostic.h"
-#include "source/validate.h"
 
 using std::array;
 using std::make_pair;
@@ -80,12 +80,12 @@ class Block {
 
   /// Sets the instructions which will appear in the body of the block
   Block& SetBody(std::string body) {
-      body_ = body;
+    body_ = body;
     return *this;
   }
 
   Block& AppendBody(std::string body) {
-      body_ += body;
+    body_ += body;
     return *this;
   }
 
@@ -465,7 +465,7 @@ TEST_P(ValidateCFG, BranchTargetFirstBlockBad) {
   CompileSuccessfully(str);
   ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
   EXPECT_THAT(getDiagnosticString(),
-              MatchesRegex("First block .\\[entry\\] of funciton .\\[Main\\] "
+              MatchesRegex("First block .\\[entry\\] of function .\\[Main\\] "
                            "is targeted by block .\\[bad\\]"));
 }
 
@@ -489,7 +489,7 @@ TEST_P(ValidateCFG, BranchConditionalTrueTargetFirstBlockBad) {
   CompileSuccessfully(str);
   ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
   EXPECT_THAT(getDiagnosticString(),
-              MatchesRegex("First block .\\[entry\\] of funciton .\\[Main\\] "
+              MatchesRegex("First block .\\[entry\\] of function .\\[Main\\] "
                            "is targeted by block .\\[bad\\]"));
 }
 
@@ -516,7 +516,7 @@ TEST_P(ValidateCFG, BranchConditionalFalseTargetFirstBlockBad) {
   CompileSuccessfully(str);
   ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
   EXPECT_THAT(getDiagnosticString(),
-              MatchesRegex("First block .\\[entry\\] of funciton .\\[Main\\] "
+              MatchesRegex("First block .\\[entry\\] of function .\\[Main\\] "
                            "is targeted by block .\\[bad\\]"));
 }
 
@@ -550,7 +550,7 @@ TEST_P(ValidateCFG, SwitchTargetFirstBlockBad) {
   CompileSuccessfully(str);
   ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
   EXPECT_THAT(getDiagnosticString(),
-              MatchesRegex("First block .\\[entry\\] of funciton .\\[Main\\] "
+              MatchesRegex("First block .\\[entry\\] of function .\\[Main\\] "
                            "is targeted by block .\\[bad\\]"));
 }
 
@@ -1019,7 +1019,8 @@ TEST_P(ValidateCFG, BranchOutOfConstructToMergeBad) {
     EXPECT_THAT(getDiagnosticString(),
                 MatchesRegex("The continue construct with the continue target "
                              ".\\[loop\\] is not post dominated by the "
-                             "back-edge block .\\[cont\\]")) << str;
+                             "back-edge block .\\[cont\\]"))
+        << str;
   } else {
     ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
   }
@@ -1254,7 +1255,7 @@ TEST_P(ValidateCFG, SingleLatchBlockMultipleBranchesToLoopHeader) {
 
   str += entry >> loop;
   str += loop >> vector<Block>({latch, merge});
-  str += latch >> vector<Block>({loop, loop}); // This is the key
+  str += latch >> vector<Block>({loop, loop});  // This is the key
   str += merge;
   str += "OpFunctionEnd";
 
index 4ebb9f5..5188ca3 100644 (file)
@@ -103,3 +103,62 @@ TEST_F(ValidateLimits, structNumMembersExceededBad) {
                         "the limit (16383)."));
 }
 
+// Valid: Switch statement has 16,383 branches.
+TEST_F(ValidateLimits, switchNumBranchesGood) {
+  std::ostringstream spirv;
+  spirv << header << R"(
+%1 = OpTypeVoid
+%2 = OpTypeFunction %1
+%3 = OpTypeInt 32 0
+%4 = OpConstant %3 1234
+%5 = OpFunction %1 None %2
+%7 = OpLabel
+%8 = OpIAdd %3 %4 %4
+%9 = OpSwitch %4 %10)";
+
+  // Now add the (literal, label) pairs
+  for (int i = 0; i < 16383; ++i) {
+    spirv << " 1 %10";
+  }
+
+  spirv << R"(
+%10 = OpLabel
+OpReturn
+OpFunctionEnd
+  )";
+
+  CompileSuccessfully(spirv.str());
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+// Invalid: Switch statement has 16,384 branches.
+TEST_F(ValidateLimits, switchNumBranchesBad) {
+  std::ostringstream spirv;
+  spirv << header << R"(
+%1 = OpTypeVoid
+%2 = OpTypeFunction %1
+%3 = OpTypeInt 32 0
+%4 = OpConstant %3 1234
+%5 = OpFunction %1 None %2
+%7 = OpLabel
+%8 = OpIAdd %3 %4 %4
+%9 = OpSwitch %4 %10)";
+
+  // Now add the (literal, label) pairs
+  for (int i = 0; i < 16384; ++i) {
+    spirv << " 1 %10";
+  }
+
+  spirv << R"(
+%10 = OpLabel
+OpReturn
+OpFunctionEnd
+  )";
+
+  CompileSuccessfully(spirv.str());
+  ASSERT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Number of (literal, label) pairs in OpSwitch (16384) "
+                        "exceeds the limit (16383)."));
+}
+