From: Lei Zhang Date: Fri, 2 Sep 2016 22:06:18 +0000 (-0400) Subject: Add a callback mechanism for communicating messages to callers. X-Git-Tag: upstream/2018.6~1054 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=755f97f534153519e15636502f1e2d6f4f48cda6;p=platform%2Fupstream%2FSPIRV-Tools.git Add a callback mechanism for communicating messages to callers. Every time an event happens in the library that the user should be aware of, the callback will be invoked. The existing diagnostic mechanism is hijacked internally by a callback that creates an diagnostic object each time an event happens. --- diff --git a/include/spirv-tools/libspirv.h b/include/spirv-tools/libspirv.h index 42ccb2a..c9812ec 100644 --- a/include/spirv-tools/libspirv.h +++ b/include/spirv-tools/libspirv.h @@ -340,14 +340,14 @@ typedef enum { SPV_ENV_UNIVERSAL_1_0, // SPIR-V 1.0 latest revision, no other restrictions. SPV_ENV_VULKAN_1_0, // Vulkan 1.0 latest revision. SPV_ENV_UNIVERSAL_1_1, // SPIR-V 1.1 latest revision, no other restrictions. - SPV_ENV_OPENCL_2_1, // OpenCL 2.1 latest revision. - SPV_ENV_OPENCL_2_2, // OpenCL 2.2 latest revision. - SPV_ENV_OPENGL_4_0, // OpenGL 4.0 plus GL_ARB_gl_spirv, latest revisions. - SPV_ENV_OPENGL_4_1, // OpenGL 4.1 plus GL_ARB_gl_spirv, latest revisions. - SPV_ENV_OPENGL_4_2, // OpenGL 4.2 plus GL_ARB_gl_spirv, latest revisions. - SPV_ENV_OPENGL_4_3, // OpenGL 4.3 plus GL_ARB_gl_spirv, latest revisions. + SPV_ENV_OPENCL_2_1, // OpenCL 2.1 latest revision. + SPV_ENV_OPENCL_2_2, // OpenCL 2.2 latest revision. + SPV_ENV_OPENGL_4_0, // OpenGL 4.0 plus GL_ARB_gl_spirv, latest revisions. + SPV_ENV_OPENGL_4_1, // OpenGL 4.1 plus GL_ARB_gl_spirv, latest revisions. + SPV_ENV_OPENGL_4_2, // OpenGL 4.2 plus GL_ARB_gl_spirv, latest revisions. + SPV_ENV_OPENGL_4_3, // OpenGL 4.3 plus GL_ARB_gl_spirv, latest revisions. // There is no variant for OpenGL 4.4. - SPV_ENV_OPENGL_4_5, // OpenGL 4.5 plus GL_ARB_gl_spirv, latest revisions. + SPV_ENV_OPENGL_4_5, // OpenGL 4.5 plus GL_ARB_gl_spirv, latest revisions. } spv_target_env; // Returns a string describing the given SPIR-V target environment. @@ -361,8 +361,9 @@ void spvContextDestroy(spv_context context); // Encodes the given SPIR-V assembly text to its binary representation. The // length parameter specifies the number of bytes for text. Encoded binary will -// be stored into *binary. Any error will be written into *diagnostic. The -// generated binary is independent of the context and may outlive it. +// be stored into *binary. Any error will be written into *diagnostic if +// diagnostic is non-null. The generated binary is independent of the context +// and may outlive it. spv_result_t spvTextToBinary(const spv_const_context context, const char* text, const size_t length, spv_binary* binary, spv_diagnostic* diagnostic); @@ -374,7 +375,8 @@ void spvTextDestroy(spv_text text); // Decodes the given SPIR-V binary representation to its assembly text. The // word_count parameter specifies the number of words for binary. The options // parameter is a bit field of spv_binary_to_text_options_t. Decoded text will -// be stored into *text. Any error will be written into *diagnostic. +// be stored into *text. Any error will be written into *diagnostic if +// diagnostic is non-null. spv_result_t spvBinaryToText(const spv_const_context context, const uint32_t* binary, const size_t word_count, const uint32_t options, spv_text* text, @@ -385,7 +387,7 @@ spv_result_t spvBinaryToText(const spv_const_context context, void spvBinaryDestroy(spv_binary binary); // Validates a SPIR-V binary for correctness. Any errors will be written into -// *diagnostic. +// *diagnostic if diagnostic is non-null. spv_result_t spvValidate(const spv_const_context context, const spv_const_binary binary, spv_diagnostic* diagnostic); diff --git a/source/binary.cpp b/source/binary.cpp index a709c02..c8b7c38 100644 --- a/source/binary.cpp +++ b/source/binary.cpp @@ -61,6 +61,7 @@ class Parser { spv_parsed_header_fn_t parsed_header_fn, spv_parsed_instruction_fn_t parsed_instruction_fn) : grammar_(context), + consumer_(context->consumer), user_data_(user_data), parsed_header_fn_(parsed_header_fn), parsed_instruction_fn_(parsed_instruction_fn) {} @@ -120,8 +121,7 @@ class Parser { // returned object will be propagated to the current parse's diagnostic // object. libspirv::DiagnosticStream diagnostic(spv_result_t error) { - return libspirv::DiagnosticStream({0, 0, _.word_index}, _.diagnostic, - error); + return libspirv::DiagnosticStream({0, 0, _.word_index}, consumer_, error); } // Returns a diagnostic stream object with the default parse error code. @@ -156,6 +156,7 @@ class Parser { // Data members const libspirv::AssemblyGrammar grammar_; // SPIR-V syntax utility. + const spvtools::MessageConsumer& consumer_; // Message consumer callback. void* const user_data_; // Context for the callbacks const spv_parsed_header_fn_t parsed_header_fn_; // Parsed header callback const spv_parsed_instruction_fn_t @@ -752,7 +753,12 @@ spv_result_t spvBinaryParse(const spv_const_context context, void* user_data, spv_parsed_header_fn_t parsed_header, spv_parsed_instruction_fn_t parsed_instruction, spv_diagnostic* diagnostic) { - Parser parser(context, user_data, parsed_header, parsed_instruction); + spv_context_t hijack_context = *context; + if (diagnostic) { + *diagnostic = nullptr; + libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic); + } + Parser parser(&hijack_context, user_data, parsed_header, parsed_instruction); return parser.parse(code, num_words, diagnostic); } diff --git a/source/diagnostic.cpp b/source/diagnostic.cpp index b1a9cac..8f0e348 100644 --- a/source/diagnostic.cpp +++ b/source/diagnostic.cpp @@ -20,6 +20,7 @@ #include #include "spirv-tools/libspirv.h" +#include "table.h" // Diagnostic API @@ -68,12 +69,47 @@ spv_result_t spvDiagnosticPrint(const spv_diagnostic diagnostic) { namespace libspirv { DiagnosticStream::~DiagnosticStream() { - if (pDiagnostic_ && error_ != SPV_FAILED_MATCH) { - *pDiagnostic_ = spvDiagnosticCreate(&position_, stream_.str().c_str()); + using spvtools::MessageLevel; + if (error_ != SPV_FAILED_MATCH && consumer_ != nullptr) { + auto level = MessageLevel::Error; + switch (error_) { + case SPV_SUCCESS: + case SPV_REQUESTED_TERMINATION: // Essentially success. + level = MessageLevel::Info; + break; + case SPV_WARNING: + level = MessageLevel::Warning; + break; + case SPV_UNSUPPORTED: + case SPV_ERROR_INTERNAL: + case SPV_ERROR_INVALID_TABLE: + level = MessageLevel::InternalError; + break; + case SPV_ERROR_OUT_OF_MEMORY: + level = MessageLevel::Fatal; + break; + default: + break; + } + consumer_(level, "", position_, stream_.str().c_str()); } } -std::string -spvResultToString(spv_result_t res) { + +void UseDiagnosticAsMessageConsumer(spv_context context, + spv_diagnostic* diagnostic) { + assert(diagnostic && *diagnostic == nullptr); + + auto create_diagnostic = [diagnostic](spvtools::MessageLevel, const char*, + const spv_position_t& position, + const char* message) { + auto p = position; + spvDiagnosticDestroy(*diagnostic); // Avoid memory leak. + *diagnostic = spvDiagnosticCreate(&p, message); + }; + SetContextMessageConsumer(context, std::move(create_diagnostic)); +} + +std::string spvResultToString(spv_result_t res) { std::string out; switch (res) { case SPV_SUCCESS: diff --git a/source/diagnostic.h b/source/diagnostic.h index 840ec9f..9bf9ae2 100644 --- a/source/diagnostic.h +++ b/source/diagnostic.h @@ -19,6 +19,7 @@ #include #include +#include "message.h" #include "spirv-tools/libspirv.h" namespace libspirv { @@ -29,20 +30,16 @@ namespace libspirv { // emitted during the destructor. class DiagnosticStream { public: - DiagnosticStream(spv_position_t position, spv_diagnostic* pDiagnostic, + DiagnosticStream(spv_position_t position, + const spvtools::MessageConsumer& consumer, spv_result_t error) - : position_(position), pDiagnostic_(pDiagnostic), error_(error) {} + : position_(position), consumer_(consumer), error_(error) {} DiagnosticStream(DiagnosticStream&& other) : stream_(other.stream_.str()), position_(other.position_), - pDiagnostic_(other.pDiagnostic_), - error_(other.error_) { - // The other object's destructor will emit the text in its stream_ - // member if its pDiagnostic_ member is non-null. Prevent that, - // since emitting that text is now the responsibility of *this. - other.pDiagnostic_ = nullptr; - } + consumer_(other.consumer_), + error_(other.error_) {} ~DiagnosticStream(); @@ -59,10 +56,18 @@ class DiagnosticStream { private: std::stringstream stream_; spv_position_t position_; - spv_diagnostic* pDiagnostic_; + const spvtools::MessageConsumer& consumer_; // Message consumer callback. spv_result_t error_; }; +// Changes the MessageConsumer in |context| to one that updates |diagnostic| +// with the last message received. +// +// This function expects that |diagnostic| is not nullptr and its content is a +// nullptr. +void UseDiagnosticAsMessageConsumer(spv_context context, + spv_diagnostic* diagnostic); + std::string spvResultToString(spv_result_t res); } // namespace libspirv diff --git a/source/disassemble.cpp b/source/disassemble.cpp index e71581d..267ed17 100644 --- a/source/disassemble.cpp +++ b/source/disassemble.cpp @@ -393,10 +393,13 @@ spv_result_t spvBinaryToText(const spv_const_context context, const uint32_t* code, const size_t wordCount, const uint32_t options, spv_text* pText, spv_diagnostic* pDiagnostic) { - // Invalid arguments return error codes, but don't necessarily generate - // diagnostics. These are programmer errors, not user errors. - if (!pDiagnostic) return SPV_ERROR_INVALID_DIAGNOSTIC; - const libspirv::AssemblyGrammar grammar(context); + spv_context_t hijack_context = *context; + if (pDiagnostic) { + *pDiagnostic = nullptr; + libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic); + } + + const libspirv::AssemblyGrammar grammar(&hijack_context); if (!grammar.isValid()) return SPV_ERROR_INVALID_TABLE; // Generate friendly names for Ids if requested. @@ -404,15 +407,15 @@ spv_result_t spvBinaryToText(const spv_const_context context, libspirv::NameMapper name_mapper = libspirv::GetTrivialNameMapper(); if (options & SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES) { friendly_mapper.reset( - new libspirv::FriendlyNameMapper(context, code, wordCount)); + new libspirv::FriendlyNameMapper(&hijack_context, code, wordCount)); name_mapper = friendly_mapper->GetNameMapper(); } // Now disassemble! Disassembler disassembler(grammar, options, name_mapper); - if (auto error = spvBinaryParse(context, &disassembler, code, wordCount, - DisassembleHeader, DisassembleInstruction, - pDiagnostic)) { + if (auto error = spvBinaryParse(&hijack_context, &disassembler, code, + wordCount, DisassembleHeader, + DisassembleInstruction, pDiagnostic)) { return error; } diff --git a/source/message.h b/source/message.h new file mode 100644 index 0000000..68e4cd7 --- /dev/null +++ b/source/message.h @@ -0,0 +1,47 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SPIRV_TOOLS_MESSAGE_H_ +#define SPIRV_TOOLS_MESSAGE_H_ + +#include + +#include "spirv-tools/libspirv.h" + +namespace spvtools { + +// TODO(antiagainst): This eventually should be in the C++ interface. + +// Severity levels of messages communicated to the consumer. +enum class MessageLevel { + Fatal, // Unrecoverable error due to environment. Will abort the program + // immediately. E.g., out of memory. + InternalError, // Unrecoverable error due to SPIRV-Tools internals. Will + // abort the program immediately. E.g., unimplemented feature. + Error, // Normal error due to user input. + Warning, // Warning information. + Info, // General information. + Debug, // Debug information. +}; + +// Message consumer. The C strings for source and message are only alive for the +// specific invocation. +using MessageConsumer = std::function; + +} // namespace spvtools + +#endif // SPIRV_TOOLS_MESSAGE_H_ diff --git a/source/opt/libspirv.cpp b/source/opt/libspirv.cpp index 0636c64..eabc260 100644 --- a/source/opt/libspirv.cpp +++ b/source/opt/libspirv.cpp @@ -16,6 +16,7 @@ #include "ir_loader.h" #include "make_unique.h" +#include "table.h" namespace spvtools { @@ -39,6 +40,10 @@ spv_result_t SetSpvInst(void* builder, const spv_parsed_instruction_t* inst) { } // annoymous namespace +void SpvTools::SetMessageConsumer(MessageConsumer consumer) { + SetContextMessageConsumer(context_, std::move(consumer)); +} + spv_result_t SpvTools::Assemble(const std::string& text, std::vector* binary) { spv_binary spvbinary = nullptr; diff --git a/source/opt/libspirv.hpp b/source/opt/libspirv.hpp index 46d7318..e645fcc 100644 --- a/source/opt/libspirv.hpp +++ b/source/opt/libspirv.hpp @@ -19,6 +19,7 @@ #include #include +#include "message.h" #include "module.h" #include "spirv-tools/libspirv.h" @@ -36,7 +37,8 @@ class SpvTools { ~SpvTools() { spvContextDestroy(context_); } - // TODO(antiagainst): handle error message in the following APIs. + // Sets the message consumer to the given |consumer|. + void SetMessageConsumer(MessageConsumer consumer); // Assembles the given assembly |text| and writes the result to |binary|. // Returns SPV_SUCCESS on successful assembling. diff --git a/source/table.cpp b/source/table.cpp index 6bdbd9b..24ab520 100644 --- a/source/table.cpp +++ b/source/table.cpp @@ -41,7 +41,13 @@ spv_context spvContextCreate(spv_target_env env) { spvOperandTableGet(&operand_table, env); spvExtInstTableGet(&ext_inst_table, env); - return new spv_context_t{env, opcode_table, operand_table, ext_inst_table}; + return new spv_context_t{env, opcode_table, operand_table, ext_inst_table, + nullptr /* a null default consumer */}; } void spvContextDestroy(spv_context context) { delete context; } + +void SetContextMessageConsumer(spv_context context, + spvtools::MessageConsumer consumer) { + context->consumer = std::move(consumer); +} diff --git a/source/table.h b/source/table.h index 8eedd0a..abce443 100644 --- a/source/table.h +++ b/source/table.h @@ -18,6 +18,7 @@ #include "spirv/1.1/spirv.h" #include "enum_set.h" +#include "message.h" #include "spirv-tools/libspirv.h" typedef struct spv_opcode_desc_t { @@ -87,8 +88,14 @@ struct spv_context_t { const spv_opcode_table opcode_table; const spv_operand_table operand_table; const spv_ext_inst_table ext_inst_table; + spvtools::MessageConsumer consumer; }; +// Sets the message consumer to |consumer| in the given |context|. The original +// message consumer will be overwritten. +void SetContextMessageConsumer(spv_context context, + spvtools::MessageConsumer consumer); + // Populates *table with entries for env. spv_result_t spvOpcodeTableGet(spv_opcode_table* table, spv_target_env env); diff --git a/source/text.cpp b/source/text.cpp index d317264..6e68ac2 100644 --- a/source/text.cpp +++ b/source/text.cpp @@ -31,6 +31,7 @@ #include "diagnostic.h" #include "ext_inst.h" #include "instruction.h" +#include "message.h" #include "opcode.h" #include "operand.h" #include "spirv-tools/libspirv.h" @@ -662,10 +663,9 @@ spv_result_t SetHeader(spv_target_env env, const uint32_t bound, // If a diagnostic is generated, it is not yet marked as being // for a text-based input. spv_result_t spvTextToBinaryInternal(const libspirv::AssemblyGrammar& grammar, - const spv_text text, spv_binary* pBinary, - spv_diagnostic* pDiagnostic) { - if (!pDiagnostic) return SPV_ERROR_INVALID_DIAGNOSTIC; - libspirv::AssemblyContext context(text, pDiagnostic); + const spvtools::MessageConsumer& consumer, + const spv_text text, spv_binary* pBinary) { + libspirv::AssemblyContext context(text, consumer); if (!text->str) return context.diagnostic() << "Missing assembly text."; if (!grammar.isValid()) { @@ -673,9 +673,6 @@ spv_result_t spvTextToBinaryInternal(const libspirv::AssemblyGrammar& grammar, } if (!pBinary) return SPV_ERROR_INVALID_POINTER; - // NOTE: Ensure diagnostic is zero initialised - *pDiagnostic = {}; - std::vector instructions; // Skip past whitespace and comments. @@ -728,11 +725,17 @@ spv_result_t spvTextToBinary(const spv_const_context context, const char* input_text, const size_t input_text_size, spv_binary* pBinary, spv_diagnostic* pDiagnostic) { + spv_context_t hijack_context = *context; + if (pDiagnostic) { + *pDiagnostic = nullptr; + libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic); + } + spv_text_t text = {input_text, input_text_size}; - libspirv::AssemblyGrammar grammar(context); + libspirv::AssemblyGrammar grammar(&hijack_context); spv_result_t result = - spvTextToBinaryInternal(grammar, &text, pBinary, pDiagnostic); + spvTextToBinaryInternal(grammar, hijack_context.consumer, &text, pBinary); if (pDiagnostic && *pDiagnostic) (*pDiagnostic)->isTextSource = true; return result; diff --git a/source/text_handler.h b/source/text_handler.h index 9951643..1bd004c 100644 --- a/source/text_handler.h +++ b/source/text_handler.h @@ -22,6 +22,7 @@ #include "diagnostic.h" #include "instruction.h" +#include "message.h" #include "spirv-tools/libspirv.h" #include "text.h" @@ -116,11 +117,8 @@ class ClampToZeroIfUnsignedType< // Encapsulates the data used during the assembly of a SPIR-V module. class AssemblyContext { public: - AssemblyContext(spv_text text, spv_diagnostic* diagnostic_arg) - : current_position_({}), - pDiagnostic_(diagnostic_arg), - text_(text), - bound_(1) {} + AssemblyContext(spv_text text, const spvtools::MessageConsumer& consumer) + : current_position_({}), consumer_(consumer), text_(text), bound_(1) {} // Assigns a new integer value to the given text ID, or returns the previously // assigned integer value if the ID has been seen before. @@ -148,7 +146,7 @@ class AssemblyContext { // stream, and for the given error code. Any data written to this object will // show up in pDiagnsotic on destruction. DiagnosticStream diagnostic(spv_result_t error) { - return DiagnosticStream(current_position_, pDiagnostic_, error); + return DiagnosticStream(current_position_, consumer_, error); } // Returns a diagnostic object with the default assembly error code. @@ -227,7 +225,6 @@ class AssemblyContext { spv_ext_inst_type_t getExtInstTypeForId(uint32_t id) const; private: - // Maps ID names to their corresponding numerical ids. using spv_named_id_table = std::unordered_map; // Maps type-defining IDs to their IdType. @@ -241,7 +238,7 @@ class AssemblyContext { // Maps an extended instruction import Id to the extended instruction type. std::unordered_map import_id_to_ext_inst_type_; spv_position_t current_position_; - spv_diagnostic* pDiagnostic_; + spvtools::MessageConsumer consumer_; spv_text text_; uint32_t bound_; }; diff --git a/source/val/ValidationState.cpp b/source/val/ValidationState.cpp index c8735dc..1368416 100644 --- a/source/val/ValidationState.cpp +++ b/source/val/ValidationState.cpp @@ -182,9 +182,8 @@ bool IsInstructionInLayoutSection(ModuleLayoutSection layout, SpvOp op) { } // anonymous namespace -ValidationState_t::ValidationState_t(spv_diagnostic* diagnostic, - const spv_const_context context) - : diagnostic_(diagnostic), +ValidationState_t::ValidationState_t(const spv_const_context ctx) + : context_(ctx), instruction_counter_(0), unresolved_forward_ids_{}, operand_names_{}, @@ -193,7 +192,7 @@ ValidationState_t::ValidationState_t(spv_diagnostic* diagnostic, module_capabilities_(), ordered_instructions_(), all_definitions_(), - grammar_(context), + grammar_(ctx), addressing_model_(SpvAddressingModelLogical), memory_model_(SpvMemoryModelSimple), in_function_(false) {} @@ -290,7 +289,7 @@ bool ValidationState_t::IsOpcodeInCurrentLayoutSection(SpvOp op) { DiagnosticStream ValidationState_t::diag(spv_result_t error_code) const { return libspirv::DiagnosticStream( - {0, 0, static_cast(instruction_counter_)}, diagnostic_, + {0, 0, static_cast(instruction_counter_)}, context_->consumer, error_code); } @@ -377,8 +376,8 @@ spv_result_t ValidationState_t::RegisterFunctionEnd() { void ValidationState_t::RegisterInstruction( const spv_parsed_instruction_t& inst) { if (in_function_body()) { - ordered_instructions_.emplace_back( - &inst, ¤t_function(), current_function().current_block()); + ordered_instructions_.emplace_back(&inst, ¤t_function(), + current_function().current_block()); } else { ordered_instructions_.emplace_back(&inst, nullptr, nullptr); } diff --git a/source/val/ValidationState.h b/source/val/ValidationState.h index 1f5c001..ff23d05 100644 --- a/source/val/ValidationState.h +++ b/source/val/ValidationState.h @@ -53,8 +53,10 @@ enum ModuleLayoutSection { /// This class manages the state of the SPIR-V validation as it is being parsed. class ValidationState_t { public: - ValidationState_t(spv_diagnostic* diagnostic, - const spv_const_context context); + ValidationState_t(const spv_const_context context); + + /// Returns the context + spv_const_context context() const { return context_; } /// Forward declares the id in the module spv_result_t ForwardDeclareId(uint32_t id); @@ -174,7 +176,8 @@ class ValidationState_t { private: ValidationState_t(const ValidationState_t&); - spv_diagnostic* diagnostic_; + const spv_const_context context_; + /// Tracks the number of instructions evaluated by the validator int instruction_counter_; @@ -191,7 +194,8 @@ class ValidationState_t { std::deque module_functions_; /// The capabilities available in the module - libspirv::CapabilitySet module_capabilities_; /// Module's declared capabilities. + libspirv::CapabilitySet + module_capabilities_; /// Module's declared capabilities. /// List of all instructions in the order they appear in the binary std::deque ordered_instructions_; diff --git a/source/validate.cpp b/source/validate.cpp index 1b50d9a..b0699cc 100644 --- a/source/validate.cpp +++ b/source/validate.cpp @@ -50,15 +50,17 @@ using libspirv::ModuleLayoutPass; using libspirv::IdPass; using libspirv::ValidationState_t; -spv_result_t spvValidateIDs( - const spv_instruction_t* pInsts, const uint64_t count, - const spv_opcode_table opcodeTable, const spv_operand_table operandTable, - const spv_ext_inst_table extInstTable, const ValidationState_t& state, - spv_position position, spv_diagnostic* pDiagnostic) { +spv_result_t spvValidateIDs(const spv_instruction_t* pInsts, + const uint64_t count, + const spv_opcode_table opcodeTable, + const spv_operand_table operandTable, + const spv_ext_inst_table extInstTable, + const ValidationState_t& state, + spv_position position) { position->index = SPV_INDEX_INSTRUCTION; if (auto error = spvValidateInstructionIDs(pInsts, count, opcodeTable, operandTable, - extInstTable, state, position, pDiagnostic)) + extInstTable, state, position)) return error; return SPV_SUCCESS; } @@ -175,29 +177,33 @@ UNUSED(void PrintDotGraph(ValidationState_t& _, libspirv::Function func)) { spv_result_t spvValidate(const spv_const_context context, const spv_const_binary binary, spv_diagnostic* pDiagnostic) { - if (!pDiagnostic) return SPV_ERROR_INVALID_DIAGNOSTIC; + spv_context_t hijack_context = *context; + if (pDiagnostic) { + *pDiagnostic = nullptr; + libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic); + } spv_endianness_t endian; spv_position_t position = {}; if (spvBinaryEndianness(binary, &endian)) { - return libspirv::DiagnosticStream(position, pDiagnostic, + return libspirv::DiagnosticStream(position, hijack_context.consumer, SPV_ERROR_INVALID_BINARY) << "Invalid SPIR-V magic number."; } spv_header_t header; if (spvBinaryHeaderGet(binary, endian, &header)) { - return libspirv::DiagnosticStream(position, pDiagnostic, + return libspirv::DiagnosticStream(position, hijack_context.consumer, SPV_ERROR_INVALID_BINARY) << "Invalid SPIR-V header."; } // NOTE: Parse the module and perform inline validation checks. These // checks do not require the the knowledge of the whole module. - ValidationState_t vstate(pDiagnostic, context); - if (auto error = - spvBinaryParse(context, &vstate, binary->code, binary->wordCount, - setHeader, ProcessInstruction, pDiagnostic)) + ValidationState_t vstate(&hijack_context); + if (auto error = spvBinaryParse(&hijack_context, &vstate, binary->code, + binary->wordCount, setHeader, + ProcessInstruction, pDiagnostic)) return error; if (vstate.in_function_body()) @@ -243,7 +249,7 @@ spv_result_t spvValidate(const spv_const_context context, position.index = SPV_INDEX_INSTRUCTION; return spvValidateIDs(instructions.data(), instructions.size(), - context->opcode_table, context->operand_table, - context->ext_inst_table, vstate, &position, - pDiagnostic); + hijack_context.opcode_table, + hijack_context.operand_table, + hijack_context.ext_inst_table, vstate, &position); } diff --git a/source/validate.h b/source/validate.h index 2ef341a..a19e911 100644 --- a/source/validate.h +++ b/source/validate.h @@ -20,6 +20,7 @@ #include #include "instruction.h" +#include "message.h" #include "spirv-tools/libspirv.h" #include "table.h" @@ -154,7 +155,6 @@ spv_result_t InstructionPass(ValidationState_t& _, /// @param[in] operandTable table of specified operands /// @param[in] usedefs use-def info from module parsing /// @param[in,out] position current position in the stream -/// @param[out] pDiag contains diagnostic on failure /// /// @return result code spv_result_t spvValidateInstructionIDs(const spv_instruction_t* pInsts, @@ -163,8 +163,7 @@ spv_result_t spvValidateInstructionIDs(const spv_instruction_t* pInsts, const spv_operand_table operandTable, const spv_ext_inst_table extInstTable, const libspirv::ValidationState_t& state, - spv_position position, - spv_diagnostic* pDiag); + spv_position position); /// @brief Validate the ID's within a SPIR-V binary /// @@ -174,7 +173,7 @@ spv_result_t spvValidateInstructionIDs(const spv_instruction_t* pInsts, /// @param[in] opcodeTable table of specified Opcodes /// @param[in] operandTable table of specified operands /// @param[in,out] position current word in the binary -/// @param[out] pDiagnostic contains diagnostic on failure +/// @param[in] consumer message consumer callback /// /// @return result code spv_result_t spvValidateIDs(const spv_instruction_t* pInstructions, @@ -182,6 +181,7 @@ spv_result_t spvValidateIDs(const spv_instruction_t* pInstructions, const spv_opcode_table opcodeTable, const spv_operand_table operandTable, const spv_ext_inst_table extInstTable, - spv_position position, spv_diagnostic* pDiagnostic); + spv_position position, + const spvtools::MessageConsumer& consumer); #endif // LIBSPIRV_VALIDATE_H_ diff --git a/source/validate_id.cpp b/source/validate_id.cpp index 1e10545..611d39f 100644 --- a/source/validate_id.cpp +++ b/source/validate_id.cpp @@ -24,6 +24,7 @@ #include "diagnostic.h" #include "instruction.h" +#include "message.h" #include "opcode.h" #include "spirv-tools/libspirv.h" #include "val/Function.h" @@ -48,7 +49,7 @@ class idUsage { const SpvMemoryModel memoryModelArg, const SpvAddressingModel addressingModelArg, const ValidationState_t& module, const vector& entry_points, - spv_position positionArg, spv_diagnostic* pDiagnosticArg) + spv_position positionArg, const spvtools::MessageConsumer& consumer) : opcodeTable(opcodeTableArg), operandTable(operandTableArg), extInstTable(extInstTableArg), @@ -57,7 +58,7 @@ class idUsage { memoryModel(memoryModelArg), addressingModel(addressingModelArg), position(positionArg), - pDiagnostic(pDiagnosticArg), + consumer_(consumer), module_(module), entry_points_(entry_points) {} @@ -75,14 +76,14 @@ class idUsage { const SpvMemoryModel memoryModel; const SpvAddressingModel addressingModel; spv_position position; - spv_diagnostic* pDiagnostic; + const spvtools::MessageConsumer& consumer_; const ValidationState_t& module_; vector entry_points_; }; #define DIAG(INDEX) \ position->index += INDEX; \ - libspirv::DiagnosticStream helper(*position, pDiagnostic, \ + libspirv::DiagnosticStream helper(*position, consumer_, \ SPV_ERROR_INVALID_DIAGNOSTIC); \ helper @@ -2553,11 +2554,10 @@ spv_result_t spvValidateInstructionIDs(const spv_instruction_t* pInsts, const spv_operand_table operandTable, const spv_ext_inst_table extInstTable, const libspirv::ValidationState_t& state, - spv_position position, - spv_diagnostic* pDiag) { + spv_position position) { idUsage idUsage(opcodeTable, operandTable, extInstTable, pInsts, instCount, state.memory_model(), state.addressing_model(), state, - state.entry_points(), position, pDiag); + state.entry_points(), position, state.context()->consumer); for (uint64_t instIndex = 0; instIndex < instCount; ++instIndex) { if (!idUsage.isValid(&pInsts[instIndex])) return SPV_ERROR_INVALID_ID; position->index += pInsts[instIndex].words.size(); diff --git a/test/BinaryParse.cpp b/test/BinaryParse.cpp index 97c2104..48e7edf 100644 --- a/test/BinaryParse.cpp +++ b/test/BinaryParse.cpp @@ -19,6 +19,8 @@ #include "TestFixture.h" #include "UnitSPIRV.h" #include "gmock/gmock.h" +#include "source/message.h" +#include "source/table.h" #include "spirv/1.0/OpenCL.std.h" // Returns true if two spv_parsed_operand_t values are equal. @@ -258,6 +260,112 @@ TEST_F(BinaryParseTest, NullDiagnosticsIsOkForBadParse) { words.size(), invoke_header, invoke_instruction, nullptr)); } +// Make sure that we don't blow up when both the consumer and the diagnostic are +// null. +TEST_F(BinaryParseTest, NullConsumerNullDiagnosticsForBadParse) { + auto words = CompileSuccessfully(""); + + auto ctx = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + SetContextMessageConsumer(ctx, nullptr); + + words.push_back(0xffffffff); // Certainly invalid instruction header. + EXPECT_HEADER(1).WillOnce(Return(SPV_SUCCESS)); + EXPECT_CALL(client_, Instruction(_)).Times(0); // No instruction callback. + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + spvBinaryParse(ctx, &client_, words.data(), words.size(), + invoke_header, invoke_instruction, nullptr)); + + spvContextDestroy(ctx); +} + +TEST_F(BinaryParseTest, SpecifyConsumerNullDiagnosticsForGoodParse) { + const auto words = CompileSuccessfully(""); + + auto ctx = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + SetContextMessageConsumer( + ctx, [&invocation](spvtools::MessageLevel, const char*, + const spv_position_t&, const char*) { ++invocation; }); + + EXPECT_HEADER(1).WillOnce(Return(SPV_SUCCESS)); + EXPECT_CALL(client_, Instruction(_)).Times(0); // No instruction callback. + EXPECT_EQ(SPV_SUCCESS, + spvBinaryParse(ctx, &client_, words.data(), words.size(), + invoke_header, invoke_instruction, nullptr)); + EXPECT_EQ(0, invocation); + + spvContextDestroy(ctx); +} + +TEST_F(BinaryParseTest, SpecifyConsumerNullDiagnosticsForBadParse) { + auto words = CompileSuccessfully(""); + + auto ctx = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + SetContextMessageConsumer( + ctx, [&invocation](spvtools::MessageLevel level, const char* source, + const spv_position_t& position, const char* message) { + ++invocation; + EXPECT_EQ(spvtools::MessageLevel::Error, level); + EXPECT_STREQ("", source); + EXPECT_EQ(0u, position.line); + EXPECT_EQ(0u, position.column); + EXPECT_EQ(5u, position.index); + EXPECT_STREQ("Invalid opcode: 65535", message); + }); + + words.push_back(0xffffffff); // Certainly invalid instruction header. + EXPECT_HEADER(1).WillOnce(Return(SPV_SUCCESS)); + EXPECT_CALL(client_, Instruction(_)).Times(0); // No instruction callback. + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + spvBinaryParse(ctx, &client_, words.data(), words.size(), + invoke_header, invoke_instruction, nullptr)); + EXPECT_EQ(1, invocation); + + spvContextDestroy(ctx); +} + +TEST_F(BinaryParseTest, SpecifyConsumerSpecifyDiagnosticsForGoodParse) { + const auto words = CompileSuccessfully(""); + + auto ctx = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + SetContextMessageConsumer( + ctx, [&invocation](spvtools::MessageLevel, const char*, + const spv_position_t&, const char*) { ++invocation; }); + + EXPECT_HEADER(1).WillOnce(Return(SPV_SUCCESS)); + EXPECT_CALL(client_, Instruction(_)).Times(0); // No instruction callback. + EXPECT_EQ(SPV_SUCCESS, + spvBinaryParse(ctx, &client_, words.data(), words.size(), + invoke_header, invoke_instruction, &diagnostic_)); + EXPECT_EQ(0, invocation); + EXPECT_EQ(nullptr, diagnostic_); + + spvContextDestroy(ctx); +} + +TEST_F(BinaryParseTest, SpecifyConsumerSpecifyDiagnosticsForBadParse) { + auto words = CompileSuccessfully(""); + + auto ctx = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + SetContextMessageConsumer( + ctx, [&invocation](spvtools::MessageLevel, const char*, + const spv_position_t&, const char*) { ++invocation; }); + + words.push_back(0xffffffff); // Certainly invalid instruction header. + EXPECT_HEADER(1).WillOnce(Return(SPV_SUCCESS)); + EXPECT_CALL(client_, Instruction(_)).Times(0); // No instruction callback. + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + spvBinaryParse(ctx, &client_, words.data(), words.size(), + invoke_header, invoke_instruction, &diagnostic_)); + EXPECT_EQ(0, invocation); + EXPECT_STREQ("Invalid opcode: 65535", diagnostic_->error); + + spvContextDestroy(ctx); +} + TEST_F(BinaryParseTest, ModuleWithSingleInstructionHasValidHeaderAndInstructionCallback) { for (bool endian_swap : kSwapEndians) { diff --git a/test/BinaryToText.cpp b/test/BinaryToText.cpp index ec4e411..1d28e9d 100644 --- a/test/BinaryToText.cpp +++ b/test/BinaryToText.cpp @@ -149,13 +149,6 @@ TEST_F(BinaryToText, InvalidMagicNumber) { spvDiagnosticDestroy(diagnostic); } -TEST_F(BinaryToText, InvalidDiagnostic) { - spv_text text; - ASSERT_EQ(SPV_ERROR_INVALID_DIAGNOSTIC, - spvBinaryToText(context, binary->code, binary->wordCount, - SPV_BINARY_TO_TEXT_OPTION_NONE, &text, nullptr)); -} - struct FailedDecodeCase { std::string source_text; std::vector appended_instruction; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index b585a26..65fcfd3 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -130,6 +130,11 @@ add_spvtools_unittest( LIBS ${SPIRV_TOOLS}) add_spvtools_unittest( + TARGET c_interface + SRCS c_interface.cpp + LIBS ${SPIRV_TOOLS}) + +add_spvtools_unittest( TARGET cpp_interface SRCS cpp_interface.cpp LIBS SPIRV-Tools-opt ${SPIRV_TOOLS}) diff --git a/test/TextLiteral.cpp b/test/TextLiteral.cpp index 4870abf..c2f7704 100644 --- a/test/TextLiteral.cpp +++ b/test/TextLiteral.cpp @@ -14,15 +14,15 @@ #include "UnitSPIRV.h" -#include "gmock/gmock.h" #include "TestFixture.h" +#include "gmock/gmock.h" +#include "message.h" #include using ::testing::Eq; namespace { - TEST(TextLiteral, GoodI32) { spv_literal_t l; @@ -119,7 +119,7 @@ INSTANTIATE_TEST_CASE_P( {"\"\xE4\xBA\xB2\"", "\xE4\xBA\xB2"}, {"\"\\\xE4\xBA\xB2\"", "\xE4\xBA\xB2"}, {"\"this \\\" and this \\\\ and \\\xE4\xBA\xB2\"", - "this \" and this \\ and \xE4\xBA\xB2"}}),); + "this \" and this \\ and \xE4\xBA\xB2"}}), ); TEST(TextLiteral, StringTooLong) { spv_literal_t l; @@ -168,31 +168,32 @@ using IntegerTest = std::vector successfulEncode(const TextLiteralCase& test, libspirv::IdTypeClass type) { spv_instruction_t inst; - spv_diagnostic diagnostic; + std::string message; + auto capture_message = [&message](spvtools::MessageLevel, const char*, + const spv_position_t&, + const char* m) { message = m; }; libspirv::IdType expected_type{test.bitwidth, test.is_signed, type}; EXPECT_EQ(SPV_SUCCESS, - libspirv::AssemblyContext(nullptr, &diagnostic) + libspirv::AssemblyContext(nullptr, capture_message) .binaryEncodeNumericLiteral(test.text, SPV_ERROR_INVALID_TEXT, expected_type, &inst)) - << diagnostic->error; + << message; return inst.words; } std::string failedEncode(const TextLiteralCase& test, libspirv::IdTypeClass type) { spv_instruction_t inst; - spv_diagnostic diagnostic; + std::string message; + auto capture_message = [&message](spvtools::MessageLevel, const char*, + const spv_position_t&, + const char* m) { message = m; }; libspirv::IdType expected_type{test.bitwidth, test.is_signed, type}; EXPECT_EQ(SPV_ERROR_INVALID_TEXT, - libspirv::AssemblyContext(nullptr, &diagnostic) + libspirv::AssemblyContext(nullptr, capture_message) .binaryEncodeNumericLiteral(test.text, SPV_ERROR_INVALID_TEXT, expected_type, &inst)); - std::string ret_val; - if (diagnostic) { - ret_val = diagnostic->error; - spvDiagnosticDestroy(diagnostic); - } - return ret_val; + return message; } TEST_P(IntegerTest, IntegerBounds) { diff --git a/test/TextToBinary.cpp b/test/TextToBinary.cpp index 0a1439a..4db3ed2 100644 --- a/test/TextToBinary.cpp +++ b/test/TextToBinary.cpp @@ -126,14 +126,6 @@ TEST_F(TextToBinaryTest, InvalidPointer) { nullptr, &diagnostic)); } -TEST_F(TextToBinaryTest, InvalidDiagnostic) { - SetText( - "OpEntryPoint Kernel 0 \"\"\nOpExecutionMode 0 LocalSizeHint 1 1 1\n"); - ASSERT_EQ(SPV_ERROR_INVALID_DIAGNOSTIC, - spvTextToBinary(ScopedContext().context, text.str, text.length, - &binary, nullptr)); -} - TEST_F(TextToBinaryTest, InvalidPrefix) { EXPECT_EQ( "Expected or at the beginning of an instruction, " diff --git a/test/c_interface.cpp b/test/c_interface.cpp new file mode 100644 index 0000000..96fcb48 --- /dev/null +++ b/test/c_interface.cpp @@ -0,0 +1,278 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "message.h" +#include "spirv-tools/libspirv.h" +#include "table.h" + +namespace { + +using namespace spvtools; + +// The default consumer is a null std::function. +TEST(CInterface, DefaultConsumerNullDiagnosticForValidInput) { + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + const char input_text[] = + "OpCapability Shader\nOpMemoryModel Logical GLSL450"; + + spv_binary binary = nullptr; + EXPECT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text, + sizeof(input_text), &binary, nullptr)); + + { + // Sadly the compiler don't allow me to feed binary directly to + // spvValidate(). + spv_const_binary_t b{binary->code, binary->wordCount}; + EXPECT_EQ(SPV_SUCCESS, spvValidate(context, &b, nullptr)); + } + + spv_text text = nullptr; + EXPECT_EQ(SPV_SUCCESS, spvBinaryToText(context, binary->code, + binary->wordCount, 0, &text, nullptr)); + + spvTextDestroy(text); + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +// The default consumer is a null std::function. +TEST(CInterface, DefaultConsumerNullDiagnosticForInvalidAssembling) { + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + const char input_text[] = "%1 = OpName"; + + spv_binary binary = nullptr; + EXPECT_EQ(SPV_ERROR_INVALID_TEXT, + spvTextToBinary(context, input_text, sizeof(input_text), &binary, + nullptr)); + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +// The default consumer is a null std::function. +TEST(CInterface, DefaultConsumerNullDiagnosticForInvalidDiassembling) { + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + const char input_text[] = "OpNop"; + + spv_binary binary = nullptr; + ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text, + sizeof(input_text), &binary, nullptr)); + // Change OpNop to an invalid (wordcount|opcode) word. + binary->code[binary->wordCount - 1] = 0xffffffff; + + spv_text text = nullptr; + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + spvBinaryToText(context, binary->code, binary->wordCount, 0, &text, + nullptr)); + + spvTextDestroy(text); + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +// The default consumer is a null std::function. +TEST(CInterface, DefaultConsumerNullDiagnosticForInvalidValidating) { + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + const char input_text[] = "OpNop"; + + spv_binary binary = nullptr; + ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text, + sizeof(input_text), &binary, nullptr)); + + spv_const_binary_t b{binary->code, binary->wordCount}; + EXPECT_EQ(SPV_ERROR_INVALID_LAYOUT, spvValidate(context, &b, nullptr)); + + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +TEST(CInterface, SpecifyConsumerNullDiagnosticForAssembling) { + const char input_text[] = "%1 = OpName\n"; + + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + // TODO(antiagainst): Use public C API for setting the consumer once exists. + SetContextMessageConsumer( + context, + [&invocation](MessageLevel level, const char* source, + const spv_position_t& position, const char* message) { + ++invocation; + EXPECT_EQ(MessageLevel::Error, level); + // The error happens at scanning the begining of second line. + EXPECT_STREQ("", source); + EXPECT_EQ(1u, position.line); + EXPECT_EQ(0u, position.column); + EXPECT_EQ(12u, position.index); + EXPECT_STREQ("Expected operand, found end of stream.", message); + }); + + spv_binary binary = nullptr; + EXPECT_EQ(SPV_ERROR_INVALID_TEXT, + spvTextToBinary(context, input_text, sizeof(input_text), &binary, + nullptr)); + EXPECT_EQ(1, invocation); + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +TEST(CInterface, SpecifyConsumerNullDiagnosticForDisassembling) { + const char input_text[] = "OpNop"; + + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + SetContextMessageConsumer( + context, + [&invocation](MessageLevel level, const char* source, + const spv_position_t& position, const char* message) { + ++invocation; + EXPECT_EQ(MessageLevel::Error, level); + EXPECT_STREQ("", source); + EXPECT_EQ(0u, position.line); + EXPECT_EQ(0u, position.column); + EXPECT_EQ(5u, position.index); + EXPECT_STREQ("Invalid opcode: 65535", message); + }); + + spv_binary binary = nullptr; + ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text, + sizeof(input_text), &binary, nullptr)); + // Change OpNop to an invalid (wordcount|opcode) word. + binary->code[binary->wordCount - 1] = 0xffffffff; + + spv_text text = nullptr; + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + spvBinaryToText(context, binary->code, binary->wordCount, 0, &text, + nullptr)); + EXPECT_EQ(1, invocation); + + spvTextDestroy(text); + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +TEST(CInterface, SpecifyConsumerNullDiagnosticForValidating) { + const char input_text[] = "OpNop"; + + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + SetContextMessageConsumer( + context, + [&invocation](MessageLevel level, const char* source, + const spv_position_t& position, const char* message) { + ++invocation; + EXPECT_EQ(MessageLevel::Error, level); + EXPECT_STREQ("", source); + EXPECT_EQ(0u, position.line); + EXPECT_EQ(0u, position.column); + // TODO(antiagainst): what validation reports is not a word offset here. + // It is inconsistent with diassembler. Should be fixed. + EXPECT_EQ(1u, position.index); + EXPECT_STREQ("Nop cannot appear before the memory model instruction", + message); + }); + + spv_binary binary = nullptr; + ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text, + sizeof(input_text), &binary, nullptr)); + + spv_const_binary_t b{binary->code, binary->wordCount}; + EXPECT_EQ(SPV_ERROR_INVALID_LAYOUT, spvValidate(context, &b, nullptr)); + EXPECT_EQ(1, invocation); + + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +// When having both a consumer and an diagnostic object, the diagnostic object +// should take priority. +TEST(CInterface, SpecifyConsumerSpecifyDiagnosticForAssembling) { + const char input_text[] = "%1 = OpName"; + + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + SetContextMessageConsumer( + context, [&invocation](MessageLevel, const char*, const spv_position_t&, + const char*) { ++invocation; }); + + spv_binary binary = nullptr; + spv_diagnostic diagnostic = nullptr; + EXPECT_EQ(SPV_ERROR_INVALID_TEXT, + spvTextToBinary(context, input_text, sizeof(input_text), &binary, + &diagnostic)); + EXPECT_EQ(0, invocation); // Consumer should not be invoked at all. + EXPECT_STREQ("Expected operand, found end of stream.", diagnostic->error); + + spvDiagnosticDestroy(diagnostic); + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +TEST(CInterface, SpecifyConsumerSpecifyDiagnosticForDisassembling) { + const char input_text[] = "OpNop"; + + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + SetContextMessageConsumer( + context, [&invocation](MessageLevel, const char*, const spv_position_t&, + const char*) { ++invocation; }); + + spv_binary binary = nullptr; + ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text, + sizeof(input_text), &binary, nullptr)); + // Change OpNop to an invalid (wordcount|opcode) word. + binary->code[binary->wordCount - 1] = 0xffffffff; + + spv_diagnostic diagnostic = nullptr; + spv_text text = nullptr; + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + spvBinaryToText(context, binary->code, binary->wordCount, 0, &text, + &diagnostic)); + + EXPECT_EQ(0, invocation); // Consumer should not be invoked at all. + EXPECT_STREQ("Invalid opcode: 65535", diagnostic->error); + + spvTextDestroy(text); + spvDiagnosticDestroy(diagnostic); + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +TEST(CInterface, SpecifyConsumerSpecifyDiagnosticForValidating) { + const char input_text[] = "OpNop"; + + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + SetContextMessageConsumer( + context, [&invocation](MessageLevel, const char*, const spv_position_t&, + const char*) { ++invocation; }); + + spv_binary binary = nullptr; + ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text, + sizeof(input_text), &binary, nullptr)); + + spv_diagnostic diagnostic = nullptr; + spv_const_binary_t b{binary->code, binary->wordCount}; + EXPECT_EQ(SPV_ERROR_INVALID_LAYOUT, spvValidate(context, &b, &diagnostic)); + + EXPECT_EQ(0, invocation); // Consumer should not be invoked at all. + EXPECT_STREQ("Nop cannot appear before the memory model instruction", + diagnostic->error); + + spvDiagnosticDestroy(diagnostic); + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +} // anonymous namespace diff --git a/test/diagnostic.cpp b/test/diagnostic.cpp index fa35502..feb5403 100644 --- a/test/diagnostic.cpp +++ b/test/diagnostic.cpp @@ -63,16 +63,16 @@ TEST(Diagnostic, PrintInvalidDiagnostic) { TEST(DiagnosticStream, ConversionToResultType) { // Check after the DiagnosticStream object is destroyed. spv_result_t value; - { value = DiagnosticStream({}, 0, SPV_ERROR_INVALID_TEXT); } + { value = DiagnosticStream({}, nullptr, SPV_ERROR_INVALID_TEXT); } EXPECT_EQ(SPV_ERROR_INVALID_TEXT, value); // Check implicit conversion via plain assignment. - value = DiagnosticStream({}, 0, SPV_SUCCESS); + value = DiagnosticStream({}, nullptr, SPV_SUCCESS); EXPECT_EQ(SPV_SUCCESS, value); // Check conversion via constructor. EXPECT_EQ(SPV_FAILED_MATCH, - spv_result_t(DiagnosticStream({}, 0, SPV_FAILED_MATCH))); + spv_result_t(DiagnosticStream({}, nullptr, SPV_FAILED_MATCH))); } } // anonymous namespace diff --git a/test/val/ValidationState.cpp b/test/val/ValidationState.cpp index e75f8cd..23df7cc 100644 --- a/test/val/ValidationState.cpp +++ b/test/val/ValidationState.cpp @@ -36,11 +36,9 @@ using std::vector; class ValidationStateTest : public testing::Test { public: ValidationStateTest() - : context_(spvContextCreate(SPV_ENV_UNIVERSAL_1_0)), - state_(&diag_, context_) {} + : context_(spvContextCreate(SPV_ENV_UNIVERSAL_1_0)), state_(context_) {} protected: - spv_diagnostic diag_; spv_context context_; ValidationState_t state_; };