From c0d71900fd23456f1653a0d55fd27b07198f6a1d Mon Sep 17 00:00:00 2001 From: Mateusz Bencer Date: Wed, 14 Oct 2020 11:30:53 +0200 Subject: [PATCH] Provide ONNX external data mechanism to ReadNetwork (#2588) MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit * added unit test * added python test * using pword approach * Added passing path to onnx reader * support for wstring * Added more tests * Apply suggestions from code review Co-authored-by: Michał Karzyński <4430709+postrational@users.noreply.github.com> * fix build for Windows * styles applied * Fixed Windows tests * styles applied * fixed styles in tests * review remarks * cmake order * Used target_compile_definitions instead of add_definitions * Move ONNX_TEST_MODELS to other scope Co-authored-by: Michał Karzyński <4430709+postrational@users.noreply.github.com> --- inference-engine/include/ie_core.hpp | 11 +- .../src/inference_engine/ie_network_reader.cpp | 4 + .../src/inference_engine/ie_network_reader.hpp | 3 + .../src/readers/onnx_reader/ie_onnx_reader.cpp | 12 ++- .../functional/inference_engine/CMakeLists.txt | 2 + .../onnx_reader/models/data/tensor.data | Bin 0 -> 16 bytes .../onnx_reader/models/onnx_external_data.prototxt | 97 ++++++++++++++++++ ...2\346\227\245\346\234\254\350\252\236.prototxt" | 97 ++++++++++++++++++ .../onnx_reader/onnx_reader_external_data.cpp | 112 +++++++++++++++++++++ ngraph/core/include/ngraph/file_util.hpp | 14 +++ ngraph/core/include/ngraph/ngraph_visibility.hpp | 11 ++ ngraph/core/src/file_util.cpp | 53 ++++++++++ .../onnx_import/utils/tensor_external_data.hpp | 3 +- ngraph/frontend/onnx_import/src/onnx.cpp | 6 +- .../onnx_import/src/utils/tensor_external_data.cpp | 8 +- .../python/tests/test_onnx/models/data/tensor.data | Bin 0 -> 12 bytes .../tests/test_onnx/models/external_data.prototxt | 77 ++++++++++++++ .../tests/test_onnx/test_onnx_external_data.py | 41 ++++++++ 18 files changed, 544 insertions(+), 7 deletions(-) create mode 100644 inference-engine/tests/functional/inference_engine/onnx_reader/models/data/tensor.data create mode 100644 inference-engine/tests/functional/inference_engine/onnx_reader/models/onnx_external_data.prototxt create mode 100644 "inference-engine/tests/functional/inference_engine/onnx_reader/models/\320\220\320\221\320\222\320\223\320\224\320\225\320\201\320\226\320\227\320\230\320\231/\343\201\262\343\202\211\343\201\214\343\201\252\346\227\245\346\234\254\350\252\236.prototxt" create mode 100644 inference-engine/tests/functional/inference_engine/onnx_reader/onnx_reader_external_data.cpp create mode 100644 ngraph/python/tests/test_onnx/models/data/tensor.data create mode 100644 ngraph/python/tests/test_onnx/models/external_data.prototxt create mode 100644 ngraph/python/tests/test_onnx/test_onnx_external_data.py diff --git a/inference-engine/include/ie_core.hpp b/inference-engine/include/ie_core.hpp index 32c9125..7f2b9ef 100644 --- a/inference-engine/include/ie_core.hpp +++ b/inference-engine/include/ie_core.hpp @@ -57,7 +57,8 @@ public: * For IR format (*.bin): * * if path is empty, will try to read bin file with the same name as xml and * * if bin file with the same name was not found, will load IR without weights. - * ONNX models with data files are not supported + * For ONNX format (*.onnx or *.prototxt): + * * binPath parameter is not used. * @return CNNNetwork */ CNNNetwork ReadNetwork(const std::wstring& modelPath, const std::wstring& binPath = {}) const; @@ -70,7 +71,8 @@ public: * For IR format (*.bin): * * if path is empty, will try to read bin file with the same name as xml and * * if bin file with the same name was not found, will load IR without weights. - * ONNX models with data files are not supported + * For ONNX format (*.onnx or *.prototxt): + * * binPath parameter is not used. * @return CNNNetwork */ CNNNetwork ReadNetwork(const std::string& modelPath, const std::string& binPath = {}) const; @@ -78,7 +80,10 @@ public: * @brief Reads models from IR and ONNX formats * @param model string with model in IR or ONNX format * @param weights shared pointer to constant blob with weights - * ONNX models doesn't support models with data blobs. + * Reading ONNX models doesn't support loading weights from data blobs. + * If you are using an ONNX model with external data files, please use the + * `InferenceEngine::Core::ReadNetwork(const std::string& model, const Blob::CPtr& weights) const` + * function overload which takes a filesystem path to the model. * For ONNX case the second parameter should contain empty blob. * @return CNNNetwork */ diff --git a/inference-engine/src/inference_engine/ie_network_reader.cpp b/inference-engine/src/inference_engine/ie_network_reader.cpp index 0f9033c..4135930 100644 --- a/inference-engine/src/inference_engine/ie_network_reader.cpp +++ b/inference-engine/src/inference_engine/ie_network_reader.cpp @@ -168,6 +168,10 @@ CNNNetwork details::ReadNetwork(const std::string& modelPath, const std::string& #endif // Try to open model file std::ifstream modelStream(model_path, std::ios::binary); + // save path in extensible array of stream + // notice: lifetime of path pointed by pword(0) is limited by current scope + const std::string path_to_save_in_stream = modelPath; + modelStream.pword(0) = const_cast(path_to_save_in_stream.c_str()); if (!modelStream.is_open()) THROW_IE_EXCEPTION << "Model file " << modelPath << " cannot be opened!"; diff --git a/inference-engine/src/inference_engine/ie_network_reader.hpp b/inference-engine/src/inference_engine/ie_network_reader.hpp index b98c7be..cae90e3 100644 --- a/inference-engine/src/inference_engine/ie_network_reader.hpp +++ b/inference-engine/src/inference_engine/ie_network_reader.hpp @@ -26,6 +26,9 @@ CNNNetwork ReadNetwork(const std::string& modelPath, const std::string& binPath, * @param model string with IR * @param weights shared pointer to constant blob with weights * @param exts vector with extensions + * @note Reading ONNX models doesn't support loading weights from data blobs. + If you are using an ONNX model with external data files, please use the + ReadNetwork function overload which takes a filesystem path to the model. * @return CNNNetwork */ CNNNetwork ReadNetwork(const std::string& model, const Blob::CPtr& weights, const std::vector& exts); diff --git a/inference-engine/src/readers/onnx_reader/ie_onnx_reader.cpp b/inference-engine/src/readers/onnx_reader/ie_onnx_reader.cpp index a99fd71..25993eb 100644 --- a/inference-engine/src/readers/onnx_reader/ie_onnx_reader.cpp +++ b/inference-engine/src/readers/onnx_reader/ie_onnx_reader.cpp @@ -21,8 +21,18 @@ bool ONNXReader::supportModel(std::istream& model) const { return !((header.find("(stream.pword(0))}; + } +} + CNNNetwork ONNXReader::read(std::istream& model, const std::vector& exts) const { - return CNNNetwork(ngraph::onnx_import::import_onnx_model(model)); + return CNNNetwork(ngraph::onnx_import::import_onnx_model(model, readPathFromStream(model))); } INFERENCE_PLUGIN_API(StatusCode) InferenceEngine::CreateReader(IReader*& reader, ResponseDesc *resp) noexcept { diff --git a/inference-engine/tests/functional/inference_engine/CMakeLists.txt b/inference-engine/tests/functional/inference_engine/CMakeLists.txt index d4c8ff9..f9b94b5 100644 --- a/inference-engine/tests/functional/inference_engine/CMakeLists.txt +++ b/inference-engine/tests/functional/inference_engine/CMakeLists.txt @@ -52,6 +52,8 @@ if(TARGET inference_engine_onnx_reader) add_dependencies(${TARGET_NAME} inference_engine_onnx_reader) endif() +target_compile_definitions(${TARGET_NAME} PRIVATE ONNX_TEST_MODELS="${CMAKE_CURRENT_SOURCE_DIR}/onnx_reader/models/") + include(CMakeParseArguments) # diff --git a/inference-engine/tests/functional/inference_engine/onnx_reader/models/data/tensor.data b/inference-engine/tests/functional/inference_engine/onnx_reader/models/data/tensor.data new file mode 100644 index 0000000000000000000000000000000000000000..8db91db9674aa2d6dfc82f483ab345d9ad0f29f7 GIT binary patch literal 16 UcmZQzXs~BsU~m8;AZ~B~01$%$KmY&$ literal 0 HcmV?d00001 diff --git a/inference-engine/tests/functional/inference_engine/onnx_reader/models/onnx_external_data.prototxt b/inference-engine/tests/functional/inference_engine/onnx_reader/models/onnx_external_data.prototxt new file mode 100644 index 0000000..ac9bb40 --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/onnx_reader/models/onnx_external_data.prototxt @@ -0,0 +1,97 @@ +ir_version: 3 +producer_name: "nGraph ONNX Importer" +graph { + node { + input: "A" + input: "B" + output: "X" + name: "add_node1" + op_type: "Add" + } + node { + input: "X" + input: "C" + output: "Y" + name: "add_node2" + op_type: "Add" + } + name: "test_graph" + initializer { + dims: 2 + dims: 2 + data_type: 1 + name: "A" + external_data { + key: "location", + value: "data/tensor.data" + } + data_location: 1 + } + input { + name: "A" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } + input { + name: "B" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } + input { + name: "C" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } + output { + name: "Y" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } +} +opset_import { + version: 4 +} diff --git "a/inference-engine/tests/functional/inference_engine/onnx_reader/models/\320\220\320\221\320\222\320\223\320\224\320\225\320\201\320\226\320\227\320\230\320\231/\343\201\262\343\202\211\343\201\214\343\201\252\346\227\245\346\234\254\350\252\236.prototxt" "b/inference-engine/tests/functional/inference_engine/onnx_reader/models/\320\220\320\221\320\222\320\223\320\224\320\225\320\201\320\226\320\227\320\230\320\231/\343\201\262\343\202\211\343\201\214\343\201\252\346\227\245\346\234\254\350\252\236.prototxt" new file mode 100644 index 0000000..ab03f7c --- /dev/null +++ "b/inference-engine/tests/functional/inference_engine/onnx_reader/models/\320\220\320\221\320\222\320\223\320\224\320\225\320\201\320\226\320\227\320\230\320\231/\343\201\262\343\202\211\343\201\214\343\201\252\346\227\245\346\234\254\350\252\236.prototxt" @@ -0,0 +1,97 @@ +ir_version: 3 +producer_name: "nGraph ONNX Importer" +graph { + node { + input: "A" + input: "B" + output: "X" + name: "multiply_node_1" + op_type: "Mul" + } + node { + input: "X" + input: "C" + output: "Y" + name: "multiply_node_2" + op_type: "Mul" + } + name: "test_graph" + initializer { + dims: 2 + dims: 2 + data_type: 1 + name: "A" + external_data { + key: "location", + value: "../data/tensor.data" + } + data_location: 1 + } + input { + name: "A" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } + input { + name: "B" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } + input { + name: "C" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } + output { + name: "Y" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } +} +opset_import { + version: 4 +} diff --git a/inference-engine/tests/functional/inference_engine/onnx_reader/onnx_reader_external_data.cpp b/inference-engine/tests/functional/inference_engine/onnx_reader/onnx_reader_external_data.cpp new file mode 100644 index 0000000..0d4c073 --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/onnx_reader/onnx_reader_external_data.cpp @@ -0,0 +1,112 @@ +// Copyright (C) 2018-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +TEST(ONNX_Reader_Tests, ImportModelWithExternalDataFromFile) { + InferenceEngine::Core ie; + auto cnnNetwork = ie.ReadNetwork(std::string(ONNX_TEST_MODELS) + "onnx_external_data.prototxt", ""); + auto function = cnnNetwork.getFunction(); + + int count_additions = 0; + int count_constants = 0; + int count_parameters = 0; + + std::shared_ptr external_data_node; + for (auto op : function->get_ops()) { + const auto op_type = std::string(op->get_type_name()); + count_additions += (op_type == "Add" ? 1 : 0); + count_parameters += (op_type == "Parameter" ? 1 : 0); + if (op_type == "Constant") { + count_constants += 1; + external_data_node = op; + } + } + + ASSERT_EQ(function->get_output_size(), 1); + ASSERT_EQ(std::string(function->get_output_op(0)->get_type_name()), "Result"); + ASSERT_EQ(function->get_output_element_type(0), ngraph::element::f32); + ASSERT_EQ(function->get_output_shape(0), ngraph::Shape({2, 2})); + ASSERT_EQ(count_additions, 2); + ASSERT_EQ(count_constants, 1); + ASSERT_EQ(count_parameters, 2); + + const auto external_data_node_const = ngraph::as_type_ptr(external_data_node); + ASSERT_TRUE(external_data_node_const->get_vector() == (std::vector{1, 2, 3, 4})); +} + +TEST(ONNX_Reader_Tests, ImportModelWithExternalDataFromStringException) { + InferenceEngine::Core ie; + const auto path = std::string(ONNX_TEST_MODELS) + "onnx_external_data.prototxt"; + InferenceEngine::Blob::CPtr weights; //not used + std::ifstream stream(path, std::ios::binary); + std::string modelAsString((std::istreambuf_iterator(stream)), std::istreambuf_iterator()); + stream.close(); + try { + auto cnnNetwork = ie.ReadNetwork(modelAsString, weights); + } + catch(const ngraph::ngraph_error& e) { + EXPECT_PRED_FORMAT2( + testing::IsSubstring, + std::string("invalid external data:"), + e.what()); + + EXPECT_PRED_FORMAT2( + testing::IsSubstring, + std::string("data/tensor.data, offset: 0, data_lenght: 0, sha1_digest: 0)"), + e.what()); + } + catch(...) { + FAIL() << "Reading network failed for unexpected reason"; + } +} + +#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32) +TEST(ONNX_Reader_Tests, ImportModelWithExternalDataFromWstringNamedFile) { + InferenceEngine::Core ie; + std::string win_dir_path = ONNX_TEST_MODELS; + std::replace(win_dir_path.begin(), win_dir_path.end(), '/', '\\'); + const std::wstring unicode_win_dir_path = FileUtils::multiByteCharToWString(win_dir_path.c_str()); + const std::wstring path = unicode_win_dir_path + L"АБВГДЕЁЖЗИЙ\\ひらがな日本語.prototxt"; + + auto cnnNetwork = ie.ReadNetwork(path, L""); + auto function = cnnNetwork.getFunction(); + + int count_multiply = 0; + int count_constants = 0; + int count_parameters = 0; + + std::shared_ptr external_data_node; + for (auto op : function->get_ops()) { + const auto op_type = std::string(op->get_type_name()); + count_multiply += (op_type == "Multiply" ? 1 : 0); + count_parameters += (op_type == "Parameter" ? 1 : 0); + if (op_type == "Constant") { + count_constants += 1; + external_data_node = op; + } + } + + ASSERT_EQ(function->get_output_size(), 1); + ASSERT_EQ(std::string(function->get_output_op(0)->get_type_name()), "Result"); + ASSERT_EQ(function->get_output_element_type(0), ngraph::element::f32); + ASSERT_EQ(function->get_output_shape(0), ngraph::Shape({2, 2})); + ASSERT_EQ(count_multiply, 2); + ASSERT_EQ(count_constants, 1); + ASSERT_EQ(count_parameters, 2); + + const auto external_data_node_const = ngraph::as_type_ptr(external_data_node); + ASSERT_TRUE(external_data_node_const->get_vector() == (std::vector{1, 2, 3, 4})); +} +#endif diff --git a/ngraph/core/include/ngraph/file_util.hpp b/ngraph/core/include/ngraph/file_util.hpp index b589d62..8f9bf9c 100644 --- a/ngraph/core/include/ngraph/file_util.hpp +++ b/ngraph/core/include/ngraph/file_util.hpp @@ -63,5 +63,19 @@ namespace ngraph std::function func, bool recurse = false, bool include_links = false); + + /// \brief Change Linux-style path ('/') to Windows-style ('\\') + /// \param path The path to change file separator + NGRAPH_API void convert_path_win_style(std::string& path); + + /// \brief Conversion from wide character string to a single-byte chain. + /// \param wstr A wide-char string + /// \return A multi-byte string + NGRAPH_API std::string wstring_to_string(const std::wstring& wstr); + + /// \brief Conversion from single-byte chain to wide character string. + /// \param str A null-terminated string + /// \return A wide-char string + NGRAPH_API std::wstring multi_byte_char_to_wstring(const char* str); } } diff --git a/ngraph/core/include/ngraph/ngraph_visibility.hpp b/ngraph/core/include/ngraph/ngraph_visibility.hpp index 8798375..6fa16c0 100644 --- a/ngraph/core/include/ngraph/ngraph_visibility.hpp +++ b/ngraph/core/include/ngraph/ngraph_visibility.hpp @@ -30,3 +30,14 @@ #else #define NGRAPH_API NGRAPH_HELPER_DLL_IMPORT #endif // ngraph_EXPORTS + +#ifndef ENABLE_UNICODE_PATH_SUPPORT +#ifdef _WIN32 +#if defined __INTEL_COMPILER || defined _MSC_VER +#define ENABLE_UNICODE_PATH_SUPPORT +#endif +#elif defined(__GNUC__) && (__GNUC__ > 5 || (__GNUC__ == 5 && __GNUC_MINOR__ > 2)) || \ + defined(__clang__) +#define ENABLE_UNICODE_PATH_SUPPORT +#endif +#endif diff --git a/ngraph/core/src/file_util.cpp b/ngraph/core/src/file_util.cpp index cd22252..1d18396 100644 --- a/ngraph/core/src/file_util.cpp +++ b/ngraph/core/src/file_util.cpp @@ -23,6 +23,7 @@ #include #include #endif +#include #include #include #include @@ -43,6 +44,10 @@ #else #define RMDIR(a) rmdir(a) #define RMFILE(a) remove(a) +#ifdef ENABLE_UNICODE_PATH_SUPPORT +#include +#include +#endif #endif using namespace std; @@ -77,10 +82,19 @@ string file_util::get_file_ext(const string& s) string file_util::get_directory(const string& s) { string rc = s; + // Linux-style separator auto pos = s.find_last_of('/'); if (pos != string::npos) { rc = s.substr(0, pos); + return rc; + } + // Windows-style separator + pos = s.find_last_of('\\'); + if (pos != string::npos) + { + rc = s.substr(0, pos); + return rc; } return rc; } @@ -240,3 +254,42 @@ void file_util::iterate_files(const string& path, func(f, true); } } + +NGRAPH_API void file_util::convert_path_win_style(std::string& path) +{ + std::replace(path.begin(), path.end(), '/', '\\'); +} + +#ifdef ENABLE_UNICODE_PATH_SUPPORT + +std::string file_util::wstring_to_string(const std::wstring& wstr) +{ +#ifdef _WIN32 + int size_needed = + WideCharToMultiByte(CP_UTF8, 0, &wstr[0], (int)wstr.size(), NULL, 0, NULL, NULL); // NOLINT + std::string strTo(size_needed, 0); + WideCharToMultiByte( + CP_UTF8, 0, &wstr[0], (int)wstr.size(), &strTo[0], size_needed, NULL, NULL); // NOLINT + return strTo; +#else + std::wstring_convert> wstring_decoder; + return wstring_decoder.to_bytes(wstr); +#endif +} + +std::wstring file_util::multi_byte_char_to_wstring(const char* str) +{ +#ifdef _WIN32 + int strSize = static_cast(std::strlen(str)); + int size_needed = MultiByteToWideChar(CP_UTF8, 0, str, strSize, NULL, 0); + std::wstring wstrTo(size_needed, 0); + MultiByteToWideChar(CP_UTF8, 0, str, strSize, &wstrTo[0], size_needed); + return wstrTo; +#else + std::wstring_convert> wstring_encoder; + std::wstring result = wstring_encoder.from_bytes(str); + return result; +#endif +} + +#endif // ENABLE_UNICODE_PATH_SUPPORT diff --git a/ngraph/frontend/onnx_import/include/onnx_import/utils/tensor_external_data.hpp b/ngraph/frontend/onnx_import/include/onnx_import/utils/tensor_external_data.hpp index 0baaef5..9e5f4e3 100644 --- a/ngraph/frontend/onnx_import/include/onnx_import/utils/tensor_external_data.hpp +++ b/ngraph/frontend/onnx_import/include/onnx_import/utils/tensor_external_data.hpp @@ -33,7 +33,8 @@ namespace ngraph /// \brief Load external data from tensor passed to constructor /// /// \note If read data from external file fails, - /// the invalid_external_data is thrown + /// \note If reading data from external files fails, + /// the invalid_external_data exception is thrown. /// /// \return External binary data loaded into a std::string std::string load_external_data() const; diff --git a/ngraph/frontend/onnx_import/src/onnx.cpp b/ngraph/frontend/onnx_import/src/onnx.cpp index 28333fe..c2679f7 100644 --- a/ngraph/frontend/onnx_import/src/onnx.cpp +++ b/ngraph/frontend/onnx_import/src/onnx.cpp @@ -119,9 +119,13 @@ namespace ngraph { const auto external_data_relative_path = initializer_tensor.external_data(location_key_value_index).value(); - const auto external_data_full_path = + auto external_data_full_path = file_util::path_join(model_dir_path, external_data_relative_path); +#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32) + file_util::convert_path_win_style(external_data_full_path); +#endif + // Set full paths to the external file initializer_tensor.mutable_external_data(location_key_value_index) ->set_value(external_data_full_path); diff --git a/ngraph/frontend/onnx_import/src/utils/tensor_external_data.cpp b/ngraph/frontend/onnx_import/src/utils/tensor_external_data.cpp index c7cfd5e..1d47211 100644 --- a/ngraph/frontend/onnx_import/src/utils/tensor_external_data.cpp +++ b/ngraph/frontend/onnx_import/src/utils/tensor_external_data.cpp @@ -17,6 +17,7 @@ #include #include +#include "ngraph/file_util.hpp" #include "ngraph/log.hpp" #include "onnx_import/exceptions.hpp" #include "tensor_external_data.hpp" @@ -44,7 +45,12 @@ namespace ngraph std::string TensorExternalData::load_external_data() const { - std::ifstream external_data_stream(m_data_location, +#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32) + std::wstring path = file_util::multi_byte_char_to_wstring(m_data_location.c_str()); +#else + std::string path = m_data_location; +#endif + std::ifstream external_data_stream(path, std::ios::binary | std::ios::in | std::ios::ate); if (external_data_stream.fail()) throw error::invalid_external_data{*this}; diff --git a/ngraph/python/tests/test_onnx/models/data/tensor.data b/ngraph/python/tests/test_onnx/models/data/tensor.data new file mode 100644 index 0000000000000000000000000000000000000000..5116510eebcfbd3a254c8e0e661dbb88acd086a7 GIT binary patch literal 12 TcmZQzSm40G&|uHN;NSoN4Tl0C literal 0 HcmV?d00001 diff --git a/ngraph/python/tests/test_onnx/models/external_data.prototxt b/ngraph/python/tests/test_onnx/models/external_data.prototxt new file mode 100644 index 0000000..e5ab8d3 --- /dev/null +++ b/ngraph/python/tests/test_onnx/models/external_data.prototxt @@ -0,0 +1,77 @@ +ir_version: 3 +producer_name: "nGraph ONNX Importer" +graph { + node { + input: "data_a" + input: "data_b" + input: "data_c" + output: "result" + op_type: "Mean" + } + name: "test_mean_example" + initializer { + dims: 3 + data_type: 1 + name: "data_c" + external_data { + key: "location", + value: "data/tensor.data" + } + data_location: 1 + } + input { + name: "data_a" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + } + } + } + } + input { + name: "data_b" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + } + } + } + } + input { + name: "data_c" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + } + } + } + } + output { + name: "result" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + } + } + } + } +} +opset_import { + version: 8 +} diff --git a/ngraph/python/tests/test_onnx/test_onnx_external_data.py b/ngraph/python/tests/test_onnx/test_onnx_external_data.py new file mode 100644 index 0000000..28077f0 --- /dev/null +++ b/ngraph/python/tests/test_onnx/test_onnx_external_data.py @@ -0,0 +1,41 @@ +# ****************************************************************************** +# Copyright 2017-2020 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ****************************************************************************** + +import os + +import numpy as np +import ngraph as ng +from openvino.inference_engine import IECore + +from tests.runtime import get_runtime + + +def test_import_onnx_with_external_data(): + model_path = os.path.join(os.path.dirname(__file__), "models/external_data.prototxt") + ie = IECore() + ie_network = ie.read_network(model=model_path) + + ng_function = ng.function_from_cnn(ie_network) + + dtype = np.float32 + value_a = np.array([1.0, 3.0, 5.0], dtype=dtype) + value_b = np.array([3.0, 5.0, 1.0], dtype=dtype) + # third input [5.0, 1.0, 3.0] read from external file + + runtime = get_runtime() + computation = runtime.computation(ng_function) + result = computation(value_a, value_b) + assert np.allclose(result, np.array([3.0, 3.0, 3.0], dtype=dtype)) -- 2.7.4