//
#include "ie_onnx_reader.hpp"
+#include "onnx_model_validator.hpp"
#include <ie_api.h>
#include <onnx_import/onnx.hpp>
using namespace InferenceEngine;
-bool ONNXReader::supportModel(std::istream& model) const {
- model.seekg(0, model.beg);
- const int header_size = 128;
- std::string header(header_size, ' ');
- model.read(&header[0], header_size);
- model.seekg(0, model.beg);
- // find 'onnx' substring in the .onnx files
- // find 'ir_version' and 'graph' for prototxt
- // return (header.find("onnx") != std::string::npos) || (header.find("pytorch") != std::string::npos) ||
- // (header.find("ir_version") != std::string::npos && header.find("graph") != std::string::npos);
- return !((header.find("<net ") != std::string::npos) || (header.find("<Net ") != std::string::npos));
-}
-
namespace {
std::string readPathFromStream(std::istream& stream) {
if (stream.pword(0) == nullptr) {
}
}
+bool ONNXReader::supportModel(std::istream& model) const {
+ model.seekg(0, model.beg);
+
+ const auto model_path = readPathFromStream(model);
+
+ // this might mean that the model is loaded from a string in memory
+ // let's try to figure out if it's any of the supported formats
+ if (model_path.empty()) {
+ if (!is_valid_model(model, onnx_format{})) {
+ model.seekg(0, model.beg);
+ return is_valid_model(model, prototxt_format{});
+ } else {
+ return true;
+ }
+ }
+
+ if (model_path.find(".prototxt", 0) != std::string::npos) {
+ return is_valid_model(model, prototxt_format{});
+ } else {
+ return is_valid_model(model, onnx_format{});
+ }
+}
+
CNNNetwork ONNXReader::read(std::istream& model, const std::vector<IExtensionPtr>& exts) const {
return CNNNetwork(ngraph::onnx_import::import_onnx_model(model, readPathFromStream(model)));
}
--- /dev/null
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "onnx_model_validator.hpp"
+
+#include <algorithm>
+#include <array>
+#include <exception>
+#include <map>
+#include <vector>
+#include <iostream>
+namespace detail {
+namespace onnx {
+ enum Field {
+ IR_VERSION = 1,
+ PRODUCER_NAME = 2,
+ PRODUCER_VERSION = 3,
+ DOMAIN_ = 4, // DOMAIN collides with some existing symbol in MSVC thus - underscore
+ MODEL_VERSION = 5,
+ DOC_STRING = 6,
+ GRAPH = 7,
+ OPSET_IMPORT = 8,
+ METADATA_PROPS = 14,
+ TRAINING_INFO = 20
+ };
+
+ enum WireType {
+ VARINT = 0,
+ BITS_64 = 1,
+ LENGTH_DELIMITED = 2,
+ START_GROUP = 3,
+ END_GROUP = 4,
+ BITS_32 = 5
+ };
+
+ // A PB key consists of a field number (defined in onnx.proto) and a type of data that follows this key
+ using PbKey = std::pair<char, char>;
+
+ // This pair represents a key found in the encoded model and optional size of the payload
+ // that follows the key (in bytes). They payload should be skipped for the fast check purposes.
+ using ONNXField = std::pair<Field, uint32_t>;
+
+ bool is_correct_onnx_field(const PbKey& decoded_key) {
+ static const std::map<Field, WireType> onnx_fields = {
+ {IR_VERSION, VARINT},
+ {PRODUCER_NAME, LENGTH_DELIMITED},
+ {PRODUCER_VERSION, LENGTH_DELIMITED},
+ {DOMAIN_, LENGTH_DELIMITED},
+ {MODEL_VERSION, VARINT},
+ {DOC_STRING, LENGTH_DELIMITED},
+ {GRAPH, LENGTH_DELIMITED},
+ {OPSET_IMPORT, LENGTH_DELIMITED},
+ {METADATA_PROPS, LENGTH_DELIMITED},
+ {TRAINING_INFO, LENGTH_DELIMITED},
+ };
+
+ if (!onnx_fields.count(static_cast<Field>(decoded_key.first))) {
+ return false;
+ }
+
+ return onnx_fields.at(static_cast<Field>(decoded_key.first)) == static_cast<WireType>(decoded_key.second);
+ }
+
+ /**
+ * Only 7 bits in each component of a varint count in this algorithm. The components form
+ * a decoded number when they are concatenated bitwise in a reverse order. For example:
+ * bytes = [b1, b2, b3, b4]
+ * varint = b4 ++ b3 ++ b2 ++ b1 <== only 7 bits of each byte should be extracted before concat
+ *
+ * b1 b2
+ * bytes = [00101100, 00000010]
+ * b2 b1
+ * varint = 0000010 ++ 0101100 = 100101100 => decimal: 300
+ * Each consecutive varint byte needs to be left shifted "7 x its position in the vector"
+ * and bitwise added to the accumulator afterwards.
+ */
+ uint32_t varint_bytes_to_number(const std::vector<char>& bytes) {
+ uint32_t accumulator = 0u;
+
+ for (size_t i = 0; i < bytes.size(); ++i) {
+ uint32_t b = bytes[i];
+ b <<= 7 * i;
+ accumulator |= b;
+ }
+
+ return accumulator;
+ }
+
+ uint32_t decode_varint(std::istream& model) {
+ std::vector<char> bytes;
+ bytes.reserve(4);
+
+ char key_component = 0;
+ model.get(key_component);
+
+ // keep reading all bytes from the stream which have the MSB on
+ while (key_component & 0x80) {
+ // drop the most significant bit
+ const char component = key_component & ~0x80;
+ bytes.push_back(component);
+ model.get(key_component);
+ }
+ // add the last byte - the one with MSB off
+ bytes.push_back(key_component);
+
+ return varint_bytes_to_number(bytes);
+ }
+
+ PbKey decode_key(const char key) {
+ // 3 least significant bits
+ const char wire_type = key & 0b111;
+ // remaining bits
+ const char field_number = key >> 3;
+ return {field_number, wire_type};
+ }
+
+ ONNXField decode_next_field(std::istream& model) {
+ char key = 0;
+ model.get(key);
+
+ const auto decoded_key = decode_key(key);
+
+ if (!is_correct_onnx_field(decoded_key)) {
+ throw std::runtime_error{"Incorrect field detected in the processed model"};
+ }
+
+ const auto onnx_field = static_cast<Field>(decoded_key.first);
+
+ switch (decoded_key.second) {
+ case VARINT: {
+ // the decoded varint is the payload in this case but its value doesnt matter
+ // in the fast check process so we just discard it
+ decode_varint(model);
+ return {onnx_field, 0};
+ }
+ case LENGTH_DELIMITED:
+ // the varint following the key determines the payload length
+ return {onnx_field, decode_varint(model)};
+ case BITS_64:
+ return {onnx_field, 8};
+ case BITS_32:
+ return {onnx_field, 4};
+ case START_GROUP:
+ case END_GROUP:
+ throw std::runtime_error{"StartGroup and EndGroup are not used in ONNX models"};
+ default:
+ throw std::runtime_error{"Unknown WireType encountered in the model"};
+ }
+ }
+
+ inline void skip_payload(std::istream& model, uint32_t payload_size) {
+ model.seekg(payload_size, std::ios::cur);
+ }
+} // namespace onnx
+
+namespace prototxt {
+ bool contains_onnx_model_keys(const std::string& model, const size_t expected_keys_num) {
+ size_t keys_found = 0;
+
+ const std::vector<std::string> onnx_keys = {
+ "ir_version", "producer_name", "producer_version", "domain", "model_version",
+ "doc_string", "graph", "opset_import", "metadata_props", "training_info"
+ };
+
+ size_t search_start_pos = 0;
+
+ while (keys_found < expected_keys_num) {
+ const auto key_finder = [&search_start_pos, &model](const std::string& key) {
+ const auto key_pos = model.find(key, search_start_pos);
+ if (key_pos != model.npos) {
+ // don't search from the beginning each time
+ search_start_pos = key_pos + key.size();
+ return true;
+ } else {
+ return false;
+ }
+ };
+
+ const auto found = std::any_of(std::begin(onnx_keys), std::end(onnx_keys), key_finder);
+ if (!found) {
+ break;
+ } else {
+ ++keys_found;
+ }
+ }
+
+ return keys_found == expected_keys_num;
+ }
+} // namespace prototxt
+} // namespace detail
+
+namespace InferenceEngine {
+ bool is_valid_model(std::istream& model, onnx_format) {
+ // the model usually starts with a 0x08 byte indicating the ir_version value
+ // so this checker expects at least 2 valid ONNX keys to be found in the validated model
+ const unsigned int EXPECTED_FIELDS_FOUND = 2u;
+ unsigned int valid_fields_found = 0u;
+ try {
+ while (!model.eof() && valid_fields_found < EXPECTED_FIELDS_FOUND) {
+ const auto field = detail::onnx::decode_next_field(model);
+
+ ++valid_fields_found;
+
+ if (field.second > 0) {
+ detail::onnx::skip_payload(model, field.second);
+ }
+ }
+
+ return valid_fields_found == EXPECTED_FIELDS_FOUND;
+ } catch (...) {
+ return false;
+ }
+ }
+
+ bool is_valid_model(std::istream& model, prototxt_format) {
+ std::array<char, 512> head_of_file;
+
+ model.seekg(0, model.beg);
+ model.read(head_of_file.data(), head_of_file.size());
+ model.clear();
+ model.seekg(0, model.beg);
+
+ return detail::prototxt::contains_onnx_model_keys(
+ std::string{std::begin(head_of_file), std::end(head_of_file)}, 2);
+ }
+} // namespace InferenceEngine
--- /dev/null
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+#include <fstream>
+
+#include <ie_core.hpp>
+
+namespace {
+ std::string model_path(const char* model) {
+ std::string path = ONNX_TEST_MODELS;
+ path += "support_test/";
+ path += model;
+ return path;
+ }
+}
+
+TEST(ONNXReader_ModelSupported, basic_model) {
+ // this model is a basic ONNX model taken from ngraph's unit test (add_abc.onnx)
+ // it contains the minimum number of fields required to accept this file as a valid model
+ EXPECT_NO_THROW(InferenceEngine::Core{}.ReadNetwork(model_path("supported/basic.onnx")));
+}
+
+TEST(ONNXReader_ModelSupported, basic_reverse_fields_order) {
+ // this model contains the same fields as basic.onnx but serialized in reverse order
+ EXPECT_NO_THROW(InferenceEngine::Core{}.ReadNetwork(model_path("supported/basic_reverse_fields_order.onnx")));
+}
+
+TEST(ONNXReader_ModelSupported, more_fields) {
+ // this model contains some optional fields (producer_name and doc_string) but 5 fields in total
+ EXPECT_NO_THROW(InferenceEngine::Core{}.ReadNetwork(model_path("supported/more_fields.onnx")));
+}
+
+TEST(ONNXReader_ModelSupported, varint_on_two_bytes) {
+ // the docstring's payload length is encoded as varint using 2 bytes which should be parsed correctly
+ EXPECT_NO_THROW(InferenceEngine::Core{}.ReadNetwork(model_path("supported/varint_on_two_bytes.onnx")));
+}
+
+TEST(ONNXReader_ModelSupported, prototxt_basic) {
+ EXPECT_NO_THROW(InferenceEngine::Core{}.ReadNetwork(model_path("supported/basic.prototxt")));
+}
+
+TEST(ONNXReader_ModelSupported, scrambled_keys) {
+ // same as the prototxt_basic but with a different order of keys
+ EXPECT_NO_THROW(InferenceEngine::Core{}.ReadNetwork(model_path("supported/scrambled_keys.prototxt")));
+}
+
+TEST(ONNXReader_ModelUnsupported, no_graph_field) {
+ // this model contains only 2 fields (it doesn't contain a graph in particular)
+ EXPECT_THROW(InferenceEngine::Core{}.ReadNetwork(model_path("unsupported/no_graph_field.onnx")),
+ InferenceEngine::details::InferenceEngineException);
+}
+
+TEST(ONNXReader_ModelUnsupported, incorrect_onnx_field) {
+ // in this model the second field's key is F8 (field number 31) which is doesn't exist in ONNX
+ // this test will have to be changed if the number of fields in onnx.proto
+ // (ModelProto message definition) ever reaches 31 or more
+ EXPECT_THROW(InferenceEngine::Core{}.ReadNetwork(model_path("unsupported/incorrect_onnx_field.onnx")),
+ InferenceEngine::details::InferenceEngineException);
+}
+
+TEST(ONNXReader_ModelUnsupported, unknown_wire_type) {
+ // in this model the graph key contains wire type 7 encoded in it - this value is incorrect
+ EXPECT_THROW(InferenceEngine::Core{}.ReadNetwork(model_path("unsupported/unknown_wire_type.onnx")),
+ InferenceEngine::details::InferenceEngineException);
+}
+
+TEST(ONNXReader_ModelUnsupported, no_valid_keys) {
+ EXPECT_THROW(InferenceEngine::Core{}.ReadNetwork(model_path("unsupported/no_valid_keys.prototxt")),
+ InferenceEngine::details::InferenceEngineException);
+}