Add validation for OpBranchConditional
authorNuno Subtil <nsubtil@nvidia.com>
Thu, 26 Oct 2017 20:20:32 +0000 (13:20 -0700)
committerDavid Neto <dneto@google.com>
Tue, 31 Oct 2017 16:05:20 +0000 (12:05 -0400)
source/validate_id.cpp
test/val/val_id_test.cpp

index 3b21143..842c2cf 100644 (file)
@@ -2038,11 +2038,47 @@ bool idUsage::isValid<OpBranch>(const spv_instruction_t *inst,
                                 const spv_opcode_desc opcodeEntry) {}
 #endif
 
-#if 0
 template <>
-bool idUsage::isValid<OpBranchConditional>(
-    const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {}
-#endif
+bool idUsage::isValid<SpvOpBranchConditional>(
+    const spv_instruction_t *inst, const spv_opcode_desc) {
+  const size_t numOperands = inst->words.size() - 1;
+  const size_t condOperandIndex = 1;
+  const size_t targetTrueIndex = 2;
+  const size_t targetFalseIndex = 3;
+
+  // num_operands is either 3 or 5 --- if 5, the last two need to be literal integers
+  if (numOperands != 3 &&
+      numOperands != 5) {
+    DIAG(0) << "OpBranchConditional requires either 3 or 5 parameters";
+    return false;
+  }
+
+  bool ret = true;
+
+  // grab the condition operand and check that it is a bool
+  const auto condOp = module_.FindDef(inst->words[condOperandIndex]);
+  if (!condOp || !module_.IsBoolScalarType(condOp->type_id())) {
+    DIAG(0) << "Condition operand for OpBranchConditional must be of boolean type";
+    ret = false;
+  }
+
+  // target operands must be OpLabel
+  // note that we don't need to check that the target labels are in the same function,
+  // PerformCfgChecks already checks for that
+  const auto targetOpTrue = module_.FindDef(inst->words[targetTrueIndex]);
+  if (!targetOpTrue || SpvOpLabel != targetOpTrue->opcode()) {
+    DIAG(0) << "The 'True Label' operand for OpBranchConditional must be the ID of an OpLabel instruction";
+    ret = false;
+  }
+
+  const auto targetOpFalse = module_.FindDef(inst->words[targetFalseIndex]);
+  if (!targetOpFalse || SpvOpLabel != targetOpFalse->opcode()) {
+    DIAG(0) << "The 'False Label' operand for OpBranchConditional must be the ID of an OpLabel instruction";
+    ret = false;
+  }
+
+  return ret;
+}
 
 #if 0
 template <>
@@ -2558,7 +2594,7 @@ bool idUsage::isValid(const spv_instruction_t* inst) {
     TODO(OpLoopMerge)
     TODO(OpSelectionMerge)
     TODO(OpBranch)
-    TODO(OpBranchConditional)
+    CASE(OpBranchConditional)
     TODO(OpSwitch)
     CASE(OpReturnValue)
     TODO(OpLifetimeStart)
index 4ad0421..a9074b7 100644 (file)
@@ -94,6 +94,52 @@ string sampledImageSetup = R"(
             %sampler_inst = OpLoad %sampler_type %s
 )";
 
+string BranchConditionalSetup = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main"
+               OpExecutionMode %main OriginUpperLeft
+               OpSource GLSL 140
+               OpName %main "main"
+
+             ; type definitions
+       %bool = OpTypeBool
+       %uint = OpTypeInt 32 0
+        %int = OpTypeInt 32 1
+      %float = OpTypeFloat 32
+    %v4float = OpTypeVector %float 4
+
+             ; constants
+         %i0 = OpConstant %int 0
+         %i1 = OpConstant %int 1
+         %f0 = OpConstant %float 0
+         %f1 = OpConstant %float 1
+
+
+             ; main function header
+       %void = OpTypeVoid
+   %voidfunc = OpTypeFunction %void
+       %main = OpFunction %void None %voidfunc
+      %lmain = OpLabel
+
+               OpSelectionMerge %end None
+)";
+
+string BranchConditionalTail = R"(
+   %target_t = OpLabel
+               OpNop
+               OpBranch %end
+   %target_f = OpLabel
+               OpNop
+               OpBranch %end
+
+        %end = OpLabel
+
+               OpReturn
+               OpFunctionEnd
+)";
+
 // TODO: OpUndef
 
 TEST_F(ValidateIdWithMessage, OpName) {
@@ -4161,7 +4207,101 @@ TEST_F(ValidateIdWithMessage, OpVectorShuffleLiterals) {
 // TODO: OpLoopMerge
 // TODO: OpSelectionMerge
 // TODO: OpBranch
-// TODO: OpBranchConditional
+
+TEST_F(ValidateIdWithMessage, OpBranchConditionalGood) {
+  string spirv = BranchConditionalSetup + R"(
+    %branch_cond = OpINotEqual %bool %i0 %i1
+                   OpBranchConditional %branch_cond %target_t %target_f
+  )" + BranchConditionalTail;
+
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
+}
+
+TEST_F(ValidateIdWithMessage, OpBranchConditionalWithWeightsGood) {
+  string spirv = BranchConditionalSetup + R"(
+    %branch_cond = OpINotEqual %bool %i0 %i1
+                   OpBranchConditional %branch_cond %target_t %target_f 1 1
+  )" + BranchConditionalTail;
+
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
+}
+
+TEST_F(ValidateIdWithMessage, OpBranchConditional_CondIsScalarInt) {
+  string spirv = BranchConditionalSetup + R"(
+    OpBranchConditional %i0 %target_t %target_f
+  )" + BranchConditionalTail;
+
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("Condition operand for OpBranchConditional must be of boolean type"));
+}
+
+TEST_F(ValidateIdWithMessage, OpBranchConditional_TrueTargetIsNotLabel) {
+  string spirv = BranchConditionalSetup + R"(
+                   OpBranchConditional %i0 %i0 %target_f
+  )" + BranchConditionalTail;
+
+  CompileSuccessfully(spirv.c_str());
+  // EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  // EXPECT_THAT(
+  //     getDiagnosticString(),
+  //     HasSubstr("The 'True Label' operand for OpBranchConditional must be the ID of an OpLabel instruction"));
+
+  // xxxnsubtil: this is actually caught by the ID validation instead
+  EXPECT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("are referenced but not defined in function"));
+}
+
+TEST_F(ValidateIdWithMessage, OpBranchConditional_FalseTargetIsNotLabel) {
+  string spirv = BranchConditionalSetup + R"(
+    OpBranchConditional %i0 %target_t %i0
+  )" + BranchConditionalTail;
+
+  CompileSuccessfully(spirv.c_str());
+  // EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  // EXPECT_THAT(
+  //     getDiagnosticString(),
+  //     HasSubstr("The 'False Label' operand for OpBranchConditional must be the ID of an OpLabel instruction"));
+
+  // xxxnsubtil: this is actually caught by the ID validation
+  EXPECT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("are referenced but not defined in function"));
+}
+
+TEST_F(ValidateIdWithMessage, OpBranchConditional_NotEnoughWeights) {
+  string spirv = BranchConditionalSetup + R"(
+    %branch_cond = OpINotEqual %bool %i0 %i1
+                   OpBranchConditional %branch_cond %target_t %target_f 1
+  )" + BranchConditionalTail;
+
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("OpBranchConditional requires either 3 or 5 parameters"));
+}
+
+TEST_F(ValidateIdWithMessage, OpBranchConditional_TooManyWeights) {
+  string spirv = BranchConditionalSetup + R"(
+    %branch_cond = OpINotEqual %bool %i0 %i1
+                   OpBranchConditional %branch_cond %target_t %target_f 1 2 3
+  )" + BranchConditionalTail;
+
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("OpBranchConditional requires either 3 or 5 parameters"));
+}
+
 // TODO: OpSwitch
 
 TEST_F(ValidateIdWithMessage, OpReturnValueConstantGood) {