Fix validation of group ops in SPV_AMD_shader_ballot
author David Neto <dneto@google.com>
Fri, 24 Nov 2017 19:18:17 +0000 (14:18 -0500)
committer David Neto <dneto@google.com>
Thu, 30 Nov 2017 15:26:04 +0000 (10:26 -0500)
This needs custom code since the rules from the extension
are not encoded in the grammar.

Changes are:
- The new group instructions don't require Group capability
  when the extension is declared.
- The Reduce, InclusiveScan, ExclusiveScan normally require the Kernel
  capability, but don't when the extension is declared.

Fixes https://github.com/KhronosGroup/SPIRV-Tools/issues/991

source/val/validation_state.cpp
source/val/validation_state.h
source/validate_instruction.cpp
test/val/val_extensions_test.cpp

index 4e8db28..b76620e 100644 (file)
@@ -292,6 +292,9 @@ void ValidationState_t::RegisterCapability(SpvCapability cap) {
   }
 
   switch (cap) {
+    case SpvCapabilityKernel:
+      features_.group_ops_reduce_and_scans = true;
+      break;
     case SpvCapabilityInt16:
       features_.declare_int16_type = true;
       break;
@@ -323,6 +326,18 @@ void ValidationState_t::RegisterExtension(Extension ext) {
   if (module_extensions_.Contains(ext)) return;
 
   module_extensions_.Add(ext);
+
+  switch (ext) {
+    case kSPV_AMD_shader_ballot:
+      // The grammar doesn't encode the fact that SPV_AMD_shader_ballot
+      // enables the use of group operations Reduce, InclusiveScan,
+      // and ExclusiveScan.  Enable it manually.
+      // https://github.com/KhronosGroup/SPIRV-Tools/issues/991
+      features_.group_ops_reduce_and_scans = true;
+      break;
+    default:
+      break;
+  }
 }
 
 bool ValidationState_t::HasAnyOfCapabilities(
index 48bc3be..6005605 100644 (file)
@@ -69,6 +69,9 @@ class ValidationState_t {
     // Allow functionalities enabled by VariablePointersStorageBuffer
     // capability.
     bool variable_pointers_storage_buffer = false;
+
+    // Permit group operations Reduce, InclusiveScan, ExclusiveScan.
+    bool group_ops_reduce_and_scans = false;
   };
 
   ValidationState_t(const spv_const_context context,
index 0181507..dc013cb 100644 (file)
@@ -67,6 +67,39 @@ spv_result_t CapabilityError(ValidationState_t& _, int which_operand,
          << " requires one of these capabilities: " << required_capabilities;
 }
 
+// Computes the set of capabilities that enable |opcode|.  An empty set means
+// the opcode is not restricted.  A non-empty set means the module must declare
+// at least one of the listed capabilities before the opcode may be used.
+CapabilitySet EnablingCapabilitiesForOp(const ValidationState_t& state,
+                                        SpvOp opcode) {
+  // The group instructions added by SPV_AMD_shader_ballot are listed in the
+  // grammar as requiring the Groups capability, but declaring the extension
+  // is enough to permit them.
+  const bool is_amd_ballot_group_op =
+      opcode == SpvOpGroupIAddNonUniformAMD ||
+      opcode == SpvOpGroupFAddNonUniformAMD ||
+      opcode == SpvOpGroupFMinNonUniformAMD ||
+      opcode == SpvOpGroupUMinNonUniformAMD ||
+      opcode == SpvOpGroupSMinNonUniformAMD ||
+      opcode == SpvOpGroupFMaxNonUniformAMD ||
+      opcode == SpvOpGroupUMaxNonUniformAMD ||
+      opcode == SpvOpGroupSMaxNonUniformAMD;
+  if (is_amd_ballot_group_op &&
+      state.HasExtension(libspirv::kSPV_AMD_shader_ballot)) {
+    return CapabilitySet();
+  }
+
+  // Otherwise defer to the grammar.  An opcode the grammar does not know
+  // about yields an empty (unrestricted) set.
+  spv_opcode_desc desc = {};
+  if (SPV_SUCCESS != state.grammar().lookupOpcode(opcode, &desc)) {
+    return CapabilitySet();
+  }
+  return CapabilitySet(desc->numCapabilities, desc->capabilities);
+}
+
 // Returns an operand's required capabilities.
 CapabilitySet RequiredCapabilities(const ValidationState_t& state,
                                    spv_operand_type_t type, uint32_t operand) {
@@ -97,12 +130,18 @@ CapabilitySet RequiredCapabilities(const ValidationState_t& state,
     CapabilitySet result(operand_desc->numCapabilities,
                          operand_desc->capabilities);
 
-    // Allow FPRoundingMode decoration if requested
+    // Allow FPRoundingMode decoration if requested.
     if (state.features().free_fp_rounding_mode &&
         type == SPV_OPERAND_TYPE_DECORATION &&
         operand_desc->value == SpvDecorationFPRoundingMode) {
       return CapabilitySet();
     }
+    // Allow certain group operations if requested.
+    if (state.features().group_ops_reduce_and_scans &&
+        type == SPV_OPERAND_TYPE_GROUP_OPERATION &&
+        (operand <= uint32_t(SpvGroupOperationExclusiveScan))) {
+      return CapabilitySet();
+    }
     return result;
   }
 
@@ -128,16 +167,13 @@ namespace libspirv {
 
 spv_result_t CapabilityCheck(ValidationState_t& _,
                              const spv_parsed_instruction_t* inst) {
-  spv_opcode_desc opcode_desc = {};
   const SpvOp opcode = static_cast<SpvOp>(inst->opcode);
-  if (SPV_SUCCESS == _.grammar().lookupOpcode(opcode, &opcode_desc)) {
-    CapabilitySet opcode_caps(opcode_desc->numCapabilities,
-                              opcode_desc->capabilities);
-    if (!_.HasAnyOfCapabilities(opcode_caps))
-      return _.diag(SPV_ERROR_INVALID_CAPABILITY)
-             << "Opcode " << spvOpcodeString(opcode)
-             << " requires one of these capabilities: "
-             << ToString(opcode_caps, _.grammar());
+  CapabilitySet opcode_caps = EnablingCapabilitiesForOp(_, opcode);
+  if (!_.HasAnyOfCapabilities(opcode_caps)) {
+    return _.diag(SPV_ERROR_INVALID_CAPABILITY)
+           << "Opcode " << spvOpcodeString(opcode)
+           << " requires one of these capabilities: "
+           << ToString(opcode_caps, _.grammar());
   }
   for (int i = 0; i < inst->num_operands; ++i) {
     const auto& operand = inst->operands[i];
index e4993ca..bd3998c 100644 (file)
@@ -30,6 +30,7 @@ using ::libspirv::Extension;
 using ::testing::HasSubstr;
 using ::testing::Not;
 using ::testing::Values;
+using ::testing::ValuesIn;
 
 using std::string;
 
@@ -106,4 +107,106 @@ TEST_F(ValidateExtensionCapabilities, DeclCapabilityFailure) {
   EXPECT_THAT(getDiagnosticString(), HasSubstr("SPV_KHR_device_group"));
 }
 
+
+using ValidateAMDShaderBallotCapabilities = spvtest::ValidateBase<string>;
+
+// Returns a vector of strings for the prefix of a SPIR-V assembly shader
+// that can use the group instructions introduced by SPV_AMD_shader_ballot.
+// Element 0 holds the capability declarations.  Element 1 holds the memory
+// model, the types and constants used by the test instructions, and the
+// opening of a function body.  Callers may splice an OpExtension declaration
+// between the two elements.
+std::vector<string> ShaderPartsForAMDShaderBallot() {
+  return std::vector<string>{R"(
+  OpCapability Shader
+  OpCapability Linkage
+  )",
+                             R"(
+  OpMemoryModel Logical GLSL450
+  %float = OpTypeFloat 32
+  %uint = OpTypeInt 32 0
+  %int = OpTypeInt 32 1
+  %scope = OpConstant %uint 3
+  %uint_const = OpConstant %uint 42
+  %int_const = OpConstant %int 45
+  %float_const = OpConstant %float 3.5
+
+  %void = OpTypeVoid
+  %fn_ty = OpTypeFunction %void
+  %fn = OpFunction %void None %fn_ty
+  %entry = OpLabel
+  )"};
+}
+
+// Returns a list of SPIR-V assembly strings, where each uses only types
+// and IDs that can fit with a shader made from parts from the result
+// of ShaderPartsForAMDShaderBallot.
+//
+// Each of the eight group opcodes added by SPV_AMD_shader_ballot (IAdd, FAdd,
+// FMin, UMin, SMin, FMax, UMax, SMax) appears once with each of the group
+// operations Reduce, InclusiveScan and ExclusiveScan.
+std::vector<string> AMDShaderBallotGroupInstructions() {
+  return std::vector<string>{
+  "%iadd_reduce = OpGroupIAddNonUniformAMD %uint %scope Reduce %uint_const",
+  "%iadd_iscan = OpGroupIAddNonUniformAMD %uint %scope InclusiveScan %uint_const",
+  "%iadd_escan = OpGroupIAddNonUniformAMD %uint %scope ExclusiveScan %uint_const",
+
+  "%fadd_reduce = OpGroupFAddNonUniformAMD %float %scope Reduce %float_const",
+  "%fadd_iscan = OpGroupFAddNonUniformAMD %float %scope InclusiveScan %float_const",
+  "%fadd_escan = OpGroupFAddNonUniformAMD %float %scope ExclusiveScan %float_const",
+
+  "%fmin_reduce = OpGroupFMinNonUniformAMD %float %scope Reduce %float_const",
+  "%fmin_iscan = OpGroupFMinNonUniformAMD %float %scope InclusiveScan %float_const",
+  "%fmin_escan = OpGroupFMinNonUniformAMD %float %scope ExclusiveScan %float_const",
+
+  "%umin_reduce = OpGroupUMinNonUniformAMD %uint %scope Reduce %uint_const",
+  "%umin_iscan = OpGroupUMinNonUniformAMD %uint %scope InclusiveScan %uint_const",
+  "%umin_escan = OpGroupUMinNonUniformAMD %uint %scope ExclusiveScan %uint_const",
+
+  "%smin_reduce = OpGroupSMinNonUniformAMD %int %scope Reduce %int_const",
+  "%smin_iscan = OpGroupSMinNonUniformAMD %int %scope InclusiveScan %int_const",
+  "%smin_escan = OpGroupSMinNonUniformAMD %int %scope ExclusiveScan %int_const",
+
+  "%fmax_reduce = OpGroupFMaxNonUniformAMD %float %scope Reduce %float_const",
+  "%fmax_iscan = OpGroupFMaxNonUniformAMD %float %scope InclusiveScan %float_const",
+  "%fmax_escan = OpGroupFMaxNonUniformAMD %float %scope ExclusiveScan %float_const",
+
+  "%umax_reduce = OpGroupUMaxNonUniformAMD %uint %scope Reduce %uint_const",
+  "%umax_iscan = OpGroupUMaxNonUniformAMD %uint %scope InclusiveScan %uint_const",
+  "%umax_escan = OpGroupUMaxNonUniformAMD %uint %scope ExclusiveScan %uint_const",
+
+  "%smax_reduce = OpGroupSMaxNonUniformAMD %int %scope Reduce %int_const",
+  "%smax_iscan = OpGroupSMaxNonUniformAMD %int %scope InclusiveScan %int_const",
+  "%smax_escan = OpGroupSMaxNonUniformAMD %int %scope ExclusiveScan %int_const"
+  };
+}
+
+TEST_P(ValidateAMDShaderBallotCapabilities, ExpectSuccess) {
+  // Declaring SPV_AMD_shader_ballot should make the instruction under test
+  // validate cleanly.
+  const std::vector<string> pieces = ShaderPartsForAMDShaderBallot();
+
+  // Layout: capabilities, the extension declaration, the shared preamble,
+  // the instruction under test, then the function epilogue.
+  string module_text = pieces[0];
+  module_text += "OpExtension \"SPV_AMD_shader_ballot\"\n";
+  module_text += pieces[1];
+  module_text += GetParam();
+  module_text += "\nOpReturn OpFunctionEnd";
+
+  CompileSuccessfully(module_text.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString();
+}
+
+INSTANTIATE_TEST_CASE_P(ExpectSuccess, ValidateAMDShaderBallotCapabilities,
+                        ValuesIn(AMDShaderBallotGroupInstructions()));
+
+TEST_P(ValidateAMDShaderBallotCapabilities, ExpectFailure) {
+  // Without the SPV_AMD_shader_ballot extension declaration, validation of
+  // the instruction under test must fail with a capability error.
+  const std::vector<string> pieces = ShaderPartsForAMDShaderBallot();
+
+  string module_text = pieces[0];
+  module_text += pieces[1];
+  module_text += GetParam();
+  module_text += "\nOpReturn OpFunctionEnd";
+
+  CompileSuccessfully(module_text.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, ValidateInstructions());
+
+  // The diagnostic names the opcode without its "Op" prefix.  Take the text
+  // starting at "Group" and ending just before the next space.
+  const string& instruction = GetParam();
+  const string from_group = instruction.substr(instruction.find("Group"));
+  const string opcode_name = from_group.substr(0, from_group.find(' '));
+  const string expected_message =
+      "Opcode " + opcode_name + " requires one of these capabilities: Groups";
+  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_message));
+}
+
+INSTANTIATE_TEST_CASE_P(ExpectFailure, ValidateAMDShaderBallotCapabilities,
+                        ValuesIn(AMDShaderBallotGroupInstructions()));
+
 }  // anonymous namespace