From 2f5a28d44f1fc8e866e587121ea88dd3f711dc8e Mon Sep 17 00:00:00 2001 From: Ivan Tikhonov Date: Fri, 4 Sep 2020 09:04:36 +0300 Subject: [PATCH] LSTMCell/Sequence v1, reference implementations and decompose transformations for LSTM/GRU/RNN Cells (#2000) * validate_and_infer_types() implementation * input parameter validation for LSTM, GRU and RNN * style-check applied * Add LSTMSequence dynamic shape validation and test props for RNNCell, GRUCell, LSTMCell and LSTMSequence. * recurrent_sequence.hpp moved to ngraph/core/include/ngraph/op/util/ * style check applied * removed unused variable from LSTMSequence::validate_and_infer_types * Add missing newline mark at the end of file. * Add supression macro for FusedOp deprecation. * Add element type initialization * transpose,rnn cell reference implementations * Apply PR review remarks * reference implementations for cells op, single layer tests, align lstm cell/sequence according to the spec * lstm/gru/rnn cell decompostion transformations * ngraph codestyle * clean up * ngraph code style * change inheritance of Cells, fix build * fix build * fix build again * remove Peepholes from LSTMSeq, fix copy_runtime_info in transformations * Rewrite tests to use gtest exception assertions. * resolve tests issues * ngraph codestyle * add missed files * fix typeprop tests * fix lstm sequence checks * fix arm build * fix arm again * delete unnecessary file * add convert weghts format function, enable lstm test, resolve review comments * add ngraph builders * ngraph codestyle * fix unit tests * revert transpose reference implementation * revert LSTM Cell v0, add LSTMCell v1, update transformation lstm_cell_to_cell_ie * v1 version of LSTMCell op * LSTMSequence v1 operation, exclude LSTMSeq from opset4 * fix python api tests * resolve review comments, tests for decomposition transformations, switch lstm cell to opset4 in mo Co-authored-by: Szymon Durawa --- .../src/convert_function_to_cnn_network.cpp | 24 +- .../legacy_api/src/ie_cnn_layer_builder_ngraph.cpp | 48 - .../src/readers/ir_reader/ie_ir_parser.cpp | 6 +- .../include/ngraph_ops/lstm_cell_ie.hpp | 7 +- .../transformations/gru_cell_decomposition.hpp | 41 + .../transformations/lstm_cell_decomposition.hpp | 42 + .../transformations/rnn_cell_decomposition.hpp | 36 + .../include/transformations/utils/utils.hpp | 3 + .../src/ngraph_ops/lstm_cell_ie.cpp | 9 + .../convert_cells_to_cells_ie.cpp | 13 +- .../src/transformations/gru_cell_decomposition.cpp | 104 ++ .../transformations/lstm_cell_decomposition.cpp | 85 ++ .../src/transformations/rnn_cell_decomposition.cpp | 52 + .../src/transformations/utils/utils.cpp | 12 + .../inference_engine/ngraph_reader/ti.cpp | 1403 +++++++++++++++++++- .../convert_cells_to_cells_ie_test.cpp | 77 +- .../single_layer_tests/gru_cell.cpp | 37 + .../single_layer_tests/lstm_cell.cpp | 36 + .../single_layer_tests/rnn_cell.cpp | 34 + .../shared/include/single_layer_tests/gru_cell.hpp | 38 + .../include/single_layer_tests/lstm_cell.hpp | 37 + .../shared/include/single_layer_tests/rnn_cell.hpp | 37 + .../shared/src/single_layer_tests/gru_cell.cpp | 90 ++ .../shared/src/single_layer_tests/lstm_cell.cpp | 89 ++ .../shared/src/single_layer_tests/rnn_cell.cpp | 82 ++ .../shared/src/subgraph_tests/basic_lstm.cpp | 6 +- .../include/ngraph_functions/builders.hpp | 26 + .../include/ngraph_functions/subgraph_builders.hpp | 2 +- .../tests/ngraph_functions/src/gru_cell.cpp | 30 + .../tests/ngraph_functions/src/lstm_cell.cpp | 29 + .../tests/ngraph_functions/src/rnn_cell.cpp | 29 + model-optimizer/extensions/ops/lstm_cell.py | 2 +- ngraph/core/include/ngraph/op/gru_cell.hpp | 9 +- ngraph/core/include/ngraph/op/lstm_cell.hpp | 163 ++- ngraph/core/include/ngraph/op/lstm_sequence.hpp | 66 +- ngraph/core/include/ngraph/op/rnn_cell.hpp | 13 +- .../core/include/ngraph/op/util/rnn_cell_base.hpp | 33 +- ngraph/core/include/ngraph/opsets/opset4_tbl.hpp | 3 +- .../include/ngraph/runtime/reference/gru_cell.hpp | 316 +++++ .../include/ngraph/runtime/reference/lstm_cell.hpp | 217 +++ .../include/ngraph/runtime/reference/rnn_cell.hpp | 132 ++ .../include/ngraph/runtime/reference/split.hpp | 37 + .../core/reference/src/runtime/reference/split.cpp | 54 + ngraph/core/src/op/gru_cell.cpp | 180 +-- ngraph/core/src/op/lstm_cell.cpp | 613 ++++----- ngraph/core/src/op/lstm_sequence.cpp | 165 +++ ngraph/core/src/op/rnn_cell.cpp | 216 +-- ngraph/core/src/op/split.cpp | 56 +- ngraph/core/src/op/util/rnn_cell_base.cpp | 33 +- ngraph/frontend/onnx_import/src/op/lstm.cpp | 6 +- ngraph/python/src/ngraph/opset4/__init__.py | 2 +- ngraph/python/src/ngraph/opset4/ops.py | 51 + ngraph/python/tests/test_ngraph/test_create_op.py | 69 +- ngraph/test/attributes.cpp | 32 +- ngraph/test/backend/fused_op.in.cpp | 213 ++- ngraph/test/op_is.cpp | 8 +- ngraph/test/runtime/ie/unit_test.manifest | 16 +- ngraph/test/runtime/interpreter/int_executable.hpp | 67 +- ngraph/test/runtime/interpreter/opset_int_tbl.hpp | 3 + ngraph/test/runtime/interpreter/unit_test.manifest | 17 + ngraph/test/runtime/opset0_tbl.hpp | 5 +- ngraph/test/type_prop/gru_cell.cpp | 45 +- ngraph/test/type_prop/lstm_cell.cpp | 238 ++-- ngraph/test/type_prop/lstm_sequence.cpp | 141 +- ngraph/test/type_prop/rnn_cell.cpp | 154 ++- 65 files changed, 4693 insertions(+), 1246 deletions(-) create mode 100644 inference-engine/src/transformations/include/transformations/gru_cell_decomposition.hpp create mode 100644 inference-engine/src/transformations/include/transformations/lstm_cell_decomposition.hpp create mode 100644 inference-engine/src/transformations/include/transformations/rnn_cell_decomposition.hpp create mode 100644 inference-engine/src/transformations/src/transformations/gru_cell_decomposition.cpp create mode 100644 inference-engine/src/transformations/src/transformations/lstm_cell_decomposition.cpp create mode 100644 inference-engine/src/transformations/src/transformations/rnn_cell_decomposition.cpp create mode 100644 inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gru_cell.cpp create mode 100644 inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/lstm_cell.cpp create mode 100644 inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/rnn_cell.cpp create mode 100644 inference-engine/tests/functional/plugin/shared/include/single_layer_tests/gru_cell.hpp create mode 100644 inference-engine/tests/functional/plugin/shared/include/single_layer_tests/lstm_cell.hpp create mode 100644 inference-engine/tests/functional/plugin/shared/include/single_layer_tests/rnn_cell.hpp create mode 100644 inference-engine/tests/functional/plugin/shared/src/single_layer_tests/gru_cell.cpp create mode 100644 inference-engine/tests/functional/plugin/shared/src/single_layer_tests/lstm_cell.cpp create mode 100644 inference-engine/tests/functional/plugin/shared/src/single_layer_tests/rnn_cell.cpp create mode 100644 inference-engine/tests/ngraph_functions/src/gru_cell.cpp create mode 100644 inference-engine/tests/ngraph_functions/src/lstm_cell.cpp create mode 100644 inference-engine/tests/ngraph_functions/src/rnn_cell.cpp create mode 100644 ngraph/core/reference/include/ngraph/runtime/reference/gru_cell.hpp create mode 100644 ngraph/core/reference/include/ngraph/runtime/reference/lstm_cell.hpp create mode 100644 ngraph/core/reference/include/ngraph/runtime/reference/rnn_cell.hpp create mode 100644 ngraph/core/reference/include/ngraph/runtime/reference/split.hpp create mode 100644 ngraph/core/reference/src/runtime/reference/split.cpp diff --git a/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp b/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp index 498b73c..fe86357 100644 --- a/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp +++ b/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp @@ -410,6 +410,29 @@ InferenceEngine::details::CNNLayerCreator::CNNLayerCreator(const std::shared_ptr return res; }); + addSpecificCreator({"LSTMCellIE"}, [](const std::shared_ptr<::ngraph::Node>& node, + const std::map params) -> CNNLayerPtr { + LayerParams attrs = {node->get_friendly_name(), "LSTMCell", + details::convertPrecision(node->get_output_element_type(0))}; + auto res = std::make_shared(attrs); + res->params = params; + Builder::NodeConverter converter; + const auto weightsNode = node->input_value(3).get_node_shared_ptr(); + if (converter.canCreate(weightsNode)) { + const auto& weights = converter.createLayer(weightsNode); + res->blobs["weights"] = weights->blobs["custom"]; + res->_weights = weights->blobs["custom"]; + } + + const auto biasNode = node->input_value(4).get_node_shared_ptr(); + if (converter.canCreate(biasNode)) { + const auto& bias = converter.createLayer(biasNode); + res->blobs["biases"] = bias->blobs["custom"]; + res->_biases = bias->blobs["custom"]; + } + return res; + }); + addSpecificCreator({"RNNCellIE"}, [](const std::shared_ptr<::ngraph::Node>& node, const std::map& params) -> CNNLayerPtr { LayerParams attrs = {node->get_friendly_name(), "RNNCell", @@ -672,7 +695,6 @@ void convertFunctionToICNNNetwork(const std::shared_ptr>(), std::make_shared>(), std::make_shared>(), - std::make_shared>(), std::make_shared>(), std::make_shared>(), std::make_shared>(), diff --git a/inference-engine/src/legacy_api/src/ie_cnn_layer_builder_ngraph.cpp b/inference-engine/src/legacy_api/src/ie_cnn_layer_builder_ngraph.cpp index 03ac65d..16bbebf 100644 --- a/inference-engine/src/legacy_api/src/ie_cnn_layer_builder_ngraph.cpp +++ b/inference-engine/src/legacy_api/src/ie_cnn_layer_builder_ngraph.cpp @@ -1867,54 +1867,6 @@ CNNLayer::Ptr NodeConverter::createLayer(const std:: } template <> -CNNLayer::Ptr NodeConverter::createLayer(const std::shared_ptr& layer) const { - LayerParams params = {layer->get_friendly_name(), "LSTMCell", - details::convertPrecision(layer->get_output_element_type(0))}; - auto castedLayer = ngraph::as_type_ptr(layer); - if (castedLayer == nullptr) THROW_IE_EXCEPTION << "Cannot get " << params.type << " layer " << params.name; - - auto res = std::make_shared(params); - res->params["hidden_size"] = asString(castedLayer->get_hidden_size()); - std::string value; - for (const auto& val : castedLayer->get_activations()) { - if (!value.empty()) value += ","; - value += val; - } - res->params["activations"] = value; - - value.clear(); - for (const auto& val : castedLayer->get_activations_alpha()) { - if (!value.empty()) value += ","; - value += val; - } - res->params["activations_alpha"] = value; - - value.clear(); - for (const auto& val : castedLayer->get_activations_beta()) { - if (!value.empty()) value += ","; - value += val; - } - res->params["activations_beta"] = value; - res->params["clip"] = asString(castedLayer->get_clip()); - - NodeConverter converter; - const auto weightsNode = layer->input_value(3).get_node_shared_ptr(); - if (converter.canCreate(weightsNode)) { - const auto& weights = converter.createLayer(weightsNode); - res->blobs["weights"] = weights->blobs["custom"]; - res->_weights = weights->blobs["custom"]; - } - - const auto biasNode = layer->input_value(4).get_node_shared_ptr(); - if (converter.canCreate(biasNode)) { - const auto& bias = converter.createLayer(biasNode); - res->blobs["biases"] = bias->blobs["custom"]; - res->_biases = bias->blobs["custom"]; - } - return res; -} - -template <> CNNLayer::Ptr NodeConverter::createLayer(const std::shared_ptr& layer) const { LayerParams params = {layer->get_friendly_name(), "Gemm", details::convertPrecision(layer->get_output_element_type(0))}; diff --git a/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp b/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp index e291ba0..d34891e 100644 --- a/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp +++ b/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp @@ -439,7 +439,7 @@ std::shared_ptr V10Parser::createNode(const std::vector>("Select"), std::make_shared>("LRN"), std::make_shared>("MVN"), - std::make_shared>("LSTMCell"), + std::make_shared>("LSTMCell"), std::make_shared>("MaxPool"), std::make_shared>("Maximum"), std::make_shared>("Minimum"), @@ -910,7 +910,7 @@ std::shared_ptr V10Parser::LayerCreator::crea // LSTMCell layer template <> -std::shared_ptr V10Parser::LayerCreator::createLayer( +std::shared_ptr V10Parser::LayerCreator::createLayer( const ngraph::OutputVector& inputs, const pugi::xml_node& node, std::istream& binStream, const GenericLayerParams& layerParsePrms) { checkParameters(inputs, layerParsePrms, 6); @@ -922,7 +922,7 @@ std::shared_ptr V10Parser::LayerCreator::cre std::vector activations_alpha = getParameters(dn, "activations_alpha", {}); std::vector activations_beta = getParameters(dn, "activations_beta", {}); float clip = GetFloatAttr(dn, "clip", 0.f); - return std::make_shared(inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], inputs[5], + return std::make_shared(inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], inputs[5], GetUInt64Attr(dn, "hidden_size"), ngraph::op::LSTMWeightsFormat::IFCO, activations, activations_alpha, activations_beta, clip); } diff --git a/inference-engine/src/transformations/include/ngraph_ops/lstm_cell_ie.hpp b/inference-engine/src/transformations/include/ngraph_ops/lstm_cell_ie.hpp index 733630b..7b5b9b5 100644 --- a/inference-engine/src/transformations/include/ngraph_ops/lstm_cell_ie.hpp +++ b/inference-engine/src/transformations/include/ngraph_ops/lstm_cell_ie.hpp @@ -41,13 +41,14 @@ public: const std::vector& get_activations_alpha() { return m_activations_alpha; } const std::vector& get_activations_beta() { return m_activations_beta; } float get_clip() {return m_clip;} + bool visit_attributes(AttributeVisitor& visitor) override; protected: int64_t m_hidden_size{}; - const std::vector m_activations; - const std::vector m_activations_alpha; - const std::vector m_activations_beta; + std::vector m_activations; + std::vector m_activations_alpha; + std::vector m_activations_beta; float m_clip; }; diff --git a/inference-engine/src/transformations/include/transformations/gru_cell_decomposition.hpp b/inference-engine/src/transformations/include/transformations/gru_cell_decomposition.hpp new file mode 100644 index 0000000..784df58 --- /dev/null +++ b/inference-engine/src/transformations/include/transformations/gru_cell_decomposition.hpp @@ -0,0 +1,41 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include + +#include + +namespace ngraph { +namespace pass { + +class TRANSFORMATIONS_API GRUCellDecomposition; + +} // namespace pass +} // namespace ngraph + +/** + * @ingroup ie_transformation_common_api + * @brief GRUCellDecomposition transformation decomposes GRUCell layer with inputs X, H, W, R, B + * to Add, Split, MatMul, Multiply and Subtract ops according to the formula: + (.) - Denotes element-wise multiplication. + * - Denotes dot product. + f, g - are activation functions + + zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz) + rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr) + ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh) # when linear_before_reset := false # (default) + ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) # when linear_before_reset:= true + Ht = (1 - zt) (.) ht + zt (.) Ht-1 + * * + */ + +class ngraph::pass::GRUCellDecomposition: public ngraph::pass::MatcherPass { +public: + GRUCellDecomposition(); +}; diff --git a/inference-engine/src/transformations/include/transformations/lstm_cell_decomposition.hpp b/inference-engine/src/transformations/include/transformations/lstm_cell_decomposition.hpp new file mode 100644 index 0000000..dd381a7 --- /dev/null +++ b/inference-engine/src/transformations/include/transformations/lstm_cell_decomposition.hpp @@ -0,0 +1,42 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include + +#include + +namespace ngraph { +namespace pass { + +class TRANSFORMATIONS_API LSTMCellDecomposition; + +} // namespace pass +} // namespace ngraph + +/** + * @ingroup ie_transformation_common_api + * @brief LSTMCellDecomposition transformation decomposes LSTMCell layer with inputs X, H, C, W, R, B + * to Add, Split, MatMul, Multiply ops according to the formula: + * (.) - Denotes element-wise multiplication. + * - Denotes dot product. + f, g, h - are activation functions. + + * it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi) + ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Wbf + Rbf) + ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc) + ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Wbo + Rbo) + Ct = ft (.) Ct-1 + it (.) ct + Ht = ot (.) h(Ct) + * * + */ + +class ngraph::pass::LSTMCellDecomposition: public ngraph::pass::MatcherPass { +public: + LSTMCellDecomposition(); +}; diff --git a/inference-engine/src/transformations/include/transformations/rnn_cell_decomposition.hpp b/inference-engine/src/transformations/include/transformations/rnn_cell_decomposition.hpp new file mode 100644 index 0000000..bf25e35 --- /dev/null +++ b/inference-engine/src/transformations/include/transformations/rnn_cell_decomposition.hpp @@ -0,0 +1,36 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include + +#include + +namespace ngraph { +namespace pass { + +class TRANSFORMATIONS_API RNNCellDecomposition; + +} // namespace pass +} // namespace ngraph + +/** + * @ingroup ie_transformation_common_api + * @brief RNNCellDecomposition transformation decomposes RNNCell layer with inputs X, H, W, R, B + * to Add, MatMul ops according to the formula: + * - Denotes dot product. + f - is an activation functions. + + * Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi) + * * + */ + +class ngraph::pass::RNNCellDecomposition: public ngraph::pass::MatcherPass { +public: + RNNCellDecomposition(); +}; diff --git a/inference-engine/src/transformations/include/transformations/utils/utils.hpp b/inference-engine/src/transformations/include/transformations/utils/utils.hpp index 2783bab..4bd8586 100644 --- a/inference-engine/src/transformations/include/transformations/utils/utils.hpp +++ b/inference-engine/src/transformations/include/transformations/utils/utils.hpp @@ -101,6 +101,9 @@ TRANSFORMATIONS_API bool has_f16_constants(const std::shared_ptr activation(const std::string& activation_name, + const ngraph::Output& apply_to); + } // namespace util } // namespace op } // namespace ngraph diff --git a/inference-engine/src/transformations/src/ngraph_ops/lstm_cell_ie.cpp b/inference-engine/src/transformations/src/ngraph_ops/lstm_cell_ie.cpp index 196ad4c..58d0fb4 100644 --- a/inference-engine/src/transformations/src/ngraph_ops/lstm_cell_ie.cpp +++ b/inference-engine/src/transformations/src/ngraph_ops/lstm_cell_ie.cpp @@ -37,6 +37,15 @@ void op::LSTMCellIE::validate_and_infer_types() { set_output_type(1, arg_type, output_shape); } +bool ngraph::op::LSTMCellIE::visit_attributes(AttributeVisitor& visitor) { + visitor.on_attribute("hidden_size", m_hidden_size); + visitor.on_attribute("activations", m_activations); + visitor.on_attribute("activations_alpha", m_activations_alpha); + visitor.on_attribute("activations_beta", m_activations_beta); + visitor.on_attribute("clip", m_clip); + return true; +} + shared_ptr op::LSTMCellIE::clone_with_new_inputs(const OutputVector& new_args) const { check_new_args_count(this, new_args); return make_shared(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), new_args.at(4), diff --git a/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_cells_to_cells_ie.cpp b/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_cells_to_cells_ie.cpp index 91960c9..4bd931d 100644 --- a/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_cells_to_cells_ie.cpp +++ b/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_cells_to_cells_ie.cpp @@ -9,22 +9,25 @@ #include #include +#include #include #include +#include #include #include #include ngraph::pass::ConvertLSTMCellMatcher::ConvertLSTMCellMatcher() { - auto lstm_cell_ngraph = ngraph::pattern::wrap_type(); - + auto is_supported_lstm_cell = [](const std::shared_ptr& n) { + return pattern::has_class()(n) || pattern::has_class()(n); + }; + auto any_lstm = std::make_shared(element::f32, Shape{}, is_supported_lstm_cell); ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) { - auto lstm_cell = std::dynamic_pointer_cast (m.get_match_root()); + auto lstm_cell = std::dynamic_pointer_cast(m.get_match_root()); if (!lstm_cell) { return false; } - auto W = std::dynamic_pointer_cast (lstm_cell->input_value(3).get_node_shared_ptr()); if (!W) { return false; @@ -53,7 +56,7 @@ ngraph::pass::ConvertLSTMCellMatcher::ConvertLSTMCellMatcher() { return true; }; - auto m = std::make_shared(lstm_cell_ngraph, "ConvertLSTMCellToLSTMCellIE"); + auto m = std::make_shared(any_lstm, "ConvertLSTMCellToLSTMCellIE"); this->register_matcher(m, callback); } diff --git a/inference-engine/src/transformations/src/transformations/gru_cell_decomposition.cpp b/inference-engine/src/transformations/src/transformations/gru_cell_decomposition.cpp new file mode 100644 index 0000000..7489da1 --- /dev/null +++ b/inference-engine/src/transformations/src/transformations/gru_cell_decomposition.cpp @@ -0,0 +1,104 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/gru_cell_decomposition.hpp" + +#include +#include + +#include +#include +#include +#include + +ngraph::pass::GRUCellDecomposition::GRUCellDecomposition() { + auto gru_cell = ngraph::pattern::wrap_type(); + ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) { + auto gru_cell = std::dynamic_pointer_cast (m.get_match_root()); + if (!gru_cell) { + return false; + } + + const Output& X = gru_cell->input_value(0); + const Output& H_t = gru_cell->input_value(1); + const Output& W = gru_cell->input_value(2); + const Output& R = gru_cell->input_value(3); + const Output& B = gru_cell->input_value(4); + + // Xt*(W^T) + auto Xt_W = std::make_shared(X, W, false, true); + // Ht-1*(R^T) + auto Ht_R = std::make_shared(H_t, R, false, true); + + // split to gates: + auto axis_0 = ngraph::opset4::Constant::create(element::i64, Shape{}, {0}); + auto axis_1 = ngraph::opset4::Constant::create(element::i64, Shape{}, {1}); + auto Xt_W_zrh = std::make_shared(Xt_W, axis_1, 3); + auto R_zrh = std::make_shared(R, axis_0, 3); + auto Ht_R_zrh = std::make_shared(Ht_R, axis_1, 3); + auto biases_zrh = std::make_shared(B, axis_0, gru_cell->get_linear_before_reset() ? 4 : 3); + + // Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz + auto add_z_1 = std::make_shared(Ht_R_zrh->output(0), biases_zrh->output(0)); + auto add_z_2 = std::make_shared(Xt_W_zrh->output(0), add_z_1); + + // Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr + auto add_r_1 = std::make_shared(Ht_R_zrh->output(1), biases_zrh->output(1)); + auto add_r_2 = std::make_shared(Xt_W_zrh->output(1), add_r_1); + + auto clip = gru_cell->get_clip(); + std::shared_ptr clamp_z = add_z_2; + std::shared_ptr clamp_r = add_r_2; + if (clip > 0.f) { + clamp_z = std::make_shared(add_z_2, -clip, clip); + clamp_r = std::make_shared(add_r_2, -clip, clip); + ngraph::copy_runtime_info(gru_cell, {clamp_z, clamp_r}); + } + + // zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz) + auto z_t = ngraph::op::util::activation(gru_cell->get_activations()[0], clamp_z); + // rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr) + auto r_t = ngraph::op::util::activation(gru_cell->get_activations()[0], clamp_r); + + std::shared_ptr _h; + if (gru_cell->get_linear_before_reset()) { + // _h = Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh + auto Ht_Rh_Rbh = std::make_shared(Ht_R_zrh->output(2), biases_zrh->output(3)); + auto mul_h_1 = std::make_shared(r_t, Ht_Rh_Rbh); + auto add_h_1 = std::make_shared(mul_h_1, biases_zrh->output(2)); + _h = std::make_shared(Xt_W_zrh->output(2), add_h_1); + ngraph::copy_runtime_info(gru_cell, {Ht_Rh_Rbh, mul_h_1, add_h_1, _h}); + } else { + // _h = Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh + auto rt_Ht = std::make_shared(r_t, H_t); + auto mul_h_1 = std::make_shared(rt_Ht, R_zrh->output(2), false, true); + auto add_h_1 = std::make_shared(mul_h_1, biases_zrh->output(2)); + _h = std::make_shared(Xt_W_zrh->output(2), add_h_1); + ngraph::copy_runtime_info(gru_cell, {rt_Ht, mul_h_1, add_h_1, _h}); + } + // ht = g(_h) + std::shared_ptr clamp_h = _h; + if (clip > 0.f) { + clamp_h = std::make_shared(_h, -clip, clip); + ngraph::copy_runtime_info(gru_cell, clamp_h); + } + auto h_t = ngraph::op::util::activation(gru_cell->get_activations()[1], clamp_h); + + // Ht = (1 - zt) (.) ht + zt (.) Ht-1 + auto one = opset4::Constant::create(z_t->get_element_type(), Shape{1}, {1.f}); + auto sub = std::make_shared(one, z_t); + auto mul_1 = std::make_shared(sub, h_t); + auto mul_2 = std::make_shared(z_t, H_t); + auto out_H = std::make_shared(mul_1, mul_2); + + out_H->set_friendly_name(gru_cell->get_friendly_name()); + ngraph::copy_runtime_info(gru_cell, {Xt_W, Ht_R, axis_0, Xt_W_zrh, R_zrh, Ht_R_zrh, biases_zrh, + add_z_1, add_z_2, add_r_1, add_r_2, h_t, one, sub, mul_1, mul_2, out_H}); + ngraph::replace_node(gru_cell, out_H); + return true; + }; + + auto m = std::make_shared(gru_cell, "GRUCellDecomposition"); + register_matcher(m, callback); +} diff --git a/inference-engine/src/transformations/src/transformations/lstm_cell_decomposition.cpp b/inference-engine/src/transformations/src/transformations/lstm_cell_decomposition.cpp new file mode 100644 index 0000000..3cf6b5b --- /dev/null +++ b/inference-engine/src/transformations/src/transformations/lstm_cell_decomposition.cpp @@ -0,0 +1,85 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/lstm_cell_decomposition.hpp" + +#include +#include + +#include +#include +#include +#include + +ngraph::pass::LSTMCellDecomposition::LSTMCellDecomposition() { + auto lstm_cell = ngraph::pattern::wrap_type(); + ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) { + auto lstm_cell = std::dynamic_pointer_cast (m.get_match_root()); + if (!lstm_cell) { + return false; + } + const Output& X = lstm_cell->input_value(0); + const Output& H_t = lstm_cell->input_value(1); + const Output& C_t = lstm_cell->input_value(2); + const Output& W = lstm_cell->input_value(3); + const Output& R = lstm_cell->input_value(4); + const Output& bias = lstm_cell->input_value(5); + + // Xt*(W^T) + auto Xt_W = std::make_shared(X, W, false, true); + // Ht-1*(R^T) + auto Ht_R = std::make_shared(H_t, R, false, true); + // Xt*(W^T) + Ht-1*(R^T) + Wb + Rb + auto add = std::make_shared(Ht_R, bias); + auto XHB = std::make_shared(Xt_W, add); + + auto axis_node = ngraph::opset4::Constant::create(element::u64, Shape{}, {1}); + auto split = std::make_shared(XHB, axis_node, 4); + Output f = split->output(0); + Output i = split->output(1); + Output c = split->output(2); + Output o = split->output(3); + + auto clip = lstm_cell->get_clip(); + if (clip > 0.f) { + auto clamp_f = std::make_shared(f, -clip, clip); + auto clamp_i = std::make_shared(i, -clip, clip); + auto clamp_c = std::make_shared(c, -clip, clip); + auto clamp_o = std::make_shared(o, -clip, clip); + f = clamp_f; + i = clamp_i; + c = clamp_c; + o = clamp_o; + ngraph::copy_runtime_info(lstm_cell, {clamp_f, clamp_i, clamp_c, clamp_o}); + } + + // ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Wbf + Rbf) + // it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi) + // ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc) + // ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Wbo + Rbo) + auto f_t = ngraph::op::util::activation(lstm_cell->get_activations()[0], f); + auto i_t = ngraph::op::util::activation(lstm_cell->get_activations()[0], i); + auto c_t = ngraph::op::util::activation(lstm_cell->get_activations()[1], c); + auto o_t = ngraph::op::util::activation(lstm_cell->get_activations()[0], o); + + // Ct = ft (.) Ct-1 + it (.) ct + auto mul1 = std::make_shared(f_t, C_t); + auto mul2 = std::make_shared(i_t, c_t); + auto out_C = std::make_shared(mul1, mul2); + + // H = ot (.) h(Ct) + auto hC = ngraph::op::util::activation(lstm_cell->get_activations()[2], out_C); + auto out_H = std::make_shared(o_t, hC); + + out_H->set_friendly_name(lstm_cell->get_friendly_name()+".0"); + out_C->set_friendly_name(lstm_cell->get_friendly_name()+".1"); + ngraph::copy_runtime_info(lstm_cell, {Xt_W, Ht_R, add, split, mul1, mul2, out_H, hC, out_C, axis_node, XHB, + f_t, i_t, c_t, o_t}); + ngraph::replace_node(lstm_cell, {out_H->output(0), out_C->output(0)}); + return true; + }; + + auto m = std::make_shared(lstm_cell, "LSTMCellDecomposition"); + register_matcher(m, callback); +} diff --git a/inference-engine/src/transformations/src/transformations/rnn_cell_decomposition.cpp b/inference-engine/src/transformations/src/transformations/rnn_cell_decomposition.cpp new file mode 100644 index 0000000..d02938f --- /dev/null +++ b/inference-engine/src/transformations/src/transformations/rnn_cell_decomposition.cpp @@ -0,0 +1,52 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/rnn_cell_decomposition.hpp" + +#include +#include + +#include +#include +#include +#include + +ngraph::pass::RNNCellDecomposition::RNNCellDecomposition() { + auto rnn_cell = ngraph::pattern::wrap_type(); + ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) { + auto rnn_cell = std::dynamic_pointer_cast (m.get_match_root()); + if (!rnn_cell) { + return false; + } + const Output& X = rnn_cell->input_value(0); + const Output& H_t = rnn_cell->input_value(1); + const Output& W = rnn_cell->input_value(2); + const Output& R = rnn_cell->input_value(3); + const Output& bias = rnn_cell->input_value(4); + + // Xt*(W^T) + auto Xt_W = std::make_shared(X, W, false, true); + // Ht-1*(R^T) + auto Ht_R = std::make_shared(H_t, R, false, true); + // Xt*(W^T) + Ht-1*(R^T) + Wb + Rb + auto add = std::make_shared(Ht_R, bias); + auto i_t = std::make_shared(Xt_W, add); + + // f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi) + auto clip = rnn_cell->get_clip(); + std::shared_ptr clamp = i_t; + if (clip > 0.f) { + clamp = std::make_shared(i_t, -clip, clip); + ngraph::copy_runtime_info(rnn_cell, clamp); + } + auto out = ngraph::op::util::activation(rnn_cell->get_activations()[0], clamp); + out->set_friendly_name(rnn_cell->get_friendly_name()); + ngraph::copy_runtime_info(rnn_cell, {Xt_W, Ht_R, add, i_t, out}); + ngraph::replace_node(rnn_cell, out); + return true; + }; + + auto m = std::make_shared(rnn_cell, "RNNCellDecomposition"); + register_matcher(m, callback); +} diff --git a/inference-engine/src/transformations/src/transformations/utils/utils.cpp b/inference-engine/src/transformations/src/transformations/utils/utils.cpp index dbd2e21..8f94b57 100644 --- a/inference-engine/src/transformations/src/transformations/utils/utils.cpp +++ b/inference-engine/src/transformations/src/transformations/utils/utils.cpp @@ -108,6 +108,18 @@ bool check_for_broadcast(const ngraph::Shape &ref_shape, const ngraph::Shape &ot return false; } +std::shared_ptr activation(const std::string& activation_name, const ngraph::Output& apply_to) { + if (activation_name == "relu") { + return std::make_shared(apply_to); + } else if (activation_name == "sigmoid") { + return std::make_shared(apply_to); + } else if (activation_name == "tanh") { + return std::make_shared(apply_to); + } else { + throw ngraph_error("Unsupported activation function"); + } +} + } // namespace util } // namespace op } // namespace ngraph diff --git a/inference-engine/tests/functional/inference_engine/ngraph_reader/ti.cpp b/inference-engine/tests/functional/inference_engine/ngraph_reader/ti.cpp index 1c97a0b..f2e4193 100644 --- a/inference-engine/tests/functional/inference_engine/ngraph_reader/ti.cpp +++ b/inference-engine/tests/functional/inference_engine/ngraph_reader/ti.cpp @@ -4,7 +4,7 @@ #include #include "ngraph_reader_tests.hpp" -TEST_F(NGraphReaderTests, ReadTensorIteratorNetwork) { +TEST_F(NGraphReaderTests, ReadTensorIteratorNetwork_opset1) { std::string model_v10 = R"V0G0N( @@ -457,7 +457,7 @@ TEST_F(NGraphReaderTests, ReadTensorIteratorNetwork) { }); } -TEST_F(NGraphReaderTests, ReadTensorIteratorNetwork_resnet) { +TEST_F(NGraphReaderTests, ReadTensorIteratorNetwork_resnet_opset1) { std::string model_v10 = R"V0G0N( @@ -948,7 +948,7 @@ TEST_F(NGraphReaderTests, ReadTensorIteratorNetwork_resnet) { }); } -TEST_F(NGraphReaderTests, ReadTensorIteratorNetwork_negative_stride) { +TEST_F(NGraphReaderTests, ReadTensorIteratorNetwork_negative_stride_opset1) { std::string model_v10 = R"V0G0N( @@ -1400,3 +1400,1400 @@ TEST_F(NGraphReaderTests, ReadTensorIteratorNetwork_negative_stride) { data[393732] = 256; }); } + +TEST_F(NGraphReaderTests, ReadTensorIteratorNetwork_opset4) { + std::string model_v10 = R"V0G0N( + + + + + + + 1 + 25 + 512 + + + + + + + + 1 + 256 + + + + + + + + 1 + 256 + + + + + + + 1 + 25 + 512 + + + 1 + 256 + + + 1 + 256 + + + + + 1 + 25 + 256 + + + + + + + + + + + + + + + + + + + 1 + 1 + 512 + + + + + + + + 2 + + + + + + + + 1 + 1 + 512 + + + 2 + + + + + 1 + 512 + + + + + + + + 1 + 256 + + + + + + + + 1 + 256 + + + + + + + + 1024 + 512 + + + + + + + + 1024 + 256 + + + + + + + + 1024 + + + + + + + + 1 + 512 + + + 1 + 256 + + + 1 + 256 + + + 1024 + 512 + + + 1024 + 256 + + + 1024 + + + + + 1 + 256 + + + 1 + 256 + + + + + + + 1 + 256 + + + + + + + 1 + 256 + + + + + + + + 3 + + + + + + + + 1 + 256 + + + 3 + + + + + 1 + 1 + 256 + + + + + + + 1 + 1 + 256 + + + + + + + + + + + + + + + + + + + + + + + + + 1 + 25 + 256 + + + + + + + + + + + + )V0G0N"; + std::string model_v6 = R"VOGON( + + + + + + 1 + 25 + 512 + + + + + + + 1 + 256 + + + + + + + 1 + 256 + + + + + + + 1 + 25 + 512 + + + 1 + 256 + + + 1 + 256 + + + + + 1 + 25 + 256 + + + + + + + + + + + + + + + + + + 2 + + + + + + + + + + 1 + 1 + 512 + + + 2 + + + + + 1 + 512 + + + + + + + + 1 + 512 + + + 1 + 256 + + + 1 + 256 + + + + + 1 + 256 + + + 1 + 256 + + + + + + + + + + + 3 + + + + + + + + + + 1 + 256 + + + 3 + + + + + 1 + 1 + 256 + + + + + + + + + + + + + + + + + + + + )VOGON"; + + compareIRs(model_v10, model_v6, 3149864, [](Blob::Ptr& weights) { + auto *data = weights->buffer().as(); + data[0] = 1; + data[1] = 512; + + data[393730] = 1; + data[393731] = 1; + data[393732] = 256; + }); +} + +TEST_F(NGraphReaderTests, ReadTensorIteratorNetwork_resnet_opset4) { + std::string model_v10 = R"V0G0N( + + + + + + + 16 + 1 + 512 + + + + + + + + 1 + 512 + + + + + + + + 1 + 512 + + + + + + + 16 + 1 + 512 + + + 1 + 512 + + + 1 + 512 + + + + + 16 + 1 + 512 + + + 1 + 512 + + + 1 + 512 + + + + + + + + + + + + + + + + + + + + + 1 + 1 + 512 + + + + + + + + 2 + + + + + + + + 1 + 1 + 512 + + + 2 + + + + + 1 + 512 + + + + + + + + 1 + 512 + + + + + + + + 1 + 512 + + + + + + + + 2048 + 512 + + + + + + + + 2048 + 512 + + + + + + + + 2048 + + + + + + + + 1 + 512 + + + 1 + 512 + + + 1 + 512 + + + 2048 + 512 + + + 2048 + 512 + + + 2048 + + + + + 1 + 512 + + + 1 + 512 + + + + + + + 1 + 512 + + + + + + + 1 + 512 + + + + + + + + 3 + + + + + + + + 1 + 512 + + + 3 + + + + + 1 + 1 + 512 + + + + + + + 1 + 1 + 512 + + + + + + + + + + + + + + + + + + + + + + + + + 16 + 1 + 512 + + + + + + + 1 + 512 + + + + + + + 1 + 512 + + + + + + + + + + + + + + )V0G0N"; + std::string model_v6 = R"V0G0N( + + + + + + 16 + 1 + 512 + + + + + + + 1 + 512 + + + + + + + 1 + 512 + + + + + + + 16 + 1 + 512 + + + 1 + 512 + + + 1 + 512 + + + + + 16 + 1 + 512 + + + 1 + 512 + + + 1 + 512 + + + + + + + + + + + + + + + + + + + + 2 + + + + + + + + + + 1 + 1 + 512 + + + 2 + + + + + 1 + 512 + + + + + + + + 1 + 512 + + + 1 + 512 + + + 1 + 512 + + + + + 1 + 512 + + + 1 + 512 + + + + + + + + + + + 3 + + + + + + + + + + 1 + 512 + + + 3 + + + + + 1 + 1 + 512 + + + + + + + + + + + + + + + + + + + + )V0G0N"; + + compareIRs(model_v10, model_v6, 8396840, [](Blob::Ptr& weights) { + auto *data = weights->buffer().as(); + data[0] = 1; + data[1] = 512; + + data[1049602] = 1; + data[1049603] = 1; + data[1049604] = 512; + }); +} + +TEST_F(NGraphReaderTests, ReadTensorIteratorNetwork_negative_stride_opset4) { + std::string model_v10 = R"V0G0N( + + + + + + + 1 + 25 + 512 + + + + + + + + 1 + 256 + + + + + + + + 1 + 256 + + + + + + + 1 + 25 + 512 + + + 1 + 256 + + + 1 + 256 + + + + + 1 + 25 + 256 + + + + + + + + + + + + + + + + + + + 1 + 1 + 512 + + + + + + + + 2 + + + + + + + + 1 + 1 + 512 + + + 2 + + + + + 1 + 512 + + + + + + + + 1 + 256 + + + + + + + + 1 + 256 + + + + + + + + 1024 + 512 + + + + + + + + 1024 + 256 + + + + + + + + 1024 + + + + + + + + 1 + 512 + + + 1 + 256 + + + 1 + 256 + + + 1024 + 512 + + + 1024 + 256 + + + 1024 + + + + + 1 + 256 + + + 1 + 256 + + + + + + + 1 + 256 + + + + + + + 1 + 256 + + + + + + + + 3 + + + + + + + + 1 + 256 + + + 3 + + + + + 1 + 1 + 256 + + + + + + + 1 + 1 + 256 + + + + + + + + + + + + + + + + + + + + + + + + + 1 + 25 + 256 + + + + + + + + + + + + )V0G0N"; + std::string model_v6 = R"VOGON( + + + + + + 1 + 25 + 512 + + + + + + + 1 + 256 + + + + + + + 1 + 256 + + + + + + + 1 + 25 + 512 + + + 1 + 256 + + + 1 + 256 + + + + + 1 + 25 + 256 + + + + + + + + + + + + + + + + + + 2 + + + + + + + + + + 1 + 1 + 512 + + + 2 + + + + + 1 + 512 + + + + + + + + 1 + 512 + + + 1 + 256 + + + 1 + 256 + + + + + 1 + 256 + + + 1 + 256 + + + + + + + + + + + 3 + + + + + + + + + + 1 + 256 + + + 3 + + + + + 1 + 1 + 256 + + + + + + + + + + + + + + + + + + + + )VOGON"; + + compareIRs(model_v10, model_v6, 3149864, [](Blob::Ptr& weights) { + auto *data = weights->buffer().as(); + data[0] = 1; + data[1] = 512; + + data[393730] = 1; + data[393731] = 1; + data[393732] = 256; + }); +} diff --git a/inference-engine/tests/functional/inference_engine/transformations/convert_cells_to_cells_ie_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/convert_cells_to_cells_ie_test.cpp index 516cf85..3f6bff5 100644 --- a/inference-engine/tests/functional/inference_engine/transformations/convert_cells_to_cells_ie_test.cpp +++ b/inference-engine/tests/functional/inference_engine/transformations/convert_cells_to_cells_ie_test.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -129,7 +130,7 @@ TEST(TransformationTests, RNNCellConversionTest) { ASSERT_TRUE(cell_node->get_friendly_name() == "test_cell") << "Transformation ConvertRNNCellToRNNCellIE should keep output names.\n"; } -TEST(TransformationTests, LSTMCellConversionTest) { +TEST(TransformationTests, LSTMCellConversionTest_opset3) { const size_t batch_size = 2; const size_t input_size = 3; const size_t hidden_size = 3; @@ -186,4 +187,76 @@ TEST(TransformationTests, LSTMCellConversionTest) { auto result_node_of_converted_f = f->get_output_op(0); auto cell_node = result_node_of_converted_f->input(0).get_source_output().get_node_shared_ptr(); ASSERT_TRUE(cell_node->get_friendly_name() == "test_cell") << "Transformation ConvertLSTMCellToLSTMCellIE should keep output names.\n"; -} \ No newline at end of file +} + +TEST(TransformationTests, LSTMCellConversionTest_opset4) { + const size_t batch_size = 2; + const size_t input_size = 3; + const size_t hidden_size = 3; + const size_t gates_count = 4; + + std::shared_ptr f(nullptr), f_ref(nullptr); + std::shared_ptr cell; + { + const auto X = std::make_shared(ngraph::element::f32, + ngraph::Shape{batch_size, input_size}); + const auto W = + std::make_shared(ngraph::element::f32, + ngraph::Shape{gates_count * hidden_size, input_size}); + const auto R = + std::make_shared(ngraph::element::f32, + ngraph::Shape{gates_count * hidden_size, hidden_size}); + const auto H_t = std::make_shared(ngraph::element::f32, + ngraph::Shape{batch_size, hidden_size}); + const auto C_t = std::make_shared(ngraph::element::f32, + ngraph::Shape{batch_size, hidden_size}); + const auto B = std::make_shared(ngraph::element::f32, + ngraph::Shape{gates_count * hidden_size}); + + cell = std::make_shared(X, H_t, C_t, W, R, B, hidden_size); + cell->set_friendly_name("test_cell"); + + f = std::make_shared(ngraph::NodeVector{cell}, ngraph::ParameterVector{X, H_t, C_t}); + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + const auto X = std::make_shared(ngraph::element::f32, + ngraph::Shape{batch_size, input_size}); + const auto W = + std::make_shared(ngraph::element::f32, + ngraph::Shape{gates_count * hidden_size, input_size}); + const auto R = + std::make_shared(ngraph::element::f32, + ngraph::Shape{gates_count * hidden_size, hidden_size}); + const auto H_t = std::make_shared(ngraph::element::f32, + ngraph::Shape{batch_size, hidden_size}); + const auto C_t = std::make_shared(ngraph::element::f32, + ngraph::Shape{batch_size, hidden_size}); + const auto B = std::make_shared(ngraph::element::f32, + ngraph::Shape{gates_count * hidden_size}); + + auto concat = std::make_shared(ngraph::NodeVector({W, R}), 1); + auto cell_ie = std::make_shared(X, H_t, C_t, concat, B, + cell->get_hidden_size(), + cell->get_activations(), + cell->get_activations_alpha(), + cell->get_activations_beta(), + cell->get_clip()); + cell_ie->set_friendly_name("test_cell"); + + f_ref = std::make_shared(ngraph::NodeVector{cell_ie}, ngraph::ParameterVector{X, H_t, C_t}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; + + auto result_node_of_converted_f = f->get_output_op(0); + auto cell_node = result_node_of_converted_f->input(0).get_source_output().get_node_shared_ptr(); + ASSERT_TRUE(cell_node->get_friendly_name() == "test_cell") + << "Transformation ConvertLSTMCellToLSTMCellIE should keep output names.\n"; +} diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gru_cell.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gru_cell.cpp new file mode 100644 index 0000000..4d015df --- /dev/null +++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gru_cell.cpp @@ -0,0 +1,37 @@ +// Copyright (C) 2019 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "single_layer_tests/gru_cell.hpp" +#include "common_test_utils/test_constants.hpp" + +using namespace LayerTestsDefinitions; + +namespace { + std::vector should_decompose{false, true}; + std::vector batch{5}; + std::vector hidden_size{1, 10}; + std::vector input_size{1, 30}; + std::vector> activations = {{"relu", "tanh"}, {"tanh", "sigmoid"}, {"sigmoid", "tanh"}, + {"tanh", "relu"}}; + std::vector clip = {0.0f, 0.7f}; + std::vector linear_before_reset = {true, false}; + std::vector netPrecisions = {InferenceEngine::Precision::FP32, + InferenceEngine::Precision::FP16}; + + INSTANTIATE_TEST_CASE_P(GRUCellCommon, GRUCellTest, + ::testing::Combine( + ::testing::ValuesIn(should_decompose), + ::testing::ValuesIn(batch), + ::testing::ValuesIn(hidden_size), + ::testing::ValuesIn(input_size), + ::testing::ValuesIn(activations), + ::testing::ValuesIn(clip), + ::testing::ValuesIn(linear_before_reset), + ::testing::ValuesIn(netPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + GRUCellTest::getTestCaseName); + +} // namespace diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/lstm_cell.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/lstm_cell.cpp new file mode 100644 index 0000000..abf5114 --- /dev/null +++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/lstm_cell.cpp @@ -0,0 +1,36 @@ +// Copyright (C) 2019 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "single_layer_tests/lstm_cell.hpp" +#include "common_test_utils/test_constants.hpp" + +using namespace LayerTestsDefinitions; + +namespace { + std::vector should_decompose{false, true}; + std::vector batch{5}; + std::vector hidden_size{1, 10}; + std::vector input_size{1, 30}; + std::vector> activations = {{"relu", "sigmoid", "tanh"}, {"sigmoid", "tanh", "tanh"}, + {"tanh", "relu", "sigmoid"}, {"sigmoid", "sigmoid", "sigmoid"}, + {"tanh", "tanh", "tanh"}, {"relu", "relu", "relu"}}; + std::vector clip{0.f, 0.7f}; + std::vector netPrecisions = {InferenceEngine::Precision::FP32, + InferenceEngine::Precision::FP16}; + + INSTANTIATE_TEST_CASE_P(LSTMCellCommon, LSTMCellTest, + ::testing::Combine( + ::testing::ValuesIn(should_decompose), + ::testing::ValuesIn(batch), + ::testing::ValuesIn(hidden_size), + ::testing::ValuesIn(input_size), + ::testing::ValuesIn(activations), + ::testing::ValuesIn(clip), + ::testing::ValuesIn(netPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + LSTMCellTest::getTestCaseName); + +} // namespace diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/rnn_cell.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/rnn_cell.cpp new file mode 100644 index 0000000..cf9f572 --- /dev/null +++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/rnn_cell.cpp @@ -0,0 +1,34 @@ +// Copyright (C) 2019 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "single_layer_tests/rnn_cell.hpp" +#include "common_test_utils/test_constants.hpp" + +using namespace LayerTestsDefinitions; + +namespace { + std::vector should_decompose{false, true}; + std::vector batch{1, 5}; + std::vector hidden_size{1, 10}; + std::vector input_size{1, 30}; + std::vector> activations = {{"relu"}, {"sigmoid"}, {"tanh"}}; + std::vector clip = {0.f, 0.7f}; + std::vector netPrecisions = {InferenceEngine::Precision::FP32, + InferenceEngine::Precision::FP16}; + + INSTANTIATE_TEST_CASE_P(RNNCellCommon, RNNCellTest, + ::testing::Combine( + ::testing::ValuesIn(should_decompose), + ::testing::ValuesIn(batch), + ::testing::ValuesIn(hidden_size), + ::testing::ValuesIn(input_size), + ::testing::ValuesIn(activations), + ::testing::ValuesIn(clip), + ::testing::ValuesIn(netPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + RNNCellTest::getTestCaseName); + +} // namespace diff --git a/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/gru_cell.hpp b/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/gru_cell.hpp new file mode 100644 index 0000000..72f7a4f --- /dev/null +++ b/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/gru_cell.hpp @@ -0,0 +1,38 @@ +// Copyright (C) 2019 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +#include "functional_test_utils/layer_test_utils.hpp" +#include "ngraph_functions/builders.hpp" +#include "ngraph_functions/utils/ngraph_helpers.hpp" + +namespace LayerTestsDefinitions { + +using GRUCellParams = typename std::tuple< + bool, // using decompose to sub-ops transformation + size_t, // batch + size_t, // hidden size + size_t, // input size + std::vector, // activations + float, // clip + bool, // linear_before_reset + InferenceEngine::Precision, // Network precision + std::string>; // Device name + +class GRUCellTest : public testing::WithParamInterface, + virtual public LayerTestsUtils::LayerTestsCommon { +public: + static std::string getTestCaseName(const testing::TestParamInfo &obj); + +protected: + void SetUp() override; +}; + +} // namespace LayerTestsDefinitions diff --git a/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/lstm_cell.hpp b/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/lstm_cell.hpp new file mode 100644 index 0000000..c43a8a9 --- /dev/null +++ b/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/lstm_cell.hpp @@ -0,0 +1,37 @@ +// Copyright (C) 2019 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +#include "functional_test_utils/layer_test_utils.hpp" +#include "ngraph_functions/builders.hpp" +#include "ngraph_functions/utils/ngraph_helpers.hpp" + +namespace LayerTestsDefinitions { + +using LSTMCellParams = typename std::tuple< + bool, // using decompose to sub-ops transformation + size_t, // batch + size_t, // hidden size + size_t, // input size + std::vector, // activations + float, // clip + InferenceEngine::Precision, // Network precision + std::string>; // Device name + +class LSTMCellTest : public testing::WithParamInterface, + virtual public LayerTestsUtils::LayerTestsCommon { +public: + static std::string getTestCaseName(const testing::TestParamInfo &obj); + +protected: + void SetUp() override; +}; + +} // namespace LayerTestsDefinitions diff --git a/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/rnn_cell.hpp b/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/rnn_cell.hpp new file mode 100644 index 0000000..8e6a961 --- /dev/null +++ b/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/rnn_cell.hpp @@ -0,0 +1,37 @@ +// Copyright (C) 2019 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +#include "functional_test_utils/layer_test_utils.hpp" +#include "ngraph_functions/builders.hpp" +#include "ngraph_functions/utils/ngraph_helpers.hpp" + +namespace LayerTestsDefinitions { + +using RNNCellParams = typename std::tuple< + bool, // using decompose to sub-ops transformation + size_t, // batch + size_t, // hidden size + size_t, // input size + std::vector, // activations + float, // clip + InferenceEngine::Precision, // Network precision + std::string>; // Device name + +class RNNCellTest : public testing::WithParamInterface, + virtual public LayerTestsUtils::LayerTestsCommon { +public: + static std::string getTestCaseName(const testing::TestParamInfo &obj); + +protected: + void SetUp() override; +}; + +} // namespace LayerTestsDefinitions diff --git a/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/gru_cell.cpp b/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/gru_cell.cpp new file mode 100644 index 0000000..0750819 --- /dev/null +++ b/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/gru_cell.cpp @@ -0,0 +1,90 @@ +// Copyright (C) 2019 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include + +#include "ie_core.hpp" + +#include "common_test_utils/common_utils.hpp" +#include "functional_test_utils/blob_utils.hpp" +#include "functional_test_utils/precision_utils.hpp" +#include "functional_test_utils/plugin_cache.hpp" +#include "functional_test_utils/skip_tests_config.hpp" + +#include +#include "single_layer_tests/gru_cell.hpp" + +namespace LayerTestsDefinitions { + +std::string GRUCellTest::getTestCaseName(const testing::TestParamInfo &obj) { + bool should_decompose; + size_t batch; + size_t hidden_size; + size_t input_size; + std::vector activations; + std::vector activations_alpha; + std::vector activations_beta; + float clip; + bool linear_before_reset; + std::vector> inputShapes; + InferenceEngine::Precision netPrecision; + std::string targetDevice; + std::tie(should_decompose, batch, hidden_size, input_size, activations, clip, + linear_before_reset, netPrecision, targetDevice) = obj.param; + std::ostringstream result; + result << "decomposition" << should_decompose << "_"; + result << "batch=" << batch << "_"; + result << "hidden_size=" << hidden_size << "_"; + result << "input_size=" << input_size << "_"; + result << "IS=" << CommonTestUtils::vec2str(inputShapes) << "_"; + result << "activations=" << CommonTestUtils::vec2str(activations) << "_"; + result << "clip=" << clip << "_"; + result << "linear_before_reset=" << linear_before_reset << "_"; + result << "netPRC=" << netPrecision.name() << "_"; + result << "targetDevice=" << targetDevice << "_"; + return result.str(); +} + +void GRUCellTest::SetUp() { + bool should_decompose; + size_t batch; + size_t hidden_size; + size_t input_size; + std::vector activations; + std::vector activations_alpha; + std::vector activations_beta; + float clip; + bool linear_before_reset; + InferenceEngine::Precision netPrecision; + std::tie(should_decompose, batch, hidden_size, input_size, activations, clip, linear_before_reset, + netPrecision, targetDevice) = this->GetParam(); + + std::vector> inputShapes = { + {{batch, input_size}, {batch, hidden_size}, {3 * hidden_size, input_size}, + {3 * hidden_size, hidden_size}, {(linear_before_reset? 4 : 3) * hidden_size}}, + }; + + auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision); + auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1]}); + std::vector WRB = {inputShapes[2], inputShapes[3], inputShapes[4]}; + auto gru_cell = ngraph::builder::makeGRUCell(ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(params)), + WRB, hidden_size, activations, {}, {}, clip, linear_before_reset); + ngraph::ResultVector results{std::make_shared(gru_cell->output(0))}; + function = std::make_shared(results, params, "gru_cell"); + if (should_decompose) { + ngraph::pass::Manager m; + m.register_pass(); + m.run_passes(function); + } +} + + +TEST_P(GRUCellTest, CompareWithRefs) { + Run(); +}; +} // namespace LayerTestsDefinitions diff --git a/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/lstm_cell.cpp b/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/lstm_cell.cpp new file mode 100644 index 0000000..2c8c9c7 --- /dev/null +++ b/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/lstm_cell.cpp @@ -0,0 +1,89 @@ +// Copyright (C) 2019 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include + +#include "ie_core.hpp" + +#include "common_test_utils/common_utils.hpp" +#include "functional_test_utils/blob_utils.hpp" +#include "functional_test_utils/precision_utils.hpp" +#include "functional_test_utils/plugin_cache.hpp" +#include "functional_test_utils/skip_tests_config.hpp" + +#include +#include "single_layer_tests/lstm_cell.hpp" + +namespace LayerTestsDefinitions { + +std::string LSTMCellTest::getTestCaseName(const testing::TestParamInfo &obj) { + bool should_decompose; + size_t batch; + size_t hidden_size; + size_t input_size; + std::vector activations; + std::vector activations_alpha; + std::vector activations_beta; + float clip; + InferenceEngine::Precision netPrecision; + std::string targetDevice; + std::tie(should_decompose, batch, hidden_size, input_size, activations, clip, netPrecision, + targetDevice) = obj.param; + std::vector> inputShapes = { + {{batch, input_size}, {batch, hidden_size}, {batch, hidden_size}, {4 * hidden_size, input_size}, + {4 * hidden_size, hidden_size}, {4 * hidden_size}}, + }; + std::ostringstream result; + result << "decomposition" << should_decompose << "_"; + result << "batch=" << batch << "_"; + result << "hidden_size=" << hidden_size << "_"; + result << "input_size=" << input_size << "_"; + result << "IS=" << CommonTestUtils::vec2str(inputShapes) << "_"; + result << "activations=" << CommonTestUtils::vec2str(activations) << "_"; + result << "clip=" << clip << "_"; + result << "netPRC=" << netPrecision.name() << "_"; + result << "targetDevice=" << targetDevice << "_"; + return result.str(); +} + +void LSTMCellTest::SetUp() { + bool should_decompose; + size_t batch; + size_t hidden_size; + size_t input_size; + std::vector activations; + std::vector activations_alpha; + std::vector activations_beta; + float clip; + InferenceEngine::Precision netPrecision; + std::tie(should_decompose, batch, hidden_size, input_size, activations, clip, netPrecision, + targetDevice) = this->GetParam(); + std::vector> inputShapes = { + {{batch, input_size}, {batch, hidden_size}, {batch, hidden_size}, {4 * hidden_size, input_size}, + {4 * hidden_size, hidden_size}, {4 * hidden_size}}, + }; + auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision); + auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1], inputShapes[2]}); + std::vector WRB = {inputShapes[3], inputShapes[4], inputShapes[5]}; + auto lstm_cell = ngraph::builder::makeLSTMCell(ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(params)), + WRB, hidden_size, activations, {}, {}, clip); + ngraph::ResultVector results{std::make_shared(lstm_cell->output(0)), + std::make_shared(lstm_cell->output(1))}; + function = std::make_shared(results, params, "lstm_cell"); + if (should_decompose) { + ngraph::pass::Manager m; + m.register_pass(); + m.run_passes(function); + } +} + + +TEST_P(LSTMCellTest, CompareWithRefs) { + Run(); +}; +} // namespace LayerTestsDefinitions diff --git a/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/rnn_cell.cpp b/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/rnn_cell.cpp new file mode 100644 index 0000000..97c1c08 --- /dev/null +++ b/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/rnn_cell.cpp @@ -0,0 +1,82 @@ +// Copyright (C) 2019 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include + +#include "ie_core.hpp" + +#include "common_test_utils/common_utils.hpp" +#include "functional_test_utils/blob_utils.hpp" +#include "functional_test_utils/precision_utils.hpp" +#include "functional_test_utils/plugin_cache.hpp" +#include "functional_test_utils/skip_tests_config.hpp" + +#include +#include "single_layer_tests/rnn_cell.hpp" + +namespace LayerTestsDefinitions { + +std::string RNNCellTest::getTestCaseName(const testing::TestParamInfo &obj) { + bool should_decompose; + size_t batch; + size_t hidden_size; + size_t input_size; + std::vector activations; + float clip; + InferenceEngine::Precision netPrecision; + std::string targetDevice; + std::tie(should_decompose, batch, hidden_size, input_size, activations, clip, + netPrecision, targetDevice) = obj.param; + std::vector> inputShapes = {{batch, input_size}, {batch, hidden_size}, + {hidden_size, input_size}, {hidden_size, hidden_size}, {hidden_size}}; + std::ostringstream result; + result << "decomposition" << should_decompose << "_"; + result << "batch=" << batch << "_"; + result << "hidden_size=" << hidden_size << "_"; + result << "input_size=" << input_size << "_"; + result << "IS=" << CommonTestUtils::vec2str(inputShapes) << "_"; + result << "activations=" << CommonTestUtils::vec2str(activations) << "_"; + result << "clip=" << clip << "_"; + result << "netPRC=" << netPrecision.name() << "_"; + result << "targetDevice=" << targetDevice << "_"; + return result.str(); +} + +void RNNCellTest::SetUp() { + bool should_decompose; + size_t batch; + size_t hidden_size; + size_t input_size; + std::vector activations; + std::vector activations_alpha; + std::vector activations_beta; + float clip; + InferenceEngine::Precision netPrecision; + std::tie(should_decompose, batch, hidden_size, input_size, activations, clip, + netPrecision, targetDevice) = this->GetParam(); + std::vector> inputShapes = {{batch, input_size}, {batch, hidden_size}, + {hidden_size, input_size}, {hidden_size, hidden_size}, {hidden_size}}; + auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision); + auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1]}); + std::vector WRB = {inputShapes[2], inputShapes[3], inputShapes[4]}; + auto rnn_cell = ngraph::builder::makeRNNCell(ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(params)), + WRB, hidden_size, activations, {}, {}, clip); + ngraph::ResultVector results{std::make_shared(rnn_cell)}; + function = std::make_shared(results, params, "rnn_cell"); + if (should_decompose) { + ngraph::pass::Manager m; + m.register_pass(); + m.run_passes(function); + } +} + + +TEST_P(RNNCellTest, CompareWithRefs) { + Run(); +}; +} // namespace LayerTestsDefinitions diff --git a/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/basic_lstm.cpp b/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/basic_lstm.cpp index f98f86b..f2e03f7 100644 --- a/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/basic_lstm.cpp +++ b/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/basic_lstm.cpp @@ -72,7 +72,7 @@ void Basic_LSTM_S::SetUp() { //lstm [1, 10], [1, 118], [1, 118] -> [1, 118], [1, 118] outFormShapes1 = { batch_size, reshape1_shape[2] }; auto constantX = std::make_shared(ngraph::element::i64, ngraph::Shape{2}, outFormShapes1); - auto lstm1 = std::make_shared(std::make_shared(X, constantX, false), + auto lstm1 = std::make_shared(std::make_shared(X, constantX, false), H_t, C_t, weightsNode, reccurrenceWeightsNode, hidden_size); @@ -137,7 +137,7 @@ std::shared_ptr Basic_LSTM_S::CreateGraphWithUnrolledTI() { ngraph::Output H[iterations + 1]; ngraph::Output C[iterations + 1]; - std::shared_ptr lstm[iterations]; + std::shared_ptr lstm[iterations]; H[0] = ngraph::builder::makeConstant(ngPrc, { batch_size, hidden_size }, {}, true); C[0] = ngraph::builder::makeConstant(ngPrc, { batch_size, hidden_size }, {}, true); auto reshape1_shape = reshape1->output(0).get_shape(); @@ -149,7 +149,7 @@ std::shared_ptr Basic_LSTM_S::CreateGraphWithUnrolledTI() { for (size_t i = 0; i < iterations; ++i) { auto X = split1->output(i); - lstm[i] = std::make_shared(std::make_shared(X, constantX, false), + lstm[i] = std::make_shared(std::make_shared(X, constantX, false), H[i], C[i], weightsNode, reccurrenceWeightsNode, hidden_size); diff --git a/inference-engine/tests/ngraph_functions/include/ngraph_functions/builders.hpp b/inference-engine/tests/ngraph_functions/include/ngraph_functions/builders.hpp index 0d1c088..4285467 100644 --- a/inference-engine/tests/ngraph_functions/include/ngraph_functions/builders.hpp +++ b/inference-engine/tests/ngraph_functions/include/ngraph_functions/builders.hpp @@ -389,5 +389,31 @@ std::shared_ptr makePad(const ngraph::Output& data, std::shared_ptr makeBatchNormInference(const ngraph::Output& data, double epsilon); +std::shared_ptr makeLSTMCell(const OutputVector& in, + const std::vector& WRB, + std::size_t hidden_size, + const std::vector& activations = + std::vector{"sigmoid", "tanh", "tanh"}, + const std::vector& activations_alpha = {}, + const std::vector& activations_beta = {}, + float clip = 0.f); + +std::shared_ptr makeGRUCell(const OutputVector& in, + const std::vector& WRB, + std::size_t hidden_size, + const std::vector& activations = + std::vector{"sigmoid", "tanh"}, + const std::vector& activations_alpha = {}, + const std::vector& activations_beta = {}, + float clip = 0.f, + bool linear_before_reset = false); + +std::shared_ptr makeRNNCell(const OutputVector& in, + const std::vector& WRB, + std::size_t hidden_size, + const std::vector& activations = std::vector{"tanh"}, + const std::vector& activations_alpha = {}, + const std::vector& activations_beta = {}, + float clip = 0.f); } // namespace builder } // namespace ngraph diff --git a/inference-engine/tests/ngraph_functions/include/ngraph_functions/subgraph_builders.hpp b/inference-engine/tests/ngraph_functions/include/ngraph_functions/subgraph_builders.hpp index 57b90b3..d6f002f 100644 --- a/inference-engine/tests/ngraph_functions/include/ngraph_functions/subgraph_builders.hpp +++ b/inference-engine/tests/ngraph_functions/include/ngraph_functions/subgraph_builders.hpp @@ -130,7 +130,7 @@ static std::shared_ptr makeTIwithLSTMcell(InferenceEngine::Pre inShape = {N, I}; auto constantX = std::make_shared(ngraph::element::i64, ngraph::Shape{2}, inShape); auto LSTM_cell = - std::make_shared(std::make_shared(X, constantX, false), + std::make_shared(std::make_shared(X, constantX, false), std::make_shared(H_t, constantH, false), std::make_shared(C_t, constantH, false), W_body, diff --git a/inference-engine/tests/ngraph_functions/src/gru_cell.cpp b/inference-engine/tests/ngraph_functions/src/gru_cell.cpp new file mode 100644 index 0000000..487959f --- /dev/null +++ b/inference-engine/tests/ngraph_functions/src/gru_cell.cpp @@ -0,0 +1,30 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include "ngraph_functions/builders.hpp" + +namespace ngraph { +namespace builder { + +std::shared_ptr makeGRUCell(const OutputVector& in, + const std::vector& WRB, + std::size_t hidden_size, + const std::vector& activations, + const std::vector& activations_alpha, + const std::vector& activations_beta, + float clip, + bool linear_before_reset) { + std::vector empty; + auto W = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[0], empty, true); + auto R = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[1], empty, true); + auto B = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[2], empty, true); + return std::make_shared(in[0], in[1], W, R, B, hidden_size, activations, + activations_alpha, activations_beta, clip, linear_before_reset); +} + +} // namespace builder +} // namespace ngraph \ No newline at end of file diff --git a/inference-engine/tests/ngraph_functions/src/lstm_cell.cpp b/inference-engine/tests/ngraph_functions/src/lstm_cell.cpp new file mode 100644 index 0000000..38f39f7 --- /dev/null +++ b/inference-engine/tests/ngraph_functions/src/lstm_cell.cpp @@ -0,0 +1,29 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include "ngraph_functions/builders.hpp" + +namespace ngraph { +namespace builder { + +std::shared_ptr makeLSTMCell(const std::vector>& in, + const std::vector& WRB, + std::size_t hidden_size, + const std::vector& activations, + const std::vector& activations_alpha, + const std::vector& activations_beta, + float clip) { + std::vector empty; + auto W = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[0], empty, true); + auto R = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[1], empty, true); + auto B = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[2], empty, true); + return std::make_shared(in[0], in[1], in[2], W, R, B, hidden_size, activations, + activations_alpha, activations_beta, clip); +} + +} // namespace builder +} // namespace ngraph \ No newline at end of file diff --git a/inference-engine/tests/ngraph_functions/src/rnn_cell.cpp b/inference-engine/tests/ngraph_functions/src/rnn_cell.cpp new file mode 100644 index 0000000..824c4a8 --- /dev/null +++ b/inference-engine/tests/ngraph_functions/src/rnn_cell.cpp @@ -0,0 +1,29 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include "ngraph_functions/builders.hpp" + +namespace ngraph { +namespace builder { + +std::shared_ptr makeRNNCell(const OutputVector& in, + const std::vector& WRB, + std::size_t hidden_size, + const std::vector& activations, + const std::vector& activations_alpha, + const std::vector& activations_beta, + float clip) { + std::vector empty; + auto W = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[0], empty, true); + auto R = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[1], empty, true); + auto B = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[2], empty, true); + return std::make_shared(in[0], in[1], W, R, B, hidden_size, activations, + activations_alpha, activations_beta, clip); +} + +} // namespace builder +} // namespace ngraph \ No newline at end of file diff --git a/model-optimizer/extensions/ops/lstm_cell.py b/model-optimizer/extensions/ops/lstm_cell.py index 7c82f54..bdd8c05 100644 --- a/model-optimizer/extensions/ops/lstm_cell.py +++ b/model-optimizer/extensions/ops/lstm_cell.py @@ -42,7 +42,7 @@ class LSTMCell(Op): mandatory_props = { 'type': __class__.op, 'op': __class__.op, - 'version': 'opset1', + 'version': 'opset4', 'infer': __class__.infer, 'in_ports_count': 5, 'out_ports_count': 2, diff --git a/ngraph/core/include/ngraph/op/gru_cell.hpp b/ngraph/core/include/ngraph/op/gru_cell.hpp index e7a608e..64f8b62 100644 --- a/ngraph/core/include/ngraph/op/gru_cell.hpp +++ b/ngraph/core/include/ngraph/op/gru_cell.hpp @@ -26,8 +26,6 @@ #include "ngraph/op/util/fused_op.hpp" #include "ngraph/op/util/rnn_cell_base.hpp" -NGRAPH_SUPPRESS_DEPRECATED_START - namespace ngraph { namespace op @@ -42,7 +40,7 @@ namespace ngraph /// /// Note this class represents only single *cell* and not whole GRU *layer*. /// - class NGRAPH_API GRUCell : public util::FusedOp, public util::RNNCellBase + class NGRAPH_API GRUCell : public util::RNNCellBase { public: static constexpr NodeTypeInfo type_info{"GRUCell", 3}; @@ -151,8 +149,6 @@ namespace ngraph virtual void validate_and_infer_types() override; bool visit_attributes(AttributeVisitor& visitor) override; - virtual void pre_validate_and_infer_types() override; - virtual OutputVector decompose_op() const override; virtual std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; @@ -180,8 +176,5 @@ namespace ngraph bool m_linear_before_reset; }; } - using v3::GRUCell; } } - -NGRAPH_SUPPRESS_DEPRECATED_END diff --git a/ngraph/core/include/ngraph/op/lstm_cell.hpp b/ngraph/core/include/ngraph/op/lstm_cell.hpp index dcd7d94..c830cae 100644 --- a/ngraph/core/include/ngraph/op/lstm_cell.hpp +++ b/ngraph/core/include/ngraph/op/lstm_cell.hpp @@ -69,7 +69,7 @@ namespace ngraph /// /// \sa LSTMSequence, RNNCell, GRUCell /// - class NGRAPH_API LSTMCell : public util::FusedOp, public util::RNNCellBase + class NGRAPH_API LSTMCell : public util::RNNCellBase { public: static constexpr NodeTypeInfo type_info{"LSTMCell", 0}; @@ -216,24 +216,11 @@ namespace ngraph virtual void validate_and_infer_types() override; bool visit_attributes(AttributeVisitor& visitor) override; - virtual void pre_validate_and_infer_types() override; - virtual OutputVector decompose_op() const override; virtual std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; bool get_input_forget() const { return m_input_forget; } LSTMWeightsFormat get_weights_format() const { return m_weights_format; } - /// - /// \brief Change data format of provided node into IFCO. - /// - /// \node The IFCO format was chosen because it's default DNNL format. - /// - /// \param[in] node The input node to be permuted. - /// - /// \return Node representing reshaped tensor according to IFCO weights format. - /// - std::shared_ptr convert_node_format(const Output& node) const; - private: /// /// \brief Creates the default bias input initialized with zeros. @@ -273,9 +260,149 @@ namespace ngraph static constexpr std::size_t s_gates_count{4}; static constexpr std::size_t s_peepholes_count{3}; }; - } - using v0::LSTMCell; - } // namespace op + } // v0 + + namespace v4 + { + /// + /// \brief Class for single lstm cell node. + /// + /// \note Following implementation supports: + /// \li \c peepholes Gers & Schmidhuber (2000) + /// https://ieeexplore.ieee.org/document/861302 + /// \li Coupling input and forget gates. + /// + /// \note It calculates following equations: + /// + /// it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi) + /// ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Wbf + Rbf) + /// ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc) + /// Ct = ft (.) Ct-1 + it (.) ct + /// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Wbo + Rbo) + /// Ht = ot (.) h(Ct) + /// + /// * - Is a dot product, + /// (.) - is a Hadamard product (element-wise), + /// f, g, h - are activation functions. + /// + /// \note This class represents only single *cell* (for current time step) and not + /// the whole LSTM Sequence layer + /// + /// \sa LSTMSequence, RNNCell, GRUCell + /// + class NGRAPH_API LSTMCell : public util::RNNCellBase + { + public: + static constexpr NodeTypeInfo type_info{"LSTMCell", 1}; + const NodeTypeInfo& get_type_info() const override { return type_info; } + LSTMCell(); + /// + /// \brief Constructs LSTMCell node. + /// + /// \param[in] X The input tensor with shape: [batch_size, + /// input_size]. + /// \param[in] initial_hidden_state The hidden state tensor at current time step + /// with shape: [batch_size, hidden_size]. + /// \param[in] initial_cell_state The cell state tensor at current time step + /// with shape: [batch_size, hidden_size]. + /// \param[in] W The gate weights tensor with shape: + /// [4*hidden_size, input_size]. + /// \param[in] R The recurrence weights tensor with shape: + /// [4*hidden_size, hidden_size]. + /// \param[in] hidden_size The number of hidden units for recurrent cell. + /// \param[in] activations The vector of activation functions used inside + /// recurrent cell. + /// \param[in] activations_alpha The vector of alpha parameters for activation + /// functions in order respective to activation + /// list. + /// \param[in] activations_beta The vector of beta parameters for activation + /// functions in order respective to activation + /// list. + /// \param[in] clip The value defining clipping range [-clip, + /// clip] on input of activation functions. + LSTMCell(const Output& X, + const Output& initial_hidden_state, + const Output& initial_cell_state, + const Output& W, + const Output& R, + std::size_t hidden_size, + const std::vector& activations = + std::vector{"sigmoid", "tanh", "tanh"}, + const std::vector& activations_alpha = {}, + const std::vector& activations_beta = {}, + float clip = 0.f); + + /// + /// \brief Constructs LSTMCell node. + /// + /// \param[in] X The input tensor with shape: [batch_size, + /// input_size]. + /// \param[in] initial_hidden_state The hidden state tensor at current time step + /// with shape: [batch_size, hidden_size]. + /// \param[in] initial_cell_state The cell state tensor at current time step + /// with shape: [batch_size, hidden_size]. + /// \param[in] W The weight tensor with shape: [4*hidden_size, + /// input_size]. + /// \param[in] R The recurrence weight tensor with shape: + /// [4*hidden_size, hidden_size]. + /// \param[in] B The bias tensor for gates with shape: + /// [4*hidden_size]. + /// \param[in] hidden_size The number of hidden units for recurrent cell. + /// \param[in] activations The vector of activation functions used inside + /// recurrent cell. + /// \param[in] activations_alpha The vector of alpha parameters for activation + /// functions in order respective to activation + /// list. + /// \param[in] activations_beta The vector of beta parameters for activation + /// functions in order respective to activation + /// list. + /// \param[in] clip The value defining clipping range [-clip, + /// clip] on input of activation functions. + /// + LSTMCell(const Output& X, + const Output& initial_hidden_state, + const Output& initial_cell_state, + const Output& W, + const Output& R, + const Output& B, + std::size_t hidden_size, + const std::vector& activations = + std::vector{"sigmoid", "tanh", "tanh"}, + const std::vector& activations_alpha = {}, + const std::vector& activations_beta = {}, + float clip = 0.f); + + void validate_and_infer_types() override; + + bool visit_attributes(AttributeVisitor& visitor) override; + std::shared_ptr + clone_with_new_inputs(const OutputVector& new_args) const override; + + private: + /// + /// \brief Creates the default bias input initialized with zeros. + /// + /// \return The object of Output class. + /// + Output get_default_bias_input() const; + + /// + /// \brief The Activation function f. + /// + util::ActivationFunction m_activation_f; + /// + /// \brief The Activation function g. + /// + util::ActivationFunction m_activation_g; + /// + /// \brief The Activation function h. + /// + util::ActivationFunction m_activation_h; + + static constexpr std::size_t s_gates_count{4}; + }; + } // v1 + } // namespace op NGRAPH_API std::ostream& operator<<(std::ostream& s, const op::LSTMWeightsFormat& type); @@ -294,5 +421,3 @@ namespace ngraph const DiscreteTypeInfo& get_type_info() const override { return type_info; } }; } // namespace ngraph - -NGRAPH_SUPPRESS_DEPRECATED_END diff --git a/ngraph/core/include/ngraph/op/lstm_sequence.hpp b/ngraph/core/include/ngraph/op/lstm_sequence.hpp index 4309f16..7fbe1f3 100644 --- a/ngraph/core/include/ngraph/op/lstm_sequence.hpp +++ b/ngraph/core/include/ngraph/op/lstm_sequence.hpp @@ -27,8 +27,7 @@ #include "ngraph/op/lstm_cell.hpp" #include "ngraph/op/util/attr_types.hpp" #include "ngraph/op/util/fused_op.hpp" - -NGRAPH_SUPPRESS_DEPRECATED_START +#include "ngraph/op/util/rnn_cell_base.hpp" namespace ngraph { @@ -186,9 +185,66 @@ namespace ngraph LSTMWeightsFormat m_weights_format; }; } - using v0::LSTMSequence; + + namespace v1 + { + /// + /// \brief Class for lstm sequence node. + /// + /// \note It follows notation and equations defined as in ONNX standard: + /// https://github.com/onnx/onnx/blob/master/docs/Operators.md#LSTM + /// + /// \sa LSTMCell, RNNCell, GRUCell + /// + /// + class NGRAPH_API LSTMSequence : public util::RNNCellBase + { + public: + static constexpr NodeTypeInfo type_info{"LSTMSequence", 1}; + const NodeTypeInfo& get_type_info() const override { return type_info; } + LSTMSequence() = default; + + using direction = RecurrentSequenceDirection; + + size_t get_default_output_index() const override { return no_default_index(); } + explicit LSTMSequence(const Output& X, + const Output& initial_hidden_state, + const Output& initial_cell_state, + const Output& sequence_lengths, + const Output& W, + const Output& R, + const Output& B, + const std::int64_t hidden_size, + const direction lstm_direction, + const std::vector activations_alpha = {}, + const std::vector activations_beta = {}, + const std::vector activations = {"sigmoid", + "tanh", + "tanh"}, + const float clip = 0.f) + : RNNCellBase( + {X, initial_hidden_state, initial_cell_state, sequence_lengths, W, R, B}, + hidden_size, + clip, + activations, + activations_alpha, + activations_beta) + , m_direction(lstm_direction) + { + constructor_validate_and_infer_types(); + } + + void validate_and_infer_types() override; + bool visit_attributes(AttributeVisitor& visitor) override; + + virtual std::shared_ptr + clone_with_new_inputs(const OutputVector& new_args) const override; + + direction get_direction() const { return m_direction; } + private: + direction m_direction; + }; + } } // namespace op } // namespace ngraph - -NGRAPH_SUPPRESS_DEPRECATED_END diff --git a/ngraph/core/include/ngraph/op/rnn_cell.hpp b/ngraph/core/include/ngraph/op/rnn_cell.hpp index 6b7055c..42d36a4 100644 --- a/ngraph/core/include/ngraph/op/rnn_cell.hpp +++ b/ngraph/core/include/ngraph/op/rnn_cell.hpp @@ -26,8 +26,6 @@ #include "ngraph/op/util/fused_op.hpp" #include "ngraph/op/util/rnn_cell_base.hpp" -NGRAPH_SUPPRESS_DEPRECATED_START - namespace ngraph { namespace op @@ -52,7 +50,7 @@ namespace ngraph /// /// \sa LSTMSequence, LSTMCell, GRUCell /// - class NGRAPH_API RNNCell : public util::FusedOp, public util::RNNCellBase + class NGRAPH_API RNNCell : public util::RNNCellBase { public: static constexpr NodeTypeInfo type_info{"RNNCell", 0}; @@ -129,11 +127,9 @@ namespace ngraph const std::vector& activations_beta = {}, float clip = 0.f); - virtual void validate_and_infer_types() override; + void validate_and_infer_types() override; bool visit_attributes(AttributeVisitor& visitor) override; - virtual void pre_validate_and_infer_types() override; - virtual OutputVector decompose_op() const override; - virtual std::shared_ptr + std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; private: @@ -152,8 +148,5 @@ namespace ngraph static constexpr std::size_t s_gates_count{1}; }; } - using v0::RNNCell; } // namespace op } // namespace ngraph - -NGRAPH_SUPPRESS_DEPRECATED_END diff --git a/ngraph/core/include/ngraph/op/util/rnn_cell_base.hpp b/ngraph/core/include/ngraph/op/util/rnn_cell_base.hpp index f24cd52..103d915 100644 --- a/ngraph/core/include/ngraph/op/util/rnn_cell_base.hpp +++ b/ngraph/core/include/ngraph/op/util/rnn_cell_base.hpp @@ -30,11 +30,39 @@ namespace ngraph { namespace util { + enum class LSTMWeightsFormat + { + FICO, // IE + ICOF, // PyTorch + IFCO, // DNNL, TF, MxNet + IFOC, // Caffe + IOFC, // ONNX + }; + + /// + /// \brief Change data format of provided node. + /// + /// \param[in] node The input node to be permuted. + /// + /// + /// \param[in] from_format Original node weights format. + /// + /// + /// \param[in] to_format Weights format to convert to. + /// + /// \return Node representing reshaped tensor according to `to_format` weights + /// format. + /// + std::shared_ptr NGRAPH_API + convert_lstm_node_format(const Output& node, + LSTMWeightsFormat from_format, + LSTMWeightsFormat to_format = LSTMWeightsFormat::FICO); + /// \brief Base class for all recurrent network cells. /// /// \note It holds all common attributes. /// - class NGRAPH_API RNNCellBase + class NGRAPH_API RNNCellBase : public Op { public: /// @@ -50,7 +78,8 @@ namespace ngraph /// \param[in] activations_beta The vector of beta parameters for activation /// functions in order respective to activation list. /// - RNNCellBase(std::size_t hidden_size, + RNNCellBase(const OutputVector& args, + std::size_t hidden_size, float clip, const std::vector& activations, const std::vector& activations_alpha, diff --git a/ngraph/core/include/ngraph/opsets/opset4_tbl.hpp b/ngraph/core/include/ngraph/opsets/opset4_tbl.hpp index 980ea91..001af3f 100644 --- a/ngraph/core/include/ngraph/opsets/opset4_tbl.hpp +++ b/ngraph/core/include/ngraph/opsets/opset4_tbl.hpp @@ -70,8 +70,7 @@ NGRAPH_OP(LogicalNot, ngraph::op::v1) NGRAPH_OP(LogicalOr, ngraph::op::v1) NGRAPH_OP(LogicalXor, ngraph::op::v1) NGRAPH_OP(LRN, ngraph::op::v0) -NGRAPH_OP(LSTMCell, ngraph::op::v0) -NGRAPH_OP(LSTMSequence, ngraph::op::v0) +NGRAPH_OP(LSTMCell, ngraph::op::v4) NGRAPH_OP(MatMul, ngraph::op::v0) NGRAPH_OP(MaxPool, ngraph::op::v1) NGRAPH_OP(Maximum, ngraph::op::v1) diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/gru_cell.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/gru_cell.hpp new file mode 100644 index 0000000..9f89fec --- /dev/null +++ b/ngraph/core/reference/include/ngraph/runtime/reference/gru_cell.hpp @@ -0,0 +1,316 @@ +//***************************************************************************** +// Copyright 2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ngraph +{ + namespace runtime + { + namespace reference + { + template + void gru_cell(const T* X, + const Shape& X_shape, + const T* H, + const Shape& H_shape, + const T* W, + const Shape& W_shape, + const T* R, + const Shape& R_shape, + const T* B, + const Shape& B_shape, + T* dst_data, + const std::string& activation_f, + const std::string& activation_g, + float clip, + bool linear_before_reset) + { + // ------ VARIABLE'S NAMES AND ACRONYM DEFINITIONS ------ + // The names used below are analogous to the one used in ONNX documentation. + // + // ------ ACRONYMS ------ + // z_t - update gate at current time step + // r_t - reset gate at current time step + // h_t - hidden gate at current time step + // t - time step (t-1 means previous time step) + // X The input data tensor. Shape: [batch_size, input_size]. + // W[zrh] - The weight tensor for update, reset and hidden gates. + // Shape: [gates_count * hidden_size, input_size]. + // R[zrh] - The recurrence weight tensor for update, reset and hidden gates. + // Shape: [gates_count * hidden_size, hidden_size]. + // H_t - The hidden state tensor at current time step. Shape: [batch_size, + // hidden_size]. + // B - The sum of biases (weight and recurrence) for update, reset and hidden + // gates. + // If linear_before_reset := true then biases for hidden gates are placed + // separately + // (weight and recurrence). + // Shape: [gates_count * hidden_size] when linear_before_reset := false + // Shape: [(gates_count + 1) * hidden_size] when linear_before_reset := + // true + // Wb[zrh] - W bias vectors for update, reset and hidden gates. + // Rb[zrh] - R bias vectors for update, reset and hidden gates. + + // (.) - Denotes element-wise multiplication. + // * - Denotes dot product. + + // ---- Equations ---- + // f, g - are activation functions + // zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz) + // rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr) + // ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh) # when linear_before_reset + // := false + // # (default) + // ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) # when linear_before_reset + // := true + // Ht = (1 - zt) (.) ht + zt (.) Ht-1 + // ------------------- + + Shape gate_shape{X_shape[0], H_shape[1]}; + Shape all_gates_shape{X_shape[0], 3 * H_shape[1]}; + Shape bias_shape{H_shape[1], H_shape[1]}; + auto gate_shape_size = X_shape[0] * H_shape[1]; + auto all_gates_shape_size = gate_shape_size * 3; + auto bias_shape_size = H_shape[1] * H_shape[1]; + + // Xt*(W^T) + std::vector Xt_W(all_gates_shape_size); + reference::matmul( + X, W, Xt_W.data(), X_shape, W_shape, all_gates_shape, false, true); + + // Ht-1*(R^T) + std::vector Ht_R(all_gates_shape_size); + reference::matmul( + H, R, Ht_R.data(), H_shape, R_shape, all_gates_shape, false, true); + + std::vector> X_W_zrh(3, std::vector(gate_shape_size)); + std::vector pointers_XW = {reinterpret_cast(X_W_zrh[0].data()), + reinterpret_cast(X_W_zrh[1].data()), + reinterpret_cast(X_W_zrh[2].data())}; + std::vector> R_zrh(3, std::vector(bias_shape_size)); + std::vector pointers_R = {reinterpret_cast(R_zrh[0].data()), + reinterpret_cast(R_zrh[1].data()), + reinterpret_cast(R_zrh[2].data())}; + std::vector> Ht_R_zrh(3, std::vector(gate_shape_size)); + std::vector pointers_H_R = {reinterpret_cast(Ht_R_zrh[0].data()), + reinterpret_cast(Ht_R_zrh[1].data()), + reinterpret_cast(Ht_R_zrh[2].data())}; + + size_t num_b_splits = linear_before_reset ? 4 : 3; + std::vector> biases_zrh(num_b_splits, + std::vector(B_shape[0] / num_b_splits)); + std::vector pointers_biases = { + reinterpret_cast(biases_zrh[0].data()), + reinterpret_cast(biases_zrh[1].data()), + reinterpret_cast(biases_zrh[2].data())}; + if (linear_before_reset) + { + pointers_biases.push_back(reinterpret_cast(biases_zrh[3].data())); + } + + // split on gates + reference::split(reinterpret_cast(Xt_W.data()), + all_gates_shape, + sizeof(T), + 1, + 3, + pointers_XW.data()); + reference::split( + reinterpret_cast(R), R_shape, sizeof(T), 0, 3, pointers_R.data()); + reference::split(reinterpret_cast(Ht_R.data()), + all_gates_shape, + sizeof(T), + 1, + 3, + pointers_H_R.data()); + reference::split(reinterpret_cast(B), + B_shape, + sizeof(T), + 0, + num_b_splits, + pointers_biases.data()); + + auto clip_activation = [&clip](std::vector& gate, + const std::string& activation) { + if (clip > 0.f) + { + reference::clamp(gate.data(), + gate.data(), + static_cast(-clip), + static_cast(clip), + gate.size()); + } + if (activation == "relu") + { + reference::relu(gate.data(), gate.data(), gate.size()); + } + else if (activation == "sigmoid") + { + reference::sigmoid(gate.data(), gate.data(), gate.size()); + } + else if (activation == "tanh") + { + reference::tanh(gate.data(), gate.data(), gate.size()); + } + else + { + throw ngraph_error("Activation function " + activation + + " is not supported."); + } + }; + + // calculate z_t + // steps: + // Ht-1*(Rz^T) + Wbz + Rbz + // Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz + // zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz) + std::vector z_t(gate_shape_size); + reference::add(Ht_R_zrh[0].data(), + biases_zrh[0].data(), + z_t.data(), + gate_shape, + {B_shape[0] / num_b_splits}, + op::AutoBroadcastSpec::NUMPY); // + reference::add(X_W_zrh[0].data(), + z_t.data(), + z_t.data(), + gate_shape, + gate_shape, + op::AutoBroadcastSpec::NUMPY); // + clip_activation(z_t, activation_f); + + // calculate r_t + // steps: + // Ht-1*(Rr^T) + Wbr + Rbr + // Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr + // rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr) + std::vector r_t(gate_shape_size); + reference::add(Ht_R_zrh[1].data(), + biases_zrh[1].data(), + r_t.data(), + gate_shape, + {B_shape[0] / num_b_splits}, + op::AutoBroadcastSpec::NUMPY); + reference::add(X_W_zrh[1].data(), + r_t.data(), + r_t.data(), + gate_shape, + gate_shape, + op::AutoBroadcastSpec::NUMPY); + clip_activation(r_t, activation_f); + + // calculate h_t + vector h_t(gate_shape_size); + if (linear_before_reset) + { + // ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) + reference::add(Ht_R_zrh[2].data(), + biases_zrh[3].data(), + h_t.data(), + gate_shape, + {B_shape[0] / num_b_splits}, + op::AutoBroadcastSpec::NUMPY); + reference::multiply(r_t.data(), + h_t.data(), + h_t.data(), + gate_shape, + gate_shape, + op::AutoBroadcastSpec::NUMPY); + reference::add(h_t.data(), + biases_zrh[2].data(), + h_t.data(), + gate_shape, + {B_shape[0] / num_b_splits}, + op::AutoBroadcastSpec::NUMPY); + reference::add(X_W_zrh[2].data(), + h_t.data(), + h_t.data(), + gate_shape, + gate_shape, + op::AutoBroadcastSpec::NUMPY); + } + else + { + // ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh) + reference::multiply(r_t.data(), + H, + h_t.data(), + gate_shape, + H_shape, + op::AutoBroadcastSpec::NUMPY); + std::vector matmul(gate_shape_size); + reference::matmul(h_t.data(), + R_zrh[2].data(), + matmul.data(), + gate_shape, + bias_shape, + gate_shape, + false, + true); + reference::add(matmul.data(), + biases_zrh[2].data(), + h_t.data(), + gate_shape, + {B_shape[0] / num_b_splits}, + op::AutoBroadcastSpec::NUMPY); + reference::add(X_W_zrh[2].data(), + h_t.data(), + h_t.data(), + gate_shape, + gate_shape, + op::AutoBroadcastSpec::NUMPY); + } + clip_activation(h_t, activation_g); + // Ht = (1 - zt) (.) ht + zt (.) Ht-1 + vector mul1(gate_shape_size); + vector mul2(gate_shape_size); + T one[] = {1}; + reference::subtract( + one, z_t.data(), mul1.data(), {1}, gate_shape, op::AutoBroadcastSpec::NUMPY); + reference::multiply(mul1.data(), + h_t.data(), + mul1.data(), + gate_shape, + gate_shape, + op::AutoBroadcastSpec::NUMPY); + reference::multiply(z_t.data(), + H, + mul2.data(), + gate_shape, + gate_shape, + op::AutoBroadcastSpec::NUMPY); + reference::add(mul1.data(), + mul2.data(), + dst_data, + gate_shape, + gate_shape, + op::AutoBroadcastSpec::NUMPY); + } + } + } +} diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/lstm_cell.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/lstm_cell.hpp new file mode 100644 index 0000000..583332d --- /dev/null +++ b/ngraph/core/reference/include/ngraph/runtime/reference/lstm_cell.hpp @@ -0,0 +1,217 @@ +//***************************************************************************** +// Copyright 2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ngraph +{ + namespace runtime + { + namespace reference + { + template + void lstm_cell(const T* X, + const Shape& X_shape, + const T* H, + const Shape& H_shape, + const T* C, + const Shape& C_shape, + const T* W, + const Shape& W_shape, + const T* R, + const Shape& R_shape, + const T* B, + const Shape& B_shape, + T* out_Ht, + T* out_Ct, + const std::string& activation_f, + const std::string& activation_g, + const std::string& activation_h, + float clip) + { + // ------ VARIABLE'S NAMES AND ACRONYM DEFINITIONS ------ + // The names used below are analogous to the one used in ONNX documentation. + // + // ------ ACRONYMS ------ + // i - input gate + // o - output gate + // f - forget gate + // c - cell gate + // t - time step (t-1 means previous time step) + // Wb - W bias vectors for input, output, forget, and cell gates. + // Rb - R bias vectors for input, output, forget, and cell gates. + // P - The peephole weights for input, output and forget gates. + // ------ VARIABLE NAMES ------ + // X - The input data tensor. Shape: [batch_size, input_size]. + // W - The weight matrix for input, forget, cell and output gates + // Shape: [4*hidden_size, input_size] + // R - The recurrence weight matrix for input, forget, cell and output gates. + // Shape: [4*hidden_size, hidden_size]. + // H_t - The hidden state tensor at current time step. Shape: [batch_size, + // hidden_size]. + // C_t - The cell state tensor at current time step. Shape: [batch_size, + // hidden_size]. + // bias - The sum of biases (weight and recurrence) for input, forget, cell and + // output gates. + // Shape: [4 * hidden_size] + // p_[iof] - The peephole weight vector for respectively: input, output, and forget + // gates. + // Each peephole has shape [hidden_size]. + // + // (.) - Denotes element-wise multiplication. + // * - Denotes dot product. + // + // ---- Equations ---- + // f, g, h - are activation functions. + // it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi) + // ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Wbf + Rbf) + // ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc) + // ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Wbo + Rbo) + // Ct = ft (.) Ct-1 + it (.) ct + // Ht = ot (.) h(Ct) + // -------------------- + Shape gate_shape{X_shape[0], H_shape[1]}; + Shape all_gates_shape{X_shape[0], 4 * H_shape[1]}; + auto gate_shape_size = X_shape[0] * H_shape[1]; + auto all_gates_shape_size = gate_shape_size * 4; + // Xt*(W^T) + std::vector Xt_W(all_gates_shape_size); + reference::matmul( + X, W, Xt_W.data(), X_shape, W_shape, all_gates_shape, false, true); + + // Ht-1*(R^T) + std::vector Ht_R(all_gates_shape_size); + reference::matmul( + H, R, Ht_R.data(), H_shape, R_shape, all_gates_shape, false, true); + + // Ht-1*(R^T) + Wb + Rb + std::vector Ht_R_B(all_gates_shape_size); + reference::add(Ht_R.data(), + B, + Ht_R_B.data(), + all_gates_shape, + B_shape, + op::AutoBroadcastSpec::NUMPY); + + // Xt*(W^T) + Ht-1*(R^T) + Wb + Rb + std::vector XHB(all_gates_shape_size); + reference::add(Xt_W.data(), + Ht_R_B.data(), + XHB.data(), + all_gates_shape, + all_gates_shape, + op::AutoBroadcastSpec::NUMPY); + + std::vector> X_W_fico(4, std::vector(all_gates_shape_size / 4)); + std::vector pointers = {reinterpret_cast(X_W_fico[0].data()), + reinterpret_cast(X_W_fico[1].data()), + reinterpret_cast(X_W_fico[2].data()), + reinterpret_cast(X_W_fico[3].data())}; + // split on gates + reference::split(reinterpret_cast(XHB.data()), + all_gates_shape, + sizeof(T), + 1, + 4, + pointers.data()); + + auto clip_activation = [&clip]( + std::vector& gate, const std::string& activation, bool enable_clip = true) { + if (clip > 0.f && enable_clip) + { + reference::clamp(gate.data(), + gate.data(), + static_cast(-clip), + static_cast(clip), + gate.size()); + } + if (activation == "relu") + { + reference::relu(gate.data(), gate.data(), gate.size()); + } + else if (activation == "sigmoid") + { + reference::sigmoid(gate.data(), gate.data(), gate.size()); + } + else if (activation == "tanh") + { + reference::tanh(gate.data(), gate.data(), gate.size()); + } + else + { + throw ngraph_error("Activation function " + activation + + " is not supported."); + } + }; + + // ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Wbf + Rbf) + clip_activation(X_W_fico[0], activation_f); + // it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi) + clip_activation(X_W_fico[1], activation_f); + // ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc) + clip_activation(X_W_fico[2], activation_g); + // ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Wbo + Rbo) + clip_activation(X_W_fico[3], activation_f); + + vector mul1(gate_shape_size); + vector mul2(gate_shape_size); + vector Ct(gate_shape_size); + // ft (.) Ct-1 + reference::multiply(X_W_fico[0].data(), + C, + mul1.data(), + gate_shape, + C_shape, + op::AutoBroadcastSpec::NUMPY); + // it (.) ct + reference::multiply(X_W_fico[1].data(), + X_W_fico[2].data(), + mul2.data(), + gate_shape, + gate_shape, + op::AutoBroadcastSpec::NUMPY); + // Ct = ft (.) Ct-1 + it (.) ct + reference::add(mul1.data(), + mul2.data(), + Ct.data(), + gate_shape, + gate_shape, + op::AutoBroadcastSpec::NUMPY); + std::memcpy(out_Ct, Ct.data(), Ct.size() * sizeof(T)); + clip_activation(Ct, activation_h, false); + + // Ht = ot (.) h(Ct) + reference::multiply(X_W_fico[3].data(), + Ct.data(), + out_Ht, + gate_shape, + gate_shape, + op::AutoBroadcastSpec::NUMPY); + } + } + } +} diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/rnn_cell.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/rnn_cell.hpp new file mode 100644 index 0000000..b54045d --- /dev/null +++ b/ngraph/core/reference/include/ngraph/runtime/reference/rnn_cell.hpp @@ -0,0 +1,132 @@ +//***************************************************************************** +// Copyright 2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace ngraph +{ + namespace runtime + { + namespace reference + { + template + void rnn_cell(const T* X, + const Shape& X_shape, + const T* H, + const Shape& H_shape, + const T* W, + const Shape& W_shape, + const T* R, + const Shape& R_shape, + const T* B, + const Shape& B_shape, + T* dst_data, + const std::string& activation_f, + float clip) + { + // ------ VARIABLE'S NAMES AND ACRONYM DEFINITIONS ------ + // The names used below are analogous to the one used in ONNX documentation. + // + // ------ ACRONYMS ------ + // i_t - input gate at current time step + // t - time step (t-1 means previous time step) + // X - The input data tensor. Shape: [batch_size, input_size]. + // W - The weight tensor for input gate. Shape: [hidden_size, input_size]. + // R - The recurrence weight tensor for input gate. Shape: [hidden_size, + // hidden_size]. + // H_t - The hidden state tensor at current time step. Shape: [batch_size, + // hidden_size]. + // B - The bias tensor for the input gate. Shape: [hidden_size]. + // Wb - W bias vectors for input gate. + // Rb - R bias vectors for input gate. + // ------ VARIABLE NAMES ------ + // Xt_W - Input sequence multiplied by weights tensor at current time step. + // Ht_R - Hidden state multiplied by weights tensor at current time step. + + // (.) - Denotes element-wise multiplication. + // * - Denotes dot product. + + // ---- Equations ---- + // f - is activation functions. + // Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi) + // -------------------- + + // Xt*(W^T) + std::vector Xt_W(X_shape[0] * W_shape[0]); + reference::matmul( + X, W, Xt_W.data(), X_shape, W_shape, {X_shape[0], W_shape[0]}, false, true); + + // Ht-1*(R^T) + std::vector Ht_R(H_shape[0] * R_shape[0]); + reference::matmul( + H, R, Ht_R.data(), H_shape, R_shape, {H_shape[0], R_shape[0]}, false, true); + + // Ht-1*(R^T) + Wb + Rb + std::vector Ht_R_B(H_shape[0] * R_shape[0]); + reference::add(Ht_R.data(), + B, + Ht_R_B.data(), + {H_shape[0], R_shape[0]}, + B_shape, + op::AutoBroadcastSpec::NUMPY); + + // Xt*(W^T) + Ht-1*(R^T) + Wb + Rb + std::vector i_t(H_shape[0] * R_shape[0]); + reference::add(Xt_W.data(), + Ht_R_B.data(), + i_t.data(), + {X_shape[0], W_shape[0]}, + {H_shape[0], R_shape[0]}, + op::AutoBroadcastSpec::NUMPY); + + // f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi) + if (clip != 0.f) + { + reference::clamp(i_t.data(), + i_t.data(), + static_cast(-clip), + static_cast(clip), + i_t.size()); + } + if (activation_f == "relu") + { + reference::relu(i_t.data(), dst_data, i_t.size()); + } + else if (activation_f == "sigmoid") + { + reference::sigmoid(i_t.data(), dst_data, i_t.size()); + } + else if (activation_f == "tanh") + { + reference::tanh(i_t.data(), dst_data, i_t.size()); + } + else + { + throw ngraph_error("Activation function " + activation_f + + " is not supported."); + } + } + } + } +} diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/split.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/split.hpp new file mode 100644 index 0000000..517c8f4 --- /dev/null +++ b/ngraph/core/reference/include/ngraph/runtime/reference/split.hpp @@ -0,0 +1,37 @@ +//***************************************************************************** +// Copyright 2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#pragma once + +#include + +#include "ngraph/runtime/reference/slice.hpp" + +namespace ngraph +{ + namespace runtime + { + namespace reference + { + void split(const char* data, + const Shape& data_shape, + size_t elem_size, + int64_t axis, + size_t num_splits, + char** out_data); + } + } +} diff --git a/ngraph/core/reference/src/runtime/reference/split.cpp b/ngraph/core/reference/src/runtime/reference/split.cpp new file mode 100644 index 0000000..6cd11cc --- /dev/null +++ b/ngraph/core/reference/src/runtime/reference/split.cpp @@ -0,0 +1,54 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#include +#include + +#include "ngraph/check.hpp" +#include "ngraph/runtime/reference/split.hpp" + +using namespace ngraph; + +void runtime::reference::split(const char* data, + const Shape& data_shape, + size_t elem_size, + int64_t axis, + size_t num_splits, + char** out_data) +{ + const size_t part_length = data_shape.at(axis) / num_splits; + + Shape output_shape = data_shape; + output_shape.at(axis) = part_length; + + std::vector lower_bounds(data_shape.size(), 0); + std::vector upper_bounds = data_shape; + upper_bounds.at(axis) = part_length; + + for (size_t i = 0; i < num_splits; ++i) + { + runtime::reference::slice(data, + out_data[i], + data_shape, + lower_bounds, + upper_bounds, + Strides(lower_bounds.size(), 1), + output_shape, + elem_size); + lower_bounds.at(axis) += part_length; + upper_bounds.at(axis) += part_length; + } +} diff --git a/ngraph/core/src/op/gru_cell.cpp b/ngraph/core/src/op/gru_cell.cpp index ba6d4ca..fff0fdd 100644 --- a/ngraph/core/src/op/gru_cell.cpp +++ b/ngraph/core/src/op/gru_cell.cpp @@ -15,12 +15,9 @@ //***************************************************************************** #include -#include -#include "ngraph/builder/reshape.hpp" -#include "ngraph/builder/split.hpp" +#include "itt.hpp" #include "ngraph/op/constant.hpp" -#include "ngraph/op/dot.hpp" #include "ngraph/op/gru_cell.hpp" #include "ngraph/shape.hpp" #include "ngraph/type/element_type.hpp" @@ -28,8 +25,6 @@ using namespace std; using namespace ngraph; -NGRAPH_SUPPRESS_DEPRECATED_START - constexpr NodeTypeInfo op::v3::GRUCell::type_info; op::v3::GRUCell::GRUCell() @@ -68,8 +63,12 @@ op::v3::GRUCell::GRUCell(const Output& X, const vector& activations_beta, float clip, bool linear_before_reset) - : FusedOp({X, initial_hidden_state, W, R}) - , RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta) + : RNNCellBase({X, initial_hidden_state, W, R}, + hidden_size, + clip, + activations, + activations_alpha, + activations_beta) , m_activation_f{get_activation_function(0)} , m_activation_g{get_activation_function(1)} , m_linear_before_reset{linear_before_reset} @@ -89,8 +88,12 @@ op::v3::GRUCell::GRUCell(const Output& X, const vector& activations_beta, float clip, bool linear_before_reset) - : FusedOp({X, initial_hidden_state, W, R, B}) - , RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta) + : RNNCellBase({X, initial_hidden_state, W, R, B}, + hidden_size, + clip, + activations, + activations_alpha, + activations_beta) , m_activation_f{get_activation_function(0)} , m_activation_g{get_activation_function(1)} , m_linear_before_reset{linear_before_reset} @@ -104,83 +107,12 @@ bool op::v3::GRUCell::visit_attributes(AttributeVisitor& visitor) return op::util::RNNCellBase::visit_attributes(visitor); } -void op::v3::GRUCell::pre_validate_and_infer_types() -{ - set_output_type(0, get_input_element_type(0), PartialShape::dynamic()); - - if (is_dynamic()) - { - return; - } - - const auto& x_pshape = get_input_partial_shape(0); - const auto& ht_pshape = get_input_partial_shape(1); - const auto& w_pshape = get_input_partial_shape(2); - const auto& r_pshape = get_input_partial_shape(3); - const auto& b_pshape = get_input_partial_shape(4); - - const Shape& x_shape{x_pshape.to_shape()}; - - const size_t batch_size = x_shape.at(0); - const size_t input_size = x_shape.at(1); - - const Shape& w_shape{w_pshape.to_shape()}; - const Shape& r_shape{r_pshape.to_shape()}; - const Shape& ht_shape{ht_pshape.to_shape()}; - - NODE_VALIDATION_CHECK(this, - (w_shape == Shape{s_gates_count * get_hidden_size(), input_size}), - "Input tensor W must have shape (", - s_gates_count * get_hidden_size(), - ", ", - input_size, - "). Actual shape is:", - w_shape, - "."); - NODE_VALIDATION_CHECK(this, - (r_shape == Shape{s_gates_count * get_hidden_size(), get_hidden_size()}), - "Input tensor R must have shape (", - s_gates_count * get_hidden_size(), - ", ", - get_hidden_size(), - "). Actual shape is:", - w_shape, - "."); - NODE_VALIDATION_CHECK(this, - (ht_shape == Shape{batch_size, get_hidden_size()}), - "Input tensor initial_hidden_state must have shape (", - batch_size, - ", ", - get_hidden_size(), - "). Actual shape is:", - w_shape, - "."); - - const Shape& b_shape{b_pshape.to_shape()}; - NODE_VALIDATION_CHECK( - this, - (b_shape == Shape{(s_gates_count + m_linear_before_reset) * get_hidden_size()}), - "Input tensor B must have shape (", - (s_gates_count + m_linear_before_reset) * get_hidden_size(), - "). Actual shape is:", - b_shape, - "."); -} - void op::v3::GRUCell::validate_and_infer_types() { - std::vector input_param{}; - auto merged_batch_size = Dimension::dynamic(); auto merged_hidden_size = Dimension::dynamic(); auto result_et = element::dynamic; - // Copy all inputs for further validation - for (size_t i = 0; i < get_input_size(); i++) - { - input_param.push_back(get_input_partial_shape(i)); - } - // Get input partial shape for all inputs const auto& x_pshape = get_input_partial_shape(0); const auto& ht_pshape = get_input_partial_shape(1); @@ -188,7 +120,7 @@ void op::v3::GRUCell::validate_and_infer_types() const auto& r_pshape = get_input_partial_shape(3); const auto& b_pshape = get_input_partial_shape(4); - validate_input_rank_dimension(input_param); + validate_input_rank_dimension({x_pshape, ht_pshape, w_pshape, r_pshape, b_pshape}); // Validate input types and save result for output type NODE_VALIDATION_CHECK( @@ -265,90 +197,6 @@ void op::v3::GRUCell::validate_and_infer_types() set_output_type(0, result_et, {merged_batch_size, merged_hidden_size}); } -OutputVector op::v3::GRUCell::decompose_op() const -{ - // ------ VARIABLE'S NAMES AND ACRONYM DEFINITIONS ------ - // The names used below are analogous to the one used in ONNX documentation. - // - // ------ ACRONYMS ------ - // z_t - update gate at current time step - // r_t - reset gate at current time step - // h_t - hidden gate at current time step - // t - time step (t-1 means previous time step) - // X The input data tensor. Shape: [batch_size, input_size]. - // W[zrh] - The weight tensor for update, reset and hidden gates. - // Shape: [gates_count * hidden_size, input_size]. - // R[zrh] - The recurrence weight tensor for update, reset and hidden gates. - // Shape: [gates_count * hidden_size, hidden_size]. - // H_t - The hidden state tensor at current time step. Shape: [batch_size, hidden_size]. - // B - The sum of biases (weight and recurrence) for update, reset and hidden gates. - // If linear_before_reset := true then biases for hidden gates are placed separately - // (weight and recurrence). - // Shape: [gates_count * hidden_size] when linear_before_reset := false - // Shape: [(gates_count + 1) * hidden_size] when linear_before_reset := true - // Wb[zrh] - W bias vectors for update, reset and hidden gates. - // Rb[zrh] - R bias vectors for update, reset and hidden gates. - - // (.) - Denotes element-wise multiplication. - // * - Denotes dot product. - - // ---- Equations ---- - // f, g - are activation functions - // zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz) - // rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr) - // ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh) # when linear_before_reset := false - // # (default) - // ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) # when linear_before_reset := true - // Ht = (1 - zt) (.) ht + zt (.) Ht-1 - // ------------------- - - Output X = input_value(0); - Output H_t = input_value(1); - Output W = input_value(2); - Output R = input_value(3); - Output B = input_value(4); - - // Xt*(W^T) - auto Xt_W = make_shared(X, builder::opset1::transpose(W)); - auto R_transpose = builder::opset1::transpose(R); - // Ht-1*(R^T) - auto Ht_R = make_shared(H_t, R_transpose); - - // split to gates: - OutputVector Xt_W_zrh = builder::split(Xt_W, 3, 1); - OutputVector R_zrh = builder::split(R_transpose, 3, 1); - OutputVector Ht_R_zrh = builder::split(Ht_R, 3, 1); - OutputVector biases_zrh = m_linear_before_reset ? builder::split(B, 4) : builder::split(B, 3); - - // zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz) - auto z_t = m_activation_f(clip(add(Xt_W_zrh[0], add(Ht_R_zrh[0], biases_zrh[0])))); - // rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr) - auto r_t = m_activation_f(clip(add(Xt_W_zrh[1], add(Ht_R_zrh[1], biases_zrh[1])))); - - Output h_t; - if (m_linear_before_reset) - { - // ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) - auto Ht_Rh_Rbh = add(Ht_R_zrh[2], biases_zrh[3]); - h_t = m_activation_g(clip(add(Xt_W_zrh[2], add(mul(r_t, Ht_Rh_Rbh), biases_zrh[2])))); - } - else - { - // ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh) - auto rt_Ht = mul(r_t, H_t); - auto rt_Ht_Rh = make_shared(rt_Ht, R_zrh[2]); - // Tensor shape: [batch_size, hidden_size] - h_t = m_activation_g(clip(add(Xt_W_zrh[2], add(rt_Ht_Rh, biases_zrh[2])))); - } - - auto one = op::Constant::create(z_t->get_element_type(), - z_t->get_shape(), - vector(shape_size(z_t->get_shape()), 1.f)); - // Ht = (1 - zt) (.) ht + zt (.) Ht-1 - H_t = add(mul(sub(one, z_t), h_t), mul(z_t, H_t)); - return {H_t.get_node_shared_ptr()}; -} - void op::v3::GRUCell::add_default_bias_input() { Output B = op::Constant::create( diff --git a/ngraph/core/src/op/lstm_cell.cpp b/ngraph/core/src/op/lstm_cell.cpp index 354afa5..6f72cc0 100644 --- a/ngraph/core/src/op/lstm_cell.cpp +++ b/ngraph/core/src/op/lstm_cell.cpp @@ -18,12 +18,8 @@ #include #include "ngraph/attribute_visitor.hpp" -#include "ngraph/builder/reshape.hpp" -#include "ngraph/builder/split.hpp" -#include "ngraph/op/add.hpp" #include "ngraph/op/concat.hpp" #include "ngraph/op/constant.hpp" -#include "ngraph/op/dot.hpp" #include "ngraph/op/lstm_cell.hpp" #include "ngraph/shape.hpp" #include "ngraph/type/element_type.hpp" @@ -31,11 +27,10 @@ using namespace std; using namespace ngraph; -NGRAPH_SUPPRESS_DEPRECATED_START +constexpr NodeTypeInfo op::v4::LSTMCell::type_info; +constexpr NodeTypeInfo op::v0::LSTMCell::type_info; -constexpr NodeTypeInfo op::LSTMCell::type_info; - -op::LSTMCell::LSTMCell() +op::v0::LSTMCell::LSTMCell() : m_input_forget(false) , m_weights_format(LSTMWeightsFormat::IFCO) { @@ -45,20 +40,24 @@ op::LSTMCell::LSTMCell() m_activation_h = get_activation_function(2); } -op::LSTMCell::LSTMCell(const Output& X, - const Output& initial_hidden_state, - const Output& initial_cell_state, - const Output& W, - const Output& R, - size_t hidden_size, - op::LSTMWeightsFormat weights_format, - const vector& activations, - const vector& activations_alpha, - const vector& activations_beta, - float clip, - bool input_forget) - : FusedOp({X, initial_hidden_state, initial_cell_state, W, R}) - , RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta) +op::v0::LSTMCell::LSTMCell(const Output& X, + const Output& initial_hidden_state, + const Output& initial_cell_state, + const Output& W, + const Output& R, + size_t hidden_size, + op::LSTMWeightsFormat weights_format, + const vector& activations, + const vector& activations_alpha, + const vector& activations_beta, + float clip, + bool input_forget) + : RNNCellBase({X, initial_hidden_state, initial_cell_state, W, R}, + hidden_size, + clip, + activations, + activations_alpha, + activations_beta) , m_activation_f{get_activation_function(0)} , m_activation_g{get_activation_function(1)} , m_activation_h{get_activation_function(2)} @@ -70,21 +69,25 @@ op::LSTMCell::LSTMCell(const Output& X, constructor_validate_and_infer_types(); } -op::LSTMCell::LSTMCell(const Output& X, - const Output& initial_hidden_state, - const Output& initial_cell_state, - const Output& W, - const Output& R, - const Output& B, - size_t hidden_size, - op::LSTMWeightsFormat weights_format, - const vector& activations, - const vector& activations_alpha, - const vector& activations_beta, - float clip, - bool input_forget) - : FusedOp({X, initial_hidden_state, initial_cell_state, W, R, B}) - , RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta) +op::v0::LSTMCell::LSTMCell(const Output& X, + const Output& initial_hidden_state, + const Output& initial_cell_state, + const Output& W, + const Output& R, + const Output& B, + size_t hidden_size, + op::LSTMWeightsFormat weights_format, + const vector& activations, + const vector& activations_alpha, + const vector& activations_beta, + float clip, + bool input_forget) + : RNNCellBase({X, initial_hidden_state, initial_cell_state, W, R, B}, + hidden_size, + clip, + activations, + activations_alpha, + activations_beta) , m_activation_f{get_activation_function(0)} , m_activation_g{get_activation_function(1)} , m_activation_h{get_activation_function(2)} @@ -95,22 +98,26 @@ op::LSTMCell::LSTMCell(const Output& X, constructor_validate_and_infer_types(); } -op::LSTMCell::LSTMCell(const Output& X, - const Output& initial_hidden_state, - const Output& initial_cell_state, - const Output& W, - const Output& R, - const Output& B, - const Output& P, - size_t hidden_size, - op::LSTMWeightsFormat weights_format, - const vector& activations, - const vector& activations_alpha, - const vector& activations_beta, - float clip, - bool input_forget) - : FusedOp({X, initial_hidden_state, initial_cell_state, W, R, B, P}) - , RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta) +op::v0::LSTMCell::LSTMCell(const Output& X, + const Output& initial_hidden_state, + const Output& initial_cell_state, + const Output& W, + const Output& R, + const Output& B, + const Output& P, + size_t hidden_size, + op::LSTMWeightsFormat weights_format, + const vector& activations, + const vector& activations_alpha, + const vector& activations_beta, + float clip, + bool input_forget) + : RNNCellBase({X, initial_hidden_state, initial_cell_state, W, R, B, P}, + hidden_size, + clip, + activations, + activations_alpha, + activations_beta) , m_activation_f{get_activation_function(0)} , m_activation_g{get_activation_function(1)} , m_activation_h{get_activation_function(2)} @@ -133,101 +140,7 @@ bool ngraph::op::v0::LSTMCell::visit_attributes(AttributeVisitor& visitor) return true; } -void op::LSTMCell::pre_validate_and_infer_types() -{ - set_output_type(0, get_input_element_type(0), PartialShape::dynamic()); - set_output_type(1, get_input_element_type(0), PartialShape::dynamic()); - if (is_dynamic()) - { - return; - } - - const auto& x_pshape = get_input_partial_shape(0); - const auto& ht_pshape = get_input_partial_shape(1); - const auto& ct_pshape = get_input_partial_shape(2); - const auto& w_pshape = get_input_partial_shape(3); - const auto& r_pshape = get_input_partial_shape(4); - - NODE_VALIDATION_CHECK(this, - (x_pshape.is_static() || w_pshape.is_static() || r_pshape.is_static() || - ht_pshape.is_static() || ct_pshape.is_static()), - "LSTMCell supports only static input tensors."); - - const Shape& x_shape{x_pshape.to_shape()}; - - const size_t batch_size = x_shape.at(0); - const size_t input_size = x_shape.at(1); - - const Shape& w_shape{w_pshape.to_shape()}; - const Shape& r_shape{r_pshape.to_shape()}; - const Shape& ht_shape{ht_pshape.to_shape()}; - const Shape& ct_shape{ct_pshape.to_shape()}; - - NODE_VALIDATION_CHECK(this, - (w_shape == Shape{s_gates_count * get_hidden_size(), input_size}), - "Input tensor W must have shape (", - s_gates_count * get_hidden_size(), - ", ", - input_size, - "). Actual shape is:", - w_shape, - "."); - NODE_VALIDATION_CHECK(this, - (r_shape == Shape{s_gates_count * get_hidden_size(), get_hidden_size()}), - "Input tensor R must have shape (", - s_gates_count * get_hidden_size(), - ", ", - get_hidden_size(), - "). Actual shape is:", - r_shape, - "."); - NODE_VALIDATION_CHECK(this, - (ht_shape == Shape{batch_size, get_hidden_size()}), - "Input tensor initial_hidden_state must have shape (", - batch_size, - ", ", - get_hidden_size(), - "). Actual shape is:", - ht_shape, - "."); - NODE_VALIDATION_CHECK(this, - (ct_shape == Shape{batch_size, get_hidden_size()}), - "Input tensor initial_cell_state must have shape (", - batch_size, - ", ", - get_hidden_size(), - "). Actual shape is:", - ct_shape, - "."); - - const auto& b_pshape = get_input_partial_shape(5); - const auto& p_pshape = get_input_partial_shape(6); - - NODE_VALIDATION_CHECK(this, - (b_pshape.is_static() || p_pshape.is_static()), - "LSTMCell supports only static input tensors."); - - const Shape& b_shape{b_pshape.to_shape()}; - const Shape& p_shape{p_pshape.to_shape()}; - - NODE_VALIDATION_CHECK(this, - (b_shape == Shape{s_gates_count * get_hidden_size()}), - "Input tensor B must have shape (", - s_gates_count * get_hidden_size(), - "). Actual shape is:", - b_shape, - "."); - - NODE_VALIDATION_CHECK(this, - (p_shape == Shape{s_peepholes_count * get_hidden_size()}), - "Input tensor P must have shape (", - s_peepholes_count * get_hidden_size(), - "). Actual shape is:", - p_shape, - "."); -} - -void op::LSTMCell::validate_and_infer_types() +void op::v0::LSTMCell::validate_and_infer_types() { std::vector input_param{}; @@ -367,137 +280,273 @@ void op::LSTMCell::validate_and_infer_types() set_output_type(1, result_et, {merged_batch_size, merged_hidden_size}); } -OutputVector op::LSTMCell::decompose_op() const +Output op::v0::LSTMCell::get_default_bias_input() const +{ + return Output{op::Constant::create( + get_input_element_type(0), Shape{s_gates_count * get_hidden_size()}, vector{0.f})}; +} + +Output op::v0::LSTMCell::get_default_peepholes_input() const { - // ------ VARIABLE'S NAMES AND ACRONYM DEFINITIONS ------ - // The names used below are analogous to the one used in ONNX documentation. - // - // ------ ACRONYMS ------ - // i - input gate - // o - output gate - // f - forget gate - // c - cell gate - // t - time step (t-1 means previous time step) - // Wb - W bias vectors for input, output, forget, and cell gates. - // Rb - R bias vectors for input, output, forget, and cell gates. - // P - The peephole weights for input, output and forget gates. - // ------ VARIABLE NAMES ------ - // X - The input data tensor. Shape: [batch_size, input_size]. - // W - The weight matrix for input, forget, cell and output gates - // Shape: [4*hidden_size, input_size] - // R - The recurrence weight matrix for input, forget, cell and output gates. - // Shape: [4*hidden_size, hidden_size]. - // H_t - The hidden state tensor at current time step. Shape: [batch_size, hidden_size]. - // C_t - The cell state tensor at current time step. Shape: [batch_size, hidden_size]. - // bias - The sum of biases (weight and recurrence) for input, forget, cell and output gates. - // Shape: [4 * hidden_size] - // p_[iof] - The peephole weight vector for respectively: input, output, and forget gates. - // Each peephole has shape [hidden_size]. - // - // (.) - Denotes element-wise multiplication. - // * - Denotes dot product. - // - // ---- Equations ---- - // f, g, h - are activation functions. - // it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) - // ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf) - // ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc) - // Ct = ft (.) Ct-1 + it (.) ct - // ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo) - // Ht = ot (.) h(Ct) - // -------------------- - - Output X = input_value(0); - Output H_t = input_value(1); - Output C_t = input_value(2); - Output W = input_value(3); - Output R = input_value(4); - Output bias = input_value(5); - OutputVector p_iof = builder::split(input_value(6), s_peepholes_count); - - // Converting to IFCO format since it's DNNL default. - if (m_weights_format != op::LSTMWeightsFormat::IFCO) + return Output{op::Constant::create(get_input_element_type(0), + Shape{s_peepholes_count * get_hidden_size()}, + vector{0.f})}; +} + +shared_ptr op::v0::LSTMCell::clone_with_new_inputs(const OutputVector& new_args) const +{ + check_new_args_count(this, new_args); + if (new_args.size() == 5) { - W = convert_node_format(W); - R = convert_node_format(R); - bias = convert_node_format(bias); + return make_shared(new_args.at(0), + new_args.at(1), + new_args.at(2), + new_args.at(3), + new_args.at(4), + get_hidden_size(), + get_weights_format(), + get_activations(), + get_activations_alpha(), + get_activations_beta(), + get_clip(), + m_input_forget); } - - const auto& p_i = p_iof.at(0); - const auto& p_o = p_iof.at(1); - const auto& p_f = p_iof.at(2); - - // Xt*(W^T) -- for [iofc] gates. - auto Xt_W = make_shared(X, builder::opset1::transpose(W)); - // Ht-1*(R^T) -- for [iofc] gates. - auto Ht_R = make_shared(H_t, builder::opset1::transpose(R)); - // Xt*(W^T) + Ht-1*(R^T) + Wb + Rb -- for [iofc] gates. - auto gates = add(Xt_W, add(Ht_R, bias)); - - OutputVector split_gates = builder::split(gates, 4, -1); - auto i_t = split_gates.at(0); - auto f_t = split_gates.at(1); - auto c_t = split_gates.at(2); - auto o_t = split_gates.at(3); - - // f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) - i_t = m_activation_f(clip(add(i_t, mul(p_i, C_t)))); - if (m_input_forget) + else if (new_args.size() == 6) + { + return make_shared(new_args.at(0), + new_args.at(1), + new_args.at(2), + new_args.at(3), + new_args.at(4), + new_args.at(5), + get_hidden_size(), + get_weights_format(), + get_activations(), + get_activations_alpha(), + get_activations_beta(), + get_clip(), + m_input_forget); + } + else if (new_args.size() == 7) { - // Couple input with forget gate: 1 - i_t - f_t = sub(op::Constant::create(i_t.get_element_type(), - i_t.get_shape(), - vector(shape_size(i_t.get_shape()), 1.f)), - i_t); + return make_shared(new_args.at(0), + new_args.at(1), + new_args.at(2), + new_args.at(3), + new_args.at(4), + new_args.at(5), + new_args.at(6), + get_hidden_size(), + get_weights_format(), + get_activations(), + get_activations_alpha(), + get_activations_beta(), + get_clip(), + m_input_forget); } else { - // f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf) - f_t = m_activation_f(clip(add(f_t, mul(p_f, C_t)))); + throw ngraph_error("Incorrect number of new arguments"); } - // ft (.) Ct-1 + it (.) ct - auto C = add(mul(f_t, C_t), mul(i_t, m_activation_g(clip(c_t)))); - // f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo) - o_t = m_activation_f(clip(add(o_t, mul(p_o, C)))); - // ot (.) h(Ct) - auto H = mul(o_t, m_activation_h(clip(C))); - - return {H, C}; } -Output op::LSTMCell::get_default_bias_input() const +namespace ngraph { - return Output{op::Constant::create( - get_input_element_type(0), Shape{s_gates_count * get_hidden_size()}, vector{0.f})}; + template <> + EnumNames& EnumNames::get() + { + static auto enum_names = + EnumNames("op::LSTMWeightsFormat", + {{"fico", op::LSTMWeightsFormat::FICO}, + {"icof", op::LSTMWeightsFormat::ICOF}, + {"ifco", op::LSTMWeightsFormat::IFCO}, + {"ifoc", op::LSTMWeightsFormat::IFOC}, + {"iofc", op::LSTMWeightsFormat::IOFC}}); + return enum_names; + } + + constexpr DiscreteTypeInfo AttributeAdapter::type_info; + + std::ostream& operator<<(std::ostream& s, const op::LSTMWeightsFormat& type) + { + return s << as_string(type); + } +} // namespace ngraph + +op::v4::LSTMCell::LSTMCell() +{ + m_activations = {"sigmoid", "tanh", "tanh"}; + m_activation_f = get_activation_function(0); + m_activation_g = get_activation_function(1); + m_activation_h = get_activation_function(2); } -Output op::LSTMCell::get_default_peepholes_input() const +op::v4::LSTMCell::LSTMCell(const Output& X, + const Output& initial_hidden_state, + const Output& initial_cell_state, + const Output& W, + const Output& R, + size_t hidden_size, + const vector& activations, + const vector& activations_alpha, + const vector& activations_beta, + float clip) + : RNNCellBase({X, initial_hidden_state, initial_cell_state, W, R}, + hidden_size, + clip, + activations, + activations_alpha, + activations_beta) + , m_activation_f{get_activation_function(0)} + , m_activation_g{get_activation_function(1)} + , m_activation_h{get_activation_function(2)} { - return Output{op::Constant::create(get_input_element_type(0), - Shape{s_peepholes_count * get_hidden_size()}, - vector{0.f})}; + set_argument(5, get_default_bias_input()); + constructor_validate_and_infer_types(); } -shared_ptr op::LSTMCell::convert_node_format(const Output& node) const +op::v4::LSTMCell::LSTMCell(const Output& X, + const Output& initial_hidden_state, + const Output& initial_cell_state, + const Output& W, + const Output& R, + const Output& B, + size_t hidden_size, + const vector& activations, + const vector& activations_alpha, + const vector& activations_beta, + float clip) + : RNNCellBase({X, initial_hidden_state, initial_cell_state, W, R, B}, + hidden_size, + clip, + activations, + activations_alpha, + activations_beta) + , m_activation_f{get_activation_function(0)} + , m_activation_g{get_activation_function(1)} + , m_activation_h{get_activation_function(2)} +{ + constructor_validate_and_infer_types(); +} + +bool ngraph::op::v4::LSTMCell::visit_attributes(AttributeVisitor& visitor) +{ + return op::util::RNNCellBase::visit_attributes(visitor); +} + +void op::v4::LSTMCell::validate_and_infer_types() { - static const std::map> gate_order_conversion_map{ - {op::LSTMWeightsFormat::FICO, {1, 0, 2, 3}}, - {op::LSTMWeightsFormat::ICOF, {0, 3, 1, 2}}, - {op::LSTMWeightsFormat::IFOC, {0, 1, 3, 2}}, - {op::LSTMWeightsFormat::IOFC, {0, 2, 3, 1}}, - }; - - OutputVector splitted_node = builder::split(node, s_gates_count); - OutputVector nodes_in_new_format; - nodes_in_new_format.reserve(s_gates_count); - for (const auto& axis : gate_order_conversion_map.at(m_weights_format)) + auto merged_batch_size = Dimension::dynamic(); + auto merged_hidden_size = Dimension::dynamic(); + auto result_et = element::dynamic; + + // Get input partial shape for all inputs + const auto& x_pshape = get_input_partial_shape(0); + const auto& ht_pshape = get_input_partial_shape(1); + const auto& ct_pshape = get_input_partial_shape(2); + const auto& w_pshape = get_input_partial_shape(3); + const auto& r_pshape = get_input_partial_shape(4); + const auto& b_pshape = get_input_partial_shape(5); + + // Validate rank and dimension for initial_cell_state input + NODE_VALIDATION_CHECK(this, + (ct_pshape.rank().is_static()), + "LSTMCell input tensor initial_cell_state shall have static rank."); + + NODE_VALIDATION_CHECK(this, + (ct_pshape.rank().get_length() == 2), + "LSTMCell input tensor initial_cell_state shall have dimension 2D."); + + validate_input_rank_dimension({x_pshape, ht_pshape, w_pshape, r_pshape, b_pshape}); + + // Validate input element types and save result for output type + NODE_VALIDATION_CHECK( + this, + element::Type::merge(result_et, result_et, get_input_element_type(0)) && + element::Type::merge(result_et, result_et, get_input_element_type(1)) && + element::Type::merge(result_et, result_et, get_input_element_type(2)) && + element::Type::merge(result_et, result_et, get_input_element_type(3)) && + element::Type::merge(result_et, result_et, get_input_element_type(4)) && + element::Type::merge(result_et, result_et, get_input_element_type(5)), + "Element types for X, initial_hidden_state, initial_cell_state, W, R and B do not match."); + + // Merge batch_size dimension across all inputs to evaluate output[0] dimension + NODE_VALIDATION_CHECK( + this, + Dimension::merge(merged_batch_size, merged_batch_size, ht_pshape[0]) && + Dimension::merge(merged_batch_size, merged_batch_size, ct_pshape[0]) && + Dimension::merge(merged_batch_size, merged_batch_size, x_pshape[0]), + "Parameter batch_size not matched for X, initial_hidden_state or initial_cell_state " + "inputs."); + + // Merge hidden_size dimension across all inputs to evaluate output[1] dimension + NODE_VALIDATION_CHECK( + this, + Dimension::merge(merged_hidden_size, merged_hidden_size, ht_pshape[1]) && + Dimension::merge(merged_hidden_size, merged_hidden_size, ct_pshape[1]) && + Dimension::merge(merged_hidden_size, merged_hidden_size, r_pshape[1]), + "Parameter hidden_size not matched for R, initial_hidden_state and initial_cell_state " + "inputs."); + + // Validate hidden_size value for W, R and P inputs + if (merged_hidden_size.is_static()) { - nodes_in_new_format.push_back(splitted_node.at(axis)); + if (w_pshape[0].is_static()) + { + NODE_VALIDATION_CHECK( + this, + w_pshape[0].compatible(merged_hidden_size * s_gates_count), + "Parameter hidden_size mistmatched in W input. Current value is: ", + w_pshape[0].get_length(), + ", expected: ", + merged_hidden_size.get_length() * s_gates_count, + "."); + } + + if (r_pshape[0].is_static()) + { + NODE_VALIDATION_CHECK( + this, + r_pshape[0].compatible(merged_hidden_size * s_gates_count), + "Parameter hidden_size mistmatched in R input. Current value is: ", + r_pshape[0].get_length(), + ", expected: ", + merged_hidden_size.get_length() * s_gates_count, + "."); + } + + if (b_pshape[0].is_static()) + { + NODE_VALIDATION_CHECK( + this, + b_pshape[0].compatible(merged_hidden_size * s_gates_count), + "Parameter hidden_size mistmatched in B input. Current value is: ", + b_pshape[0].get_length(), + ", expected: ", + merged_hidden_size.get_length() * s_gates_count, + "."); + } } - return make_shared(nodes_in_new_format, 0); + + // Mark inputs which are relevant to output parameters + set_input_is_relevant_to_shape(0); + set_input_is_relevant_to_shape(1); + set_input_is_relevant_to_shape(2); + set_input_is_relevant_to_shape(4); + + // Set output size, type and shape + set_output_size(2); + set_output_type(0, result_et, {merged_batch_size, merged_hidden_size}); + set_output_type(1, result_et, {merged_batch_size, merged_hidden_size}); +} + +Output op::v4::LSTMCell::get_default_bias_input() const +{ + return Output{op::Constant::create( + get_input_element_type(0), Shape{s_gates_count * get_hidden_size()}, vector{0.f})}; } -shared_ptr op::LSTMCell::clone_with_new_inputs(const OutputVector& new_args) const +shared_ptr op::v4::LSTMCell::clone_with_new_inputs(const OutputVector& new_args) const { check_new_args_count(this, new_args); if (new_args.size() == 5) @@ -508,12 +557,10 @@ shared_ptr op::LSTMCell::clone_with_new_inputs(const OutputVector& new_arg new_args.at(3), new_args.at(4), get_hidden_size(), - get_weights_format(), get_activations(), get_activations_alpha(), get_activations_beta(), - get_clip(), - m_input_forget); + get_clip()); } else if (new_args.size() == 6) { @@ -524,55 +571,13 @@ shared_ptr op::LSTMCell::clone_with_new_inputs(const OutputVector& new_arg new_args.at(4), new_args.at(5), get_hidden_size(), - get_weights_format(), get_activations(), get_activations_alpha(), get_activations_beta(), - get_clip(), - m_input_forget); - } - else if (new_args.size() == 7) - { - return make_shared(new_args.at(0), - new_args.at(1), - new_args.at(2), - new_args.at(3), - new_args.at(4), - new_args.at(5), - new_args.at(6), - get_hidden_size(), - get_weights_format(), - get_activations(), - get_activations_alpha(), - get_activations_beta(), - get_clip(), - m_input_forget); + get_clip()); } else { throw ngraph_error("Incorrect number of new arguments"); } } - -namespace ngraph -{ - template <> - EnumNames& EnumNames::get() - { - static auto enum_names = - EnumNames("op::LSTMWeightsFormat", - {{"fico", op::LSTMWeightsFormat::FICO}, - {"icof", op::LSTMWeightsFormat::ICOF}, - {"ifco", op::LSTMWeightsFormat::IFCO}, - {"ifoc", op::LSTMWeightsFormat::IFOC}, - {"iofc", op::LSTMWeightsFormat::IOFC}}); - return enum_names; - } - - constexpr DiscreteTypeInfo AttributeAdapter::type_info; - - std::ostream& operator<<(std::ostream& s, const op::LSTMWeightsFormat& type) - { - return s << as_string(type); - } -} // namespace ngraph diff --git a/ngraph/core/src/op/lstm_sequence.cpp b/ngraph/core/src/op/lstm_sequence.cpp index ec11ef4..10a5b75 100644 --- a/ngraph/core/src/op/lstm_sequence.cpp +++ b/ngraph/core/src/op/lstm_sequence.cpp @@ -22,13 +22,16 @@ #include "ngraph/builder/split.hpp" #include "ngraph/opsets/opset1.hpp" +#include "ngraph/opsets/opset4.hpp" #include "ngraph/op/util/recurrent_sequence.hpp" using namespace ngraph; using namespace std; +constexpr NodeTypeInfo op::v1::LSTMSequence::type_info; constexpr NodeTypeInfo op::v0::LSTMSequence::type_info; + bool ngraph::op::v0::LSTMSequence::visit_attributes(AttributeVisitor& visitor) { visitor.on_attribute("hidden_size", m_hidden_size); @@ -415,3 +418,165 @@ void op::v0::LSTMSequence::validate_and_infer_types() set_output_type(1, result_et, {merged_batch_size, merged_num_directions, merged_hidden_size}); set_output_type(2, result_et, {merged_batch_size, merged_num_directions, merged_hidden_size}); } + +bool ngraph::op::v1::LSTMSequence::visit_attributes(AttributeVisitor& visitor) +{ + visitor.on_attribute("direction", m_direction); + return op::util::RNNCellBase::visit_attributes(visitor); +} + +shared_ptr op::v1::LSTMSequence::clone_with_new_inputs(const OutputVector& new_args) const +{ + check_new_args_count(this, new_args); + if (new_args.size() == 7) + { + return make_shared(new_args.at(0), // X + new_args.at(1), // initial_hidden_state + new_args.at(2), // initial_cell_state + new_args.at(3), // sequence_lengths + new_args.at(4), // W + new_args.at(5), // R + new_args.at(6), // B + m_hidden_size, + m_direction, + m_activations_alpha, + m_activations_beta, + m_activations, + m_clip); + } + else + { + throw ngraph_error("Incorrect number of new arguments"); + } +} + +void op::v1::LSTMSequence::validate_and_infer_types() +{ + std::vector input_param{}; + + auto lstm_seq_gates_count = 4; + auto merged_batch_size = Dimension::dynamic(); + auto merged_hidden_size = Dimension::dynamic(); + auto merged_num_directions = Dimension::dynamic(); + auto result_et = element::dynamic; + + // Copy all inputs without initial_cell_state information for further validation + for (size_t i = 0; i < get_input_size(); i++) + { + // exclude initial_cell_state from the loop + if (i != 2) + { + input_param.push_back(get_input_partial_shape(i)); + } + } + + // Get input partial shape for all inputs + const auto& x_pshape = get_input_partial_shape(0); + const auto& ht_pshape = get_input_partial_shape(1); + const auto& ct_pshape = get_input_partial_shape(2); + const auto& sl_pshape = get_input_partial_shape(3); + const auto& w_pshape = get_input_partial_shape(4); + const auto& r_pshape = get_input_partial_shape(5); + const auto& b_pshape = get_input_partial_shape(6); + + ngraph::op::util::validate_seq_input_rank_dimension(input_param); + + // Validate rank and dimension for initial_cell_state input + NODE_VALIDATION_CHECK(this, + (ct_pshape.rank().is_static()), + "LSTMSequence input tensor initial_cell_state shall have static rank."); + + NODE_VALIDATION_CHECK(this, + (ct_pshape.rank().get_length() == 3), + "LSTMSequence input tensor initial_cell_state shall have dimension 3D."); + + // Validate input types and save result for output type + NODE_VALIDATION_CHECK( + this, + element::Type::merge(result_et, result_et, get_input_element_type(0)) && + element::Type::merge(result_et, result_et, get_input_element_type(1)) && + element::Type::merge(result_et, result_et, get_input_element_type(2)) && + element::Type::merge(result_et, result_et, get_input_element_type(4)) && + element::Type::merge(result_et, result_et, get_input_element_type(5)) && + element::Type::merge(result_et, result_et, get_input_element_type(6)), + "Element types for X, initial_hidden_state, initial_cell_state, W, R and B inputs do not " + "match."); + + // Merge batch_size dimension across all inputs to evaluate output[0] dimension + NODE_VALIDATION_CHECK( + this, + Dimension::merge(merged_batch_size, merged_batch_size, ht_pshape[0]) && + Dimension::merge(merged_batch_size, merged_batch_size, ct_pshape[0]) && + Dimension::merge(merged_batch_size, merged_batch_size, x_pshape[0]) && + Dimension::merge(merged_batch_size, merged_batch_size, sl_pshape[0]), + "Parameter batch_size not matched in LSTMSequence."); + + // Merge hidden_size dimension across all inputs to evaluate output dimension + NODE_VALIDATION_CHECK( + this, + Dimension::merge(merged_hidden_size, merged_hidden_size, ht_pshape[2]) && + Dimension::merge(merged_hidden_size, merged_hidden_size, ct_pshape[2]) && + Dimension::merge(merged_hidden_size, merged_hidden_size, r_pshape[2]), + "Parameter hidden_size not matched LSTMSequence."); + + // Merge num_directions dimension across all inputs to evaluate output dimension + NODE_VALIDATION_CHECK( + this, + Dimension::merge(merged_num_directions, merged_num_directions, ht_pshape[1]) && + Dimension::merge(merged_num_directions, merged_num_directions, ct_pshape[1]) && + Dimension::merge(merged_num_directions, merged_num_directions, w_pshape[0]) && + Dimension::merge(merged_num_directions, merged_num_directions, r_pshape[0]) && + Dimension::merge(merged_num_directions, merged_num_directions, b_pshape[0]), + "Parameter num_directions not matched in LSTMSequence."); + + // Validate hidden_size value for W, R, B inputs + if (merged_hidden_size.is_static()) + { + if (w_pshape[0].is_static()) + { + NODE_VALIDATION_CHECK( + this, + w_pshape[1].compatible(merged_hidden_size * lstm_seq_gates_count), + "Parameter hidden_size mistmatched in W input. Current value is: ", + w_pshape[1].get_length(), + ", expected: ", + merged_hidden_size.get_length() * lstm_seq_gates_count, + "."); + } + + if (r_pshape[0].is_static()) + { + NODE_VALIDATION_CHECK( + this, + r_pshape[1].compatible(merged_hidden_size * lstm_seq_gates_count), + "Parameter hidden_size mistmatched in R input. Current value is: ", + r_pshape[1].get_length(), + ", expected: ", + merged_hidden_size.get_length() * lstm_seq_gates_count, + "."); + } + + if (b_pshape[0].is_static()) + { + NODE_VALIDATION_CHECK( + this, + b_pshape[1].compatible(merged_hidden_size * lstm_seq_gates_count), + "Parameter hidden_size mistmatched in B input. Current value is: ", + b_pshape[1].get_length(), + ", expected: ", + merged_hidden_size.get_length() * lstm_seq_gates_count, + "."); + } + } + + // Mark inputs which are relevant to output parameters + for (size_t i = 0; i <= 6; ++i) + set_input_is_relevant_to_shape(i); + + // Set output size, type and shape + set_output_size(3); + set_output_type( + 0, result_et, {merged_batch_size, merged_num_directions, x_pshape[1], merged_hidden_size}); + set_output_type(1, result_et, {merged_batch_size, merged_num_directions, merged_hidden_size}); + set_output_type(2, result_et, {merged_batch_size, merged_num_directions, merged_hidden_size}); +} diff --git a/ngraph/core/src/op/rnn_cell.cpp b/ngraph/core/src/op/rnn_cell.cpp index e34425d..6310b23 100644 --- a/ngraph/core/src/op/rnn_cell.cpp +++ b/ngraph/core/src/op/rnn_cell.cpp @@ -14,156 +14,79 @@ // limitations under the License. //***************************************************************************** +#include "ngraph/op/rnn_cell.hpp" #include -#include - +#include "itt.hpp" #include "ngraph/builder/reshape.hpp" -#include "ngraph/builder/split.hpp" -#include "ngraph/op/add.hpp" #include "ngraph/op/constant.hpp" #include "ngraph/op/dot.hpp" -#include "ngraph/op/rnn_cell.hpp" #include "ngraph/shape.hpp" #include "ngraph/type/element_type.hpp" using namespace std; using namespace ngraph; -NGRAPH_SUPPRESS_DEPRECATED_START +constexpr NodeTypeInfo op::v0::RNNCell::type_info; -constexpr NodeTypeInfo op::RNNCell::type_info; - -op::RNNCell::RNNCell() +op::v0::RNNCell::RNNCell() { m_activations = {"tanh"}; m_activation_f = get_activation_function(0); } -op::RNNCell::RNNCell(const Output& X, - const Output& initial_hidden_state, - const Output& W, - const Output& R, - size_t hidden_size, - const vector& activations, - const vector& activations_alpha, - const vector& activations_beta, - float clip) - : FusedOp({X, initial_hidden_state, W, R}) - , RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta) +op::v0::RNNCell::RNNCell(const Output& X, + const Output& initial_hidden_state, + const Output& W, + const Output& R, + size_t hidden_size, + const vector& activations, + const vector& activations_alpha, + const vector& activations_beta, + float clip) + : RNNCellBase({X, initial_hidden_state, W, R}, + hidden_size, + clip, + activations, + activations_alpha, + activations_beta) , m_activation_f{get_activation_function(0)} { set_argument(4, get_default_bias_input()); constructor_validate_and_infer_types(); } -op::RNNCell::RNNCell(const Output& X, - const Output& initial_hidden_state, - const Output& W, - const Output& R, - const Output& B, - size_t hidden_size, - const vector& activations, - const vector& activations_alpha, - const vector& activations_beta, - float clip) - : FusedOp({X, initial_hidden_state, W, R, B}) - , RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta) +op::v0::RNNCell::RNNCell(const Output& X, + const Output& initial_hidden_state, + const Output& W, + const Output& R, + const Output& B, + size_t hidden_size, + const vector& activations, + const vector& activations_alpha, + const vector& activations_beta, + float clip) + : RNNCellBase({X, initial_hidden_state, W, R, B}, + hidden_size, + clip, + activations, + activations_alpha, + activations_beta) , m_activation_f{get_activation_function(0)} { constructor_validate_and_infer_types(); } -bool op::RNNCell::visit_attributes(AttributeVisitor& visitor) +bool op::v0::RNNCell::visit_attributes(AttributeVisitor& visitor) { return op::util::RNNCellBase::visit_attributes(visitor); } -void op::RNNCell::pre_validate_and_infer_types() +void op::v0::RNNCell::validate_and_infer_types() { - set_output_type(0, get_input_element_type(0), PartialShape::dynamic()); - - if (is_dynamic()) - { - return; - } - - const auto& x_pshape = get_input_partial_shape(0); - const auto& ht_pshape = get_input_partial_shape(1); - const auto& w_pshape = get_input_partial_shape(2); - const auto& r_pshape = get_input_partial_shape(3); - - NODE_VALIDATION_CHECK(this, - (x_pshape.is_static() || w_pshape.is_static() || r_pshape.is_static() || - ht_pshape.is_static()), - "RNNCell supports only static input tensors."); - - const Shape& x_shape{x_pshape.to_shape()}; - - const size_t batch_size = x_shape.at(0); - const size_t input_size = x_shape.at(1); - - const Shape& w_shape{w_pshape.to_shape()}; - const Shape& r_shape{r_pshape.to_shape()}; - const Shape& ht_shape{ht_pshape.to_shape()}; - - NODE_VALIDATION_CHECK(this, - (w_shape == Shape{get_hidden_size(), input_size}), - "Input tensor W must have shape (", - get_hidden_size(), - ", ", - input_size, - "). Actual shape is:", - w_shape, - "."); - NODE_VALIDATION_CHECK(this, - (r_shape == Shape{get_hidden_size(), get_hidden_size()}), - "Input tensor R must have shape (", - get_hidden_size(), - ", ", - get_hidden_size(), - "). Actual shape is:", - w_shape, - "."); - NODE_VALIDATION_CHECK(this, - (ht_shape == Shape{batch_size, get_hidden_size()}), - "Input tensor initial_hidden_state must have shape (", - batch_size, - ", ", - get_hidden_size(), - "). Actual shape is:", - w_shape, - "."); - - const auto& b_pshape = get_input_partial_shape(4); - - NODE_VALIDATION_CHECK( - this, b_pshape.is_static(), "RNNCell supports only static input tensors."); - - const Shape& b_shape{b_pshape.to_shape()}; - - NODE_VALIDATION_CHECK(this, - (b_shape == Shape{get_hidden_size()}), - "Input tensor B must have shape (", - get_hidden_size(), - "). Actual shape is:", - b_shape, - "."); -} - -void op::RNNCell::validate_and_infer_types() -{ - std::vector input_param{}; - auto merged_batch_size = Dimension::dynamic(); auto merged_hidden_size = Dimension::dynamic(); auto result_et = element::dynamic; - // Copy all inputs for further validation - for (size_t i = 0; i < get_input_size(); i++) - { - input_param.push_back(get_input_partial_shape(i)); - } - // Get input partial shape for all inputs const auto& x_pshape = get_input_partial_shape(0); const auto& ht_pshape = get_input_partial_shape(1); @@ -171,7 +94,7 @@ void op::RNNCell::validate_and_infer_types() const auto& r_pshape = get_input_partial_shape(3); const auto& b_pshape = get_input_partial_shape(4); - validate_input_rank_dimension(input_param); + validate_input_rank_dimension({x_pshape, ht_pshape, w_pshape, r_pshape, b_pshape}); // Validate input types and save result for output type NODE_VALIDATION_CHECK( @@ -238,72 +161,23 @@ void op::RNNCell::validate_and_infer_types() } // Mark inputs which are relevant to output parameters - set_input_is_relevant_to_shape(0); - set_input_is_relevant_to_shape(1); - set_input_is_relevant_to_shape(2); - set_input_is_relevant_to_shape(3); - set_input_is_relevant_to_shape(4); + for (size_t i = 0; i <= 4; ++i) + set_input_is_relevant_to_shape(i); // Set output size, type and shape set_output_size(1); set_output_type(0, result_et, {merged_batch_size, merged_hidden_size}); } -OutputVector op::RNNCell::decompose_op() const -{ - // ------ VARIABLE'S NAMES AND ACRONYM DEFINITIONS ------ - // The names used below are analogous to the one used in ONNX documentation. - // - // ------ ACRONYMS ------ - // i_t - input gate at current time step - // t - time step (t-1 means previous time step) - // X - The input data tensor. Shape: [batch_size, input_size]. - // W - The weight tensor for input gate. Shape: [hidden_size, input_size]. - // R - The recurrence weight tensor for input gate. Shape: [hidden_size, hidden_size]. - // H_t - The hidden state tensor at current time step. Shape: [batch_size, hidden_size]. - // B - The bias tensor for the input gate. Shape: [hidden_size]. - // Wb - W bias vectors for input gate. - // Rb - R bias vectors for input gate. - // ------ VARIABLE NAMES ------ - // Xt_W - Input sequence multiplied by weights tensor at current time step. - // Ht_R - Hidden state multiplied by weights tensor at current time step. - - // (.) - Denotes element-wise multiplication. - // * - Denotes dot product. - - // ---- Equations ---- - // f - is activation functions. - // Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi) - // -------------------- - - Output X = input_value(0); - Output H_t = input_value(1); - Output W = input_value(2); - Output R = input_value(3); - Output bias = input_value(4); - - // Xt*(W^T) - auto Xt_W = std::make_shared(X, builder::opset1::transpose(W)); - // Ht-1*(R^T) - auto Ht_R = std::make_shared(H_t, builder::opset1::transpose(R)); - // Xt*(W^T) + Ht-1*(R^T) + Wb + Rb - auto i_t = add(Xt_W, add(Ht_R, bias)); - - // f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi) - i_t = m_activation_f(clip(i_t)); - - return {i_t}; -} - -Output op::RNNCell::get_default_bias_input() const +Output op::v0::RNNCell::get_default_bias_input() const { return Output{ - op::Constant::create(get_input_element_type(0), - Shape{s_gates_count * get_hidden_size()}, - vector(s_gates_count * get_hidden_size(), 0.f))}; + op::v0::Constant::create(get_input_element_type(0), + Shape{s_gates_count * get_hidden_size()}, + vector(s_gates_count * get_hidden_size(), 0.f))}; } -shared_ptr op::RNNCell::clone_with_new_inputs(const OutputVector& new_args) const +shared_ptr op::v0::RNNCell::clone_with_new_inputs(const OutputVector& new_args) const { check_new_args_count(this, new_args); if (new_args.size() == 4) diff --git a/ngraph/core/src/op/split.cpp b/ngraph/core/src/op/split.cpp index e8ac326..ecf1160 100644 --- a/ngraph/core/src/op/split.cpp +++ b/ngraph/core/src/op/split.cpp @@ -13,8 +13,8 @@ // See the License for the specific language governing permissions and // limitations under the License. //***************************************************************************** +#include "ngraph/runtime/reference/split.hpp" #include - #include "ngraph/attribute_visitor.hpp" #include "ngraph/builder/split.hpp" #include "ngraph/op/constant.hpp" @@ -23,8 +23,6 @@ #include "ngraph/validation_util.hpp" #include "ngraph/runtime/host_tensor.hpp" -#include "ngraph/runtime/reference/slice.hpp" - NGRAPH_SUPPRESS_DEPRECATED_START using namespace std; @@ -196,20 +194,25 @@ shared_ptr op::v1::Split::clone_with_new_inputs(const OutputVector& new_ar namespace { - inline bool evaluate(const HostTensorPtr& in, - const HostTensorPtr& out, - const Coordinate& lower_bounds, - const Coordinate& upper_bounds) + inline bool evaluate(const HostTensorPtr& data_tensor, + const HostTensorVector& outputs, + const int64_t axis, + const int64_t num_splits) { - runtime::reference::slice(in->get_data_ptr(), - out->get_data_ptr(), - in->get_shape(), - lower_bounds, - upper_bounds, - Strides(lower_bounds.size(), 1), - out->get_shape(), - in->get_element_type().size()); - + Shape output_shape = data_tensor->get_shape(); + std::vector outputs_data(num_splits); + output_shape.at(axis) /= num_splits; + for (size_t i = 0; i < outputs.size(); ++i) + { + outputs[i]->set_shape(output_shape); + outputs_data[i] = outputs[i]->get_data_ptr(); + } + ngraph::runtime::reference::split(data_tensor->get_data_ptr(), + data_tensor->get_shape(), + data_tensor->get_element_type().size(), + axis, + num_splits, + outputs_data.data()); return true; } @@ -236,26 +239,7 @@ namespace break; } axis = ngraph::normalize_axis(split_node, axis, data_tensor->get_partial_shape().rank()); - - const auto data_shape = data_tensor->get_shape(); - const size_t axis_dim_length = data_shape.at(axis); - const size_t part_length = axis_dim_length / num_splits; - - Shape output_shape = data_shape; - output_shape.at(axis) = part_length; - - std::vector lower_bounds(data_shape.size(), 0); - std::vector upper_bounds = data_shape; - upper_bounds.at(axis) = part_length; - - for (const auto& output : outputs) - { - output->set_shape(output_shape); - evaluate(data_tensor, output, lower_bounds, upper_bounds); - lower_bounds.at(axis) += part_length; - upper_bounds.at(axis) += part_length; - } - + evaluate(data_tensor, outputs, axis, num_splits); return true; } } diff --git a/ngraph/core/src/op/util/rnn_cell_base.cpp b/ngraph/core/src/op/util/rnn_cell_base.cpp index 202cd31..1683ed8 100644 --- a/ngraph/core/src/op/util/rnn_cell_base.cpp +++ b/ngraph/core/src/op/util/rnn_cell_base.cpp @@ -24,11 +24,38 @@ #include "ngraph/op/multiply.hpp" #include "ngraph/op/subtract.hpp" #include "ngraph/op/util/rnn_cell_base.hpp" +#include "ngraph/opsets/opset4.hpp" #include "ngraph/util.hpp" using namespace std; using namespace ngraph; +std::shared_ptr ngraph::op::util::convert_lstm_node_format(const Output& node, + LSTMWeightsFormat from_format, + LSTMWeightsFormat to_format) +{ + static const std::map> gate_order_map{ + {op::util::LSTMWeightsFormat::FICO, {0, 1, 2, 3}}, + {op::util::LSTMWeightsFormat::ICOF, {1, 2, 3, 0}}, + {op::util::LSTMWeightsFormat::IFOC, {1, 0, 3, 2}}, + {op::util::LSTMWeightsFormat::IOFC, {1, 3, 0, 2}}, + {op::util::LSTMWeightsFormat::IFCO, {1, 0, 2, 3}}, + }; + const auto& from = gate_order_map.at(from_format); + const auto& to = gate_order_map.at(to_format); + size_t num_gates = 4; + + auto axis_const = std::make_shared(element::i64, Shape{}, 0); + OutputVector splitted_node = + std::make_shared(node, axis_const, num_gates)->outputs(); + OutputVector nodes_in_new_format(num_gates); + for (size_t i = 0; i < num_gates; ++i) + { + nodes_in_new_format[to[from[i]]] = splitted_node[i]; + } + return std::make_shared(nodes_in_new_format, 0); +} + // Modify input vector in-place and return reference to modified vector. static vector to_lower_case(const vector& vs) { @@ -43,12 +70,14 @@ op::util::RNNCellBase::RNNCellBase() { } -op::util::RNNCellBase::RNNCellBase(size_t hidden_size, +op::util::RNNCellBase::RNNCellBase(const OutputVector& args, + size_t hidden_size, float clip, const vector& activations, const vector& activations_alpha, const vector& activations_beta) - : m_hidden_size(hidden_size) + : Op(args) + , m_hidden_size(hidden_size) , m_clip(clip) , m_activations(to_lower_case(activations)) , m_activations_alpha(activations_alpha) diff --git a/ngraph/frontend/onnx_import/src/op/lstm.cpp b/ngraph/frontend/onnx_import/src/op/lstm.cpp index e575de1..ed07732 100644 --- a/ngraph/frontend/onnx_import/src/op/lstm.cpp +++ b/ngraph/frontend/onnx_import/src/op/lstm.cpp @@ -29,6 +29,7 @@ #include "ngraph/op/constant.hpp" #include "ngraph/op/lstm_sequence.hpp" #include "ngraph/op/util/attr_types.hpp" +#include "ngraph/opsets/opset3.hpp" #include "ngraph/shape.hpp" #include "ngraph/type/element_type.hpp" #include "onnx_import/core/null_node.hpp" @@ -212,7 +213,10 @@ namespace ngraph LSTMNgInputMap input_map{node}; LSTMAttributes attributes{node}; - auto lstmSequence = std::make_shared( + // LSTMSequence is not fully supported in OpenVINO and is excluded from + // opset4 (current the latest opset version), use one of the previous + // opsets instead of default + auto lstmSequence = std::make_shared( input_map.at(LSTMInput::LSTM_INPUT_X), input_map.at(LSTMInput::LSTM_INPUT_INIT_H), input_map.at(LSTMInput::LSTM_INPUT_INIT_C), diff --git a/ngraph/python/src/ngraph/opset4/__init__.py b/ngraph/python/src/ngraph/opset4/__init__.py index 07d2c07..b6a179c 100644 --- a/ngraph/python/src/ngraph/opset4/__init__.py +++ b/ngraph/python/src/ngraph/opset4/__init__.py @@ -82,7 +82,7 @@ from ngraph.opset1.ops import logical_not from ngraph.opset1.ops import logical_or from ngraph.opset1.ops import logical_xor from ngraph.opset1.ops import lrn -from ngraph.opset1.ops import lstm_cell +from ngraph.opset4.ops import lstm_cell from ngraph.opset1.ops import lstm_sequence from ngraph.opset1.ops import matmul from ngraph.opset1.ops import max_pool diff --git a/ngraph/python/src/ngraph/opset4/ops.py b/ngraph/python/src/ngraph/opset4/ops.py index 8149a32..badc360 100644 --- a/ngraph/python/src/ngraph/opset4/ops.py +++ b/ngraph/python/src/ngraph/opset4/ops.py @@ -367,3 +367,54 @@ def reduce_l2( return _get_node_factory_opset4().create( "ReduceL2", as_nodes(node, reduction_axes), {"keep_dims": keep_dims} ) + + +@nameable_op +def lstm_cell( + X: NodeInput, + initial_hidden_state: NodeInput, + initial_cell_state: NodeInput, + W: NodeInput, + R: NodeInput, + B: NodeInput, + hidden_size: int, + activations: List[str] = None, + activations_alpha: List[float] = None, + activations_beta: List[float] = None, + clip: float = 0.0, + name: Optional[str] = None, +) -> Node: + """Return a node which performs LSTMCell operation. + + :param X: The input tensor with shape: [batch_size, input_size]. + :param initial_hidden_state: The hidden state tensor with shape: [batch_size, hidden_size]. + :param initial_cell_state: The cell state tensor with shape: [batch_size, hidden_size]. + :param W: The weight tensor with shape: [4*hidden_size, input_size]. + :param R: The recurrence weight tensor with shape: [4*hidden_size, hidden_size]. + :param B: The bias tensor for gates with shape: [4*hidden_size]. + :param hidden_size: Specifies hidden state size. + :param activations: The list of three activation functions for gates. + :param activations_alpha: The list of alpha parameters for activation functions. + :param activations_beta: The list of beta parameters for activation functions. + :param clip: Specifies bound values [-C, C] for tensor clipping performed before activations. + :param name: An optional name of the output node. + + :return: The new node represents LSTMCell. Node outputs count: 2. + """ + if activations is None: + activations = ["sigmoid", "tanh", "tanh"] + if activations_alpha is None: + activations_alpha = [] + if activations_beta is None: + activations_beta = [] + + node_inputs = as_nodes(X, initial_hidden_state, initial_cell_state, W, R, B) + + attributes = { + "hidden_size": hidden_size, + "activations": activations, + "activations_alpha": activations_alpha, + "activations_beta": activations_beta, + "clip": clip, + } + return _get_node_factory_opset4().create("LSTMCell", node_inputs, attributes) diff --git a/ngraph/python/tests/test_ngraph/test_create_op.py b/ngraph/python/tests/test_ngraph/test_create_op.py index 5d1eae3..674b8d0 100644 --- a/ngraph/python/tests/test_ngraph/test_create_op.py +++ b/ngraph/python/tests/test_ngraph/test_create_op.py @@ -18,6 +18,7 @@ import pytest from _pyngraph import PartialShape import ngraph as ng +import ngraph.opset1 as ng_opset1 from ngraph.impl import Type np_types = [np.float32, np.int32] @@ -231,6 +232,62 @@ def test_lstm_cell_operator(dtype): @pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_lstm_cell_operator_opset1(dtype): + batch_size = 1 + input_size = 16 + hidden_size = 128 + + X_shape = [batch_size, input_size] + H_t_shape = [batch_size, hidden_size] + C_t_shape = [batch_size, hidden_size] + W_shape = [4 * hidden_size, input_size] + R_shape = [4 * hidden_size, hidden_size] + B_shape = [4 * hidden_size] + + parameter_X = ng.parameter(X_shape, name="X", dtype=dtype) + parameter_H_t = ng.parameter(H_t_shape, name="H_t", dtype=dtype) + parameter_C_t = ng.parameter(C_t_shape, name="C_t", dtype=dtype) + parameter_W = ng.parameter(W_shape, name="W", dtype=dtype) + parameter_R = ng.parameter(R_shape, name="R", dtype=dtype) + parameter_B = ng.parameter(B_shape, name="B", dtype=dtype) + + expected_shape = [1, 128] + + node_default = ng_opset1.lstm_cell( + parameter_X, parameter_H_t, parameter_C_t, parameter_W, parameter_R, parameter_B, hidden_size, + ) + + assert node_default.get_type_name() == "LSTMCell" + assert node_default.get_output_size() == 2 + assert list(node_default.get_output_shape(0)) == expected_shape + assert list(node_default.get_output_shape(1)) == expected_shape + + activations = ["tanh", "Sigmoid", "RELU"] + activation_alpha = [1.0, 2.0, 3.0] + activation_beta = [3.0, 2.0, 1.0] + clip = 0.5 + + node_param = ng_opset1.lstm_cell( + parameter_X, + parameter_H_t, + parameter_C_t, + parameter_W, + parameter_R, + parameter_B, + hidden_size, + activations, + activation_alpha, + activation_beta, + clip, + ) + + assert node_param.get_type_name() == "LSTMCell" + assert node_param.get_output_size() == 2 + assert list(node_param.get_output_shape(0)) == expected_shape + assert list(node_param.get_output_shape(1)) == expected_shape + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_lstm_sequence_operator_bidirectional(dtype): batch_size = 1 input_size = 16 @@ -255,7 +312,7 @@ def test_lstm_sequence_operator_bidirectional(dtype): parameter_B = ng.parameter(B_shape, name="B", dtype=dtype) direction = "BIDIRECTIONAL" - node = ng.lstm_sequence( + node = ng_opset1.lstm_sequence( parameter_X, parameter_H_t, parameter_C_t, @@ -275,7 +332,7 @@ def test_lstm_sequence_operator_bidirectional(dtype): activation_beta = [3.0, 2.0, 1.0] clip = 1.22 - node_param = ng.lstm_sequence( + node_param = ng_opset1.lstm_sequence( parameter_X, parameter_H_t, parameter_C_t, @@ -321,7 +378,7 @@ def test_lstm_sequence_operator_reverse(dtype): direction = "REVERSE" - node_default = ng.lstm_sequence( + node_default = ng_opset1.lstm_sequence( parameter_X, parameter_H_t, parameter_C_t, @@ -341,7 +398,7 @@ def test_lstm_sequence_operator_reverse(dtype): activation_beta = [3.0, 2.0, 1.0] clip = 1.22 - node_param = ng.lstm_sequence( + node_param = ng_opset1.lstm_sequence( parameter_X, parameter_H_t, parameter_C_t, @@ -387,7 +444,7 @@ def test_lstm_sequence_operator_forward(dtype): direction = "forward" - node_default = ng.lstm_sequence( + node_default = ng_opset1.lstm_sequence( parameter_X, parameter_H_t, parameter_C_t, @@ -407,7 +464,7 @@ def test_lstm_sequence_operator_forward(dtype): activation_beta = [1.0] clip = 0.5 - node = ng.lstm_sequence( + node = ng_opset1.lstm_sequence( parameter_X, parameter_H_t, parameter_C_t, diff --git a/ngraph/test/attributes.cpp b/ngraph/test/attributes.cpp index 03972a4..093cfbf 100644 --- a/ngraph/test/attributes.cpp +++ b/ngraph/test/attributes.cpp @@ -20,6 +20,7 @@ #include "ngraph/op/util/attr_types.hpp" #include "ngraph/opsets/opset1.hpp" #include "ngraph/opsets/opset3.hpp" +#include "ngraph/opsets/opset4.hpp" #include "util/visitor.hpp" @@ -1063,7 +1064,7 @@ TEST(attributes, lrn_op) TEST(attributes, lstm_cell_op) { - FactoryRegistry::get().register_factory(); + FactoryRegistry::get().register_factory(); auto X = make_shared(element::f32, Shape{2, 3}); auto H = make_shared(element::f32, Shape{2, 3}); auto W = make_shared(element::f32, Shape{12, 3}); @@ -1072,40 +1073,33 @@ TEST(attributes, lstm_cell_op) const auto initial_cell_state = make_shared(element::f32, Shape{2, 3}); const auto hidden_size = 3; - const auto weights_format = op::LSTMWeightsFormat::ICOF; const std::vector activations = {"tanh", "sigmoid", "tanh"}; auto activations_alpha = std::vector{1.0, 1.5}; auto activations_beta = std::vector{2.0, 1.0}; const float clip = 0.5f; - bool input_forget = true; - - const auto lstm_cell = make_shared(X, + const auto lstm_cell = make_shared(X, initial_hidden_state, initial_cell_state, W, R, hidden_size, - weights_format, activations, activations_alpha, activations_beta, - clip, - input_forget); + clip); NodeBuilder builder(lstm_cell); - auto g_lstm_cell = as_type_ptr(builder.create()); + auto g_lstm_cell = as_type_ptr(builder.create()); EXPECT_EQ(g_lstm_cell->get_hidden_size(), lstm_cell->get_hidden_size()); EXPECT_EQ(g_lstm_cell->get_activations(), lstm_cell->get_activations()); EXPECT_EQ(g_lstm_cell->get_activations_alpha(), lstm_cell->get_activations_alpha()); EXPECT_EQ(g_lstm_cell->get_activations_beta(), lstm_cell->get_activations_beta()); EXPECT_EQ(g_lstm_cell->get_clip(), lstm_cell->get_clip()); - EXPECT_EQ(g_lstm_cell->get_input_forget(), lstm_cell->get_input_forget()); - EXPECT_EQ(g_lstm_cell->get_weights_format(), lstm_cell->get_weights_format()); } TEST(attributes, lstm_sequence_op) { - FactoryRegistry::get().register_factory(); + FactoryRegistry::get().register_factory(); const size_t batch_size = 4; const size_t num_directions = 2; @@ -1127,14 +1121,12 @@ TEST(attributes, lstm_sequence_op) const auto B = make_shared(element::f32, Shape{num_directions, 4 * hidden_size}); const auto lstm_direction = op::RecurrentSequenceDirection::BIDIRECTIONAL; - const auto weights_format = op::LSTMWeightsFormat::ICOF; const std::vector activations_alpha = {1, 2, 3}; const std::vector activations_beta = {4, 5, 6}; const std::vector activations = {"tanh", "sigmoid", "tanh"}; const float clip_threshold = 0.5f; - const bool input_forget = true; - const auto lstm_sequence = make_shared(X, + const auto lstm_sequence = make_shared(X, initial_hidden_state, initial_cell_state, sequence_lengths, @@ -1143,23 +1135,19 @@ TEST(attributes, lstm_sequence_op) B, hidden_size, lstm_direction, - weights_format, activations_alpha, activations_beta, activations, - clip_threshold, - input_forget); + clip_threshold); NodeBuilder builder(lstm_sequence); - auto g_lstm_sequence = as_type_ptr(builder.create()); + auto g_lstm_sequence = as_type_ptr(builder.create()); EXPECT_EQ(g_lstm_sequence->get_hidden_size(), lstm_sequence->get_hidden_size()); EXPECT_EQ(g_lstm_sequence->get_activations(), lstm_sequence->get_activations()); EXPECT_EQ(g_lstm_sequence->get_activations_alpha(), lstm_sequence->get_activations_alpha()); EXPECT_EQ(g_lstm_sequence->get_activations_beta(), lstm_sequence->get_activations_beta()); - EXPECT_EQ(g_lstm_sequence->get_clip_threshold(), lstm_sequence->get_clip_threshold()); + EXPECT_EQ(g_lstm_sequence->get_clip(), lstm_sequence->get_clip()); EXPECT_EQ(g_lstm_sequence->get_direction(), lstm_sequence->get_direction()); - EXPECT_EQ(g_lstm_sequence->get_input_forget(), lstm_sequence->get_input_forget()); - EXPECT_EQ(g_lstm_sequence->get_weights_format(), lstm_sequence->get_weights_format()); } TEST(attributes, shuffle_channels_op) diff --git a/ngraph/test/backend/fused_op.in.cpp b/ngraph/test/backend/fused_op.in.cpp index ca84700..7544b61 100644 --- a/ngraph/test/backend/fused_op.in.cpp +++ b/ngraph/test/backend/fused_op.in.cpp @@ -33,7 +33,9 @@ #include "gtest/gtest.h" #include "ngraph/check.hpp" #include "ngraph/ngraph.hpp" +#include "ngraph/opsets/opset4.hpp" #include "ngraph/op/util/attr_types.hpp" +#include "ngraph/op/util/rnn_cell_base.hpp" #include "op/group_conv.hpp" #include "util/all_close.hpp" #include "util/all_close_f.hpp" @@ -1629,11 +1631,17 @@ NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_zero_bias_peepholes) const auto B = make_shared(element::f32, Shape{gates_count * hidden_size}); const auto P = make_shared(element::f32, Shape{3 * hidden_size}); - const auto lstm_cell = make_shared( - X, H_t, C_t, W, R, B, P, hidden_size, op::LSTMWeightsFormat::IOFC); + const auto lstm_cell = make_shared( + X, + H_t, + C_t, + op::util::convert_lstm_node_format(W, op::util::LSTMWeightsFormat::IOFC), + op::util::convert_lstm_node_format(R, op::util::LSTMWeightsFormat::IOFC), + op::util::convert_lstm_node_format(B, op::util::LSTMWeightsFormat::IOFC), + hidden_size); auto ht_function = make_shared(OutputVector{lstm_cell->output(0)}, - ParameterVector{X, H_t, C_t, W, R, B, P}); + ParameterVector{X, H_t, C_t, W, R, B}); auto ht_test_case = test::TestCase(ht_function); // X @@ -1665,18 +1673,16 @@ NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_zero_bias_peepholes) // P vector in_P(3 * hidden_size, 0.f); - ht_test_case.add_multiple_inputs( - vector>{in_X, in_Ht, in_Ct, in_W, in_R, in_B, in_P}); + ht_test_case.add_multiple_inputs(vector>{in_X, in_Ht, in_Ct, in_W, in_R, in_B}); ht_test_case.add_expected_output( Shape{batch_size, hidden_size}, {0.81457126f, 0.61109227f, 0.769522f, 0.52239674f, 0.4324641f, 0.63183f}); ht_test_case.run(); auto ct_function = make_shared(OutputVector{lstm_cell->output(1)}, - ParameterVector{X, H_t, C_t, W, R, B, P}); + ParameterVector{X, H_t, C_t, W, R, B}); auto ct_test_case = test::TestCase(ct_function); - ct_test_case.add_multiple_inputs( - vector>{in_X, in_Ht, in_Ct, in_W, in_R, in_B, in_P}); + ct_test_case.add_multiple_inputs(vector>{in_X, in_Ht, in_Ct, in_W, in_R, in_B}); ct_test_case.add_expected_output( Shape{batch_size, hidden_size}, {1.4444952f, 0.9635685f, 1.2875274f, 0.8053419f, 0.7184521f, 0.95803297f}); @@ -1700,11 +1706,10 @@ NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_bias_peepholes) const auto B = make_shared(element::f32, Shape{gates_count * hidden_size}); const auto P = make_shared(element::f32, Shape{3 * hidden_size}); - const auto lstm_cell = make_shared( - X, H_t, C_t, W, R, B, P, hidden_size, op::LSTMWeightsFormat::IOFC); + const auto lstm_cell = make_shared(X, H_t, C_t, W, R, B, hidden_size); auto ht_function = make_shared(OutputVector{lstm_cell->output(0)}, - ParameterVector{X, H_t, C_t, W, R, B, P}); + ParameterVector{X, H_t, C_t, W, R, B}); auto ht_test_case = test::TestCase(ht_function); // X @@ -1755,18 +1760,16 @@ NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_bias_peepholes) 0.13840231f, 0.24175227f}; - ht_test_case.add_multiple_inputs( - vector>{in_X, in_Ht, in_Ct, in_W, in_R, in_B, in_P}); + ht_test_case.add_multiple_inputs(vector>{in_X, in_Ht, in_Ct, in_W, in_R, in_B}); ht_test_case.add_expected_output( Shape{batch_size, hidden_size}, {0.9218244f, 0.78787273f, 0.8754273f, 0.7361462f, 0.70927656f, 0.83522964f}); ht_test_case.run(); auto ct_function = make_shared(OutputVector{lstm_cell->output(1)}, - ParameterVector{X, H_t, C_t, W, R, B, P}); + ParameterVector{X, H_t, C_t, W, R, B}); auto ct_test_case = test::TestCase(ct_function); - ct_test_case.add_multiple_inputs( - vector>{in_X, in_Ht, in_Ct, in_W, in_R, in_B, in_P}); + ct_test_case.add_multiple_inputs(vector>{in_X, in_Ht, in_Ct, in_W, in_R, in_B}); ct_test_case.add_expected_output( Shape{batch_size, hidden_size}, {1.7094649f, 1.1259761f, 1.444019f, 1.086587f, 0.9762144f, 1.3066899f}); @@ -1792,22 +1795,19 @@ NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_bias_peepholes_clip_input_forget) const auto B = make_shared(element::f32, Shape{gates_count * hidden_size}); const auto P = make_shared(element::f32, Shape{3 * hidden_size}); - const auto lstm_cell = make_shared(X, - H_t, - C_t, - W, - R, - B, - P, - hidden_size, - op::LSTMWeightsFormat::IOFC, - vector{"sigmoid", "tanh", "tanh"}, - vector{}, - vector{}, - clip_threshold, - input_forget); + const auto lstm_cell = make_shared(X, + H_t, + C_t, + W, + R, + B, + hidden_size, + vector{"sigmoid", "tanh", "tanh"}, + vector{}, + vector{}, + clip_threshold); auto ht_function = make_shared(OutputVector{lstm_cell->output(0)}, - ParameterVector{X, H_t, C_t, W, R, B, P}); + ParameterVector{X, H_t, C_t, W, R, B}); auto ht_test_case = test::TestCase(ht_function); // X @@ -1858,18 +1858,16 @@ NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_bias_peepholes_clip_input_forget) 0.13840231f, 0.24175227f}; - ht_test_case.add_multiple_inputs( - vector>{in_X, in_Ht, in_Ct, in_W, in_R, in_B, in_P}); + ht_test_case.add_multiple_inputs(vector>{in_X, in_Ht, in_Ct, in_W, in_R, in_B}); ht_test_case.add_expected_output( Shape{batch_size, hidden_size}, {0.71485436f, 0.71844107f, 0.72704613f, 0.6235602f, 0.68306124f, 0.6978715f}); ht_test_case.run(); auto ct_function = make_shared(OutputVector{lstm_cell->output(1)}, - ParameterVector{X, H_t, C_t, W, R, B, P}); + ParameterVector{X, H_t, C_t, W, R, B}); auto ct_test_case = test::TestCase(ct_function); - ct_test_case.add_multiple_inputs( - vector>{in_X, in_Ht, in_Ct, in_W, in_R, in_B, in_P}); + ct_test_case.add_multiple_inputs(vector>{in_X, in_Ht, in_Ct, in_W, in_R, in_B}); ct_test_case.add_expected_output( Shape{batch_size, hidden_size}, {0.94656503f, 0.9527454f, 0.9706756f, 0.84206575f, 0.91898793f, 0.9127192f}); @@ -1898,22 +1896,19 @@ NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_activaction_functions) const auto B = make_shared(element::f32, Shape{gates_count * hidden_size}); const auto P = make_shared(element::f32, Shape{3 * hidden_size}); - const auto lstm_cell = make_shared(X, - H_t, - C_t, - W, - R, - B, - P, - hidden_size, - op::LSTMWeightsFormat::IOFC, - activations, - activation_alpha, - activation_beta, - clip_threshold, - input_forget); + const auto lstm_cell = make_shared(X, + H_t, + C_t, + W, + R, + B, + hidden_size, + activations, + activation_alpha, + activation_beta, + clip_threshold); auto ht_function = make_shared(OutputVector{lstm_cell->output(0)}, - ParameterVector{X, H_t, C_t, W, R, B, P}); + ParameterVector{X, H_t, C_t, W, R, B}); auto ht_test_case = test::TestCase(ht_function); // X @@ -1964,18 +1959,16 @@ NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_activaction_functions) 0.13840231f, 0.24175227f}; - ht_test_case.add_multiple_inputs( - vector>{in_X, in_Ht, in_Ct, in_W, in_R, in_B, in_P}); + ht_test_case.add_multiple_inputs(vector>{in_X, in_Ht, in_Ct, in_W, in_R, in_B}); ht_test_case.add_expected_output( Shape{batch_size, hidden_size}, {0.96834344f, 0.9695254f, 0.97068775f, 0.9077866f, 0.94161016f, 0.96599925f}); ht_test_case.run(); auto ct_function = make_shared(OutputVector{lstm_cell->output(1)}, - ParameterVector{X, H_t, C_t, W, R, B, P}); + ParameterVector{X, H_t, C_t, W, R, B}); auto ct_test_case = test::TestCase(ct_function); - ct_test_case.add_multiple_inputs( - vector>{in_X, in_Ht, in_Ct, in_W, in_R, in_B, in_P}); + ct_test_case.add_multiple_inputs(vector>{in_X, in_Ht, in_Ct, in_W, in_R, in_B}); ct_test_case.add_expected_output( Shape{batch_size, hidden_size}, {0.94656503f, 0.9527454f, 0.9706756f, 0.84206575f, 0.91898793f, 0.9127192f}); @@ -2168,7 +2161,7 @@ NGRAPH_TEST(${BACKEND_NAME}, rnn_cell_no_bias) const auto W = make_shared(element::f32, Shape{hidden_size, input_size}); const auto R = make_shared(element::f32, Shape{hidden_size, hidden_size}); - const auto rnn_cell = make_shared(X, H_t, W, R, hidden_size); + const auto rnn_cell = make_shared(X, H_t, W, R, hidden_size); auto function = make_shared(rnn_cell, ParameterVector{X, H_t, W, R}); auto test_case = test::TestCase(function); @@ -2219,16 +2212,16 @@ NGRAPH_TEST(${BACKEND_NAME}, rnn_cell_bias_clip) const auto R = make_shared(element::f32, Shape{hidden_size, hidden_size}); const auto B = make_shared(element::f32, Shape{hidden_size}); - const auto rnn_cell = make_shared(X, - H_t, - W, - R, - B, - hidden_size, - vector{"tanh"}, - vector{}, - vector{}, - clip); + const auto rnn_cell = make_shared(X, + H_t, + W, + R, + B, + hidden_size, + vector{"tanh"}, + vector{}, + vector{}, + clip); auto function = make_shared(rnn_cell, ParameterVector{X, H_t, W, R, B}); auto test_case = test::TestCase(function); @@ -2281,16 +2274,16 @@ NGRAPH_TEST(${BACKEND_NAME}, rnn_cell_activation_function) const auto R = make_shared(element::f32, Shape{hidden_size, hidden_size}); const auto B = make_shared(element::f32, Shape{hidden_size}); - const auto rnn_cell = make_shared(X, - H_t, - W, - R, - B, - hidden_size, - vector{"sigmoid"}, - vector{}, - vector{}, - clip); + const auto rnn_cell = make_shared(X, + H_t, + W, + R, + B, + hidden_size, + vector{"sigmoid"}, + vector{}, + vector{}, + clip); auto function = make_shared(rnn_cell, ParameterVector{X, H_t, W, R, B}); auto test_case = test::TestCase(function); @@ -2347,17 +2340,17 @@ NGRAPH_TEST(${BACKEND_NAME}, gru_cell_bias_clip) const auto H_t = make_shared(element::f32, Shape{batch_size, hidden_size}); const auto B = make_shared(element::f32, Shape{gates_count * hidden_size}); - const auto gru_cell = make_shared(X, - H_t, - W, - R, - B, - hidden_size, - vector{"sigmoid", "tanh"}, - vector{}, - vector{}, - clip, - linear_before_reset); + const auto gru_cell = make_shared(X, + H_t, + W, + R, + B, + hidden_size, + vector{"sigmoid", "tanh"}, + vector{}, + vector{}, + clip, + linear_before_reset); auto function = make_shared(gru_cell, ParameterVector{X, H_t, W, R, B}); auto test_case = test::TestCase(function); @@ -2420,17 +2413,17 @@ NGRAPH_TEST(${BACKEND_NAME}, gru_cell_linear_before_reset) const auto H_t = make_shared(element::f32, Shape{batch_size, hidden_size}); const auto B = make_shared(element::f32, Shape{(gates_count + 1) * hidden_size}); - const auto gru_cell = make_shared(X, - H_t, - W, - R, - B, - hidden_size, - vector{"sigmoid", "tanh"}, - vector{}, - vector{}, - clip, - linear_before_reset); + const auto gru_cell = make_shared(X, + H_t, + W, + R, + B, + hidden_size, + vector{"sigmoid", "tanh"}, + vector{}, + vector{}, + clip, + linear_before_reset); auto function = make_shared(gru_cell, ParameterVector{X, H_t, W, R, B}); auto test_case = test::TestCase(function); @@ -2492,17 +2485,17 @@ NGRAPH_TEST(${BACKEND_NAME}, gru_cell_activation_function) const auto H_t = make_shared(element::f32, Shape{batch_size, hidden_size}); const auto B = make_shared(element::f32, Shape{(gates_count + 1) * hidden_size}); - const auto gru_cell = make_shared(X, - H_t, - W, - R, - B, - hidden_size, - vector{"hardsigmoid", "hardsigmoid"}, - vector{1.8345f, 1.8345f}, - vector{3.05f, 3.05f}, - clip, - linear_before_reset); + const auto gru_cell = make_shared(X, + H_t, + W, + R, + B, + hidden_size, + vector{"hardsigmoid", "hardsigmoid"}, + vector{1.8345f, 1.8345f}, + vector{3.05f, 3.05f}, + clip, + linear_before_reset); auto function = make_shared(gru_cell, ParameterVector{X, H_t, W, R, B}); auto test_case = test::TestCase(function); diff --git a/ngraph/test/op_is.cpp b/ngraph/test/op_is.cpp index b4ee8e5..059f1de 100644 --- a/ngraph/test/op_is.cpp +++ b/ngraph/test/op_is.cpp @@ -346,7 +346,7 @@ namespace void op_is_GRUCell() { - op::GRUCell node; + op::v3::GRUCell node; EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node)); EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node)); EXPECT_FALSE(op::is_binary_elementwise_comparison(&node)); @@ -472,7 +472,7 @@ namespace void op_is_LSTMCell() { - op::LSTMCell node; + op::v4::LSTMCell node; EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node)); EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node)); EXPECT_FALSE(op::is_binary_elementwise_comparison(&node)); @@ -481,7 +481,7 @@ namespace void op_is_LSTMSequence() { - op::LSTMSequence node; + op::v0::LSTMSequence node; EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node)); EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node)); EXPECT_FALSE(op::is_binary_elementwise_comparison(&node)); @@ -733,7 +733,7 @@ namespace void op_is_RNNCell() { - op::RNNCell node; + op::v0::RNNCell node; EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node)); EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node)); EXPECT_FALSE(op::is_binary_elementwise_comparison(&node)); diff --git a/ngraph/test/runtime/ie/unit_test.manifest b/ngraph/test/runtime/ie/unit_test.manifest index d3858e9..783fe8c 100644 --- a/ngraph/test/runtime/ie/unit_test.manifest +++ b/ngraph/test/runtime/ie/unit_test.manifest @@ -1085,14 +1085,14 @@ IE_CPU.builder_opset1_collapse_dyn_shape # IE_CPU.interpolate_down_scales_const_linear # GRUCell operation has a form that is not supported -IE_CPU.onnx_model_gru_defaults_fwd -IE_CPU.onnx_model_gru_fwd_activations -IE_CPU.onnx_model_gru_fwd_mixed_seq_len -IE_CPU.onnx_model_gru_rev_clip -IE_CPU.onnx_model_gru_reverse -IE_CPU.onnx_model_gru_fwd_bias_initial_h -IE_CPU.onnx_model_gru_bidirectional -IE_CPU.onnx_model_gru_fwd_linear_before_reset +onnx_model_gru_defaults_fwd +onnx_model_gru_fwd_activations +onnx_model_gru_fwd_mixed_seq_len +onnx_model_gru_rev_clip +onnx_model_gru_reverse +onnx_model_gru_fwd_bias_initial_h +onnx_model_gru_bidirectional +onnx_model_gru_fwd_linear_before_reset # Not implemented Interpolate-4: IE_CPU.onnx_model_resize10_import_only diff --git a/ngraph/test/runtime/interpreter/int_executable.hpp b/ngraph/test/runtime/interpreter/int_executable.hpp index 20152a8..0f4c4ea 100644 --- a/ngraph/test/runtime/interpreter/int_executable.hpp +++ b/ngraph/test/runtime/interpreter/int_executable.hpp @@ -59,8 +59,10 @@ #include "ngraph/runtime/reference/floor.hpp" #include "ngraph/runtime/reference/gather.hpp" #include "ngraph/runtime/reference/gather_nd.hpp" +#include "ngraph/runtime/reference/gru_cell.hpp" #include "ngraph/runtime/reference/log.hpp" #include "ngraph/runtime/reference/lrn.hpp" +#include "ngraph/runtime/reference/lstm_cell.hpp" #include "ngraph/runtime/reference/matmul.hpp" #include "ngraph/runtime/reference/max.hpp" #include "ngraph/runtime/reference/max_pool.hpp" @@ -77,6 +79,7 @@ #include "ngraph/runtime/reference/result.hpp" #include "ngraph/runtime/reference/reverse.hpp" #include "ngraph/runtime/reference/reverse_sequence.hpp" +#include "ngraph/runtime/reference/rnn_cell.hpp" #include "ngraph/runtime/reference/round.hpp" #include "ngraph/runtime/reference/scatter_nd_update.hpp" #include "ngraph/runtime/reference/select.hpp" @@ -692,6 +695,67 @@ protected: } break; } + case OP_TYPEID::GRUCell_v3: + { + const op::v3::GRUCell* gru_cell = static_cast(&node); + runtime::reference::gru_cell(args[0]->get_data_ptr(), + args[0]->get_shape(), + args[1]->get_data_ptr(), + args[1]->get_shape(), + args[2]->get_data_ptr(), + args[2]->get_shape(), + args[3]->get_data_ptr(), + args[3]->get_shape(), + args[4]->get_data_ptr(), + args[4]->get_shape(), + out[0]->get_data_ptr(), + gru_cell->get_activations()[0], + gru_cell->get_activations()[1], + gru_cell->get_clip(), + gru_cell->get_linear_before_reset()); + break; + } + case OP_TYPEID::LSTMCell_v4: + { + const op::v4::LSTMCell* lstm_cell = static_cast(&node); + runtime::reference::lstm_cell(args[0]->get_data_ptr(), + args[0]->get_shape(), + args[1]->get_data_ptr(), + args[1]->get_shape(), + args[2]->get_data_ptr(), + args[2]->get_shape(), + args[3]->get_data_ptr(), + args[3]->get_shape(), + args[4]->get_data_ptr(), + args[4]->get_shape(), + args[5]->get_data_ptr(), + args[5]->get_shape(), + out[0]->get_data_ptr(), + out[1]->get_data_ptr(), + lstm_cell->get_activations()[0], + lstm_cell->get_activations()[1], + lstm_cell->get_activations()[2], + lstm_cell->get_clip()); + break; + } + case OP_TYPEID::RNNCell_v0: + { + const op::v0::RNNCell* rnn_cell = static_cast(&node); + runtime::reference::rnn_cell(args[0]->get_data_ptr(), + args[0]->get_shape(), + args[1]->get_data_ptr(), + args[1]->get_shape(), + args[2]->get_data_ptr(), + args[2]->get_shape(), + args[3]->get_data_ptr(), + args[3]->get_shape(), + args[4]->get_data_ptr(), + args[4]->get_shape(), + out[0]->get_data_ptr(), + rnn_cell->get_activations()[0], + rnn_cell->get_clip()); + break; + } case OP_TYPEID::Log: { size_t element_count = shape_size(node.get_output_shape(0)); @@ -1203,15 +1267,12 @@ protected: case OP_TYPEID::GRN: case OP_TYPEID::GroupConvolution: case OP_TYPEID::GroupConvolutionBackpropData: - case OP_TYPEID::GRUCell: case OP_TYPEID::HardSigmoid: case OP_TYPEID::Interpolate: - case OP_TYPEID::LSTMCell: case OP_TYPEID::LSTMSequence: case OP_TYPEID::MVN: case OP_TYPEID::NormalizeL2: case OP_TYPEID::PRelu: - case OP_TYPEID::RNNCell: case OP_TYPEID::ScatterUpdate_v3: case OP_TYPEID::Selu: case OP_TYPEID::ShuffleChannels: diff --git a/ngraph/test/runtime/interpreter/opset_int_tbl.hpp b/ngraph/test/runtime/interpreter/opset_int_tbl.hpp index 7badf0a..1dadbfa 100644 --- a/ngraph/test/runtime/interpreter/opset_int_tbl.hpp +++ b/ngraph/test/runtime/interpreter/opset_int_tbl.hpp @@ -20,6 +20,7 @@ #define ID_SUFFIX(NAME) NAME##_v0 NGRAPH_OP(DetectionOutput, op::v0) +NGRAPH_OP(RNNCell, op::v0) #undef ID_SUFFIX #define ID_SUFFIX(NAME) NAME##_v1 @@ -31,6 +32,7 @@ NGRAPH_OP(LogicalNot, op::v1) #undef ID_SUFFIX #define ID_SUFFIX(NAME) NAME##_v3 +NGRAPH_OP(GRUCell, op::v3) NGRAPH_OP(EmbeddingBagOffsetsSum, op::v3) NGRAPH_OP(EmbeddingBagPackedSum, op::v3) NGRAPH_OP(EmbeddingSegmentsSum, op::v3) @@ -43,4 +45,5 @@ NGRAPH_OP(ScatterUpdate, op::v3) #define ID_SUFFIX(NAME) NAME##_v4 NGRAPH_OP(CTCLoss, op::v4) +NGRAPH_OP(LSTMCell, op::v4) #undef ID_SUFFIX diff --git a/ngraph/test/runtime/interpreter/unit_test.manifest b/ngraph/test/runtime/interpreter/unit_test.manifest index cb84aca..2c864ef 100644 --- a/ngraph/test/runtime/interpreter/unit_test.manifest +++ b/ngraph/test/runtime/interpreter/unit_test.manifest @@ -105,3 +105,20 @@ INTERPRETER.onnx_model_gatherND_float # Round op doesn't support some specific cases of rounding onnx_model_round_half_nearest_even + +# Unsupported op 'LSTMSequence': not FusedOp anymore, no reference implementation yet +onnx_model_lstm_fwd_with_clip +onnx_model_lstm_fwd_mixed_seq +onnx_model_lstm_fwd_hardsigmoid_activation +onnx_model_lstm_fwd_large_batch_no_clip +onnx_model_lstm_bdir_short_input_seq +onnx_model_lstm_mixed_seq_reverse + +# Activation function hardsigmoid is not supported. +gru_cell_activation_function +lstm_cell_activaction_functions +onnx_model_gru_fwd_activations + +# Peepholes, input_forget are not supported +lstm_cell_bias_peepholes +lstm_cell_bias_peepholes_clip_input_forget diff --git a/ngraph/test/runtime/opset0_tbl.hpp b/ngraph/test/runtime/opset0_tbl.hpp index 6a24d2f..ec14923 100644 --- a/ngraph/test/runtime/opset0_tbl.hpp +++ b/ngraph/test/runtime/opset0_tbl.hpp @@ -81,7 +81,6 @@ NGRAPH_OP(Exp, ngraph::op) NGRAPH_OP(FakeQuantize, ngraph::op) NGRAPH_OP(Floor, ngraph::op) NGRAPH_OP(GRN, ngraph::op) -NGRAPH_OP(GRUCell, ngraph::op) NGRAPH_OP(Gather, ngraph::op) NGRAPH_OP(GatherND, ngraph::op) NGRAPH_OP(Gelu, ngraph::op) @@ -95,8 +94,7 @@ NGRAPH_OP(Less, ngraph::op) NGRAPH_OP(LessEq, ngraph::op) NGRAPH_OP(Log, ngraph::op) NGRAPH_OP(LRN, ngraph::op) -NGRAPH_OP(LSTMCell, ngraph::op) -NGRAPH_OP(LSTMSequence, ngraph::op) +NGRAPH_OP(LSTMSequence, ngraph::op::v0) NGRAPH_OP(MatMul, ngraph::op) NGRAPH_OP(NormalizeL2, ngraph::op) NGRAPH_OP(Max, ngraph::op) @@ -124,7 +122,6 @@ NGRAPH_OP(Reshape, ngraph::op) NGRAPH_OP(Result, ngraph::op) NGRAPH_OP(Reverse, ngraph::op) NGRAPH_OP(ReverseSequence, ngraph::op) -NGRAPH_OP(RNNCell, ngraph::op) NGRAPH_OP(Round, ngraph::op) NGRAPH_OP(Select, ngraph::op) NGRAPH_OP(Selu, ngraph::op) diff --git a/ngraph/test/type_prop/gru_cell.cpp b/ngraph/test/type_prop/gru_cell.cpp index f673fca..9ed5729 100644 --- a/ngraph/test/type_prop/gru_cell.cpp +++ b/ngraph/test/type_prop/gru_cell.cpp @@ -16,6 +16,7 @@ #include "gtest/gtest.h" #include "ngraph/ngraph.hpp" +#include "ngraph/opsets/opset4.hpp" #include "util/type_prop.hpp" using namespace std; @@ -35,7 +36,7 @@ TEST(type_prop, gru_cell) make_shared(element::f32, Shape{gates_count * hidden_size, hidden_size}); const auto H_t = make_shared(element::f32, Shape{batch_size, hidden_size}); - const auto gru_cell = make_shared(X, H_t, W, R, hidden_size); + const auto gru_cell = make_shared(X, H_t, W, R, hidden_size); EXPECT_EQ(gru_cell->get_output_element_type(0), element::f32); EXPECT_EQ(gru_cell->get_output_shape(0), (Shape{batch_size, hidden_size})); } @@ -56,7 +57,7 @@ TEST(type_prop, gru_cell_invalid_input) auto W = make_shared(element::f32, Shape{hidden_size, input_size}); try { - const auto gru_cell = make_shared(X, H_t, W, R, hidden_size); + const auto gru_cell = make_shared(X, H_t, W, R, hidden_size); FAIL() << "GRUCell node was created with invalid data."; } catch (const NodeValidationFailure& error) @@ -70,7 +71,7 @@ TEST(type_prop, gru_cell_invalid_input) R = make_shared(element::f32, Shape{hidden_size, 1}); try { - const auto gru_cell = make_shared(X, H_t, W, R, hidden_size); + const auto gru_cell = make_shared(X, H_t, W, R, hidden_size); FAIL() << "GRUCell node was created with invalid data."; } catch (const NodeValidationFailure& error) @@ -86,7 +87,7 @@ TEST(type_prop, gru_cell_invalid_input) H_t = make_shared(element::f32, Shape{4, hidden_size}); try { - const auto gru_cell = make_shared(X, H_t, W, R, hidden_size); + const auto gru_cell = make_shared(X, H_t, W, R, hidden_size); FAIL() << "GRUCell node was created with invalid data."; } catch (const NodeValidationFailure& error) @@ -101,7 +102,7 @@ TEST(type_prop, gru_cell_invalid_input) auto B = make_shared(element::f32, Shape{hidden_size}); try { - const auto gru_cell = make_shared(X, H_t, W, R, B, hidden_size); + const auto gru_cell = make_shared(X, H_t, W, R, B, hidden_size); FAIL() << "GRUCell node was created with invalid data."; } catch (const NodeValidationFailure& error) @@ -126,7 +127,7 @@ TEST(type_prop, gru_cell_dynamic_batch_size) const auto H_t = make_shared(element::f32, PartialShape{batch_size, hidden_size}); - const auto gru_cell = make_shared(X, H_t, W, R, hidden_size); + const auto gru_cell = make_shared(X, H_t, W, R, hidden_size); EXPECT_EQ(gru_cell->get_output_element_type(0), element::f32); EXPECT_EQ(gru_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size})); } @@ -146,7 +147,7 @@ TEST(type_prop, gru_cell_dynamic_hidden_size) const auto H_t = make_shared(element::f32, PartialShape{batch_size, hidden_size}); - const auto gru_cell = make_shared(X, H_t, W, R, 3); + const auto gru_cell = make_shared(X, H_t, W, R, 3); EXPECT_EQ(gru_cell->get_output_element_type(0), element::f32); EXPECT_EQ(gru_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size})); } @@ -163,7 +164,7 @@ TEST(type_prop, gru_cell_dynamic_inputs) const auto H_t = make_shared(element::f32, PartialShape{batch_size, hidden_size}); - const auto gru_cell = make_shared(X, H_t, W, R, 2); + const auto gru_cell = make_shared(X, H_t, W, R, 2); EXPECT_EQ(gru_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size})); EXPECT_EQ(gru_cell->get_output_element_type(0), element::f32); @@ -183,33 +184,37 @@ TEST(type_prop, gru_cell_invalid_input_rank0) // Invalid rank0 for W tensor. auto W = make_shared(element::f32, PartialShape{}); - ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure) + ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), + ngraph::NodeValidationFailure) << "GRUCell node was created with invalid data."; // Invalid rank0 for X tensor. W = make_shared(element::f32, PartialShape{gates_count * hidden_size, input_size}); X = make_shared(element::f32, PartialShape{}); - ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure) + ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), + ngraph::NodeValidationFailure) << "GRUCell node was created with invalid data."; // Invalid rank0 for H_t tensor. X = make_shared(element::f32, PartialShape{batch_size, input_size}); H_t = make_shared(element::f32, PartialShape{}); - ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure) + ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), + ngraph::NodeValidationFailure) << "GRUCell node was created with invalid data."; // Invalid rank0 for R tensor. H_t = make_shared(element::f32, PartialShape{batch_size, hidden_size}); R = make_shared(element::f32, PartialShape{}); - ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure) + ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), + ngraph::NodeValidationFailure) << "GRUCell node was created with invalid data."; // Invalid rank0 for B tensor. R = make_shared(element::f32, PartialShape{gates_count * hidden_size, input_size}); auto B = make_shared(element::f32, PartialShape{}); - ASSERT_THROW(make_shared(X, H_t, W, R, B, hidden_size), + ASSERT_THROW(make_shared(X, H_t, W, R, B, hidden_size), ngraph::NodeValidationFailure) << "GRUCell node was created with invalid data."; } @@ -228,32 +233,36 @@ TEST(type_prop, gru_cell_invalid_input_dynamic_rank) // Invalid dynamic rank for W tensor. auto W = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); - ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure) + ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), + ngraph::NodeValidationFailure) << "GRUCell node was created with invalid data."; // Invalid dynamic rank for X tensor. W = make_shared(element::f32, PartialShape{hidden_size, input_size}); X = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); - ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure) + ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), + ngraph::NodeValidationFailure) << "GRUCell node was created with invalid data."; // Invalid dynamic rank for H_t tensor. X = make_shared(element::f32, PartialShape{batch_size, input_size}); H_t = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); - ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure) + ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), + ngraph::NodeValidationFailure) << "GRUCell node was created with invalid data."; // Invalid dynamic rank for R tensor. H_t = make_shared(element::f32, PartialShape{batch_size, hidden_size}); R = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); - ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure) + ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), + ngraph::NodeValidationFailure) << "GRUCell node was created with invalid data."; // Invalid dynamic rank for B tensor. R = make_shared(element::f32, PartialShape{gates_count * hidden_size, hidden_size}); auto B = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); - ASSERT_THROW(make_shared(X, H_t, W, R, B, hidden_size), + ASSERT_THROW(make_shared(X, H_t, W, R, B, hidden_size), ngraph::NodeValidationFailure) << "GRUCell node was created with invalid data."; } diff --git a/ngraph/test/type_prop/lstm_cell.cpp b/ngraph/test/type_prop/lstm_cell.cpp index 2cee103..48b89cd 100644 --- a/ngraph/test/type_prop/lstm_cell.cpp +++ b/ngraph/test/type_prop/lstm_cell.cpp @@ -16,6 +16,7 @@ #include "gtest/gtest.h" #include "ngraph/ngraph.hpp" +#include "ngraph/opsets/opset4.hpp" #include "util/type_prop.hpp" using namespace std; @@ -28,15 +29,15 @@ TEST(type_prop, lstm_cell) const size_t hidden_size = 3; const size_t gates_count = 4; - const auto X = make_shared(element::f32, Shape{batch_size, input_size}); + const auto X = make_shared(element::f32, Shape{batch_size, input_size}); const auto W = - make_shared(element::f32, Shape{gates_count * hidden_size, input_size}); + make_shared(element::f32, Shape{gates_count * hidden_size, input_size}); const auto R = - make_shared(element::f32, Shape{gates_count * hidden_size, hidden_size}); - const auto H_t = make_shared(element::f32, Shape{batch_size, hidden_size}); - const auto C_t = make_shared(element::f32, Shape{batch_size, hidden_size}); + make_shared(element::f32, Shape{gates_count * hidden_size, hidden_size}); + const auto H_t = make_shared(element::f32, Shape{batch_size, hidden_size}); + const auto C_t = make_shared(element::f32, Shape{batch_size, hidden_size}); - const auto lstm_cell = make_shared(X, H_t, C_t, W, R, hidden_size); + const auto lstm_cell = make_shared(X, H_t, C_t, W, R, hidden_size); EXPECT_EQ(lstm_cell->get_hidden_size(), hidden_size); EXPECT_EQ(lstm_cell->get_clip(), 0.f); EXPECT_TRUE(lstm_cell->get_activations_alpha().empty()); @@ -44,8 +45,6 @@ TEST(type_prop, lstm_cell) EXPECT_EQ(lstm_cell->get_activations()[0], "sigmoid"); EXPECT_EQ(lstm_cell->get_activations()[1], "tanh"); EXPECT_EQ(lstm_cell->get_activations()[2], "tanh"); - EXPECT_EQ(lstm_cell->get_weights_format(), op::LSTMWeightsFormat::IFCO); - EXPECT_FALSE(lstm_cell->get_input_forget()); EXPECT_EQ(lstm_cell->get_output_element_type(0), element::f32); EXPECT_EQ(lstm_cell->get_output_shape(0), (Shape{batch_size, hidden_size})); EXPECT_EQ(lstm_cell->get_output_element_type(1), element::f32); @@ -59,17 +58,17 @@ TEST(type_prop, lstm_cell_invalid_input) const size_t hidden_size = 3; const size_t gates_count = 4; - auto X = make_shared(element::f32, Shape{batch_size, input_size}); + auto X = make_shared(element::f32, Shape{batch_size, input_size}); auto R = - make_shared(element::f32, Shape{gates_count * hidden_size, hidden_size}); - auto H_t = make_shared(element::f32, Shape{batch_size, hidden_size}); - auto C_t = make_shared(element::f32, Shape{batch_size, hidden_size}); + make_shared(element::f32, Shape{gates_count * hidden_size, hidden_size}); + auto H_t = make_shared(element::f32, Shape{batch_size, hidden_size}); + auto C_t = make_shared(element::f32, Shape{batch_size, hidden_size}); // Invalid W tensor shape. - auto W = make_shared(element::f32, Shape{1 * hidden_size, input_size}); + auto W = make_shared(element::f32, Shape{1 * hidden_size, input_size}); try { - const auto lstm_cell = make_shared(X, H_t, C_t, W, R, hidden_size); + const auto lstm_cell = make_shared(X, H_t, C_t, W, R, hidden_size); FAIL() << "LSTMCell node was created with invalid data."; } catch (const NodeValidationFailure& error) @@ -79,11 +78,11 @@ TEST(type_prop, lstm_cell_invalid_input) } // Invalid R tensor shape. - W = make_shared(element::f32, Shape{gates_count * hidden_size, input_size}); - R = make_shared(element::f32, Shape{gates_count * hidden_size, 1}); + W = make_shared(element::f32, Shape{gates_count * hidden_size, input_size}); + R = make_shared(element::f32, Shape{gates_count * hidden_size, 1}); try { - const auto lstm_cell = make_shared(X, H_t, C_t, W, R, hidden_size); + const auto lstm_cell = make_shared(X, H_t, C_t, W, R, hidden_size); FAIL() << "LSTMCell node was created with invalid data."; } catch (const NodeValidationFailure& error) @@ -94,11 +93,11 @@ TEST(type_prop, lstm_cell_invalid_input) } // Invalid H_t tensor shape. - R = make_shared(element::f32, Shape{gates_count * hidden_size, hidden_size}); - H_t = make_shared(element::f32, Shape{4, hidden_size}); + R = make_shared(element::f32, Shape{gates_count * hidden_size, hidden_size}); + H_t = make_shared(element::f32, Shape{4, hidden_size}); try { - const auto lstm_cell = make_shared(X, H_t, C_t, W, R, hidden_size); + const auto lstm_cell = make_shared(X, H_t, C_t, W, R, hidden_size); FAIL() << "LSTMCell node was created with invalid data."; } catch (const NodeValidationFailure& error) @@ -109,11 +108,11 @@ TEST(type_prop, lstm_cell_invalid_input) } // Invalid C_t tensor shape. - H_t = make_shared(element::f32, Shape{batch_size, hidden_size}); - C_t = make_shared(element::f32, Shape{4, hidden_size}); + H_t = make_shared(element::f32, Shape{batch_size, hidden_size}); + C_t = make_shared(element::f32, Shape{4, hidden_size}); try { - const auto lstm_cell = make_shared(X, H_t, C_t, W, R, hidden_size); + const auto lstm_cell = make_shared(X, H_t, C_t, W, R, hidden_size); FAIL() << "LSTMCell node was created with invalid data."; } catch (const NodeValidationFailure& error) @@ -124,12 +123,12 @@ TEST(type_prop, lstm_cell_invalid_input) } // Invalid B tensor shape. - C_t = make_shared(element::f32, Shape{batch_size, hidden_size}); - auto B = make_shared(element::f32, Shape{2 * gates_count * hidden_size}); - auto P = make_shared(element::f32, Shape{3 * hidden_size}); + C_t = make_shared(element::f32, Shape{batch_size, hidden_size}); + auto B = make_shared(element::f32, Shape{2 * gates_count * hidden_size}); + auto P = make_shared(element::f32, Shape{3 * hidden_size}); try { - const auto lstm_cell = make_shared(X, H_t, C_t, W, R, B, P, hidden_size); + const auto lstm_cell = make_shared(X, H_t, C_t, W, R, B, hidden_size); FAIL() << "LSTMCell node was created with invalid data."; } catch (const NodeValidationFailure& error) @@ -137,20 +136,6 @@ TEST(type_prop, lstm_cell_invalid_input) EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter hidden_size mistmatched in B input.")); } - - // Invalid P tensor shape. - B = make_shared(element::f32, Shape{gates_count * hidden_size}); - P = make_shared(element::f32, Shape{hidden_size}); - try - { - const auto lstm_cell = make_shared(X, H_t, C_t, W, R, B, P, hidden_size); - FAIL() << "LSTMCell node was created with invalid data."; - } - catch (const NodeValidationFailure& error) - { - EXPECT_HAS_SUBSTRING(error.what(), - std::string("Parameter hidden_size mistmatched in P input.")); - } } TEST(type_prop, lstm_cell_dynamic_batch_size) @@ -160,17 +145,18 @@ TEST(type_prop, lstm_cell_dynamic_batch_size) const size_t hidden_size = 3; const size_t gates_count = 4; - const auto X = make_shared(element::f32, PartialShape{batch_size, input_size}); - const auto W = make_shared(element::f32, - PartialShape{gates_count * hidden_size, input_size}); - const auto R = make_shared(element::f32, - PartialShape{gates_count * hidden_size, hidden_size}); + const auto X = + make_shared(element::f32, PartialShape{batch_size, input_size}); + const auto W = make_shared( + element::f32, PartialShape{gates_count * hidden_size, input_size}); + const auto R = make_shared( + element::f32, PartialShape{gates_count * hidden_size, hidden_size}); const auto H_t = - make_shared(element::f32, PartialShape{batch_size, hidden_size}); + make_shared(element::f32, PartialShape{batch_size, hidden_size}); const auto C_t = - make_shared(element::f32, PartialShape{batch_size, hidden_size}); + make_shared(element::f32, PartialShape{batch_size, hidden_size}); - const auto lstm_cell = make_shared(X, H_t, C_t, W, R, hidden_size); + const auto lstm_cell = make_shared(X, H_t, C_t, W, R, hidden_size); EXPECT_EQ(lstm_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size})); EXPECT_EQ(lstm_cell->get_output_partial_shape(1), (PartialShape{batch_size, hidden_size})); @@ -185,17 +171,18 @@ TEST(type_prop, lstm_cell_dynamic_hidden_size) const auto hidden_size = Dimension::dynamic(); const size_t gates_count = 4; - const auto X = make_shared(element::f32, PartialShape{batch_size, input_size}); - const auto W = make_shared(element::f32, - PartialShape{hidden_size * gates_count, input_size}); - const auto R = make_shared(element::f32, - PartialShape{hidden_size * gates_count, hidden_size}); + const auto X = + make_shared(element::f32, PartialShape{batch_size, input_size}); + const auto W = make_shared( + element::f32, PartialShape{hidden_size * gates_count, input_size}); + const auto R = make_shared( + element::f32, PartialShape{hidden_size * gates_count, hidden_size}); const auto H_t = - make_shared(element::f32, PartialShape{batch_size, hidden_size}); + make_shared(element::f32, PartialShape{batch_size, hidden_size}); const auto C_t = - make_shared(element::f32, PartialShape{batch_size, hidden_size}); + make_shared(element::f32, PartialShape{batch_size, hidden_size}); - const auto lstm_cell = make_shared(X, H_t, C_t, W, R, 3); + const auto lstm_cell = make_shared(X, H_t, C_t, W, R, 3); EXPECT_EQ(lstm_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size})); EXPECT_EQ(lstm_cell->get_output_partial_shape(1), (PartialShape{batch_size, hidden_size})); @@ -210,17 +197,18 @@ TEST(type_prop, lstm_cell_dynamic_inputs) const auto hidden_size = Dimension::dynamic(); const size_t gates_count = 4; - const auto X = make_shared(element::f32, PartialShape{batch_size, input_size}); - const auto W = make_shared(element::f32, - PartialShape{hidden_size * gates_count, input_size}); - const auto R = make_shared(element::f32, - PartialShape{hidden_size * gates_count, hidden_size}); + const auto X = + make_shared(element::f32, PartialShape{batch_size, input_size}); + const auto W = make_shared( + element::f32, PartialShape{hidden_size * gates_count, input_size}); + const auto R = make_shared( + element::f32, PartialShape{hidden_size * gates_count, hidden_size}); const auto H_t = - make_shared(element::f32, PartialShape{batch_size, hidden_size}); + make_shared(element::f32, PartialShape{batch_size, hidden_size}); const auto C_t = - make_shared(element::f32, PartialShape{batch_size, hidden_size}); + make_shared(element::f32, PartialShape{batch_size, hidden_size}); - const auto lstm_cell = make_shared(X, H_t, C_t, W, R, 3); + const auto lstm_cell = make_shared(X, H_t, C_t, W, R, 3); EXPECT_EQ(lstm_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size})); EXPECT_EQ(lstm_cell->get_output_partial_shape(1), (PartialShape{batch_size, hidden_size})); @@ -235,62 +223,54 @@ TEST(type_prop, lstm_cell_invalid_input_rank0) const size_t hidden_size = 3; const size_t gates_count = 4; - auto X = make_shared(element::f32, PartialShape{batch_size, input_size}); - auto W = make_shared(element::f32, - PartialShape{gates_count * hidden_size, input_size}); - auto R = make_shared(element::f32, - PartialShape{gates_count * hidden_size, hidden_size}); - auto H_t = make_shared(element::f32, PartialShape{batch_size, hidden_size}); - auto C_t = make_shared(element::f32, PartialShape{batch_size, hidden_size}); + auto X = make_shared(element::f32, PartialShape{batch_size, input_size}); + auto W = make_shared(element::f32, + PartialShape{gates_count * hidden_size, input_size}); + auto R = make_shared(element::f32, + PartialShape{gates_count * hidden_size, hidden_size}); + auto H_t = make_shared(element::f32, PartialShape{batch_size, hidden_size}); + auto C_t = make_shared(element::f32, PartialShape{batch_size, hidden_size}); // Invalid rank0 for W tensor. - W = make_shared(element::f32, PartialShape{}); - ASSERT_THROW(make_shared(X, H_t, C_t, W, R, hidden_size), + W = make_shared(element::f32, PartialShape{}); + ASSERT_THROW(make_shared(X, H_t, C_t, W, R, hidden_size), ngraph::NodeValidationFailure) << "LSTMCell node was created with invalid data."; // Invalid rank0 for X tensor. - W = make_shared(element::f32, - PartialShape{gates_count * hidden_size, input_size}); - X = make_shared(element::f32, PartialShape{}); - ASSERT_THROW(make_shared(X, H_t, C_t, W, R, hidden_size), + W = make_shared(element::f32, + PartialShape{gates_count * hidden_size, input_size}); + X = make_shared(element::f32, PartialShape{}); + ASSERT_THROW(make_shared(X, H_t, C_t, W, R, hidden_size), ngraph::NodeValidationFailure) << "LSTMCell node was created with invalid data."; // Invalid rank0 for H_t tensor. - X = make_shared(element::f32, PartialShape{batch_size, input_size}); - H_t = make_shared(element::f32, PartialShape{}); - ASSERT_THROW(make_shared(X, H_t, C_t, W, R, hidden_size), + X = make_shared(element::f32, PartialShape{batch_size, input_size}); + H_t = make_shared(element::f32, PartialShape{}); + ASSERT_THROW(make_shared(X, H_t, C_t, W, R, hidden_size), ngraph::NodeValidationFailure) << "LSTMCell node was created with invalid data."; // Invalid rank0 for C_t tensor. - H_t = make_shared(element::f32, PartialShape{batch_size, hidden_size}); - C_t = make_shared(element::f32, PartialShape{}); - ASSERT_THROW(make_shared(X, H_t, C_t, W, R, hidden_size), + H_t = make_shared(element::f32, PartialShape{batch_size, hidden_size}); + C_t = make_shared(element::f32, PartialShape{}); + ASSERT_THROW(make_shared(X, H_t, C_t, W, R, hidden_size), ngraph::NodeValidationFailure) << "LSTMCell node was created with invalid data."; // Invalid rank0 for R tensor. - C_t = make_shared(element::f32, PartialShape{batch_size, hidden_size}); - R = make_shared(element::f32, PartialShape{}); - ASSERT_THROW(make_shared(X, H_t, C_t, W, R, hidden_size), + C_t = make_shared(element::f32, PartialShape{batch_size, hidden_size}); + R = make_shared(element::f32, PartialShape{}); + ASSERT_THROW(make_shared(X, H_t, C_t, W, R, hidden_size), ngraph::NodeValidationFailure) << "LSTMCell node was created with invalid data."; // Invalid rank0 for B tensor. - R = make_shared(element::f32, - PartialShape{gates_count * hidden_size, hidden_size}); - auto B = make_shared(element::f32, PartialShape{}); - auto P = make_shared(element::f32, PartialShape{3 * hidden_size}); - ASSERT_THROW(make_shared(X, H_t, C_t, W, R, B, P, hidden_size), - ngraph::NodeValidationFailure) - << "LSTMCell node was created with invalid data."; - - // Invalid rank0 for P tensor. - B = make_shared(element::f32, PartialShape{gates_count * hidden_size}); - P = make_shared(element::f32, PartialShape{}); - ASSERT_THROW(make_shared(X, H_t, C_t, W, R, B, P, hidden_size), + R = make_shared(element::f32, + PartialShape{gates_count * hidden_size, hidden_size}); + auto B = make_shared(element::f32, PartialShape{}); + ASSERT_THROW(make_shared(X, H_t, C_t, W, R, B, hidden_size), ngraph::NodeValidationFailure) << "LSTMCell node was created with invalid data."; } @@ -302,62 +282,54 @@ TEST(type_prop, lstm_cell_invalid_input_dynamic_rank) const size_t hidden_size = 3; const size_t gates_count = 4; - auto X = make_shared(element::f32, PartialShape{batch_size, input_size}); - auto W = make_shared(element::f32, - PartialShape{gates_count * hidden_size, input_size}); - auto R = make_shared(element::f32, - PartialShape{gates_count * hidden_size, hidden_size}); - auto H_t = make_shared(element::f32, PartialShape{batch_size, hidden_size}); - auto C_t = make_shared(element::f32, PartialShape{batch_size, hidden_size}); + auto X = make_shared(element::f32, PartialShape{batch_size, input_size}); + auto W = make_shared(element::f32, + PartialShape{gates_count * hidden_size, input_size}); + auto R = make_shared(element::f32, + PartialShape{gates_count * hidden_size, hidden_size}); + auto H_t = make_shared(element::f32, PartialShape{batch_size, hidden_size}); + auto C_t = make_shared(element::f32, PartialShape{batch_size, hidden_size}); // Invalid dynamic rank for W tensor. - W = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); - ASSERT_THROW(make_shared(X, H_t, C_t, W, R, hidden_size), + W = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); + ASSERT_THROW(make_shared(X, H_t, C_t, W, R, hidden_size), ngraph::NodeValidationFailure) << "LSTMCell node was created with invalid data."; // Invalid dynamic rank for X tensor. - W = make_shared(element::f32, - PartialShape{gates_count * hidden_size, input_size}); - X = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); - ASSERT_THROW(make_shared(X, H_t, C_t, W, R, hidden_size), + W = make_shared(element::f32, + PartialShape{gates_count * hidden_size, input_size}); + X = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); + ASSERT_THROW(make_shared(X, H_t, C_t, W, R, hidden_size), ngraph::NodeValidationFailure) << "LSTMCell node was created with invalid data."; // Invalid dynamic rank for H_t tensor. - X = make_shared(element::f32, PartialShape{batch_size, input_size}); - H_t = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); - ASSERT_THROW(make_shared(X, H_t, C_t, W, R, hidden_size), + X = make_shared(element::f32, PartialShape{batch_size, input_size}); + H_t = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); + ASSERT_THROW(make_shared(X, H_t, C_t, W, R, hidden_size), ngraph::NodeValidationFailure) << "LSTMCell node was created with invalid data."; // Invalid dynamic rank for C_t tensor. - H_t = make_shared(element::f32, PartialShape{batch_size, hidden_size}); - C_t = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); - ASSERT_THROW(make_shared(X, H_t, C_t, W, R, hidden_size), + H_t = make_shared(element::f32, PartialShape{batch_size, hidden_size}); + C_t = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); + ASSERT_THROW(make_shared(X, H_t, C_t, W, R, hidden_size), ngraph::NodeValidationFailure) << "LSTMCell node was created with invalid data."; // Invalid dynamic rank for R tensor. - C_t = make_shared(element::f32, PartialShape{batch_size, hidden_size}); - R = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); - ASSERT_THROW(make_shared(X, H_t, C_t, W, R, hidden_size), + C_t = make_shared(element::f32, PartialShape{batch_size, hidden_size}); + R = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); + ASSERT_THROW(make_shared(X, H_t, C_t, W, R, hidden_size), ngraph::NodeValidationFailure) << "LSTMCell node was created with invalid data."; // Invalid dynamic rank for B tensor. - R = make_shared(element::f32, - PartialShape{gates_count * hidden_size, hidden_size}); - auto B = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); - auto P = make_shared(element::f32, PartialShape{3 * hidden_size}); - ASSERT_THROW(make_shared(X, H_t, C_t, W, R, B, P, hidden_size), - ngraph::NodeValidationFailure) - << "LSTMCell node was created with invalid data."; - - // Invalid dynamic rank for P tensor. - B = make_shared(element::f32, PartialShape{gates_count * hidden_size}); - P = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); - ASSERT_THROW(make_shared(X, H_t, C_t, W, R, B, P, hidden_size), + R = make_shared(element::f32, + PartialShape{gates_count * hidden_size, hidden_size}); + auto B = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); + ASSERT_THROW(make_shared(X, H_t, C_t, W, R, B, hidden_size), ngraph::NodeValidationFailure) << "LSTMCell node was created with invalid data."; } diff --git a/ngraph/test/type_prop/lstm_sequence.cpp b/ngraph/test/type_prop/lstm_sequence.cpp index 2515f5d..491712a 100644 --- a/ngraph/test/type_prop/lstm_sequence.cpp +++ b/ngraph/test/type_prop/lstm_sequence.cpp @@ -16,6 +16,7 @@ #include "gtest/gtest.h" #include "ngraph/ngraph.hpp" +#include "ngraph/opsets/opset4.hpp" #include "util/type_prop.hpp" // suppress FusedOp deprecation warnings @@ -40,7 +41,7 @@ struct recurrent_sequence_parameters // // Create and initialize default input test tensors. // -shared_ptr +shared_ptr lstm_seq_tensor_initialization(const recurrent_sequence_parameters& param) { auto batch_size = param.batch_size; @@ -50,20 +51,21 @@ shared_ptr auto hidden_size = param.hidden_size; auto et = param.et; - const auto X = make_shared(et, PartialShape{batch_size, seq_length, input_size}); + const auto X = + make_shared(et, PartialShape{batch_size, seq_length, input_size}); const auto initial_hidden_state = - make_shared(et, PartialShape{batch_size, num_directions, hidden_size}); + make_shared(et, PartialShape{batch_size, num_directions, hidden_size}); const auto initial_cell_state = - make_shared(et, PartialShape{batch_size, num_directions, hidden_size}); - const auto sequence_lengths = make_shared(et, PartialShape{batch_size}); - const auto W = - make_shared(et, PartialShape{num_directions, hidden_size * 4, input_size}); - const auto R = - make_shared(et, PartialShape{num_directions, hidden_size * 4, hidden_size}); - const auto B = make_shared(et, PartialShape{num_directions, hidden_size * 4}); - const auto P = make_shared(et, PartialShape{num_directions, hidden_size * 3}); + make_shared(et, PartialShape{batch_size, num_directions, hidden_size}); + const auto sequence_lengths = make_shared(et, PartialShape{batch_size}); + const auto W = make_shared( + et, PartialShape{num_directions, hidden_size * 4, input_size}); + const auto R = make_shared( + et, PartialShape{num_directions, hidden_size * 4, hidden_size}); + const auto B = + make_shared(et, PartialShape{num_directions, hidden_size * 4}); - const auto lstm_sequence = make_shared(); + const auto lstm_sequence = make_shared(); lstm_sequence->set_argument(0, X); lstm_sequence->set_argument(1, initial_hidden_state); @@ -72,7 +74,6 @@ shared_ptr lstm_sequence->set_argument(4, W); lstm_sequence->set_argument(5, R); lstm_sequence->set_argument(6, B); - lstm_sequence->set_argument(7, P); return lstm_sequence; } @@ -86,40 +87,39 @@ TEST(type_prop, lstm_sequence_forward) const size_t hidden_size = 128; const auto X = - make_shared(element::f32, Shape{batch_size, seq_length, input_size}); - const auto initial_hidden_state = - make_shared(element::f32, Shape{batch_size, num_directions, hidden_size}); - const auto initial_cell_state = - make_shared(element::f32, Shape{batch_size, num_directions, hidden_size}); - const auto sequence_lengths = make_shared(element::i32, Shape{batch_size}); - const auto W = make_shared(element::f32, - Shape{num_directions, 4 * hidden_size, input_size}); - const auto R = make_shared(element::f32, - Shape{num_directions, 4 * hidden_size, hidden_size}); - const auto B = make_shared(element::f32, Shape{num_directions, 4 * hidden_size}); + make_shared(element::f32, Shape{batch_size, seq_length, input_size}); + const auto initial_hidden_state = make_shared( + element::f32, Shape{batch_size, num_directions, hidden_size}); + const auto initial_cell_state = make_shared( + element::f32, Shape{batch_size, num_directions, hidden_size}); + const auto sequence_lengths = make_shared(element::i32, Shape{batch_size}); + const auto W = make_shared( + element::f32, Shape{num_directions, 4 * hidden_size, input_size}); + const auto R = make_shared( + element::f32, Shape{num_directions, 4 * hidden_size, hidden_size}); + const auto B = + make_shared(element::f32, Shape{num_directions, 4 * hidden_size}); const auto lstm_direction = op::RecurrentSequenceDirection::FORWARD; - const auto lstm_sequence = make_shared(X, - initial_hidden_state, - initial_cell_state, - sequence_lengths, - W, - R, - B, - hidden_size, - lstm_direction); + const auto lstm_sequence = make_shared(X, + initial_hidden_state, + initial_cell_state, + sequence_lengths, + W, + R, + B, + hidden_size, + lstm_direction); EXPECT_EQ(lstm_sequence->get_hidden_size(), hidden_size); EXPECT_EQ(lstm_sequence->get_direction(), op::RecurrentSequenceDirection::FORWARD); - EXPECT_EQ(lstm_sequence->get_weights_format(), op::LSTMWeightsFormat::IFCO); EXPECT_TRUE(lstm_sequence->get_activations_alpha().empty()); EXPECT_TRUE(lstm_sequence->get_activations_beta().empty()); EXPECT_EQ(lstm_sequence->get_activations()[0], "sigmoid"); EXPECT_EQ(lstm_sequence->get_activations()[1], "tanh"); EXPECT_EQ(lstm_sequence->get_activations()[2], "tanh"); - EXPECT_EQ(lstm_sequence->get_clip_threshold(), 0.f); - EXPECT_FALSE(lstm_sequence->get_input_forget()); + EXPECT_EQ(lstm_sequence->get_clip(), 0.f); EXPECT_EQ(lstm_sequence->get_output_element_type(0), element::f32); EXPECT_EQ(lstm_sequence->get_output_shape(0), (Shape{batch_size, num_directions, seq_length, hidden_size})); @@ -138,47 +138,44 @@ TEST(type_prop, lstm_sequence_bidirectional) const size_t hidden_size = 256; const auto X = - make_shared(element::f32, Shape{batch_size, seq_length, input_size}); - const auto initial_hidden_state = - make_shared(element::f32, Shape{batch_size, num_directions, hidden_size}); - const auto initial_cell_state = - make_shared(element::f32, Shape{batch_size, num_directions, hidden_size}); - const auto sequence_lengths = make_shared(element::i32, Shape{batch_size}); - const auto W = make_shared(element::f32, - Shape{num_directions, 4 * hidden_size, input_size}); - const auto R = make_shared(element::f32, - Shape{num_directions, 4 * hidden_size, hidden_size}); - const auto B = make_shared(element::f32, Shape{num_directions, 4 * hidden_size}); - - const auto weights_format = op::LSTMWeightsFormat::FICO; - const auto lstm_direction = op::LSTMSequence::direction::BIDIRECTIONAL; + make_shared(element::f32, Shape{batch_size, seq_length, input_size}); + const auto initial_hidden_state = make_shared( + element::f32, Shape{batch_size, num_directions, hidden_size}); + const auto initial_cell_state = make_shared( + element::f32, Shape{batch_size, num_directions, hidden_size}); + const auto sequence_lengths = make_shared(element::i32, Shape{batch_size}); + const auto W = make_shared( + element::f32, Shape{num_directions, 4 * hidden_size, input_size}); + const auto R = make_shared( + element::f32, Shape{num_directions, 4 * hidden_size, hidden_size}); + const auto B = + make_shared(element::f32, Shape{num_directions, 4 * hidden_size}); + + const auto lstm_direction = op::v1::LSTMSequence::direction::BIDIRECTIONAL; const std::vector activations_alpha = {2.7, 7.0, 32.367}; const std::vector activations_beta = {0.0, 5.49, 6.0}; const std::vector activations = {"tanh", "sigmoid", "sigmoid"}; - const auto lstm_sequence = make_shared(X, - initial_hidden_state, - initial_cell_state, - sequence_lengths, - W, - R, - B, - hidden_size, - lstm_direction, - weights_format, - activations_alpha, - activations_beta, - activations); + const auto lstm_sequence = make_shared(X, + initial_hidden_state, + initial_cell_state, + sequence_lengths, + W, + R, + B, + hidden_size, + lstm_direction, + activations_alpha, + activations_beta, + activations); EXPECT_EQ(lstm_sequence->get_hidden_size(), hidden_size); - EXPECT_EQ(lstm_sequence->get_direction(), op::LSTMSequence::direction::BIDIRECTIONAL); - EXPECT_EQ(lstm_sequence->get_weights_format(), op::LSTMWeightsFormat::FICO); + EXPECT_EQ(lstm_sequence->get_direction(), op::v1::LSTMSequence::direction::BIDIRECTIONAL); EXPECT_EQ(lstm_sequence->get_activations_alpha(), activations_alpha); EXPECT_EQ(lstm_sequence->get_activations_beta(), activations_beta); EXPECT_EQ(lstm_sequence->get_activations()[0], "tanh"); EXPECT_EQ(lstm_sequence->get_activations()[1], "sigmoid"); EXPECT_EQ(lstm_sequence->get_activations()[2], "sigmoid"); - EXPECT_EQ(lstm_sequence->get_clip_threshold(), 0.f); - EXPECT_FALSE(lstm_sequence->get_input_forget()); + EXPECT_EQ(lstm_sequence->get_clip(), 0.f); EXPECT_EQ(lstm_sequence->get_output_element_type(0), element::f32); EXPECT_EQ(lstm_sequence->get_output_shape(0), (Shape{batch_size, num_directions, seq_length, hidden_size})); @@ -330,15 +327,14 @@ TEST(type_prop, lstm_sequence_invalid_input_dimension) param.et = element::f32; auto lstm_sequence = lstm_seq_tensor_initialization(param); - auto invalid_rank0_tensor = make_shared(param.et, PartialShape{}); + auto invalid_rank0_tensor = make_shared(param.et, PartialShape{}); // Validate invalid rank0 tensor for all inputs: X, initial_hidden_state, initial_cell_state W, - // R, B and P + // R, B for (auto i = 0; i < lstm_sequence->get_input_size(); i++) { lstm_sequence = lstm_seq_tensor_initialization(param); lstm_sequence->set_argument(i, invalid_rank0_tensor); - ASSERT_THROW(lstm_sequence->validate_and_infer_types(), ngraph::CheckFailure) << "LSTMSequence node was created with invalid data."; } @@ -357,15 +353,14 @@ TEST(type_prop, lstm_sequence_invalid_input_dynamic_rank) auto lstm_sequence = lstm_seq_tensor_initialization(param); auto invalid_dynamic_tensor = - make_shared(param.et, PartialShape::dynamic(Rank::dynamic())); + make_shared(param.et, PartialShape::dynamic(Rank::dynamic())); // Validate invalid dynamic tensor for all inputs: X, initial_hidden_state, initial_cell_state - // W, R, B and P + // W, R, B for (auto i = 0; i < lstm_sequence->get_input_size(); i++) { lstm_sequence = lstm_seq_tensor_initialization(param); lstm_sequence->set_argument(i, invalid_dynamic_tensor); - ASSERT_THROW(lstm_sequence->validate_and_infer_types(), ngraph::CheckFailure) << "LSTMSequence node was created with invalid data."; } diff --git a/ngraph/test/type_prop/rnn_cell.cpp b/ngraph/test/type_prop/rnn_cell.cpp index c2a1e63..5db9e15 100644 --- a/ngraph/test/type_prop/rnn_cell.cpp +++ b/ngraph/test/type_prop/rnn_cell.cpp @@ -16,6 +16,7 @@ #include "gtest/gtest.h" #include "ngraph/ngraph.hpp" +#include "ngraph/opsets/opset4.hpp" #include "util/type_prop.hpp" using namespace std; @@ -27,12 +28,12 @@ TEST(type_prop, rnn_cell) const size_t input_size = 3; const size_t hidden_size = 3; - const auto X = make_shared(element::f32, Shape{batch_size, input_size}); - const auto H_t = make_shared(element::f32, Shape{batch_size, hidden_size}); - const auto W = make_shared(element::f32, Shape{hidden_size, input_size}); - const auto R = make_shared(element::f32, Shape{hidden_size, hidden_size}); + const auto X = make_shared(element::f32, Shape{batch_size, input_size}); + const auto H_t = make_shared(element::f32, Shape{batch_size, hidden_size}); + const auto W = make_shared(element::f32, Shape{hidden_size, input_size}); + const auto R = make_shared(element::f32, Shape{hidden_size, hidden_size}); - const auto rnn_cell = make_shared(X, H_t, W, R, hidden_size); + const auto rnn_cell = make_shared(X, H_t, W, R, hidden_size); EXPECT_EQ(rnn_cell->get_output_element_type(0), element::f32); EXPECT_EQ(rnn_cell->get_output_shape(0), (Shape{batch_size, hidden_size})); } @@ -43,15 +44,15 @@ TEST(type_prop, rnn_cell_invalid_input) const size_t input_size = 3; const size_t hidden_size = 3; - auto X = make_shared(element::f32, Shape{batch_size, input_size}); - auto R = make_shared(element::f32, Shape{hidden_size, hidden_size}); - auto H_t = make_shared(element::f32, Shape{batch_size, hidden_size}); + auto X = make_shared(element::f32, Shape{batch_size, input_size}); + auto R = make_shared(element::f32, Shape{hidden_size, hidden_size}); + auto H_t = make_shared(element::f32, Shape{batch_size, hidden_size}); // Invalid W tensor shape. - auto W = make_shared(element::f32, Shape{2 * hidden_size, input_size}); + auto W = make_shared(element::f32, Shape{2 * hidden_size, input_size}); try { - const auto rnn_cell = make_shared(X, H_t, W, R, hidden_size); + const auto rnn_cell = make_shared(X, H_t, W, R, hidden_size); FAIL() << "RNNCell node was created with invalid data."; } catch (const NodeValidationFailure& error) @@ -61,11 +62,11 @@ TEST(type_prop, rnn_cell_invalid_input) } // Invalid R tensor shape. - W = make_shared(element::f32, Shape{hidden_size, input_size}); - R = make_shared(element::f32, Shape{hidden_size, 1}); + W = make_shared(element::f32, Shape{hidden_size, input_size}); + R = make_shared(element::f32, Shape{hidden_size, 1}); try { - const auto rnn_cell = make_shared(X, H_t, W, R, hidden_size); + const auto rnn_cell = make_shared(X, H_t, W, R, hidden_size); FAIL() << "RNNCell node was created with invalid data."; } catch (const NodeValidationFailure& error) @@ -77,11 +78,11 @@ TEST(type_prop, rnn_cell_invalid_input) } // Invalid H_t tensor shape. - R = make_shared(element::f32, Shape{hidden_size, hidden_size}); - H_t = make_shared(element::f32, Shape{4, hidden_size}); + R = make_shared(element::f32, Shape{hidden_size, hidden_size}); + H_t = make_shared(element::f32, Shape{4, hidden_size}); try { - const auto rnn_cell = make_shared(X, H_t, W, R, hidden_size); + const auto rnn_cell = make_shared(X, H_t, W, R, hidden_size); FAIL() << "RNNCell node was created with invalid data."; } catch (const NodeValidationFailure& error) @@ -92,11 +93,11 @@ TEST(type_prop, rnn_cell_invalid_input) } // Invalid B tensor shape. - H_t = make_shared(element::f32, Shape{batch_size, hidden_size}); - auto B = make_shared(element::f32, Shape{2 * hidden_size}); + H_t = make_shared(element::f32, Shape{batch_size, hidden_size}); + auto B = make_shared(element::f32, Shape{2 * hidden_size}); try { - const auto rnn_cell = make_shared(X, H_t, W, R, B, hidden_size); + const auto rnn_cell = make_shared(X, H_t, W, R, B, hidden_size); FAIL() << "RNNCell node was created with invalid data."; } catch (const NodeValidationFailure& error) @@ -112,13 +113,16 @@ TEST(type_prop, rnn_cell_dynamic_batch_size) const size_t input_size = 3; const size_t hidden_size = 3; - const auto X = make_shared(element::f32, PartialShape{batch_size, input_size}); + const auto X = + make_shared(element::f32, PartialShape{batch_size, input_size}); const auto H_t = - make_shared(element::f32, PartialShape{batch_size, hidden_size}); - const auto W = make_shared(element::f32, PartialShape{hidden_size, input_size}); - const auto R = make_shared(element::f32, PartialShape{hidden_size, hidden_size}); + make_shared(element::f32, PartialShape{batch_size, hidden_size}); + const auto W = + make_shared(element::f32, PartialShape{hidden_size, input_size}); + const auto R = + make_shared(element::f32, PartialShape{hidden_size, hidden_size}); - const auto rnn_cell = make_shared(X, H_t, W, R, hidden_size); + const auto rnn_cell = make_shared(X, H_t, W, R, hidden_size); EXPECT_EQ(rnn_cell->get_output_element_type(0), element::f32); EXPECT_EQ(rnn_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size})); } @@ -129,13 +133,16 @@ TEST(type_prop, rnn_cell_dynamic_hidden_size) const size_t input_size = 3; const auto hidden_size = Dimension::dynamic(); - const auto X = make_shared(element::f32, PartialShape{batch_size, input_size}); + const auto X = + make_shared(element::f32, PartialShape{batch_size, input_size}); const auto H_t = - make_shared(element::f32, PartialShape{batch_size, hidden_size}); - const auto W = make_shared(element::f32, PartialShape{hidden_size, input_size}); - const auto R = make_shared(element::f32, PartialShape{hidden_size, hidden_size}); + make_shared(element::f32, PartialShape{batch_size, hidden_size}); + const auto W = + make_shared(element::f32, PartialShape{hidden_size, input_size}); + const auto R = + make_shared(element::f32, PartialShape{hidden_size, hidden_size}); - const auto rnn_cell = make_shared(X, H_t, W, R, 3); + const auto rnn_cell = make_shared(X, H_t, W, R, 3); EXPECT_EQ(rnn_cell->get_output_element_type(0), element::f32); EXPECT_EQ(rnn_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size})); } @@ -146,13 +153,16 @@ TEST(type_prop, rnn_cell_dynamic_inputs) const auto input_size = Dimension::dynamic(); const auto hidden_size = Dimension::dynamic(); - const auto X = make_shared(element::f32, PartialShape{batch_size, input_size}); - const auto R = make_shared(element::f32, PartialShape{hidden_size, hidden_size}); - const auto W = make_shared(element::f32, PartialShape{hidden_size, input_size}); + const auto X = + make_shared(element::f32, PartialShape{batch_size, input_size}); + const auto R = + make_shared(element::f32, PartialShape{hidden_size, hidden_size}); + const auto W = + make_shared(element::f32, PartialShape{hidden_size, input_size}); const auto H_t = - make_shared(element::f32, PartialShape{batch_size, hidden_size}); + make_shared(element::f32, PartialShape{batch_size, hidden_size}); - const auto rnn_cell = make_shared(X, H_t, W, R, 2); + const auto rnn_cell = make_shared(X, H_t, W, R, 2); EXPECT_EQ(rnn_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size})); EXPECT_EQ(rnn_cell->get_output_element_type(0), element::f32); @@ -164,37 +174,41 @@ TEST(type_prop, rnn_cell_invalid_input_rank0) const size_t input_size = 3; const size_t hidden_size = 3; - auto X = make_shared(element::f32, Shape{batch_size, input_size}); - auto R = make_shared(element::f32, Shape{hidden_size, hidden_size}); - auto H_t = make_shared(element::f32, Shape{batch_size, hidden_size}); + auto X = make_shared(element::f32, Shape{batch_size, input_size}); + auto R = make_shared(element::f32, Shape{hidden_size, hidden_size}); + auto H_t = make_shared(element::f32, Shape{batch_size, hidden_size}); // Invalid rank0 for W tensor. - auto W = make_shared(element::f32, PartialShape{}); - ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure) + auto W = make_shared(element::f32, PartialShape{}); + ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), + ngraph::NodeValidationFailure) << "RNNCell node was created with invalid data."; // Invalid rank0 for X tensor. - W = make_shared(element::f32, PartialShape{hidden_size, input_size}); - X = make_shared(element::f32, PartialShape{}); - ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure) + W = make_shared(element::f32, PartialShape{hidden_size, input_size}); + X = make_shared(element::f32, PartialShape{}); + ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), + ngraph::NodeValidationFailure) << "RNNCell node was created with invalid data."; // Invalid rank0 for H_t tensor. - X = make_shared(element::f32, Shape{batch_size, input_size}); - H_t = make_shared(element::f32, PartialShape{}); - ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure) + X = make_shared(element::f32, Shape{batch_size, input_size}); + H_t = make_shared(element::f32, PartialShape{}); + ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), + ngraph::NodeValidationFailure) << "RNNCell node was created with invalid data."; // Invalid rank0 for R tensor. - H_t = make_shared(element::f32, Shape{batch_size, hidden_size}); - R = make_shared(element::f32, PartialShape{}); - ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure) + H_t = make_shared(element::f32, Shape{batch_size, hidden_size}); + R = make_shared(element::f32, PartialShape{}); + ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), + ngraph::NodeValidationFailure) << "RNNCell node was created with invalid data."; // Invalid rank0 for B tensor. - R = make_shared(element::f32, PartialShape{hidden_size, hidden_size}); - auto B = make_shared(element::f32, PartialShape{}); - ASSERT_THROW(make_shared(X, H_t, W, R, B, hidden_size), + R = make_shared(element::f32, PartialShape{hidden_size, hidden_size}); + auto B = make_shared(element::f32, PartialShape{}); + ASSERT_THROW(make_shared(X, H_t, W, R, B, hidden_size), ngraph::NodeValidationFailure) << "RNNCell node was created with invalid data."; } @@ -205,37 +219,41 @@ TEST(type_prop, rnn_cell_invalid_input_dynamic_rank) const size_t input_size = 3; const size_t hidden_size = 3; - auto X = make_shared(element::f32, Shape{batch_size, input_size}); - auto R = make_shared(element::f32, Shape{hidden_size, hidden_size}); - auto H_t = make_shared(element::f32, Shape{batch_size, hidden_size}); + auto X = make_shared(element::f32, Shape{batch_size, input_size}); + auto R = make_shared(element::f32, Shape{hidden_size, hidden_size}); + auto H_t = make_shared(element::f32, Shape{batch_size, hidden_size}); // Invalid dynamic rank for W tensor. - auto W = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); - ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure) + auto W = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); + ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), + ngraph::NodeValidationFailure) << "RNNCell node was created with invalid data."; // Invalid dynamic rank for X tensor. - W = make_shared(element::f32, PartialShape{hidden_size, input_size}); - X = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); - ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure) + W = make_shared(element::f32, PartialShape{hidden_size, input_size}); + X = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); + ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), + ngraph::NodeValidationFailure) << "RNNCell node was created with invalid data."; // Invalid dynamic rank for H_t tensor. - X = make_shared(element::f32, Shape{batch_size, input_size}); - H_t = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); - ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure) + X = make_shared(element::f32, Shape{batch_size, input_size}); + H_t = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); + ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), + ngraph::NodeValidationFailure) << "RNNCell node was created with invalid data."; // Invalid dynamic rank for R tensor. - H_t = make_shared(element::f32, Shape{batch_size, hidden_size}); - R = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); - ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure) + H_t = make_shared(element::f32, Shape{batch_size, hidden_size}); + R = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); + ASSERT_THROW(make_shared(X, H_t, W, R, hidden_size), + ngraph::NodeValidationFailure) << "RNNCell node was created with invalid data."; // Invalid dynamic rank for B tensor. - R = make_shared(element::f32, PartialShape{hidden_size, hidden_size}); - auto B = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); - ASSERT_THROW(make_shared(X, H_t, W, R, B, hidden_size), + R = make_shared(element::f32, PartialShape{hidden_size, hidden_size}); + auto B = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); + ASSERT_THROW(make_shared(X, H_t, W, R, B, hidden_size), ngraph::NodeValidationFailure) << "RNNCell node was created with invalid data."; } -- 2.7.4