From b4cf3719368491e5aab662eeedb28b5b60045c87 Mon Sep 17 00:00:00 2001 From: Andrey Tuganov Date: Wed, 3 May 2017 15:26:39 -0400 Subject: [PATCH] Stats analyzer uses validator Stats analyzer calls validator to check the instruction and update validator state. Fixed unit tests (validator was failing). --- source/spirv_stats.cpp | 163 ++++++++++++++++++++++-------------- source/val/instruction.h | 5 ++ source/validate.cpp | 10 ++- source/validate.h | 6 ++ test/stats/stats_aggregate_test.cpp | 31 ++++++- 5 files changed, 146 insertions(+), 69 deletions(-) diff --git a/source/spirv_stats.cpp b/source/spirv_stats.cpp index 19da719..ef80229 100644 --- a/source/spirv_stats.cpp +++ b/source/spirv_stats.cpp @@ -17,6 +17,7 @@ #include #include +#include #include #include @@ -29,79 +30,114 @@ #include "operand.h" #include "spirv-tools/libspirv.h" #include "spirv_endian.h" +#include "spirv_validator_options.h" +#include "validate.h" +#include "val/instruction.h" +#include "val/validation_state.h" +using libspirv::Instruction; using libspirv::SpirvStats; +using libspirv::ValidationState_t; namespace { -struct StatsContext { - SpirvStats* stats; +// Helper class for stats aggregation. Receives as in/out parameter. +// Constructs ValidationState and updates it by running validator for each +// instruction. +class StatsAggregator { + public: + StatsAggregator(SpirvStats* in_out_stats, const spv_const_context context) { + stats_ = in_out_stats; + vstate_.reset(new ValidationState_t(context, &validator_options_)); + } - // Opcodes of already processed instructions in the order as they appear in - // the module. - std::vector opcodes; -}; + // Collects header statistics and sets correct id_bound. + spv_result_t ProcessHeader( + spv_endianness_t /* endian */, uint32_t /* magic */, + uint32_t version, uint32_t generator, uint32_t id_bound, + uint32_t /* schema */) { + vstate_->setIdBound(id_bound); + ++stats_->version_hist[version]; + ++stats_->generator_hist[generator]; + return SPV_SUCCESS; + } -// Collects statistics from SPIR-V header (version, generator). -spv_result_t ProcessHeader( - void* user_data, spv_endianness_t /* endian */, uint32_t /* magic */, - uint32_t version, uint32_t generator, uint32_t /* id_bound */, - uint32_t /* schema */) { - StatsContext* ctx = reinterpret_cast(user_data); - ++ctx->stats->version_hist[version]; - ++ctx->stats->generator_hist[generator]; - return SPV_SUCCESS; -} + // Runs validator to validate the instruction and update vstate_, + // then procession the instruction to collect stats. + spv_result_t ProcessInstruction(const spv_parsed_instruction_t* inst) { + const spv_result_t validation_result = + spvtools::ValidateInstructionAndUpdateValidationState(vstate_.get(), inst); + if (validation_result != SPV_SUCCESS) + return validation_result; -// Collects OpCapability statistics. -void ProcessCapability(StatsContext* ctx, - const spv_parsed_instruction_t* inst) { - if (static_cast(inst->opcode) != SpvOpCapability) return; - assert(inst->num_operands == 1); - const spv_parsed_operand_t& operand = inst->operands[0]; - assert(operand.num_words == 1); - assert(operand.offset < inst->num_words); - const uint32_t capability = inst->words[operand.offset]; - ++ctx->stats->capability_hist[capability]; -} + ProcessOpcode(); + ProcessCapability(); + ProcessExtension(); -// Collects OpExtension statistics. -void ProcessExtension(StatsContext* ctx, - const spv_parsed_instruction_t* inst) { - if (static_cast(inst->opcode) != SpvOpExtension) return; - const std::string extension = libspirv::GetExtensionString(inst); - ++ctx->stats->extension_hist[extension]; -} + return SPV_SUCCESS; + } -// Collects OpCode statistics. -void ProcessOpcode(StatsContext* ctx, - const spv_parsed_instruction_t* inst) { - const SpvOp opcode = static_cast(inst->opcode); - ++ctx->stats->opcode_hist[opcode]; - - auto opcode_it = ctx->opcodes.rbegin(); - auto step_it = ctx->stats->opcode_markov_hist.begin(); - for (; opcode_it != ctx->opcodes.rend() && - step_it != ctx->stats->opcode_markov_hist.end(); - ++opcode_it, ++step_it) { - auto& hist = (*step_it)[*opcode_it]; - ++hist[opcode]; + // Collects OpCapability statistics. + void ProcessCapability() { + const Instruction& inst = GetCurrentInstruction(); + if (inst.opcode() != SpvOpCapability) return; + const uint32_t capability = inst.word(inst.operands()[0].offset); + ++stats_->capability_hist[capability]; } -} -// Collects opcode usage statistics and calls other collectors. -spv_result_t ProcessInstruction( - void* user_data, const spv_parsed_instruction_t* inst) { - StatsContext* ctx = reinterpret_cast(user_data); + // Collects OpExtension statistics. + void ProcessExtension() { + const Instruction& inst = GetCurrentInstruction(); + if (inst.opcode() != SpvOpExtension) return; + const std::string extension = libspirv::GetExtensionString(&inst.c_inst()); + ++stats_->extension_hist[extension]; + } - ProcessOpcode(ctx, inst); - ProcessCapability(ctx, inst); - ProcessExtension(ctx, inst); + // Collects OpCode statistics. + void ProcessOpcode() { + auto inst_it = vstate_->ordered_instructions().rbegin(); + const SpvOp opcode = inst_it->opcode(); + ++stats_->opcode_hist[opcode]; + + ++inst_it; + auto step_it = stats_->opcode_markov_hist.begin(); + for (; inst_it != vstate_->ordered_instructions().rend() && + step_it != stats_->opcode_markov_hist.end(); ++inst_it, ++step_it) { + auto& hist = (*step_it)[inst_it->opcode()]; + ++hist[opcode]; + } + } + + SpirvStats* stats() { + return stats_; + } - const SpvOp opcode = static_cast(inst->opcode); - ctx->opcodes.push_back(opcode); + private: + // Returns the current instruction (the one last processed by the validator). + const Instruction& GetCurrentInstruction() const { + return vstate_->ordered_instructions().back(); + } + + SpirvStats* stats_; + spv_validator_options_t validator_options_; + std::unique_ptr vstate_; +}; - return SPV_SUCCESS; +spv_result_t ProcessHeader( + void* user_data, spv_endianness_t endian, uint32_t magic, + uint32_t version, uint32_t generator, uint32_t id_bound, + uint32_t schema) { + StatsAggregator* stats_aggregator = + reinterpret_cast(user_data); + return stats_aggregator->ProcessHeader( + endian, magic, version, generator, id_bound, schema); +} + +spv_result_t ProcessInstruction( + void* user_data, const spv_parsed_instruction_t* inst) { + StatsAggregator* stats_aggregator = + reinterpret_cast(user_data); + return stats_aggregator->ProcessInstruction(inst); } } // namespace @@ -109,29 +145,28 @@ spv_result_t ProcessInstruction( namespace libspirv { spv_result_t AggregateStats( - const spv_context_t& spv_context, const uint32_t* words, const size_t num_words, + const spv_context_t& context, const uint32_t* words, const size_t num_words, spv_diagnostic* pDiagnostic, SpirvStats* stats) { spv_const_binary_t binary = {words, num_words}; spv_endianness_t endian; spv_position_t position = {}; if (spvBinaryEndianness(&binary, &endian)) { - return libspirv::DiagnosticStream(position, spv_context.consumer, + return libspirv::DiagnosticStream(position, context.consumer, SPV_ERROR_INVALID_BINARY) << "Invalid SPIR-V magic number."; } spv_header_t header; if (spvBinaryHeaderGet(&binary, endian, &header)) { - return libspirv::DiagnosticStream(position, spv_context.consumer, + return libspirv::DiagnosticStream(position, context.consumer, SPV_ERROR_INVALID_BINARY) << "Invalid SPIR-V header."; } - StatsContext stats_context; - stats_context.stats = stats; + StatsAggregator stats_aggregator(stats, &context); - return spvBinaryParse(&spv_context, &stats_context, words, num_words, + return spvBinaryParse(&context, &stats_aggregator, words, num_words, ProcessHeader, ProcessInstruction, pDiagnostic); } diff --git a/source/val/instruction.h b/source/val/instruction.h index 1d8fe91..7cef7e3 100644 --- a/source/val/instruction.h +++ b/source/val/instruction.h @@ -71,6 +71,11 @@ class Instruction { return operands_; } + /// Provides direct access to the stored C instruction object. + const spv_parsed_instruction_t& c_inst() const { + return inst_; + } + private: const std::vector words_; const std::vector operands_; diff --git a/source/validate.cpp b/source/validate.cpp index ccfce05..a24c60e 100644 --- a/source/validate.cpp +++ b/source/validate.cpp @@ -374,7 +374,9 @@ spv_result_t spvValidateWithOptions(const spv_const_context context, hijack_context, binary->code, binary->wordCount, pDiagnostic, &vstate); } -spv_result_t spvtools::ValidateBinaryAndKeepValidationState( +namespace spvtools { + +spv_result_t ValidateBinaryAndKeepValidationState( const spv_const_context context, spv_const_validator_options options, const uint32_t* words, const size_t num_words, spv_diagnostic* pDiagnostic, std::unique_ptr* vstate) { @@ -390,3 +392,9 @@ spv_result_t spvtools::ValidateBinaryAndKeepValidationState( hijack_context, words, num_words, pDiagnostic, vstate->get()); } +spv_result_t ValidateInstructionAndUpdateValidationState( + ValidationState_t* vstate, const spv_parsed_instruction_t* inst) { + return ProcessInstruction(vstate, inst); +} + +} // namespace spvtools diff --git a/source/validate.h b/source/validate.h index a3976fb..f24e75d 100644 --- a/source/validate.h +++ b/source/validate.h @@ -188,6 +188,12 @@ spv_result_t ValidateBinaryAndKeepValidationState( const spv_const_context context, spv_const_validator_options options, const uint32_t* words, const size_t num_words, spv_diagnostic* pDiagnostic, std::unique_ptr* vstate); + +// Performs validation for a single instruction and updates given validation +// state. +spv_result_t ValidateInstructionAndUpdateValidationState( + libspirv::ValidationState_t* vstate, const spv_parsed_instruction_t* inst); + } // namespace spvtools #endif // LIBSPIRV_VALIDATE_H_ diff --git a/test/stats/stats_aggregate_test.cpp b/test/stats/stats_aggregate_test.cpp index 463b782..8b23504 100644 --- a/test/stats/stats_aggregate_test.cpp +++ b/test/stats/stats_aggregate_test.cpp @@ -25,10 +25,33 @@ namespace { using libspirv::SpirvStats; using spvtest::ScopedContext; +void DiagnosticsMessageHandler(spv_message_level_t level, const char*, + const spv_position_t& position, + const char* message) { + switch (level) { + case SPV_MSG_FATAL: + case SPV_MSG_INTERNAL_ERROR: + case SPV_MSG_ERROR: + std::cerr << "error: " << position.index << ": " << message + << std::endl; + break; + case SPV_MSG_WARNING: + std::cout << "warning: " << position.index << ": " << message + << std::endl; + break; + case SPV_MSG_INFO: + std::cout << "info: " << position.index << ": " << message << std::endl; + break; + default: + break; + } +} + // Calls libspirv::AggregateStats for binary compiled from |code|. void CompileAndAggregateStats(const std::string& code, SpirvStats* stats, spv_target_env env = SPV_ENV_UNIVERSAL_1_1) { ScopedContext ctx(env); + SetContextMessageConsumer(ctx.context, DiagnosticsMessageHandler); spv_binary binary; ASSERT_EQ(SPV_SUCCESS, spvTextToBinary( ctx.context, code.c_str(), code.size(), &binary, nullptr)); @@ -186,10 +209,10 @@ TEST(AggregateStats, OpcodeHistogram) { const std::string code1 = R"( OpCapability Addresses OpCapability Kernel -OpCapability GenericPointer +OpCapability Int64 OpCapability Linkage OpMemoryModel Physical32 OpenCL -%i32 = OpTypeInt 32 1 +%u64 = OpTypeInt 64 0 %u32 = OpTypeInt 32 0 %f32 = OpTypeFloat 32 )"; @@ -246,10 +269,10 @@ OpMemoryModel Logical GLSL450 const std::string code2 = R"( OpCapability Addresses OpCapability Kernel -OpCapability GenericPointer +OpCapability Int64 OpCapability Linkage OpMemoryModel Physical32 OpenCL -%i32 = OpTypeInt 32 1 +%u64 = OpTypeInt 64 0 %u32 = OpTypeInt 32 0 %f32 = OpTypeFloat 32 )"; -- 2.7.4