Add a callback mechanism for communicating messages to callers.
authorLei Zhang <antiagainst@google.com>
Fri, 2 Sep 2016 22:06:18 +0000 (18:06 -0400)
committerLei Zhang <antiagainst@google.com>
Thu, 15 Sep 2016 16:35:48 +0000 (12:35 -0400)
Every time an event happens in the library that the user should be
aware of, the callback will be invoked.

The existing diagnostic mechanism is hijacked internally by a
callback that creates an diagnostic object each time an event
happens.

25 files changed:
include/spirv-tools/libspirv.h
source/binary.cpp
source/diagnostic.cpp
source/diagnostic.h
source/disassemble.cpp
source/message.h [new file with mode: 0644]
source/opt/libspirv.cpp
source/opt/libspirv.hpp
source/table.cpp
source/table.h
source/text.cpp
source/text_handler.h
source/val/ValidationState.cpp
source/val/ValidationState.h
source/validate.cpp
source/validate.h
source/validate_id.cpp
test/BinaryParse.cpp
test/BinaryToText.cpp
test/CMakeLists.txt
test/TextLiteral.cpp
test/TextToBinary.cpp
test/c_interface.cpp [new file with mode: 0644]
test/diagnostic.cpp
test/val/ValidationState.cpp

index 42ccb2a..c9812ec 100644 (file)
@@ -340,14 +340,14 @@ typedef enum {
   SPV_ENV_UNIVERSAL_1_0,  // SPIR-V 1.0 latest revision, no other restrictions.
   SPV_ENV_VULKAN_1_0,     // Vulkan 1.0 latest revision.
   SPV_ENV_UNIVERSAL_1_1,  // SPIR-V 1.1 latest revision, no other restrictions.
-  SPV_ENV_OPENCL_2_1, // OpenCL 2.1 latest revision.
-  SPV_ENV_OPENCL_2_2, // OpenCL 2.2 latest revision.
-  SPV_ENV_OPENGL_4_0, // OpenGL 4.0 plus GL_ARB_gl_spirv, latest revisions.
-  SPV_ENV_OPENGL_4_1, // OpenGL 4.1 plus GL_ARB_gl_spirv, latest revisions.
-  SPV_ENV_OPENGL_4_2, // OpenGL 4.2 plus GL_ARB_gl_spirv, latest revisions.
-  SPV_ENV_OPENGL_4_3, // OpenGL 4.3 plus GL_ARB_gl_spirv, latest revisions.
+  SPV_ENV_OPENCL_2_1,     // OpenCL 2.1 latest revision.
+  SPV_ENV_OPENCL_2_2,     // OpenCL 2.2 latest revision.
+  SPV_ENV_OPENGL_4_0,     // OpenGL 4.0 plus GL_ARB_gl_spirv, latest revisions.
+  SPV_ENV_OPENGL_4_1,     // OpenGL 4.1 plus GL_ARB_gl_spirv, latest revisions.
+  SPV_ENV_OPENGL_4_2,     // OpenGL 4.2 plus GL_ARB_gl_spirv, latest revisions.
+  SPV_ENV_OPENGL_4_3,     // OpenGL 4.3 plus GL_ARB_gl_spirv, latest revisions.
   // There is no variant for OpenGL 4.4.
-  SPV_ENV_OPENGL_4_5, // OpenGL 4.5 plus GL_ARB_gl_spirv, latest revisions.
+  SPV_ENV_OPENGL_4_5,  // OpenGL 4.5 plus GL_ARB_gl_spirv, latest revisions.
 } spv_target_env;
 
 // Returns a string describing the given SPIR-V target environment.
@@ -361,8 +361,9 @@ void spvContextDestroy(spv_context context);
 
 // Encodes the given SPIR-V assembly text to its binary representation. The
 // length parameter specifies the number of bytes for text. Encoded binary will
-// be stored into *binary. Any error will be written into *diagnostic. The
-// generated binary is independent of the context and may outlive it.
+// be stored into *binary. Any error will be written into *diagnostic if
+// diagnostic is non-null. The generated binary is independent of the context
+// and may outlive it.
 spv_result_t spvTextToBinary(const spv_const_context context, const char* text,
                              const size_t length, spv_binary* binary,
                              spv_diagnostic* diagnostic);
@@ -374,7 +375,8 @@ void spvTextDestroy(spv_text text);
 // Decodes the given SPIR-V binary representation to its assembly text. The
 // word_count parameter specifies the number of words for binary. The options
 // parameter is a bit field of spv_binary_to_text_options_t. Decoded text will
-// be stored into *text. Any error will be written into *diagnostic.
+// be stored into *text. Any error will be written into *diagnostic if
+// diagnostic is non-null.
 spv_result_t spvBinaryToText(const spv_const_context context,
                              const uint32_t* binary, const size_t word_count,
                              const uint32_t options, spv_text* text,
@@ -385,7 +387,7 @@ spv_result_t spvBinaryToText(const spv_const_context context,
 void spvBinaryDestroy(spv_binary binary);
 
 // Validates a SPIR-V binary for correctness. Any errors will be written into
-// *diagnostic.
+// *diagnostic if diagnostic is non-null.
 spv_result_t spvValidate(const spv_const_context context,
                          const spv_const_binary binary,
                          spv_diagnostic* diagnostic);
index a709c02..c8b7c38 100644 (file)
@@ -61,6 +61,7 @@ class Parser {
          spv_parsed_header_fn_t parsed_header_fn,
          spv_parsed_instruction_fn_t parsed_instruction_fn)
       : grammar_(context),
+        consumer_(context->consumer),
         user_data_(user_data),
         parsed_header_fn_(parsed_header_fn),
         parsed_instruction_fn_(parsed_instruction_fn) {}
@@ -120,8 +121,7 @@ class Parser {
   // returned object will be propagated to the current parse's diagnostic
   // object.
   libspirv::DiagnosticStream diagnostic(spv_result_t error) {
-    return libspirv::DiagnosticStream({0, 0, _.word_index}, _.diagnostic,
-                                      error);
+    return libspirv::DiagnosticStream({0, 0, _.word_index}, consumer_, error);
   }
 
   // Returns a diagnostic stream object with the default parse error code.
@@ -156,6 +156,7 @@ class Parser {
   // Data members
 
   const libspirv::AssemblyGrammar grammar_;        // SPIR-V syntax utility.
+  const spvtools::MessageConsumer& consumer_;      // Message consumer callback.
   void* const user_data_;                          // Context for the callbacks
   const spv_parsed_header_fn_t parsed_header_fn_;  // Parsed header callback
   const spv_parsed_instruction_fn_t
@@ -752,7 +753,12 @@ spv_result_t spvBinaryParse(const spv_const_context context, void* user_data,
                             spv_parsed_header_fn_t parsed_header,
                             spv_parsed_instruction_fn_t parsed_instruction,
                             spv_diagnostic* diagnostic) {
-  Parser parser(context, user_data, parsed_header, parsed_instruction);
+  spv_context_t hijack_context = *context;
+  if (diagnostic) {
+    *diagnostic = nullptr;
+    libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic);
+  }
+  Parser parser(&hijack_context, user_data, parsed_header, parsed_instruction);
   return parser.parse(code, num_words, diagnostic);
 }
 
index b1a9cac..8f0e348 100644 (file)
@@ -20,6 +20,7 @@
 #include <iostream>
 
 #include "spirv-tools/libspirv.h"
+#include "table.h"
 
 // Diagnostic API
 
@@ -68,12 +69,47 @@ spv_result_t spvDiagnosticPrint(const spv_diagnostic diagnostic) {
 namespace libspirv {
 
 DiagnosticStream::~DiagnosticStream() {
-  if (pDiagnostic_ && error_ != SPV_FAILED_MATCH) {
-    *pDiagnostic_ = spvDiagnosticCreate(&position_, stream_.str().c_str());
+  using spvtools::MessageLevel;
+  if (error_ != SPV_FAILED_MATCH && consumer_ != nullptr) {
+    auto level = MessageLevel::Error;
+    switch (error_) {
+      case SPV_SUCCESS:
+      case SPV_REQUESTED_TERMINATION:  // Essentially success.
+        level = MessageLevel::Info;
+        break;
+      case SPV_WARNING:
+        level = MessageLevel::Warning;
+        break;
+      case SPV_UNSUPPORTED:
+      case SPV_ERROR_INTERNAL:
+      case SPV_ERROR_INVALID_TABLE:
+        level = MessageLevel::InternalError;
+        break;
+      case SPV_ERROR_OUT_OF_MEMORY:
+        level = MessageLevel::Fatal;
+        break;
+      default:
+        break;
+    }
+    consumer_(level, "", position_, stream_.str().c_str());
   }
 }
-std::string
-spvResultToString(spv_result_t res) {
+
+void UseDiagnosticAsMessageConsumer(spv_context context,
+                                    spv_diagnostic* diagnostic) {
+  assert(diagnostic && *diagnostic == nullptr);
+
+  auto create_diagnostic = [diagnostic](spvtools::MessageLevel, const char*,
+                                        const spv_position_t& position,
+                                        const char* message) {
+    auto p = position;
+    spvDiagnosticDestroy(*diagnostic);  // Avoid memory leak.
+    *diagnostic = spvDiagnosticCreate(&p, message);
+  };
+  SetContextMessageConsumer(context, std::move(create_diagnostic));
+}
+
+std::string spvResultToString(spv_result_t res) {
   std::string out;
   switch (res) {
     case SPV_SUCCESS:
index 840ec9f..9bf9ae2 100644 (file)
@@ -19,6 +19,7 @@
 #include <sstream>
 #include <utility>
 
+#include "message.h"
 #include "spirv-tools/libspirv.h"
 
 namespace libspirv {
@@ -29,20 +30,16 @@ namespace libspirv {
 // emitted during the destructor.
 class DiagnosticStream {
  public:
-  DiagnosticStream(spv_position_t position, spv_diagnostic* pDiagnostic,
+  DiagnosticStream(spv_position_t position,
+                   const spvtools::MessageConsumer& consumer,
                    spv_result_t error)
-      : position_(position), pDiagnostic_(pDiagnostic), error_(error) {}
+      : position_(position), consumer_(consumer), error_(error) {}
 
   DiagnosticStream(DiagnosticStream&& other)
       : stream_(other.stream_.str()),
         position_(other.position_),
-        pDiagnostic_(other.pDiagnostic_),
-        error_(other.error_) {
-    // The other object's destructor will emit the text in its stream_
-    // member if its pDiagnostic_ member is non-null.  Prevent that,
-    // since emitting that text is now the responsibility of *this.
-    other.pDiagnostic_ = nullptr;
-  }
+        consumer_(other.consumer_),
+        error_(other.error_) {}
 
   ~DiagnosticStream();
 
@@ -59,10 +56,18 @@ class DiagnosticStream {
  private:
   std::stringstream stream_;
   spv_position_t position_;
-  spv_diagnostic* pDiagnostic_;
+  const spvtools::MessageConsumer& consumer_;  // Message consumer callback.
   spv_result_t error_;
 };
 
+// Changes the MessageConsumer in |context| to one that updates |diagnostic|
+// with the last message received.
+//
+// This function expects that |diagnostic| is not nullptr and its content is a
+// nullptr.
+void UseDiagnosticAsMessageConsumer(spv_context context,
+                                    spv_diagnostic* diagnostic);
+
 std::string spvResultToString(spv_result_t res);
 
 }  // namespace libspirv
index e71581d..267ed17 100644 (file)
@@ -393,10 +393,13 @@ spv_result_t spvBinaryToText(const spv_const_context context,
                              const uint32_t* code, const size_t wordCount,
                              const uint32_t options, spv_text* pText,
                              spv_diagnostic* pDiagnostic) {
-  // Invalid arguments return error codes, but don't necessarily generate
-  // diagnostics.  These are programmer errors, not user errors.
-  if (!pDiagnostic) return SPV_ERROR_INVALID_DIAGNOSTIC;
-  const libspirv::AssemblyGrammar grammar(context);
+  spv_context_t hijack_context = *context;
+  if (pDiagnostic) {
+    *pDiagnostic = nullptr;
+    libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic);
+  }
+
+  const libspirv::AssemblyGrammar grammar(&hijack_context);
   if (!grammar.isValid()) return SPV_ERROR_INVALID_TABLE;
 
   // Generate friendly names for Ids if requested.
@@ -404,15 +407,15 @@ spv_result_t spvBinaryToText(const spv_const_context context,
   libspirv::NameMapper name_mapper = libspirv::GetTrivialNameMapper();
   if (options & SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES) {
     friendly_mapper.reset(
-        new libspirv::FriendlyNameMapper(context, code, wordCount));
+        new libspirv::FriendlyNameMapper(&hijack_context, code, wordCount));
     name_mapper = friendly_mapper->GetNameMapper();
   }
 
   // Now disassemble!
   Disassembler disassembler(grammar, options, name_mapper);
-  if (auto error = spvBinaryParse(context, &disassembler, code, wordCount,
-                                  DisassembleHeader, DisassembleInstruction,
-                                  pDiagnostic)) {
+  if (auto error = spvBinaryParse(&hijack_context, &disassembler, code,
+                                  wordCount, DisassembleHeader,
+                                  DisassembleInstruction, pDiagnostic)) {
     return error;
   }
 
diff --git a/source/message.h b/source/message.h
new file mode 100644 (file)
index 0000000..68e4cd7
--- /dev/null
@@ -0,0 +1,47 @@
+// Copyright (c) 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SPIRV_TOOLS_MESSAGE_H_
+#define SPIRV_TOOLS_MESSAGE_H_
+
+#include <functional>
+
+#include "spirv-tools/libspirv.h"
+
+namespace spvtools {
+
+// TODO(antiagainst): This eventually should be in the C++ interface.
+
+// Severity levels of messages communicated to the consumer.
+enum class MessageLevel {
+  Fatal,  // Unrecoverable error due to environment. Will abort the program
+          // immediately. E.g., out of memory.
+  InternalError,  // Unrecoverable error due to SPIRV-Tools internals. Will
+                  // abort the program immediately. E.g., unimplemented feature.
+  Error,          // Normal error due to user input.
+  Warning,        // Warning information.
+  Info,           // General information.
+  Debug,          // Debug information.
+};
+
+// Message consumer. The C strings for source and message are only alive for the
+// specific invocation.
+using MessageConsumer = std::function<void(
+    MessageLevel /* level */, const char* /* source */,
+    const spv_position_t& /* position */, const char* /* message */
+    )>;
+
+}  // namespace spvtools
+
+#endif  // SPIRV_TOOLS_MESSAGE_H_
index 0636c64..eabc260 100644 (file)
@@ -16,6 +16,7 @@
 
 #include "ir_loader.h"
 #include "make_unique.h"
+#include "table.h"
 
 namespace spvtools {
 
@@ -39,6 +40,10 @@ spv_result_t SetSpvInst(void* builder, const spv_parsed_instruction_t* inst) {
 
 }  // annoymous namespace
 
+void SpvTools::SetMessageConsumer(MessageConsumer consumer) {
+  SetContextMessageConsumer(context_, std::move(consumer));
+}
+
 spv_result_t SpvTools::Assemble(const std::string& text,
                                 std::vector<uint32_t>* binary) {
   spv_binary spvbinary = nullptr;
index 46d7318..e645fcc 100644 (file)
@@ -19,6 +19,7 @@
 #include <string>
 #include <vector>
 
+#include "message.h"
 #include "module.h"
 #include "spirv-tools/libspirv.h"
 
@@ -36,7 +37,8 @@ class SpvTools {
 
   ~SpvTools() { spvContextDestroy(context_); }
 
-  // TODO(antiagainst): handle error message in the following APIs.
+  // Sets the message consumer to the given |consumer|.
+  void SetMessageConsumer(MessageConsumer consumer);
 
   // Assembles the given assembly |text| and writes the result to |binary|.
   // Returns SPV_SUCCESS on successful assembling.
index 6bdbd9b..24ab520 100644 (file)
@@ -41,7 +41,13 @@ spv_context spvContextCreate(spv_target_env env) {
   spvOperandTableGet(&operand_table, env);
   spvExtInstTableGet(&ext_inst_table, env);
 
-  return new spv_context_t{env, opcode_table, operand_table, ext_inst_table};
+  return new spv_context_t{env, opcode_table, operand_table, ext_inst_table,
+                           nullptr /* a null default consumer */};
 }
 
 void spvContextDestroy(spv_context context) { delete context; }
+
+void SetContextMessageConsumer(spv_context context,
+                               spvtools::MessageConsumer consumer) {
+  context->consumer = std::move(consumer);
+}
index 8eedd0a..abce443 100644 (file)
@@ -18,6 +18,7 @@
 #include "spirv/1.1/spirv.h"
 
 #include "enum_set.h"
+#include "message.h"
 #include "spirv-tools/libspirv.h"
 
 typedef struct spv_opcode_desc_t {
@@ -87,8 +88,14 @@ struct spv_context_t {
   const spv_opcode_table opcode_table;
   const spv_operand_table operand_table;
   const spv_ext_inst_table ext_inst_table;
+  spvtools::MessageConsumer consumer;
 };
 
+// Sets the message consumer to |consumer| in the given |context|. The original
+// message consumer will be overwritten.
+void SetContextMessageConsumer(spv_context context,
+                               spvtools::MessageConsumer consumer);
+
 // Populates *table with entries for env.
 spv_result_t spvOpcodeTableGet(spv_opcode_table* table, spv_target_env env);
 
index d317264..6e68ac2 100644 (file)
@@ -31,6 +31,7 @@
 #include "diagnostic.h"
 #include "ext_inst.h"
 #include "instruction.h"
+#include "message.h"
 #include "opcode.h"
 #include "operand.h"
 #include "spirv-tools/libspirv.h"
@@ -662,10 +663,9 @@ spv_result_t SetHeader(spv_target_env env, const uint32_t bound,
 // If a diagnostic is generated, it is not yet marked as being
 // for a text-based input.
 spv_result_t spvTextToBinaryInternal(const libspirv::AssemblyGrammar& grammar,
-                                     const spv_text text, spv_binary* pBinary,
-                                     spv_diagnostic* pDiagnostic) {
-  if (!pDiagnostic) return SPV_ERROR_INVALID_DIAGNOSTIC;
-  libspirv::AssemblyContext context(text, pDiagnostic);
+                                     const spvtools::MessageConsumer& consumer,
+                                     const spv_text text, spv_binary* pBinary) {
+  libspirv::AssemblyContext context(text, consumer);
   if (!text->str) return context.diagnostic() << "Missing assembly text.";
 
   if (!grammar.isValid()) {
@@ -673,9 +673,6 @@ spv_result_t spvTextToBinaryInternal(const libspirv::AssemblyGrammar& grammar,
   }
   if (!pBinary) return SPV_ERROR_INVALID_POINTER;
 
-  // NOTE: Ensure diagnostic is zero initialised
-  *pDiagnostic = {};
-
   std::vector<spv_instruction_t> instructions;
 
   // Skip past whitespace and comments.
@@ -728,11 +725,17 @@ spv_result_t spvTextToBinary(const spv_const_context context,
                              const char* input_text,
                              const size_t input_text_size, spv_binary* pBinary,
                              spv_diagnostic* pDiagnostic) {
+  spv_context_t hijack_context = *context;
+  if (pDiagnostic) {
+    *pDiagnostic = nullptr;
+    libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic);
+  }
+
   spv_text_t text = {input_text, input_text_size};
-  libspirv::AssemblyGrammar grammar(context);
+  libspirv::AssemblyGrammar grammar(&hijack_context);
 
   spv_result_t result =
-      spvTextToBinaryInternal(grammar, &text, pBinary, pDiagnostic);
+      spvTextToBinaryInternal(grammar, hijack_context.consumer, &text, pBinary);
   if (pDiagnostic && *pDiagnostic) (*pDiagnostic)->isTextSource = true;
 
   return result;
index 9951643..1bd004c 100644 (file)
@@ -22,6 +22,7 @@
 
 #include "diagnostic.h"
 #include "instruction.h"
+#include "message.h"
 #include "spirv-tools/libspirv.h"
 #include "text.h"
 
@@ -116,11 +117,8 @@ class ClampToZeroIfUnsignedType<
 // Encapsulates the data used during the assembly of a SPIR-V module.
 class AssemblyContext {
  public:
-  AssemblyContext(spv_text text, spv_diagnostic* diagnostic_arg)
-      : current_position_({}),
-        pDiagnostic_(diagnostic_arg),
-        text_(text),
-        bound_(1) {}
+  AssemblyContext(spv_text text, const spvtools::MessageConsumer& consumer)
+      : current_position_({}), consumer_(consumer), text_(text), bound_(1) {}
 
   // Assigns a new integer value to the given text ID, or returns the previously
   // assigned integer value if the ID has been seen before.
@@ -148,7 +146,7 @@ class AssemblyContext {
   // stream, and for the given error code. Any data written to this object will
   // show up in pDiagnsotic on destruction.
   DiagnosticStream diagnostic(spv_result_t error) {
-    return DiagnosticStream(current_position_, pDiagnostic_, error);
+    return DiagnosticStream(current_position_, consumer_, error);
   }
 
   // Returns a diagnostic object with the default assembly error code.
@@ -227,7 +225,6 @@ class AssemblyContext {
   spv_ext_inst_type_t getExtInstTypeForId(uint32_t id) const;
 
  private:
-
   // Maps ID names to their corresponding numerical ids.
   using spv_named_id_table = std::unordered_map<std::string, uint32_t>;
   // Maps type-defining IDs to their IdType.
@@ -241,7 +238,7 @@ class AssemblyContext {
   // Maps an extended instruction import Id to the extended instruction type.
   std::unordered_map<uint32_t, spv_ext_inst_type_t> import_id_to_ext_inst_type_;
   spv_position_t current_position_;
-  spv_diagnostic* pDiagnostic_;
+  spvtools::MessageConsumer consumer_;
   spv_text text_;
   uint32_t bound_;
 };
index c8735dc..1368416 100644 (file)
@@ -182,9 +182,8 @@ bool IsInstructionInLayoutSection(ModuleLayoutSection layout, SpvOp op) {
 
 }  // anonymous namespace
 
-ValidationState_t::ValidationState_t(spv_diagnostic* diagnostic,
-                                     const spv_const_context context)
-    : diagnostic_(diagnostic),
+ValidationState_t::ValidationState_t(const spv_const_context ctx)
+    : context_(ctx),
       instruction_counter_(0),
       unresolved_forward_ids_{},
       operand_names_{},
@@ -193,7 +192,7 @@ ValidationState_t::ValidationState_t(spv_diagnostic* diagnostic,
       module_capabilities_(),
       ordered_instructions_(),
       all_definitions_(),
-      grammar_(context),
+      grammar_(ctx),
       addressing_model_(SpvAddressingModelLogical),
       memory_model_(SpvMemoryModelSimple),
       in_function_(false) {}
@@ -290,7 +289,7 @@ bool ValidationState_t::IsOpcodeInCurrentLayoutSection(SpvOp op) {
 
 DiagnosticStream ValidationState_t::diag(spv_result_t error_code) const {
   return libspirv::DiagnosticStream(
-      {0, 0, static_cast<size_t>(instruction_counter_)}, diagnostic_,
+      {0, 0, static_cast<size_t>(instruction_counter_)}, context_->consumer,
       error_code);
 }
 
@@ -377,8 +376,8 @@ spv_result_t ValidationState_t::RegisterFunctionEnd() {
 void ValidationState_t::RegisterInstruction(
     const spv_parsed_instruction_t& inst) {
   if (in_function_body()) {
-    ordered_instructions_.emplace_back(
-        &inst, &current_function(), current_function().current_block());
+    ordered_instructions_.emplace_back(&inst, &current_function(),
+                                       current_function().current_block());
   } else {
     ordered_instructions_.emplace_back(&inst, nullptr, nullptr);
   }
index 1f5c001..ff23d05 100644 (file)
@@ -53,8 +53,10 @@ enum ModuleLayoutSection {
 /// This class manages the state of the SPIR-V validation as it is being parsed.
 class ValidationState_t {
  public:
-  ValidationState_t(spv_diagnostic* diagnostic,
-                    const spv_const_context context);
+  ValidationState_t(const spv_const_context context);
+
+  /// Returns the context
+  spv_const_context context() const { return context_; }
 
   /// Forward declares the id in the module
   spv_result_t ForwardDeclareId(uint32_t id);
@@ -174,7 +176,8 @@ class ValidationState_t {
  private:
   ValidationState_t(const ValidationState_t&);
 
-  spv_diagnostic* diagnostic_;
+  const spv_const_context context_;
+
   /// Tracks the number of instructions evaluated by the validator
   int instruction_counter_;
 
@@ -191,7 +194,8 @@ class ValidationState_t {
   std::deque<Function> module_functions_;
 
   /// The capabilities available in the module
-  libspirv::CapabilitySet module_capabilities_;  /// Module's declared capabilities.
+  libspirv::CapabilitySet
+      module_capabilities_;  /// Module's declared capabilities.
 
   /// List of all instructions in the order they appear in the binary
   std::deque<Instruction> ordered_instructions_;
index 1b50d9a..b0699cc 100644 (file)
@@ -50,15 +50,17 @@ using libspirv::ModuleLayoutPass;
 using libspirv::IdPass;
 using libspirv::ValidationState_t;
 
-spv_result_t spvValidateIDs(
-    const spv_instruction_t* pInsts, const uint64_t count,
-    const spv_opcode_table opcodeTable, const spv_operand_table operandTable,
-    const spv_ext_inst_table extInstTable, const ValidationState_t& state,
-    spv_position position, spv_diagnostic* pDiagnostic) {
+spv_result_t spvValidateIDs(const spv_instruction_t* pInsts,
+                            const uint64_t count,
+                            const spv_opcode_table opcodeTable,
+                            const spv_operand_table operandTable,
+                            const spv_ext_inst_table extInstTable,
+                            const ValidationState_t& state,
+                            spv_position position) {
   position->index = SPV_INDEX_INSTRUCTION;
   if (auto error =
           spvValidateInstructionIDs(pInsts, count, opcodeTable, operandTable,
-                                    extInstTable, state, position, pDiagnostic))
+                                    extInstTable, state, position))
     return error;
   return SPV_SUCCESS;
 }
@@ -175,29 +177,33 @@ UNUSED(void PrintDotGraph(ValidationState_t& _, libspirv::Function func)) {
 spv_result_t spvValidate(const spv_const_context context,
                          const spv_const_binary binary,
                          spv_diagnostic* pDiagnostic) {
-  if (!pDiagnostic) return SPV_ERROR_INVALID_DIAGNOSTIC;
+  spv_context_t hijack_context = *context;
+  if (pDiagnostic) {
+    *pDiagnostic = nullptr;
+    libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic);
+  }
 
   spv_endianness_t endian;
   spv_position_t position = {};
   if (spvBinaryEndianness(binary, &endian)) {
-    return libspirv::DiagnosticStream(position, pDiagnostic,
+    return libspirv::DiagnosticStream(position, hijack_context.consumer,
                                       SPV_ERROR_INVALID_BINARY)
            << "Invalid SPIR-V magic number.";
   }
 
   spv_header_t header;
   if (spvBinaryHeaderGet(binary, endian, &header)) {
-    return libspirv::DiagnosticStream(position, pDiagnostic,
+    return libspirv::DiagnosticStream(position, hijack_context.consumer,
                                       SPV_ERROR_INVALID_BINARY)
            << "Invalid SPIR-V header.";
   }
 
   // NOTE: Parse the module and perform inline validation checks. These
   // checks do not require the the knowledge of the whole module.
-  ValidationState_t vstate(pDiagnostic, context);
-  if (auto error =
-          spvBinaryParse(context, &vstate, binary->code, binary->wordCount,
-                         setHeader, ProcessInstruction, pDiagnostic))
+  ValidationState_t vstate(&hijack_context);
+  if (auto error = spvBinaryParse(&hijack_context, &vstate, binary->code,
+                                  binary->wordCount, setHeader,
+                                  ProcessInstruction, pDiagnostic))
     return error;
 
   if (vstate.in_function_body())
@@ -243,7 +249,7 @@ spv_result_t spvValidate(const spv_const_context context,
 
   position.index = SPV_INDEX_INSTRUCTION;
   return spvValidateIDs(instructions.data(), instructions.size(),
-                        context->opcode_table, context->operand_table,
-                        context->ext_inst_table, vstate, &position,
-                        pDiagnostic);
+                        hijack_context.opcode_table,
+                        hijack_context.operand_table,
+                        hijack_context.ext_inst_table, vstate, &position);
 }
index 2ef341a..a19e911 100644 (file)
@@ -20,6 +20,7 @@
 #include <vector>
 
 #include "instruction.h"
+#include "message.h"
 #include "spirv-tools/libspirv.h"
 #include "table.h"
 
@@ -154,7 +155,6 @@ spv_result_t InstructionPass(ValidationState_t& _,
 /// @param[in] operandTable table of specified operands
 /// @param[in] usedefs use-def info from module parsing
 /// @param[in,out] position current position in the stream
-/// @param[out] pDiag contains diagnostic on failure
 ///
 /// @return result code
 spv_result_t spvValidateInstructionIDs(const spv_instruction_t* pInsts,
@@ -163,8 +163,7 @@ spv_result_t spvValidateInstructionIDs(const spv_instruction_t* pInsts,
                                        const spv_operand_table operandTable,
                                        const spv_ext_inst_table extInstTable,
                                        const libspirv::ValidationState_t& state,
-                                       spv_position position,
-                                       spv_diagnostic* pDiag);
+                                       spv_position position);
 
 /// @brief Validate the ID's within a SPIR-V binary
 ///
@@ -174,7 +173,7 @@ spv_result_t spvValidateInstructionIDs(const spv_instruction_t* pInsts,
 /// @param[in] opcodeTable table of specified Opcodes
 /// @param[in] operandTable table of specified operands
 /// @param[in,out] position current word in the binary
-/// @param[out] pDiagnostic contains diagnostic on failure
+/// @param[in] consumer message consumer callback
 ///
 /// @return result code
 spv_result_t spvValidateIDs(const spv_instruction_t* pInstructions,
@@ -182,6 +181,7 @@ spv_result_t spvValidateIDs(const spv_instruction_t* pInstructions,
                             const spv_opcode_table opcodeTable,
                             const spv_operand_table operandTable,
                             const spv_ext_inst_table extInstTable,
-                            spv_position position, spv_diagnostic* pDiagnostic);
+                            spv_position position,
+                            const spvtools::MessageConsumer& consumer);
 
 #endif  // LIBSPIRV_VALIDATE_H_
index 1e10545..611d39f 100644 (file)
@@ -24,6 +24,7 @@
 
 #include "diagnostic.h"
 #include "instruction.h"
+#include "message.h"
 #include "opcode.h"
 #include "spirv-tools/libspirv.h"
 #include "val/Function.h"
@@ -48,7 +49,7 @@ class idUsage {
           const SpvMemoryModel memoryModelArg,
           const SpvAddressingModel addressingModelArg,
           const ValidationState_t& module, const vector<uint32_t>& entry_points,
-          spv_position positionArg, spv_diagnostic* pDiagnosticArg)
+          spv_position positionArg, const spvtools::MessageConsumer& consumer)
       : opcodeTable(opcodeTableArg),
         operandTable(operandTableArg),
         extInstTable(extInstTableArg),
@@ -57,7 +58,7 @@ class idUsage {
         memoryModel(memoryModelArg),
         addressingModel(addressingModelArg),
         position(positionArg),
-        pDiagnostic(pDiagnosticArg),
+        consumer_(consumer),
         module_(module),
         entry_points_(entry_points) {}
 
@@ -75,14 +76,14 @@ class idUsage {
   const SpvMemoryModel memoryModel;
   const SpvAddressingModel addressingModel;
   spv_position position;
-  spv_diagnostic* pDiagnostic;
+  const spvtools::MessageConsumer& consumer_;
   const ValidationState_t& module_;
   vector<uint32_t> entry_points_;
 };
 
 #define DIAG(INDEX)                                                \
   position->index += INDEX;                                        \
-  libspirv::DiagnosticStream helper(*position, pDiagnostic,        \
+  libspirv::DiagnosticStream helper(*position, consumer_,          \
                                     SPV_ERROR_INVALID_DIAGNOSTIC); \
   helper
 
@@ -2553,11 +2554,10 @@ spv_result_t spvValidateInstructionIDs(const spv_instruction_t* pInsts,
                                        const spv_operand_table operandTable,
                                        const spv_ext_inst_table extInstTable,
                                        const libspirv::ValidationState_t& state,
-                                       spv_position position,
-                                       spv_diagnostic* pDiag) {
+                                       spv_position position) {
   idUsage idUsage(opcodeTable, operandTable, extInstTable, pInsts, instCount,
                   state.memory_model(), state.addressing_model(), state,
-                  state.entry_points(), position, pDiag);
+                  state.entry_points(), position, state.context()->consumer);
   for (uint64_t instIndex = 0; instIndex < instCount; ++instIndex) {
     if (!idUsage.isValid(&pInsts[instIndex])) return SPV_ERROR_INVALID_ID;
     position->index += pInsts[instIndex].words.size();
index 97c2104..48e7edf 100644 (file)
@@ -19,6 +19,8 @@
 #include "TestFixture.h"
 #include "UnitSPIRV.h"
 #include "gmock/gmock.h"
+#include "source/message.h"
+#include "source/table.h"
 #include "spirv/1.0/OpenCL.std.h"
 
 // Returns true if two spv_parsed_operand_t values are equal.
@@ -258,6 +260,112 @@ TEST_F(BinaryParseTest, NullDiagnosticsIsOkForBadParse) {
                      words.size(), invoke_header, invoke_instruction, nullptr));
 }
 
+// Make sure that we don't blow up when both the consumer and the diagnostic are
+// null.
+TEST_F(BinaryParseTest, NullConsumerNullDiagnosticsForBadParse) {
+  auto words = CompileSuccessfully("");
+
+  auto ctx = spvContextCreate(SPV_ENV_UNIVERSAL_1_1);
+  SetContextMessageConsumer(ctx, nullptr);
+
+  words.push_back(0xffffffff);  // Certainly invalid instruction header.
+  EXPECT_HEADER(1).WillOnce(Return(SPV_SUCCESS));
+  EXPECT_CALL(client_, Instruction(_)).Times(0);  // No instruction callback.
+  EXPECT_EQ(SPV_ERROR_INVALID_BINARY,
+            spvBinaryParse(ctx, &client_, words.data(), words.size(),
+                           invoke_header, invoke_instruction, nullptr));
+
+  spvContextDestroy(ctx);
+}
+
+TEST_F(BinaryParseTest, SpecifyConsumerNullDiagnosticsForGoodParse) {
+  const auto words = CompileSuccessfully("");
+
+  auto ctx = spvContextCreate(SPV_ENV_UNIVERSAL_1_1);
+  int invocation = 0;
+  SetContextMessageConsumer(
+      ctx, [&invocation](spvtools::MessageLevel, const char*,
+                         const spv_position_t&, const char*) { ++invocation; });
+
+  EXPECT_HEADER(1).WillOnce(Return(SPV_SUCCESS));
+  EXPECT_CALL(client_, Instruction(_)).Times(0);  // No instruction callback.
+  EXPECT_EQ(SPV_SUCCESS,
+            spvBinaryParse(ctx, &client_, words.data(), words.size(),
+                           invoke_header, invoke_instruction, nullptr));
+  EXPECT_EQ(0, invocation);
+
+  spvContextDestroy(ctx);
+}
+
+TEST_F(BinaryParseTest, SpecifyConsumerNullDiagnosticsForBadParse) {
+  auto words = CompileSuccessfully("");
+
+  auto ctx = spvContextCreate(SPV_ENV_UNIVERSAL_1_1);
+  int invocation = 0;
+  SetContextMessageConsumer(
+      ctx, [&invocation](spvtools::MessageLevel level, const char* source,
+                         const spv_position_t& position, const char* message) {
+        ++invocation;
+        EXPECT_EQ(spvtools::MessageLevel::Error, level);
+        EXPECT_STREQ("", source);
+        EXPECT_EQ(0u, position.line);
+        EXPECT_EQ(0u, position.column);
+        EXPECT_EQ(5u, position.index);
+        EXPECT_STREQ("Invalid opcode: 65535", message);
+      });
+
+  words.push_back(0xffffffff);  // Certainly invalid instruction header.
+  EXPECT_HEADER(1).WillOnce(Return(SPV_SUCCESS));
+  EXPECT_CALL(client_, Instruction(_)).Times(0);  // No instruction callback.
+  EXPECT_EQ(SPV_ERROR_INVALID_BINARY,
+            spvBinaryParse(ctx, &client_, words.data(), words.size(),
+                           invoke_header, invoke_instruction, nullptr));
+  EXPECT_EQ(1, invocation);
+
+  spvContextDestroy(ctx);
+}
+
+TEST_F(BinaryParseTest, SpecifyConsumerSpecifyDiagnosticsForGoodParse) {
+  const auto words = CompileSuccessfully("");
+
+  auto ctx = spvContextCreate(SPV_ENV_UNIVERSAL_1_1);
+  int invocation = 0;
+  SetContextMessageConsumer(
+      ctx, [&invocation](spvtools::MessageLevel, const char*,
+                         const spv_position_t&, const char*) { ++invocation; });
+
+  EXPECT_HEADER(1).WillOnce(Return(SPV_SUCCESS));
+  EXPECT_CALL(client_, Instruction(_)).Times(0);  // No instruction callback.
+  EXPECT_EQ(SPV_SUCCESS,
+            spvBinaryParse(ctx, &client_, words.data(), words.size(),
+                           invoke_header, invoke_instruction, &diagnostic_));
+  EXPECT_EQ(0, invocation);
+  EXPECT_EQ(nullptr, diagnostic_);
+
+  spvContextDestroy(ctx);
+}
+
+TEST_F(BinaryParseTest, SpecifyConsumerSpecifyDiagnosticsForBadParse) {
+  auto words = CompileSuccessfully("");
+
+  auto ctx = spvContextCreate(SPV_ENV_UNIVERSAL_1_1);
+  int invocation = 0;
+  SetContextMessageConsumer(
+      ctx, [&invocation](spvtools::MessageLevel, const char*,
+                         const spv_position_t&, const char*) { ++invocation; });
+
+  words.push_back(0xffffffff);  // Certainly invalid instruction header.
+  EXPECT_HEADER(1).WillOnce(Return(SPV_SUCCESS));
+  EXPECT_CALL(client_, Instruction(_)).Times(0);  // No instruction callback.
+  EXPECT_EQ(SPV_ERROR_INVALID_BINARY,
+            spvBinaryParse(ctx, &client_, words.data(), words.size(),
+                           invoke_header, invoke_instruction, &diagnostic_));
+  EXPECT_EQ(0, invocation);
+  EXPECT_STREQ("Invalid opcode: 65535", diagnostic_->error);
+
+  spvContextDestroy(ctx);
+}
+
 TEST_F(BinaryParseTest,
        ModuleWithSingleInstructionHasValidHeaderAndInstructionCallback) {
   for (bool endian_swap : kSwapEndians) {
index ec4e411..1d28e9d 100644 (file)
@@ -149,13 +149,6 @@ TEST_F(BinaryToText, InvalidMagicNumber) {
   spvDiagnosticDestroy(diagnostic);
 }
 
-TEST_F(BinaryToText, InvalidDiagnostic) {
-  spv_text text;
-  ASSERT_EQ(SPV_ERROR_INVALID_DIAGNOSTIC,
-            spvBinaryToText(context, binary->code, binary->wordCount,
-                            SPV_BINARY_TO_TEXT_OPTION_NONE, &text, nullptr));
-}
-
 struct FailedDecodeCase {
   std::string source_text;
   std::vector<uint32_t> appended_instruction;
index b585a26..65fcfd3 100644 (file)
@@ -130,6 +130,11 @@ add_spvtools_unittest(
   LIBS ${SPIRV_TOOLS})
 
 add_spvtools_unittest(
+  TARGET c_interface
+  SRCS c_interface.cpp
+  LIBS ${SPIRV_TOOLS})
+
+add_spvtools_unittest(
   TARGET cpp_interface
   SRCS cpp_interface.cpp
   LIBS SPIRV-Tools-opt ${SPIRV_TOOLS})
index 4870abf..c2f7704 100644 (file)
 
 #include "UnitSPIRV.h"
 
-#include "gmock/gmock.h"
 #include "TestFixture.h"
+#include "gmock/gmock.h"
+#include "message.h"
 
 #include <string>
 
 using ::testing::Eq;
 namespace {
 
-
 TEST(TextLiteral, GoodI32) {
   spv_literal_t l;
 
@@ -119,7 +119,7 @@ INSTANTIATE_TEST_CASE_P(
         {"\"\xE4\xBA\xB2\"", "\xE4\xBA\xB2"},
         {"\"\\\xE4\xBA\xB2\"", "\xE4\xBA\xB2"},
         {"\"this \\\" and this \\\\ and \\\xE4\xBA\xB2\"",
-         "this \" and this \\ and \xE4\xBA\xB2"}}),);
+         "this \" and this \\ and \xE4\xBA\xB2"}}), );
 
 TEST(TextLiteral, StringTooLong) {
   spv_literal_t l;
@@ -168,31 +168,32 @@ using IntegerTest =
 std::vector<uint32_t> successfulEncode(const TextLiteralCase& test,
                                        libspirv::IdTypeClass type) {
   spv_instruction_t inst;
-  spv_diagnostic diagnostic;
+  std::string message;
+  auto capture_message = [&message](spvtools::MessageLevel, const char*,
+                                    const spv_position_t&,
+                                    const char* m) { message = m; };
   libspirv::IdType expected_type{test.bitwidth, test.is_signed, type};
   EXPECT_EQ(SPV_SUCCESS,
-            libspirv::AssemblyContext(nullptr, &diagnostic)
+            libspirv::AssemblyContext(nullptr, capture_message)
                 .binaryEncodeNumericLiteral(test.text, SPV_ERROR_INVALID_TEXT,
                                             expected_type, &inst))
-      << diagnostic->error;
+      << message;
   return inst.words;
 }
 
 std::string failedEncode(const TextLiteralCase& test,
                          libspirv::IdTypeClass type) {
   spv_instruction_t inst;
-  spv_diagnostic diagnostic;
+  std::string message;
+  auto capture_message = [&message](spvtools::MessageLevel, const char*,
+                                    const spv_position_t&,
+                                    const char* m) { message = m; };
   libspirv::IdType expected_type{test.bitwidth, test.is_signed, type};
   EXPECT_EQ(SPV_ERROR_INVALID_TEXT,
-            libspirv::AssemblyContext(nullptr, &diagnostic)
+            libspirv::AssemblyContext(nullptr, capture_message)
                 .binaryEncodeNumericLiteral(test.text, SPV_ERROR_INVALID_TEXT,
                                             expected_type, &inst));
-  std::string ret_val;
-  if (diagnostic) {
-    ret_val = diagnostic->error;
-    spvDiagnosticDestroy(diagnostic);
-  }
-  return ret_val;
+  return message;
 }
 
 TEST_P(IntegerTest, IntegerBounds) {
index 0a1439a..4db3ed2 100644 (file)
@@ -126,14 +126,6 @@ TEST_F(TextToBinaryTest, InvalidPointer) {
                             nullptr, &diagnostic));
 }
 
-TEST_F(TextToBinaryTest, InvalidDiagnostic) {
-  SetText(
-      "OpEntryPoint Kernel 0 \"\"\nOpExecutionMode 0 LocalSizeHint 1 1 1\n");
-  ASSERT_EQ(SPV_ERROR_INVALID_DIAGNOSTIC,
-            spvTextToBinary(ScopedContext().context, text.str, text.length,
-                            &binary, nullptr));
-}
-
 TEST_F(TextToBinaryTest, InvalidPrefix) {
   EXPECT_EQ(
       "Expected <opcode> or <result-id> at the beginning of an instruction, "
diff --git a/test/c_interface.cpp b/test/c_interface.cpp
new file mode 100644 (file)
index 0000000..96fcb48
--- /dev/null
@@ -0,0 +1,278 @@
+// Copyright (c) 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+
+#include "message.h"
+#include "spirv-tools/libspirv.h"
+#include "table.h"
+
+namespace {
+
+using namespace spvtools;
+
+// The default consumer is a null std::function.
+TEST(CInterface, DefaultConsumerNullDiagnosticForValidInput) {
+  auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1);
+  const char input_text[] =
+      "OpCapability Shader\nOpMemoryModel Logical GLSL450";
+
+  spv_binary binary = nullptr;
+  EXPECT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text,
+                                         sizeof(input_text), &binary, nullptr));
+
+  {
+    // Sadly the compiler don't allow me to feed binary directly to
+    // spvValidate().
+    spv_const_binary_t b{binary->code, binary->wordCount};
+    EXPECT_EQ(SPV_SUCCESS, spvValidate(context, &b, nullptr));
+  }
+
+  spv_text text = nullptr;
+  EXPECT_EQ(SPV_SUCCESS, spvBinaryToText(context, binary->code,
+                                         binary->wordCount, 0, &text, nullptr));
+
+  spvTextDestroy(text);
+  spvBinaryDestroy(binary);
+  spvContextDestroy(context);
+}
+
+// The default consumer is a null std::function.
+TEST(CInterface, DefaultConsumerNullDiagnosticForInvalidAssembling) {
+  auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1);
+  const char input_text[] = "%1 = OpName";
+
+  spv_binary binary = nullptr;
+  EXPECT_EQ(SPV_ERROR_INVALID_TEXT,
+            spvTextToBinary(context, input_text, sizeof(input_text), &binary,
+                            nullptr));
+  spvBinaryDestroy(binary);
+  spvContextDestroy(context);
+}
+
+// The default consumer is a null std::function.
+TEST(CInterface, DefaultConsumerNullDiagnosticForInvalidDiassembling) {
+  auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1);
+  const char input_text[] = "OpNop";
+
+  spv_binary binary = nullptr;
+  ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text,
+                                         sizeof(input_text), &binary, nullptr));
+  // Change OpNop to an invalid (wordcount|opcode) word.
+  binary->code[binary->wordCount - 1] = 0xffffffff;
+
+  spv_text text = nullptr;
+  EXPECT_EQ(SPV_ERROR_INVALID_BINARY,
+            spvBinaryToText(context, binary->code, binary->wordCount, 0, &text,
+                            nullptr));
+
+  spvTextDestroy(text);
+  spvBinaryDestroy(binary);
+  spvContextDestroy(context);
+}
+
+// The default consumer is a null std::function.
+TEST(CInterface, DefaultConsumerNullDiagnosticForInvalidValidating) {
+  auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1);
+  const char input_text[] = "OpNop";
+
+  spv_binary binary = nullptr;
+  ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text,
+                                         sizeof(input_text), &binary, nullptr));
+
+  spv_const_binary_t b{binary->code, binary->wordCount};
+  EXPECT_EQ(SPV_ERROR_INVALID_LAYOUT, spvValidate(context, &b, nullptr));
+
+  spvBinaryDestroy(binary);
+  spvContextDestroy(context);
+}
+
+TEST(CInterface, SpecifyConsumerNullDiagnosticForAssembling) {
+  const char input_text[] = "%1 = OpName\n";
+
+  auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1);
+  int invocation = 0;
+  // TODO(antiagainst): Use public C API for setting the consumer once exists.
+  SetContextMessageConsumer(
+      context,
+      [&invocation](MessageLevel level, const char* source,
+                    const spv_position_t& position, const char* message) {
+        ++invocation;
+        EXPECT_EQ(MessageLevel::Error, level);
+        // The error happens at scanning the begining of second line.
+        EXPECT_STREQ("", source);
+        EXPECT_EQ(1u, position.line);
+        EXPECT_EQ(0u, position.column);
+        EXPECT_EQ(12u, position.index);
+        EXPECT_STREQ("Expected operand, found end of stream.", message);
+      });
+
+  spv_binary binary = nullptr;
+  EXPECT_EQ(SPV_ERROR_INVALID_TEXT,
+            spvTextToBinary(context, input_text, sizeof(input_text), &binary,
+                            nullptr));
+  EXPECT_EQ(1, invocation);
+  spvBinaryDestroy(binary);
+  spvContextDestroy(context);
+}
+
+TEST(CInterface, SpecifyConsumerNullDiagnosticForDisassembling) {
+  const char input_text[] = "OpNop";
+
+  auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1);
+  int invocation = 0;
+  SetContextMessageConsumer(
+      context,
+      [&invocation](MessageLevel level, const char* source,
+                    const spv_position_t& position, const char* message) {
+        ++invocation;
+        EXPECT_EQ(MessageLevel::Error, level);
+        EXPECT_STREQ("", source);
+        EXPECT_EQ(0u, position.line);
+        EXPECT_EQ(0u, position.column);
+        EXPECT_EQ(5u, position.index);
+        EXPECT_STREQ("Invalid opcode: 65535", message);
+      });
+
+  spv_binary binary = nullptr;
+  ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text,
+                                         sizeof(input_text), &binary, nullptr));
+  // Change OpNop to an invalid (wordcount|opcode) word.
+  binary->code[binary->wordCount - 1] = 0xffffffff;
+
+  spv_text text = nullptr;
+  EXPECT_EQ(SPV_ERROR_INVALID_BINARY,
+            spvBinaryToText(context, binary->code, binary->wordCount, 0, &text,
+                            nullptr));
+  EXPECT_EQ(1, invocation);
+
+  spvTextDestroy(text);
+  spvBinaryDestroy(binary);
+  spvContextDestroy(context);
+}
+
+TEST(CInterface, SpecifyConsumerNullDiagnosticForValidating) {
+  const char input_text[] = "OpNop";
+
+  auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1);
+  int invocation = 0;
+  SetContextMessageConsumer(
+      context,
+      [&invocation](MessageLevel level, const char* source,
+                    const spv_position_t& position, const char* message) {
+        ++invocation;
+        EXPECT_EQ(MessageLevel::Error, level);
+        EXPECT_STREQ("", source);
+        EXPECT_EQ(0u, position.line);
+        EXPECT_EQ(0u, position.column);
+        // TODO(antiagainst): what validation reports is not a word offset here.
+        // It is inconsistent with diassembler. Should be fixed.
+        EXPECT_EQ(1u, position.index);
+        EXPECT_STREQ("Nop cannot appear before the memory model instruction",
+                     message);
+      });
+
+  spv_binary binary = nullptr;
+  ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text,
+                                         sizeof(input_text), &binary, nullptr));
+
+  spv_const_binary_t b{binary->code, binary->wordCount};
+  EXPECT_EQ(SPV_ERROR_INVALID_LAYOUT, spvValidate(context, &b, nullptr));
+  EXPECT_EQ(1, invocation);
+
+  spvBinaryDestroy(binary);
+  spvContextDestroy(context);
+}
+
+// When having both a consumer and an diagnostic object, the diagnostic object
+// should take priority.
+TEST(CInterface, SpecifyConsumerSpecifyDiagnosticForAssembling) {
+  const char input_text[] = "%1 = OpName";
+
+  auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1);
+  int invocation = 0;
+  SetContextMessageConsumer(
+      context, [&invocation](MessageLevel, const char*, const spv_position_t&,
+                             const char*) { ++invocation; });
+
+  spv_binary binary = nullptr;
+  spv_diagnostic diagnostic = nullptr;
+  EXPECT_EQ(SPV_ERROR_INVALID_TEXT,
+            spvTextToBinary(context, input_text, sizeof(input_text), &binary,
+                            &diagnostic));
+  EXPECT_EQ(0, invocation);  // Consumer should not be invoked at all.
+  EXPECT_STREQ("Expected operand, found end of stream.", diagnostic->error);
+
+  spvDiagnosticDestroy(diagnostic);
+  spvBinaryDestroy(binary);
+  spvContextDestroy(context);
+}
+
+TEST(CInterface, SpecifyConsumerSpecifyDiagnosticForDisassembling) {
+  const char input_text[] = "OpNop";
+
+  auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1);
+  int invocation = 0;
+  SetContextMessageConsumer(
+      context, [&invocation](MessageLevel, const char*, const spv_position_t&,
+                             const char*) { ++invocation; });
+
+  spv_binary binary = nullptr;
+  ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text,
+                                         sizeof(input_text), &binary, nullptr));
+  // Change OpNop to an invalid (wordcount|opcode) word.
+  binary->code[binary->wordCount - 1] = 0xffffffff;
+
+  spv_diagnostic diagnostic = nullptr;
+  spv_text text = nullptr;
+  EXPECT_EQ(SPV_ERROR_INVALID_BINARY,
+            spvBinaryToText(context, binary->code, binary->wordCount, 0, &text,
+                            &diagnostic));
+
+  EXPECT_EQ(0, invocation);  // Consumer should not be invoked at all.
+  EXPECT_STREQ("Invalid opcode: 65535", diagnostic->error);
+
+  spvTextDestroy(text);
+  spvDiagnosticDestroy(diagnostic);
+  spvBinaryDestroy(binary);
+  spvContextDestroy(context);
+}
+
+TEST(CInterface, SpecifyConsumerSpecifyDiagnosticForValidating) {
+  const char input_text[] = "OpNop";
+
+  auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1);
+  int invocation = 0;
+  SetContextMessageConsumer(
+      context, [&invocation](MessageLevel, const char*, const spv_position_t&,
+                             const char*) { ++invocation; });
+
+  spv_binary binary = nullptr;
+  ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text,
+                                         sizeof(input_text), &binary, nullptr));
+
+  spv_diagnostic diagnostic = nullptr;
+  spv_const_binary_t b{binary->code, binary->wordCount};
+  EXPECT_EQ(SPV_ERROR_INVALID_LAYOUT, spvValidate(context, &b, &diagnostic));
+
+  EXPECT_EQ(0, invocation);  // Consumer should not be invoked at all.
+  EXPECT_STREQ("Nop cannot appear before the memory model instruction",
+               diagnostic->error);
+
+  spvDiagnosticDestroy(diagnostic);
+  spvBinaryDestroy(binary);
+  spvContextDestroy(context);
+}
+
+}  // anonymous namespace
index fa35502..feb5403 100644 (file)
@@ -63,16 +63,16 @@ TEST(Diagnostic, PrintInvalidDiagnostic) {
 TEST(DiagnosticStream, ConversionToResultType) {
   // Check after the DiagnosticStream object is destroyed.
   spv_result_t value;
-  { value = DiagnosticStream({}, 0, SPV_ERROR_INVALID_TEXT); }
+  { value = DiagnosticStream({}, nullptr, SPV_ERROR_INVALID_TEXT); }
   EXPECT_EQ(SPV_ERROR_INVALID_TEXT, value);
 
   // Check implicit conversion via plain assignment.
-  value = DiagnosticStream({}, 0, SPV_SUCCESS);
+  value = DiagnosticStream({}, nullptr, SPV_SUCCESS);
   EXPECT_EQ(SPV_SUCCESS, value);
 
   // Check conversion via constructor.
   EXPECT_EQ(SPV_FAILED_MATCH,
-            spv_result_t(DiagnosticStream({}, 0, SPV_FAILED_MATCH)));
+            spv_result_t(DiagnosticStream({}, nullptr, SPV_FAILED_MATCH)));
 }
 
 }  // anonymous namespace
index e75f8cd..23df7cc 100644 (file)
@@ -36,11 +36,9 @@ using std::vector;
 class ValidationStateTest : public testing::Test {
  public:
   ValidationStateTest()
-      : context_(spvContextCreate(SPV_ENV_UNIVERSAL_1_0)),
-        state_(&diag_, context_) {}
+      : context_(spvContextCreate(SPV_ENV_UNIVERSAL_1_0)), state_(context_) {}
 
  protected:
-  spv_diagnostic diag_;
   spv_context context_;
   ValidationState_t state_;
 };