From: Tomasz Dołbniak Date: Fri, 23 Oct 2020 10:13:04 +0000 (+0200) Subject: ONNX Reader supportModel() implementation (#2744) X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=f1444b33e7ef671ee154a8201093219404266f96;p=platform%2Fupstream%2Fdldt.git ONNX Reader supportModel() implementation (#2744) --- diff --git a/inference-engine/src/inference_engine/ie_network_reader.cpp b/inference-engine/src/inference_engine/ie_network_reader.cpp index 4135930..b406c22 100644 --- a/inference-engine/src/inference_engine/ie_network_reader.cpp +++ b/inference-engine/src/inference_engine/ie_network_reader.cpp @@ -183,6 +183,7 @@ CNNNetwork details::ReadNetwork(const std::string& modelPath, const std::string& auto reader = it->second; // Check that reader supports the model if (reader->supportModel(modelStream)) { + modelStream.seekg(0, modelStream.beg); // Find weights std::string bPath = binPath; if (bPath.empty()) { @@ -235,6 +236,7 @@ CNNNetwork details::ReadNetwork(const std::string& model, const Blob::CPtr& weig for (auto it = readers.begin(); it != readers.end(); it++) { auto reader = it->second; if (reader->supportModel(modelStream)) { + modelStream.seekg(0, modelStream.beg); if (weights) return reader->read(modelStream, binStream, exts); return reader->read(modelStream, exts); diff --git a/inference-engine/src/readers/onnx_reader/ie_onnx_reader.cpp b/inference-engine/src/readers/onnx_reader/ie_onnx_reader.cpp index 25993eb..0892239 100644 --- a/inference-engine/src/readers/onnx_reader/ie_onnx_reader.cpp +++ b/inference-engine/src/readers/onnx_reader/ie_onnx_reader.cpp @@ -3,24 +3,12 @@ // #include "ie_onnx_reader.hpp" +#include "onnx_model_validator.hpp" #include #include using namespace InferenceEngine; -bool ONNXReader::supportModel(std::istream& model) const { - model.seekg(0, model.beg); - const int header_size = 128; - std::string header(header_size, ' '); - model.read(&header[0], header_size); - model.seekg(0, model.beg); - // find 'onnx' substring in the 
.onnx files - // find 'ir_version' and 'graph' for prototxt - // return (header.find("onnx") != std::string::npos) || (header.find("pytorch") != std::string::npos) || - // (header.find("ir_version") != std::string::npos && header.find("graph") != std::string::npos); - return !((header.find("& exts) const { return CNNNetwork(ngraph::onnx_import::import_onnx_model(model, readPathFromStream(model))); } diff --git a/inference-engine/src/readers/onnx_reader/onnx_model_validator.cpp b/inference-engine/src/readers/onnx_reader/onnx_model_validator.cpp new file mode 100644 index 0000000..4f8d8b9 --- /dev/null +++ b/inference-engine/src/readers/onnx_reader/onnx_model_validator.cpp @@ -0,0 +1,227 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "onnx_model_validator.hpp" + +#include +#include +#include +#include +#include +#include +namespace detail { +namespace onnx { + enum Field { + IR_VERSION = 1, + PRODUCER_NAME = 2, + PRODUCER_VERSION = 3, + DOMAIN_ = 4, // DOMAIN collides with some existing symbol in MSVC thus - underscore + MODEL_VERSION = 5, + DOC_STRING = 6, + GRAPH = 7, + OPSET_IMPORT = 8, + METADATA_PROPS = 14, + TRAINING_INFO = 20 + }; + + enum WireType { + VARINT = 0, + BITS_64 = 1, + LENGTH_DELIMITED = 2, + START_GROUP = 3, + END_GROUP = 4, + BITS_32 = 5 + }; + + // A PB key consists of a field number (defined in onnx.proto) and a type of data that follows this key + using PbKey = std::pair; + + // This pair represents a key found in the encoded model and optional size of the payload + // that follows the key (in bytes). The payload should be skipped for the fast check purposes. 
+ using ONNXField = std::pair; + + bool is_correct_onnx_field(const PbKey& decoded_key) { + static const std::map onnx_fields = { + {IR_VERSION, VARINT}, + {PRODUCER_NAME, LENGTH_DELIMITED}, + {PRODUCER_VERSION, LENGTH_DELIMITED}, + {DOMAIN_, LENGTH_DELIMITED}, + {MODEL_VERSION, VARINT}, + {DOC_STRING, LENGTH_DELIMITED}, + {GRAPH, LENGTH_DELIMITED}, + {OPSET_IMPORT, LENGTH_DELIMITED}, + {METADATA_PROPS, LENGTH_DELIMITED}, + {TRAINING_INFO, LENGTH_DELIMITED}, + }; + + if (!onnx_fields.count(static_cast(decoded_key.first))) { + return false; + } + + return onnx_fields.at(static_cast(decoded_key.first)) == static_cast(decoded_key.second); + } + + /** + * Only 7 bits in each component of a varint count in this algorithm. The components form + * a decoded number when they are concatenated bitwise in a reverse order. For example: + * bytes = [b1, b2, b3, b4] + * varint = b4 ++ b3 ++ b2 ++ b1 <== only 7 bits of each byte should be extracted before concat + * + * b1 b2 + * bytes = [00101100, 00000010] + * b2 b1 + * varint = 0000010 ++ 0101100 = 100101100 => decimal: 300 + * Each consecutive varint byte needs to be left shifted "7 x its position in the vector" + * and bitwise added to the accumulator afterwards. 
+ */ + uint32_t varint_bytes_to_number(const std::vector& bytes) { + uint32_t accumulator = 0u; + + for (size_t i = 0; i < bytes.size(); ++i) { + uint32_t b = bytes[i]; + b <<= 7 * i; + accumulator |= b; + } + + return accumulator; + } + + uint32_t decode_varint(std::istream& model) { + std::vector bytes; + bytes.reserve(4); + + char key_component = 0; + model.get(key_component); + + // keep reading all bytes from the stream which have the MSB on + while (key_component & 0x80) { + // drop the most significant bit + const char component = key_component & ~0x80; + bytes.push_back(component); + model.get(key_component); + } + // add the last byte - the one with MSB off + bytes.push_back(key_component); + + return varint_bytes_to_number(bytes); + } + + PbKey decode_key(const char key) { + // 3 least significant bits + const char wire_type = key & 0b111; + // remaining bits + const char field_number = key >> 3; + return {field_number, wire_type}; + } + + ONNXField decode_next_field(std::istream& model) { + char key = 0; + model.get(key); + + const auto decoded_key = decode_key(key); + + if (!is_correct_onnx_field(decoded_key)) { + throw std::runtime_error{"Incorrect field detected in the processed model"}; + } + + const auto onnx_field = static_cast(decoded_key.first); + + switch (decoded_key.second) { + case VARINT: { + // the decoded varint is the payload in this case but its value doesn't matter + // in the fast check process so we just discard it + decode_varint(model); + return {onnx_field, 0}; + } + case LENGTH_DELIMITED: + // the varint following the key determines the payload length + return {onnx_field, decode_varint(model)}; + case BITS_64: + return {onnx_field, 8}; + case BITS_32: + return {onnx_field, 4}; + case START_GROUP: + case END_GROUP: + throw std::runtime_error{"StartGroup and EndGroup are not used in ONNX models"}; + default: + throw std::runtime_error{"Unknown WireType encountered in the model"}; + } + } + + inline void skip_payload(std::istream& 
model, uint32_t payload_size) { + model.seekg(payload_size, std::ios::cur); + } +} // namespace onnx + +namespace prototxt { + bool contains_onnx_model_keys(const std::string& model, const size_t expected_keys_num) { + size_t keys_found = 0; + + const std::vector onnx_keys = { + "ir_version", "producer_name", "producer_version", "domain", "model_version", + "doc_string", "graph", "opset_import", "metadata_props", "training_info" + }; + + size_t search_start_pos = 0; + + while (keys_found < expected_keys_num) { + const auto key_finder = [&search_start_pos, &model](const std::string& key) { + const auto key_pos = model.find(key, search_start_pos); + if (key_pos != model.npos) { + // don't search from the beginning each time + search_start_pos = key_pos + key.size(); + return true; + } else { + return false; + } + }; + + const auto found = std::any_of(std::begin(onnx_keys), std::end(onnx_keys), key_finder); + if (!found) { + break; + } else { + ++keys_found; + } + } + + return keys_found == expected_keys_num; + } +} // namespace prototxt +} // namespace detail + +namespace InferenceEngine { + bool is_valid_model(std::istream& model, onnx_format) { + // the model usually starts with a 0x08 byte indicating the ir_version value + // so this checker expects at least 2 valid ONNX keys to be found in the validated model + const unsigned int EXPECTED_FIELDS_FOUND = 2u; + unsigned int valid_fields_found = 0u; + try { + while (!model.eof() && valid_fields_found < EXPECTED_FIELDS_FOUND) { + const auto field = detail::onnx::decode_next_field(model); + + ++valid_fields_found; + + if (field.second > 0) { + detail::onnx::skip_payload(model, field.second); + } + } + + return valid_fields_found == EXPECTED_FIELDS_FOUND; + } catch (...) 
{ + return false; + } + } + + bool is_valid_model(std::istream& model, prototxt_format) { + std::array head_of_file; + + model.seekg(0, model.beg); + model.read(head_of_file.data(), head_of_file.size()); + model.clear(); + model.seekg(0, model.beg); + + return detail::prototxt::contains_onnx_model_keys( + std::string{std::begin(head_of_file), std::end(head_of_file)}, 2); + } +} // namespace InferenceEngine diff --git a/inference-engine/src/readers/onnx_reader/onnx_model_validator.hpp b/inference-engine/src/readers/onnx_reader/onnx_model_validator.hpp new file mode 100644 index 0000000..337a693 --- /dev/null +++ b/inference-engine/src/readers/onnx_reader/onnx_model_validator.hpp @@ -0,0 +1,17 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace InferenceEngine { + // 2 empty structs used for tag dispatch below + struct onnx_format {}; + struct prototxt_format {}; + + bool is_valid_model(std::istream& model, onnx_format); + + bool is_valid_model(std::istream& model, prototxt_format); +} // namespace InferenceEngine diff --git a/inference-engine/tests/functional/inference_engine/onnx_reader/model_support_tests.cpp b/inference-engine/tests/functional/inference_engine/onnx_reader/model_support_tests.cpp new file mode 100644 index 0000000..a460eb5 --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/onnx_reader/model_support_tests.cpp @@ -0,0 +1,72 @@ +// Copyright (C) 2018-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include + +namespace { + std::string model_path(const char* model) { + std::string path = ONNX_TEST_MODELS; + path += "support_test/"; + path += model; + return path; + } +} + +TEST(ONNXReader_ModelSupported, basic_model) { + // this model is a basic ONNX model taken from ngraph's unit test (add_abc.onnx) + // it contains the minimum number of fields required to accept this file as a valid model + 
EXPECT_NO_THROW(InferenceEngine::Core{}.ReadNetwork(model_path("supported/basic.onnx"))); +} + +TEST(ONNXReader_ModelSupported, basic_reverse_fields_order) { + // this model contains the same fields as basic.onnx but serialized in reverse order + EXPECT_NO_THROW(InferenceEngine::Core{}.ReadNetwork(model_path("supported/basic_reverse_fields_order.onnx"))); +} + +TEST(ONNXReader_ModelSupported, more_fields) { + // this model contains some optional fields (producer_name and doc_string) but 5 fields in total + EXPECT_NO_THROW(InferenceEngine::Core{}.ReadNetwork(model_path("supported/more_fields.onnx"))); +} + +TEST(ONNXReader_ModelSupported, varint_on_two_bytes) { + // the docstring's payload length is encoded as varint using 2 bytes which should be parsed correctly + EXPECT_NO_THROW(InferenceEngine::Core{}.ReadNetwork(model_path("supported/varint_on_two_bytes.onnx"))); +} + +TEST(ONNXReader_ModelSupported, prototxt_basic) { + EXPECT_NO_THROW(InferenceEngine::Core{}.ReadNetwork(model_path("supported/basic.prototxt"))); +} + +TEST(ONNXReader_ModelSupported, scrambled_keys) { + // same as the prototxt_basic but with a different order of keys + EXPECT_NO_THROW(InferenceEngine::Core{}.ReadNetwork(model_path("supported/scrambled_keys.prototxt"))); +} + +TEST(ONNXReader_ModelUnsupported, no_graph_field) { + // this model contains only 2 fields (it doesn't contain a graph in particular) + EXPECT_THROW(InferenceEngine::Core{}.ReadNetwork(model_path("unsupported/no_graph_field.onnx")), + InferenceEngine::details::InferenceEngineException); +} + +TEST(ONNXReader_ModelUnsupported, incorrect_onnx_field) { + // in this model the second field's key is F8 (field number 31) which doesn't exist in ONNX + // this test will have to be changed if the number of fields in onnx.proto + // (ModelProto message definition) ever reaches 31 or more + EXPECT_THROW(InferenceEngine::Core{}.ReadNetwork(model_path("unsupported/incorrect_onnx_field.onnx")), + 
InferenceEngine::details::InferenceEngineException); +} + +TEST(ONNXReader_ModelUnsupported, unknown_wire_type) { + // in this model the graph key contains wire type 7 encoded in it - this value is incorrect + EXPECT_THROW(InferenceEngine::Core{}.ReadNetwork(model_path("unsupported/unknown_wire_type.onnx")), + InferenceEngine::details::InferenceEngineException); +} + +TEST(ONNXReader_ModelUnsupported, no_valid_keys) { + EXPECT_THROW(InferenceEngine::Core{}.ReadNetwork(model_path("unsupported/no_valid_keys.prototxt")), + InferenceEngine::details::InferenceEngineException); +} diff --git a/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/supported/basic.onnx b/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/supported/basic.onnx new file mode 100644 index 0000000..2469457 --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/supported/basic.onnx @@ -0,0 +1,12 @@ +:D + +xy"Cosh +cosh_graphZ +x +  + +b +y +  + +B \ No newline at end of file diff --git a/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/supported/basic.prototxt b/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/supported/basic.prototxt new file mode 100644 index 0000000..7f63b68 --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/supported/basic.prototxt @@ -0,0 +1,74 @@ +ir_version: 3 +producer_name: "nGraph ONNX Importer" +graph { + node { + input: "A" + input: "B" + output: "X" + name: "add_node1" + op_type: "Add" + } + node { + input: "X" + input: "C" + output: "Y" + name: "add_node2" + op_type: "Add" + } + name: "test_graph" + input { + name: "A" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "B" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: 
"C" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + } + } + } + } + output { + name: "Y" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + } + } + } + } +} +opset_import { + version: 4 +} diff --git a/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/supported/basic_reverse_fields_order.onnx b/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/supported/basic_reverse_fields_order.onnx new file mode 100644 index 0000000..c7c9efd --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/supported/basic_reverse_fields_order.onnx @@ -0,0 +1,12 @@ +B :D + +xy"Cosh +cosh_graphZ +x +  + +b +y +  + + \ No newline at end of file diff --git a/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/supported/more_fields.onnx b/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/supported/more_fields.onnx new file mode 100644 index 0000000..bb3fa98 --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/supported/more_fields.onnx @@ -0,0 +1,12 @@ +ONNX Reader test2Doc string for this model:D + +xy"Cosh +cosh_graphZ +x +  + +b +y +  + +B \ No newline at end of file diff --git a/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/supported/scrambled_keys.prototxt b/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/supported/scrambled_keys.prototxt new file mode 100644 index 0000000..52676bd --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/supported/scrambled_keys.prototxt @@ -0,0 +1,73 @@ +graph { + node { + input: "A" + input: "B" + output: "X" + name: "add_node1" + op_type: "Add" + } + node { + input: "X" + input: "C" + output: "Y" + name: "add_node2" + op_type: "Add" + } + name: 
"test_graph" + input { + name: "A" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "B" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "C" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + } + } + } + } + output { + name: "Y" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + } + } + } + } +} +opset_import { + version: 4 +} +ir_version: 3 \ No newline at end of file diff --git a/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/supported/varint_on_two_bytes.onnx b/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/supported/varint_on_two_bytes.onnx new file mode 100644 index 0000000..e824946 --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/supported/varint_on_two_bytes.onnx @@ -0,0 +1,12 @@ +2¶over_128_chars_in_this_string_to_make_sure_its_length_is_encoded_as_varint_using_two_bytes__over_128_chars_in_this_string_to_make_sure_its_length_is_encoded_as_varint_using_two_bytes:D + +xy"Cosh +cosh_graphZ +x +  + +b +y +  + +B \ No newline at end of file diff --git a/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/unsupported/incorrect_onnx_field.onnx b/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/unsupported/incorrect_onnx_field.onnx new file mode 100644 index 0000000..d19de1c --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/unsupported/incorrect_onnx_field.onnx @@ -0,0 +1,12 @@ +øD + +xy"Cosh +cosh_graphZ +x +  + +b +y +  + +B \ No newline at end of file diff --git a/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/unsupported/no_graph_field.onnx 
b/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/unsupported/no_graph_field.onnx new file mode 100644 index 0000000..8552a0d Binary files /dev/null and b/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/unsupported/no_graph_field.onnx differ diff --git a/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/unsupported/no_valid_keys.prototxt b/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/unsupported/no_valid_keys.prototxt new file mode 100644 index 0000000..28a5e83 --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/unsupported/no_valid_keys.prototxt @@ -0,0 +1,5 @@ +james_bond: 007 +Shakira: "Waka Waka" +blip { + bloop: 21,37 +} diff --git a/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/unsupported/unknown_wire_type.onnx b/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/unsupported/unknown_wire_type.onnx new file mode 100644 index 0000000..9e1ce57 --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/onnx_reader/models/support_test/unsupported/unknown_wire_type.onnx @@ -0,0 +1,12 @@ +?D + +xy"Cosh +cosh_graphZ +x +  + +b +y +  + +B \ No newline at end of file