From 10b18a00c6e5c05ed7b8b68ecf41ffb9e469ee3a Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Mon, 16 Nov 2020 09:50:41 +0300 Subject: [PATCH] Improve reshapeability of models with eltwise nodes influencing shapes (#2767) * Fix ElementwiseInputReshape transformation A Reshape node always needs to be inserted in order to preserve ShapeOf nodes (and hence the reshapability of a model) that can potentially be above an elementwise node. Refactor EltwiseInputReshape_test and EltwiseInputNormalization_test since the logic of maintaining reshape for eltwise has changed. Signed-off-by: Roman Kazantsev * Merge EltwiseInputNormalization and EltwiseInputReshape transformations Signed-off-by: Roman Kazantsev * Remove Unsqueeze from Fused_op Signed-off-by: Roman Kazantsev * Fix code after code review #1 Signed-off-by: Roman Kazantsev * Fix code after review #2 Signed-off-by: Roman Kazantsev * Fix code after review #4 Signed-off-by: Roman Kazantsev * Perform full normalization based on shapes of all inputs to eltwise Signed-off-by: Roman Kazantsev * Refactor heavily to avoid the old API and edges with the unsqueeze_dims attribute Signed-off-by: Roman Kazantsev * Fix code after review Signed-off-by: Roman Kazantsev --- .../legacy_api/src/ie_cnn_layer_builder_ngraph.cpp | 4 +- .../convert_nms_to_nms_ie.cpp | 6 +- .../src/subgraph_tests/matmul_squeeze_add.cpp | 2 +- .../shared/src/subgraph_tests/memory_LSTMCell.cpp | 8 +- .../src/subgraph_tests/multiple_LSTMCell.cpp | 12 +- .../unsqueeze_function.cpp | 2 +- model-optimizer/automation/package_BOM.txt | 1 - .../extensions/middle/EltwiseInputNormalization.py | 48 -- .../middle/EltwiseInputNormalization_test.py | 202 -------- .../extensions/middle/EltwiseInputReshape.py | 145 +++--- .../extensions/middle/EltwiseInputReshape_test.py | 511 ++++++++++++++++----- model-optimizer/extensions/middle/fusings.py | 6 +- .../mo/front/common/partial_infer/eltwise.py | 11 - model-optimizer/mo/utils/shape.py | 10 +- ngraph/core/include/ngraph/op/unsqueeze.hpp | 11 +- ngraph/core/src/op/unsqueeze.cpp | 24 +- ngraph/test/backend/fused_op.in.cpp | 2 +- ngraph/test/constant_folding.cpp | 4 +- ngraph/test/op_is.cpp | 2 +- ngraph/test/runtime/opset0_tbl.hpp | 2 +- ngraph/test/type_prop/unsqueeze.cpp | 4 +- 21 files changed, 527 insertions(+), 490 deletions(-) delete mode 100644 model-optimizer/extensions/middle/EltwiseInputNormalization.py delete mode 100644 model-optimizer/extensions/middle/EltwiseInputNormalization_test.py diff --git a/inference-engine/src/legacy_api/src/ie_cnn_layer_builder_ngraph.cpp b/inference-engine/src/legacy_api/src/ie_cnn_layer_builder_ngraph.cpp index 50c6810..1444603 100644 --- a/inference-engine/src/legacy_api/src/ie_cnn_layer_builder_ngraph.cpp +++ b/inference-engine/src/legacy_api/src/ie_cnn_layer_builder_ngraph.cpp @@ -651,11 +651,11 @@ CNNLayer::Ptr NodeConverter::createLayer(const std::shared_ } template <> -CNNLayer::Ptr NodeConverter::createLayer(const std::shared_ptr& layer) const { +CNNLayer::Ptr NodeConverter::createLayer(const std::shared_ptr& layer) const { LayerParams params = {layer->get_friendly_name(), "Unsqueeze", details::convertPrecision(layer->get_output_element_type(0))}; auto res = std::make_shared(params); - auto castedLayer = ngraph::as_type_ptr(layer); + auto castedLayer = ngraph::as_type_ptr(layer); if (castedLayer == nullptr) THROW_IE_EXCEPTION << "Cannot get " << params.type << " layer " << params.name; return res; diff --git a/inference-engine/src/legacy_api/src/transformations/convert_opset1_to_legacy/convert_nms_to_nms_ie.cpp
b/inference-engine/src/legacy_api/src/transformations/convert_opset1_to_legacy/convert_nms_to_nms_ie.cpp index 589860a..5403c5f 100644 --- a/inference-engine/src/legacy_api/src/transformations/convert_opset1_to_legacy/convert_nms_to_nms_ie.cpp +++ b/inference-engine/src/legacy_api/src/transformations/convert_opset1_to_legacy/convert_nms_to_nms_ie.cpp @@ -52,7 +52,7 @@ ngraph::pass::ConvertNMSToNMSIEMatcher::ConvertNMSToNMSIEMatcher() { if (auto new_max_per_class_const = std::dynamic_pointer_cast(new_max_per_class.get_node_shared_ptr())) { new_max_per_class = opset1::Constant::create(element::i64, Shape{1}, new_max_per_class_const->cast_vector()); } else { - new_max_per_class = std::make_shared( + new_max_per_class = std::make_shared( nms->input_value(2), opset1::Constant::create(element::i64, Shape{1}, {0})); new_ops.push_back(new_max_per_class.get_node_shared_ptr()); @@ -60,14 +60,14 @@ ngraph::pass::ConvertNMSToNMSIEMatcher::ConvertNMSToNMSIEMatcher() { } auto new_iou_threshold = nms->input_value(3); if (iou_threshold_rank.get_length() == 0) { - new_iou_threshold = std::make_shared( + new_iou_threshold = std::make_shared( nms->input_value(3), opset1::Constant::create(element::i64, Shape{1}, {0})); new_ops.push_back(new_iou_threshold.get_node_shared_ptr()); } auto new_score_threshold = nms->input_value(4); if (score_threshold_rank.get_length() == 0) { - new_score_threshold = std::make_shared( + new_score_threshold = std::make_shared( nms->input_value(4), opset1::Constant::create(element::i64, Shape{1}, {0})); new_ops.push_back(new_score_threshold.get_node_shared_ptr()); diff --git a/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/matmul_squeeze_add.cpp b/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/matmul_squeeze_add.cpp index afeb81b..05f65eb 100644 --- a/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/matmul_squeeze_add.cpp +++ b/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/matmul_squeeze_add.cpp @@ -58,7 +58,7 @@ void MatmulSqueezeAddTest::SetUp() { auto matmul_0 = std::make_shared(params[0], constant_0, false, true); auto constant_1 = std::make_shared(ngraph::element::Type_t::i64, ngraph::Shape{ 1 }, std::vector{0}); - auto unsqueeze_0 = std::make_shared(matmul_0, constant_1); + auto unsqueeze_0 = std::make_shared(matmul_0, constant_1); auto constant_2 = ngraph::builder::makeConstant(ngPrc, { 1, inputShape[0], outputSize }, CommonTestUtils::generate_float_numbers(inputShape[0] * outputSize, 0, 1, seed), false); diff --git a/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/memory_LSTMCell.cpp b/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/memory_LSTMCell.cpp index 4542c8f..1f3dc9f 100644 --- a/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/memory_LSTMCell.cpp +++ b/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/memory_LSTMCell.cpp @@ -75,7 +75,7 @@ namespace SubgraphTestsDefinitions { auto mul = ngraph::builder::makeEltwise(add, input_mul_const, ngraph::helpers::EltwiseTypes::MULTIPLY); auto unsqueeze_input_const = std::make_shared(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes); - auto unsqueeze_input = std::make_shared(mul, unsqueeze_input_const); + auto unsqueeze_input = std::make_shared(mul, unsqueeze_input_const); auto permute_in_params = std::make_shared(ngraph::element::i64, ngraph::Shape{3}, ngraph::Shape{{1, 0, 2}}); auto permute_in = std::make_shared(unsqueeze_input, permute_in_params); @@ -100,7 +100,7 @@ 
namespace SubgraphTestsDefinitions { auto lstm = std::make_shared(squeeze, H_t, C_t, weightsNode, reccurrenceWeightsNode, biasNode, hiddenSize); auto unsqueeze_const = std::make_shared(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes); - auto unsqueeze = std::make_shared(lstm->output(0), unsqueeze_const); + auto unsqueeze = std::make_shared(lstm->output(0), unsqueeze_const); // body - outputs auto H_o = lstm->output(0); auto C_o = lstm->output(1); @@ -158,7 +158,7 @@ namespace SubgraphTestsDefinitions { auto mul = ngraph::builder::makeEltwise(add, input_mul_const, ngraph::helpers::EltwiseTypes::MULTIPLY); auto unsqueeze_input_const = std::make_shared(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes); - auto unsqueeze_input = std::make_shared(mul, unsqueeze_input_const); + auto unsqueeze_input = std::make_shared(mul, unsqueeze_input_const); auto cell_memory_constant = ngraph::builder::makeConstant(ngPrc, cell_memory_dims, cell_memory_init); @@ -175,7 +175,7 @@ namespace SubgraphTestsDefinitions { reccurrenceWeightsNode, biasNode, hiddenSize); auto unsqueeze_const = std::make_shared(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes); - auto unsqueeze = std::make_shared(lstm->output(0), unsqueeze_const); + auto unsqueeze = std::make_shared(lstm->output(0), unsqueeze_const); auto final_reshape_pattern = std::make_shared(ngraph::element::i64, ngraph::Shape{4}, std::vector({1, 1, 1, hiddenSize})); diff --git a/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/multiple_LSTMCell.cpp b/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/multiple_LSTMCell.cpp index 0197077..f52d1a9 100644 --- a/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/multiple_LSTMCell.cpp +++ b/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/multiple_LSTMCell.cpp @@ -73,7 +73,7 @@ void MultipleLSTMCellTest::SetUp() { auto mul = ngraph::builder::makeEltwise(add, input_mul_const, ngraph::helpers::EltwiseTypes::MULTIPLY); auto unsqueeze_input_const = std::make_shared(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes); - auto unsqueeze_input = std::make_shared(mul, unsqueeze_input_const); + auto unsqueeze_input = std::make_shared(mul, unsqueeze_input_const); auto permute_in_params = std::make_shared(ngraph::element::i64, ngraph::Shape{3}, ngraph::Shape{{1, 0, 2}}); auto permute_in = std::make_shared(unsqueeze_input, permute_in_params); @@ -100,7 +100,7 @@ void MultipleLSTMCellTest::SetUp() { auto lstm = std::make_shared(squeeze, H_t, C_t, weightsNode, reccurrenceWeightsNode, biasNode, hiddenSize); auto unsqueeze_const = std::make_shared(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes); - auto unsqueeze = std::make_shared(lstm->output(0), unsqueeze_const); + auto unsqueeze = std::make_shared(lstm->output(0), unsqueeze_const); // body - outputs auto H_o = lstm->output(0); auto C_o = lstm->output(1); @@ -155,7 +155,7 @@ void MultipleLSTMCellTest::SetUp() { auto lstm_2 = std::make_shared(squeeze_2, H_t_2, C_t_2, weightsNode_2, reccurrenceWeightsNode_2, biasNode_2, hiddenSize); auto unsqueeze_2_const = std::make_shared(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes); - auto unsqueeze_2 = std::make_shared(lstm_2->output(0), unsqueeze_2_const); + auto unsqueeze_2 = std::make_shared(lstm_2->output(0), unsqueeze_2_const); // body - outputs auto H_o_2 = lstm_2->output(0); auto C_o_2 = lstm_2->output(1); @@ -219,7 +219,7 @@ void MultipleLSTMCellTest::switchToNgraphFriendlyModel() { auto mul = ngraph::builder::makeEltwise(add, input_mul_const, 
ngraph::helpers::EltwiseTypes::MULTIPLY); auto unsqueeze_input_const = std::make_shared(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes); - auto unsqueeze_input = std::make_shared(mul, unsqueeze_input_const); + auto unsqueeze_input = std::make_shared(mul, unsqueeze_input_const); // Body 1 - layers auto cell_memory_constant = ngraph::builder::makeConstant(ngPrc, cell_memory_dims, cell_memory_init); @@ -236,7 +236,7 @@ void MultipleLSTMCellTest::switchToNgraphFriendlyModel() { reccurrenceWeightsNode, biasNode, hiddenSize); auto unsqueeze_const = std::make_shared(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes); - auto unsqueeze = std::make_shared(lstm->output(0), unsqueeze_const); + auto unsqueeze = std::make_shared(lstm->output(0), unsqueeze_const); auto first_reshape_pattern = std::make_shared(ngraph::element::i64, ngraph::Shape{4}, std::vector({1, 1, 1, hiddenSize})); @@ -261,7 +261,7 @@ void MultipleLSTMCellTest::switchToNgraphFriendlyModel() { reccurrenceWeightsNode_2, biasNode_2, hiddenSize); auto unsqueeze_2_const = std::make_shared(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes); - auto unsqueeze_2 = std::make_shared(lstm_2->output(0), unsqueeze_2_const); + auto unsqueeze_2 = std::make_shared(lstm_2->output(0), unsqueeze_2_const); auto final_reshape_pattern = std::make_shared(ngraph::element::i64, ngraph::Shape{4}, std::vector({1, 1, 1, hiddenSize})); diff --git a/inference-engine/tests/ngraph_functions/src/low_precision_transformations/unsqueeze_function.cpp b/inference-engine/tests/ngraph_functions/src/low_precision_transformations/unsqueeze_function.cpp index 49cba6c..5e72f96 100644 --- a/inference-engine/tests/ngraph_functions/src/low_precision_transformations/unsqueeze_function.cpp +++ b/inference-engine/tests/ngraph_functions/src/low_precision_transformations/unsqueeze_function.cpp @@ -67,7 +67,7 @@ std::shared_ptr UnsqueezeFunction::getReference( const std::shared_ptr dequantizationOpBefore = makeDequantization(input, dequantizationBefore); const auto unsqueeze = std::make_shared>( - op::Unsqueeze(dequantizationOpBefore, std::make_shared(element::i64, Shape{ axes.size() }, axes)), + op::v0::Unsqueeze(dequantizationOpBefore, std::make_shared(element::i64, Shape{ axes.size() }, axes)), precisionAfterOperation); const std::shared_ptr dequantizationOpAfter = makeDequantization(unsqueeze, dequantizationAfter); dequantizationOpAfter->set_friendly_name("output"); diff --git a/model-optimizer/automation/package_BOM.txt b/model-optimizer/automation/package_BOM.txt index 925d2a4..80b1aa2 100644 --- a/model-optimizer/automation/package_BOM.txt +++ b/model-optimizer/automation/package_BOM.txt @@ -526,7 +526,6 @@ extensions/middle/DeleteControlFlowEdges.py extensions/middle/DeleteNotExecutable.py extensions/middle/DilatedConvolution.py extensions/middle/EltwiseChecker.py -extensions/middle/EltwiseInputNormalization.py extensions/middle/EltwiseInputReshape.py extensions/middle/FakeSplitOutputs.py extensions/middle/FusedBatchNormNonConstant.py diff --git a/model-optimizer/extensions/middle/EltwiseInputNormalization.py b/model-optimizer/extensions/middle/EltwiseInputNormalization.py deleted file mode 100644 index e07fe0e..0000000 --- a/model-optimizer/extensions/middle/EltwiseInputNormalization.py +++ /dev/null @@ -1,48 +0,0 @@ -""" - Copyright (C) 2018-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" - -import networkx as nx -import numpy as np - -from extensions.middle.EltwiseInputReshape import EltwiseInputReshape -from mo.graph.graph import Graph -from mo.middle.replacement import MiddleReplacementPattern - - -class EltwiseInputNormalize(EltwiseInputReshape, MiddleReplacementPattern): - # This pass should be called directly from pipeline before layout change and other permutations - enabled = False - - def find_and_replace_pattern(self, graph: Graph): - eltwise_nodes = graph.get_op_nodes(is_eltwise=True) - # Iterating over all Eltwise operations and check that every input has similar shape - # in case of different shapes, we inserts new_shape attribute and then call EltwiseInputReshape extension - # that insert reshapes (in case of not constant nodes) or directly reshapes values in data nodes for specified - # shape - for node in eltwise_nodes: - output_shape = node.out_node().shape - for in_node in node.in_nodes().values(): - if len(in_node.shape) != len(output_shape): - # Set edge attribute new_shape for further transformation pass - new_shape = in_node.shape - for x in range(len(output_shape) - len(in_node.shape)): - new_shape = np.insert(new_shape, 0, 1) - - nx.set_edge_attributes(G=node.graph, - values={(in_node.id, node.id, 0): new_shape}, - name='new_shape') - - super().find_and_replace_pattern(graph) diff --git a/model-optimizer/extensions/middle/EltwiseInputNormalization_test.py b/model-optimizer/extensions/middle/EltwiseInputNormalization_test.py deleted file mode 100644 index 579a29c..0000000 --- a/model-optimizer/extensions/middle/EltwiseInputNormalization_test.py +++ /dev/null @@ -1,202 +0,0 @@ -""" - Copyright (C) 2018-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" - -import unittest - -import numpy as np - -from extensions.middle.EltwiseInputNormalization import EltwiseInputNormalize -from mo.front.common.partial_infer.utils import int64_array -from mo.middle.passes.eliminate_test import build_graph -from mo.utils.ir_engine.compare_graphs import compare_graphs - -# The dictionary with nodes attributes used to build various graphs. A key is the name of the node and the value is the -# dictionary with node attributes. 
-nodes_attributes = { - # Placeholder layers - 'placeholder_1': {'value': None, 'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, - 'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None}, - 'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None}, - 'placeholder_3_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None}, - 'placeholder_4_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None}, - - # Reshape layers - 'reshape_1': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'}, - 'reshape_1_data': {'value': None, 'shape': None, 'kind': 'data'}, - 'reshape_1_const': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': None}, - 'reshape_1_const_data': {'kind': 'data', 'value': None, 'shape': None}, - - 'reshape_2': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'}, - 'reshape_2_data': {'value': None, 'shape': None, 'kind': 'data'}, - 'reshape_2_const': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': None}, - 'reshape_2_const_data': {'kind': 'data', 'value': None, 'shape': None}, - - # Eltwise consumes layers - 'eltwise_1': {'kind': 'op', 'is_eltwise': True}, - 'eltwise_1_data': {'value': None, 'shape': None, 'kind': 'data'}, - - 'eltwise_2': {'kind': 'op', 'is_eltwise': True}, - 'eltwise_2_data': {'value': None, 'shape': None, 'kind': 'data'}, - - 'eltwise_3': {'kind': 'op', 'is_eltwise': True}, - 'eltwise_3_data': {'value': None, 'shape': None, 'kind': 'data'}, - - 'eltwise_4': {'kind': 'op', 'is_eltwise': True}, - 'eltwise_4_data': {'value': None, 'shape': None, 'kind': 'data'}, - - # Concat - 'concat': {'type': 'Concat', 'kind': 'op', 'op': 'Concat'}, -} - - -class EltwiseInputNormalizationTest(unittest.TestCase): - def test1_not_constant(self): - # - # data1(1,3,64,64)----. data(1,3,64,64)-------. - # data2(1,64,1)-------->Eltwise-->data(1,3,64,64) => data(1,64,1)->Reshape->data(1,1,64,1)-->Eltwise->... 
- # data3(64,1)------' data(64,1)->Reshape->data(1,1,64,1)-' - # - graph = build_graph(nodes_attributes, [ - ('placeholder_1', 'placeholder_1_data'), - ('placeholder_1', 'placeholder_2_data'), - ('placeholder_1', 'placeholder_3_data'), - ('placeholder_1_data', 'eltwise_1'), - ('placeholder_2_data', 'eltwise_1'), - ('placeholder_3_data', 'eltwise_1'), - ('eltwise_1', 'eltwise_1_data') - ], - {'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])}, - 'placeholder_2_data': {'shape': np.array([1, 64, 1])}, - 'placeholder_3_data': {'shape': np.array([64, 1])}, - 'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])} - }, nodes_with_edges_only=True) - - graph_ref = build_graph(nodes_attributes, - [ - ('placeholder_1', 'placeholder_1_data'), - ('placeholder_1', 'placeholder_2_data'), - ('placeholder_1', 'placeholder_3_data'), - ('placeholder_1_data', 'eltwise_1'), - ('placeholder_2_data', 'reshape_1'), - ('reshape_1_const', 'reshape_1_const_data'), - ('reshape_1_const_data', 'reshape_1'), - ('placeholder_3_data', 'reshape_2'), - ('reshape_2_const', 'reshape_2_const_data'), - ('reshape_2_const_data', 'reshape_2'), - ('reshape_1', 'reshape_1_data'), - ('reshape_2', 'reshape_2_data'), - ('reshape_1_data', 'eltwise_1'), - ('reshape_2_data', 'eltwise_1'), - ('eltwise_1', 'eltwise_1_data') - ], - {'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])}, - 'reshape_1_const': {'value': int64_array([1, 1, 64, 1]), 'shape': int64_array([4])}, - 'reshape_1_const_data': {'value': int64_array([1, 1, 64, 1]), - 'shape': int64_array([4])}, - 'reshape_1_data': {'shape': np.array([1, 1, 64, 1])}, - 'reshape_2_const': {'value': int64_array([1, 1, 64, 1]), 'shape': int64_array([4])}, - 'reshape_2_const_data': {'value': int64_array([1, 1, 64, 1]), - 'shape': int64_array([4])}, - 'reshape_2_data': {'shape': np.array([1, 1, 64, 1])}, - 'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])} - }, nodes_with_edges_only=True) - - pattern = EltwiseInputNormalize() - pattern.find_and_replace_pattern(graph) - - (flag, resp) = compare_graphs(graph, graph_ref, 'eltwise_1', check_op_attrs=True) - self.assertTrue(flag, resp) - - def test_mega_hardcore(self): - # ORIGINAL GRAPH - # - # data1(1,3,64,64)---,->Eltwise1->data(1,3,64,64)-----,->Eltwise2->data(1,3,64,64)---,->Eltwise4->data(1,3,64,64) - # /\ /\ /\ - # data2(64,1)-----,-'--------------------------------'------------------------------' - # \/ / - # data3(64,1)----`-->Eltwise3->data(64,1)----------' - # - # REFERENCE GRAPH AFTER TRANSFORMATION - # - # data1(1,3,64,64)---,->Eltwise1->data(1,3,64,64)-----,->Eltwise2->data(1,3,64,64)---,->Eltwise4->data(1,3,64,64) - # /\ /\ /\ - # data2(1,1,64,1)---'--------------------------------'-------------------------------' - # / - # data4(64,1)-------, Reshape(1,1,64,1) - # \/ | - # data3(64,1)------`---->Eltwise3->data(64,1)---' - # - graph = build_graph(nodes_attributes, - [('placeholder_1_data', 'eltwise_1'), - ('placeholder_2_data', 'eltwise_1'), - ('eltwise_1', 'eltwise_1_data'), - ('eltwise_1_data', 'eltwise_2'), - ('placeholder_2_data', 'eltwise_3'), - ('placeholder_3_data', 'eltwise_3'), - ('eltwise_3', 'eltwise_3_data'), - ('eltwise_3_data', 'eltwise_2'), - ('eltwise_2', 'eltwise_2_data'), - ('eltwise_2_data', 'eltwise_4'), - ('placeholder_2_data', 'eltwise_4'), - ('eltwise_4', 'eltwise_4_data'), - ], - {'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])}, - 'placeholder_2_data': {'shape': np.array([64, 1]), 'value': np.ones([64, 1])}, - 'placeholder_3_data': {'shape': np.array([64, 1])}, - 'eltwise_1_data': 
{'shape': np.array([1, 3, 64, 64])}, - 'eltwise_2_data': {'shape': np.array([1, 3, 64, 64])}, - 'eltwise_3_data': {'shape': np.array([64, 1])}, - 'eltwise_4_data': {'shape': np.array([1, 3, 64, 64])} - }, nodes_with_edges_only=True) - - graph_ref = build_graph(nodes_attributes, - [('placeholder_1_data', 'eltwise_1'), - ('placeholder_2_data', 'eltwise_1'), - ('eltwise_1', 'eltwise_1_data'), - ('eltwise_1_data', 'eltwise_2'), - ('placeholder_4_data', 'eltwise_3'), - ('placeholder_3_data', 'eltwise_3'), - ('eltwise_3', 'eltwise_3_data'), - ('eltwise_3_data', 'reshape_1'), - ('reshape_1_const', 'reshape_1_const_data'), - ('reshape_1_const_data', 'reshape_1'), - ('reshape_1', 'reshape_1_data'), - ('reshape_1_data', 'eltwise_2'), - ('eltwise_2', 'eltwise_2_data'), - ('eltwise_2_data', 'eltwise_4'), - ('placeholder_2_data', 'eltwise_4'), - ('eltwise_4', 'eltwise_4_data'), - ], - {'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])}, - 'placeholder_2_data': {'shape': np.array([1, 1, 64, 1]), - 'value': np.ones([1, 1, 64, 1])}, - 'placeholder_3_data': {'shape': np.array([64, 1])}, - 'placeholder_4_data': {'shape': np.array([64, 1]), 'value': np.ones([64, 1])}, - 'reshape_1_const': {'value': int64_array([1, 1, 64, 1]), 'shape': int64_array([4])}, - 'reshape_1_const_data': {'value': int64_array([1, 1, 64, 1]), - 'shape': int64_array([4])}, - 'reshape_1_data': {'shape': np.array([1, 1, 64, 1])}, - 'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])}, - 'eltwise_2_data': {'shape': np.array([1, 3, 64, 64])}, - 'eltwise_3_data': {'shape': np.array([64, 1])}, - 'eltwise_4_data': {'shape': np.array([1, 3, 64, 64])} - }, nodes_with_edges_only=True) - - pattern = EltwiseInputNormalize() - pattern.find_and_replace_pattern(graph) - - (flag, resp) = compare_graphs(graph, graph_ref, 'eltwise_4', check_op_attrs=True) - self.assertTrue(flag, resp) diff --git a/model-optimizer/extensions/middle/EltwiseInputReshape.py b/model-optimizer/extensions/middle/EltwiseInputReshape.py index 8814df7..84b40d7 100644 --- a/model-optimizer/extensions/middle/EltwiseInputReshape.py +++ b/model-optimizer/extensions/middle/EltwiseInputReshape.py @@ -17,11 +17,13 @@ import numpy as np from mo.front.common.layout import get_features_dim, shape_for_layout -from mo.graph.graph import Graph +from mo.front.common.partial_infer.utils import int64_array +from mo.front.tf.graph_utils import create_op_with_const_inputs +from mo.graph.graph import Graph, Node from mo.middle.replacement import MiddleReplacementPattern from mo.ops.const import Const -from mo.ops.op import Op from mo.ops.reshape import Reshape +from mo.ops.unsqueeze import Unsqueeze class Eltwise1DInputReshape(MiddleReplacementPattern): @@ -42,9 +44,6 @@ class Eltwise1DInputReshape(MiddleReplacementPattern): """ enabled = False - def run_after(self): - return [EltwiseInputReshape] - def find_and_replace_pattern(self, graph: Graph): layout = graph.graph['layout'] for eltwise_op_node in graph.get_op_nodes(is_eltwise=True): @@ -64,59 +63,85 @@ class Eltwise1DInputReshape(MiddleReplacementPattern): reshape_op.out_port(0).connect(eltwise_op_node.in_port(port)) -class EltwiseInputReshape(MiddleReplacementPattern): - enabled = True - force_clean_up = True - - def run_after(self): - from extensions.middle.pass_separator import MiddleStart - return [MiddleStart] - - def find_and_replace_pattern(self, graph: Graph): - for node in graph.get_data_nodes(): - # Get all requested shapes for current node - # This mapping will contain pairs like {shape:[list of consumers nodes]} - mapping = 
{} - for consumer in node.out_nodes(): - edge_attrs = graph.get_edge_data(node.id, consumer.id)[0] - if 'new_shape' in edge_attrs: - if np.array_equal(edge_attrs['new_shape'], node.shape): - continue - new_shape = tuple([x for x in edge_attrs['new_shape']]) - if not new_shape in mapping: - mapping.update({new_shape: [consumer]}) - else: - mapping[new_shape].append(consumer) - - if node.has_valid('value'): - # Check that requested shape are the same - # In case if they are different, we duplicate them - for shape_key in mapping.keys(): - shape = list(shape_key) - new_value = np.reshape(node.value, shape) - node_copy = Op.create_input_data_node(graph, node.id + '/copy', value=np.array(new_value)) - for consumer in mapping[shape_key]: - edge_attrs = graph.get_edge_data(node.id, consumer.id)[0] - del edge_attrs['new_shape'] - - # Remove edge from previous data node and connect new data node with its consumer - graph.remove_edge(node.id, consumer.id) - graph.add_edge(node_copy.id, consumer.id, **edge_attrs) +def compute_unsqueeze_map_for_eltwise(eltwise_node: Node): + ''' + The function computes a map of unsqueeze_dims for each producer of the eltwise node. + These unsqueeze_dims are needed to normalize the input shapes of the eltwise node. + ''' + eltwise_shape = eltwise_node.out_port(0).data.get_shape() + max_dims = max( + [len(port.data.get_shape()) for port in eltwise_node.in_ports().values() if port.data.get_shape() is not None]) + axis = eltwise_node.soft_get('axis', None) + unsqueeze_dims_map = {} + for consumer_port in eltwise_node.in_ports().values(): + producer_port = consumer_port.get_source() + producer_shape = producer_port.data.get_shape() + unsqueeze_dims = int64_array([]) + + # 1. Compute unsqueeze dimensions in the tail + if len(producer_shape) != max_dims and len(producer_shape) > 0 and axis is not None: + num_unsqueeze_dims = max_dims - axis - len(producer_shape) + if num_unsqueeze_dims > 0: + unsqueeze_dims = np.arange(len(producer_shape), len(producer_shape) + num_unsqueeze_dims, + dtype=np.int64) + + # 2. Compute unsqueeze dimensions in the head + unsqueeze_dims_head = np.arange(len(eltwise_shape) - len(producer_shape) - len(unsqueeze_dims), dtype=np.int64) + + # Note that the order of unsqueeze dims matters: + # the shape is normalized in the tail first and then in the head + unsqueeze_dims = np.concatenate((unsqueeze_dims, unsqueeze_dims_head)) + unsqueeze_dims_map[producer_port] = unsqueeze_dims + + return unsqueeze_dims_map + + +def normalize_eltwise_inputs(graph: Graph): + ''' + The function normalizes input shapes for eltwise nodes. + In the first step the function determines which unsqueeze dims each input requires for normalization. + In the second step the function inserts Unsqueeze nodes between non-normalized inputs and eltwise nodes.
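+
+    For example, for an eltwise node with output shape [1, 3, 64, 64] (the case covered
+    by test1_not_constant in EltwiseInputReshape_test.py), an input of shape [1, 64, 1]
+    gets unsqueeze_dims = [0] and an input of shape [64, 1] gets unsqueeze_dims = [0, 1],
+    so both are expanded to [1, 1, 64, 1] before broadcasting.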
+ ''' + # Generate a map for producers of eltwise nodes with non-normalized shapes; + # for every such producer it keeps a nested map from a tuple of unsqueeze dims + # to the list of eltwise consumer ports that need them + mapping = {} + for eltwise_node in graph.get_op_nodes(is_eltwise=True): + unsqueeze_dims_map = compute_unsqueeze_map_for_eltwise(eltwise_node) + for consumer_port in eltwise_node.in_ports().values(): + producer_port = consumer_port.get_source() + unsqueeze_dims = unsqueeze_dims_map[producer_port] + if unsqueeze_dims is not None and len(unsqueeze_dims) > 0: + unsqueeze_dims = tuple([x for x in unsqueeze_dims]) + if producer_port not in mapping: + mapping.update({producer_port: {unsqueeze_dims: [consumer_port]}}) + elif unsqueeze_dims not in mapping[producer_port]: + mapping[producer_port].update({unsqueeze_dims: [consumer_port]}) + else: + mapping[producer_port][unsqueeze_dims].append(consumer_port) + + # Walk through each producer in the map and insert Unsqueeze nodes between a producer and eltwise nodes + for producer_port in mapping.keys(): + producer_node = producer_port.node + for unsqueeze_dims in mapping[producer_port].keys(): + unsqueeze_name = producer_node.soft_get('name', producer_node.id) + '/EltwiseUnsqueeze' + unsqueeze_node = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(list(unsqueeze_dims))}, + {'name': unsqueeze_name}) + unsqueeze_node.in_port(0).connect(producer_port) + + # Insert Unsqueeze with the determined unsqueeze dimensions between the current producer and the eltwise node + for consumer_port in mapping[producer_port][unsqueeze_dims]: + consumer_port.connect(unsqueeze_node.out_port(0)) + + # The shape and value adjustments must be explicitly done within the transformation + # since the transformation is called from the Fusing transformation, which excludes + # the automatic call of the shape inference pass + producer_port_value = producer_port.data.get_value() + producer_port_shape = producer_port.data.get_shape() + new_shape = producer_port_shape.copy() + for unsqueeze_dim in unsqueeze_dims: + new_shape = np.insert(new_shape, unsqueeze_dim, 1) + if producer_port_value is not None: + unsqueeze_node.out_port(0).data.set_value(np.reshape(producer_port_value, new_shape)) else: - # Insert Reshape layer between data node and consumer - for shape_key in mapping.keys(): - shape = list(shape_key) - reshape_name = node.soft_get('name', node.id) + '/EltwiseReshape' - reshape = Reshape(graph, attrs={'name': reshape_name}) - reshape_dim = Const(graph, - {'value': shape, 'name': reshape_name + '/Shape'}).create_node_with_data() - reshape_data = reshape.create_node_with_data(inputs=[node, reshape_dim]) - - # Iterate over consumers and reconnect them to Reshape layer output - for consumer in mapping[shape_key]: - edge_attrs = graph.get_edge_data(node.id, consumer.id)[0] - del edge_attrs['new_shape'] - - # Reconnect edge from original data node to Reshape output datanode - graph.remove_edge(node.id, consumer.id) - graph.add_edge(reshape_data.id, consumer.id, **edge_attrs) + unsqueeze_node.out_port(0).data.set_shape(new_shape) diff --git a/model-optimizer/extensions/middle/EltwiseInputReshape_test.py b/model-optimizer/extensions/middle/EltwiseInputReshape_test.py index 59d24f1..ce45b46 100644 --- a/model-optimizer/extensions/middle/EltwiseInputReshape_test.py +++ b/model-optimizer/extensions/middle/EltwiseInputReshape_test.py @@ -18,7 +18,7 @@ import unittest import numpy as np -from extensions.middle.EltwiseInputReshape import EltwiseInputReshape +from
extensions.middle.EltwiseInputReshape import normalize_eltwise_inputs from mo.front.common.partial_infer.utils import int64_array from mo.middle.passes.eliminate_test import build_graph from mo.utils.ir_engine.compare_graphs import compare_graphs @@ -28,47 +28,216 @@ from mo.utils.ir_engine.compare_graphs import compare_graphs nodes_attributes = { # Placeholder layers 'placeholder_1': {'value': None, 'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, + 'placeholder_2': {'value': None, 'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, + 'placeholder_3': {'value': None, 'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, 'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None}, 'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None}, 'placeholder_3_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None}, + 'placeholder_4_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None}, # Reshape layers - 'reshape_1': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'}, + 'reshape_1': {'type': 'Unsqueeze', 'value': None, 'kind': 'op', 'op': 'Unsqueeze'}, 'reshape_1_data': {'value': None, 'shape': None, 'kind': 'data'}, 'reshape_1_const': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': None}, 'reshape_1_const_data': {'kind': 'data', 'value': None, 'shape': None}, - 'reshape_2': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'}, + 'reshape_2': {'type': 'Unsqueeze', 'value': None, 'kind': 'op', 'op': 'Unsqueeze'}, 'reshape_2_data': {'value': None, 'shape': None, 'kind': 'data'}, 'reshape_2_const': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': None}, 'reshape_2_const_data': {'kind': 'data', 'value': None, 'shape': None}, - # Fake consumes layers - 'consumer_1': {'type': 'Consumer', 'value': None, 'kind': 'op', 'op': 'Consumer'}, - 'consumer_2': {'type': 'Consumer', 'value': None, 'kind': 'op', 'op': 'Consumer'}, - 'consumer_3': {'type': 'Consumer', 'value': None, 'kind': 'op', 'op': 'Consumer'}, + # Eltwise consumes layers + 'eltwise_1': {'kind': 'op', 'is_eltwise': True}, + 'eltwise_1_data': {'value': None, 'shape': None, 'kind': 'data'}, + + 'eltwise_2': {'kind': 'op', 'is_eltwise': True}, + 'eltwise_2_data': {'value': None, 'shape': None, 'kind': 'data'}, + + 'eltwise_3': {'kind': 'op', 'is_eltwise': True}, + 'eltwise_3_data': {'value': None, 'shape': None, 'kind': 'data'}, + + 'eltwise_4': {'kind': 'op', 'is_eltwise': True}, + 'eltwise_4_data': {'value': None, 'shape': None, 'kind': 'data'}, # Concat 'concat': {'type': 'Concat', 'kind': 'op', 'op': 'Concat'}, } -class EltwiseInputReshapeTest(unittest.TestCase): +class EltwiseInputNormalizationTest(unittest.TestCase): def test1_not_constant(self): + # + # data1(1,3,64,64)----. data(1,3,64,64)-------. + # data2(1,64,1)-------->Eltwise-->data(1,3,64,64) => data(1,64,1)->Reshape->data(1,1,64,1)-->Eltwise->... 
+ # data3(64,1)------' data(64,1)->Reshape->data(1,1,64,1)-' + # + graph = build_graph(nodes_attributes, [ + ('placeholder_1', 'placeholder_1_data'), + ('placeholder_1', 'placeholder_2_data'), + ('placeholder_1', 'placeholder_3_data'), + ('placeholder_1_data', 'eltwise_1'), + ('placeholder_2_data', 'eltwise_1'), + ('placeholder_3_data', 'eltwise_1'), + ('eltwise_1', 'eltwise_1_data') + ], + {'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])}, + 'placeholder_2_data': {'shape': np.array([1, 64, 1])}, + 'placeholder_3_data': {'shape': np.array([64, 1])}, + 'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])} + }, nodes_with_edges_only=True) + + graph_ref = build_graph(nodes_attributes, + [ + ('placeholder_1', 'placeholder_1_data'), + ('placeholder_1', 'placeholder_2_data'), + ('placeholder_1', 'placeholder_3_data'), + ('placeholder_1_data', 'eltwise_1'), + ('placeholder_2_data', 'reshape_1'), + ('reshape_1_const', 'reshape_1_const_data'), + ('reshape_1_const_data', 'reshape_1'), + ('placeholder_3_data', 'reshape_2'), + ('reshape_2_const', 'reshape_2_const_data'), + ('reshape_2_const_data', 'reshape_2'), + ('reshape_1', 'reshape_1_data'), + ('reshape_2', 'reshape_2_data'), + ('reshape_1_data', 'eltwise_1'), + ('reshape_2_data', 'eltwise_1'), + ('eltwise_1', 'eltwise_1_data') + ], + {'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])}, + 'reshape_1_const': {'value': int64_array([0]), 'shape': int64_array([1])}, + 'reshape_1_const_data': {'value': int64_array([0]), + 'shape': int64_array([1])}, + 'reshape_1_data': {'shape': np.array([1, 1, 64, 1])}, + 'reshape_2_const': {'value': int64_array([0, 1]), 'shape': int64_array([2])}, + 'reshape_2_const_data': {'value': int64_array([0, 1]), + 'shape': int64_array([2])}, + 'reshape_2_data': {'shape': np.array([1, 1, 64, 1])}, + 'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])} + }, nodes_with_edges_only=True) + + normalize_eltwise_inputs(graph) + + (flag, resp) = compare_graphs(graph, graph_ref, 'eltwise_1', check_op_attrs=True) + self.assertTrue(flag, resp) + + def test_mega_hardcore(self): + # ORIGINAL GRAPH + # + # data1(1,3,64,64)---,->Eltwise1->data(1,3,64,64)-----,->Eltwise2->data(1,3,64,64)---,->Eltwise4->data(1,3,64,64) + # /\ /\ /\ + # data2(64,1)-----,-'--------------------------------'------------------------------' + # \/ / + # data3(64,1)----`-->Eltwise3->data(64,1)----------' + # + # REFERENCE GRAPH AFTER TRANSFORMATION + # + # data1(1,3,64,64)---------------------,->Eltwise1->data(1,3,64,64)-----,->Eltwise2->data(1,3,64,64)---,->Eltwise4->data(1,3,64,64) + # /\ /\ /\ + # data2(64,1)-,- Reshape1(1,1,64,1)--'--------------------------------o-------------------------------' + # | | + # | Reshape(1,1,64,1) + # \/ | + # data3(64,1)----------->Eltwise3->data(64,1)--------------------------' + # + graph = build_graph(nodes_attributes, + [('placeholder_1', 'placeholder_1_data'), + ('placeholder_2', 'placeholder_2_data'), + ('placeholder_3', 'placeholder_3_data'), + ('placeholder_1_data', 'eltwise_1'), + ('placeholder_2_data', 'eltwise_1'), + ('eltwise_1', 'eltwise_1_data'), + ('eltwise_1_data', 'eltwise_2'), + ('placeholder_2_data', 'eltwise_3'), + ('placeholder_3_data', 'eltwise_3'), + ('eltwise_3', 'eltwise_3_data'), + ('eltwise_3_data', 'eltwise_2'), + ('eltwise_2', 'eltwise_2_data'), + ('eltwise_2_data', 'eltwise_4'), + ('placeholder_2_data', 'eltwise_4'), + ('eltwise_4', 'eltwise_4_data'), + ], + {'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])}, + 'placeholder_2_data': {'shape': np.array([64, 1]), 'value':
np.ones([64, 1])}, + 'placeholder_3_data': {'shape': np.array([64, 1])}, + 'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])}, + 'eltwise_2_data': {'shape': np.array([1, 3, 64, 64])}, + 'eltwise_3_data': {'shape': np.array([64, 1])}, + 'eltwise_4_data': {'shape': np.array([1, 3, 64, 64])} + }, nodes_with_edges_only=True) + + graph_ref = build_graph(nodes_attributes, + [('placeholder_1', 'placeholder_1_data'), + ('placeholder_2', 'placeholder_2_data'), + ('placeholder_3', 'placeholder_3_data'), + ('placeholder_1_data', 'eltwise_1'), + ('placeholder_2_data', 'reshape_1'), + ('reshape_1_const', 'reshape_1_const_data'), + ('reshape_1_const_data', 'reshape_1'), + ('reshape_1', 'reshape_1_data'), + ('reshape_1_data', 'eltwise_1'), + ('eltwise_1', 'eltwise_1_data'), + ('eltwise_1_data', 'eltwise_2'), + ('placeholder_2_data', 'eltwise_3'), + ('placeholder_3_data', 'eltwise_3'), + ('eltwise_3', 'eltwise_3_data'), + ('eltwise_3_data', 'reshape_2'), + ('reshape_2_const', 'reshape_2_const_data'), + ('reshape_2_const_data', 'reshape_2'), + ('reshape_2', 'reshape_2_data'), + ('reshape_2_data', 'eltwise_2'), + ('eltwise_2', 'eltwise_2_data'), + ('eltwise_2_data', 'eltwise_4'), + ('reshape_1_data', 'eltwise_4'), + ('eltwise_4', 'eltwise_4_data'), + ], + {'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])}, + 'placeholder_2_data': {'shape': np.array([64, 1]), + 'value': np.ones([64, 1])}, + 'placeholder_3_data': {'shape': np.array([64, 1])}, + 'reshape_1_const': {'value': int64_array([0, 1]), 'shape': int64_array([2])}, + 'reshape_1_const_data': {'value': int64_array([0, 1]), + 'shape': int64_array([2])}, + 'reshape_1_data': {'shape': np.array([1, 1, 64, 1])}, + + 'reshape_2_const': {'value': int64_array([0, 1]), 'shape': int64_array([2])}, + 'reshape_2_const_data': {'value': int64_array([0, 1]), + 'shape': int64_array([2])}, + 'reshape_2_data': {'shape': np.array([1, 1, 64, 1])}, + 'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])}, + 'eltwise_2_data': {'shape': np.array([1, 3, 64, 64])}, + 'eltwise_3_data': {'shape': np.array([64, 1])}, + 'eltwise_4_data': {'shape': np.array([1, 3, 64, 64])} + }, nodes_with_edges_only=True) + + normalize_eltwise_inputs(graph) + + (flag, resp) = compare_graphs(graph, graph_ref, 'eltwise_4', check_op_attrs=True) + self.assertTrue(flag, resp) + + def test2_not_constant(self): # ,-------------->consumer3 ,------------>consumer3 # data---(new_shape1)-->consumer1 => data---->Reshape-->consumer1 # `-(new_shape2)-->consumer2 `-->Reshape-->consumer2 # graph = build_graph(nodes_attributes, [ ('placeholder_1', 'placeholder_1_data'), - ('placeholder_1_data', 'consumer_1', {'new_shape': int64_array([1, 3, 1, 1])}), - ('placeholder_1_data', 'consumer_2', {'new_shape': int64_array([1, 1, 3])}), - ('placeholder_1_data', 'consumer_3'), - ('consumer_1', 'concat'), - ('consumer_2', 'concat'), - ('consumer_3', 'concat'), + ('placeholder_1_data', 'eltwise_1'), + ('placeholder_1_data', 'eltwise_2'), + ('placeholder_1_data', 'eltwise_3'), + ('eltwise_1', 'eltwise_1_data'), + ('eltwise_2', 'eltwise_2_data'), + ('eltwise_3', 'eltwise_3_data'), + ('eltwise_1_data', 'concat'), + ('eltwise_2_data', 'concat'), + ('eltwise_3_data', 'concat'), ], - {'placeholder_1_data': {'shape': int64_array([1, 3])}}, nodes_with_edges_only=True) + {'placeholder_1_data': {'shape': int64_array([1, 3])}, + 'eltwise_1_data': {'shape': int64_array([1, 1, 1, 3])}, + 'eltwise_2_data': {'shape': int64_array([1, 1, 3])}, + 'eltwise_3_data': {'shape': int64_array([1, 3])}, + }, + nodes_with_edges_only=True) 
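+
+        # Two distinct Unsqueeze nodes are expected in the reference graph below:
+        # dims [0, 1] to reach (1, 1, 1, 3) for eltwise_1 and dims [0] to reach
+        # (1, 1, 3) for eltwise_2, while eltwise_3 consumes the (1, 3) data directly.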
graph_ref = build_graph(nodes_attributes, [ @@ -79,32 +248,34 @@ class EltwiseInputReshapeTest(unittest.TestCase): ('placeholder_1_data', 'reshape_2'), ('reshape_2_const', 'reshape_2_const_data'), ('reshape_2_const_data', 'reshape_2'), - ('placeholder_1_data', 'consumer_3'), + ('placeholder_1_data', 'eltwise_3'), ('reshape_1', 'reshape_1_data'), ('reshape_2', 'reshape_2_data'), - ('reshape_1_data', 'consumer_1'), - ('reshape_2_data', 'consumer_2'), - ('consumer_1', 'concat'), - ('consumer_2', 'concat'), - ('consumer_3', 'concat'), + ('reshape_1_data', 'eltwise_1'), + ('reshape_2_data', 'eltwise_2'), + ('eltwise_1', 'eltwise_1_data'), + ('eltwise_2', 'eltwise_2_data'), + ('eltwise_3', 'eltwise_3_data'), + ('eltwise_1_data', 'concat'), + ('eltwise_2_data', 'concat'), + ('eltwise_3_data', 'concat'), ], {'placeholder_1_data': {'shape': int64_array([1, 3])}, - 'reshape_1_const': {'value': int64_array([1, 3, 1, 1]), 'shape': int64_array([4])}, - 'reshape_1_const_data': {'value': int64_array([1, 3, 1, 1]), - 'shape': int64_array([4])}, - 'reshape_1_data': {'shape': int64_array([1, 3, 1, 1])}, - 'reshape_2_const': {'value': int64_array([1, 1, 3]), 'shape': int64_array([3])}, - 'reshape_2_const_data': {'value': int64_array([1, 1, 3]), 'shape': int64_array([3])}, + 'reshape_1_const': {'value': int64_array([0, 1]), 'shape': int64_array([2])}, + 'reshape_1_const_data': {'value': int64_array([0, 1]), + 'shape': int64_array([2])}, + 'reshape_1_data': {'shape': int64_array([1, 1, 1, 3])}, + 'reshape_2_const': {'value': int64_array([0]), 'shape': int64_array([1])}, + 'reshape_2_const_data': {'value': int64_array([0]), 'shape': int64_array([1])}, 'reshape_2_data': {'shape': int64_array([1, 1, 3])}, }, nodes_with_edges_only=True) - pattern = EltwiseInputReshape() - pattern.find_and_replace_pattern(graph) + normalize_eltwise_inputs(graph) (flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True) self.assertTrue(flag, resp) - def test2_not_constant(self): + def test3_not_constant(self): # ,--------------->consumer3 ,----------->consumer3 # data---(new_shape1)-->consumer1 => data-->Reshape-->consumer1 # `-(new_shape1)-->consumer2 `-->consumer2 @@ -112,14 +283,22 @@ class EltwiseInputReshapeTest(unittest.TestCase): graph = build_graph(nodes_attributes, [ ('placeholder_1', 'placeholder_1_data'), - ('placeholder_1_data', 'consumer_1', {'new_shape': int64_array([1, 3, 1, 1])}), - ('placeholder_1_data', 'consumer_2', {'new_shape': int64_array([1, 3, 1, 1])}), - ('placeholder_1_data', 'consumer_3'), - ('consumer_1', 'concat'), - ('consumer_2', 'concat'), - ('consumer_3', 'concat'), + ('placeholder_1_data', 'eltwise_1'), + ('placeholder_1_data', 'eltwise_2'), + ('placeholder_1_data', 'eltwise_3'), + ('eltwise_1', 'eltwise_1_data'), + ('eltwise_2', 'eltwise_2_data'), + ('eltwise_3', 'eltwise_3_data'), + ('eltwise_1_data', 'concat'), + ('eltwise_2_data', 'concat'), + ('eltwise_3_data', 'concat'), ], - {'placeholder_1_data': {'shape': int64_array([1, 3])}}, nodes_with_edges_only=True) + {'placeholder_1_data': {'shape': int64_array([1, 3])}, + 'eltwise_1_data': {'shape': int64_array([1, 1, 1, 3])}, + 'eltwise_2_data': {'shape': int64_array([1, 1, 1, 3])}, + 'eltwise_3_data': {'shape': int64_array([1, 3])}, + }, + nodes_with_edges_only=True) graph_ref = build_graph(nodes_attributes, [ @@ -127,123 +306,239 @@ class EltwiseInputReshapeTest(unittest.TestCase): ('placeholder_1_data', 'reshape_1'), ('reshape_1_const', 'reshape_1_const_data'), ('reshape_1_const_data', 'reshape_1'), - ('placeholder_1_data', 
'consumer_3'), + ('placeholder_1_data', 'eltwise_3'), ('reshape_1', 'reshape_1_data'), - ('reshape_1_data', 'consumer_1'), - ('reshape_1_data', 'consumer_2'), - ('consumer_1', 'concat'), - ('consumer_2', 'concat'), - ('consumer_3', 'concat'), + ('reshape_1_data', 'eltwise_1'), + ('reshape_1_data', 'eltwise_2'), + ('eltwise_1', 'eltwise_1_data'), + ('eltwise_2', 'eltwise_2_data'), + ('eltwise_3', 'eltwise_3_data'), + ('eltwise_1_data', 'concat'), + ('eltwise_2_data', 'concat'), + ('eltwise_3_data', 'concat'), ], {'placeholder_1_data': {'shape': int64_array([1, 3])}, - 'reshape_1_const': {'value': int64_array([1, 3, 1, 1]), 'shape': int64_array([4])}, - 'reshape_1_const_data': {'value': int64_array([1, 3, 1, 1]), - 'shape': int64_array([4])}, - 'reshape_1_data': {'shape': int64_array([1, 3, 1, 1])}, + 'reshape_1_const': {'value': int64_array([0, 1]), 'shape': int64_array([2])}, + 'reshape_1_const_data': {'value': int64_array([0, 1]), + 'shape': int64_array([2])}, + 'reshape_1_data': {'shape': int64_array([1, 1, 1, 3])}, }, nodes_with_edges_only=True) - pattern = EltwiseInputReshape() - pattern.find_and_replace_pattern(graph) + normalize_eltwise_inputs(graph) (flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True) self.assertTrue(flag, resp) - def test3_constant(self): - # ,--------------->consumer3 data-->consumer3 - # data---(new_shape1)-->consumer1 => data-->consumer1 - # `-(new_shape2)-->consumer2 data-->consumer2 + def test4_constant(self): + # ,--------------->consumer3 ,------------>consumer3 + # data---(new_shape1)-->consumer1 => data--->reshape1-->consumer1 + # `-(new_shape2)-->consumer2 `->reshape2-->consumer2 # graph = build_graph(nodes_attributes, - [('placeholder_1_data', 'consumer_1', {'new_shape': int64_array([1, 3, 1, 1])}), - ('placeholder_1_data', 'consumer_2', {'new_shape': int64_array([1, 1, 3])}), - ('placeholder_1_data', 'consumer_3'), - ('consumer_1', 'concat'), - ('consumer_2', 'concat'), - ('consumer_3', 'concat'), + [('placeholder_1', 'placeholder_1_data'), + ('placeholder_1_data', 'eltwise_1'), + ('placeholder_1_data', 'eltwise_2'), + ('placeholder_1_data', 'eltwise_3'), + ('eltwise_1', 'eltwise_1_data'), + ('eltwise_2', 'eltwise_2_data'), + ('eltwise_3', 'eltwise_3_data'), + ('eltwise_1_data', 'concat'), + ('eltwise_2_data', 'concat'), + ('eltwise_3_data', 'concat'), ], - {'placeholder_1_data': {'shape': int64_array([1, 3]), 'value': np.ones([1, 3])}}, + {'placeholder_1_data': {'shape': int64_array([1, 3]), 'value': np.ones([1, 3])}, + 'eltwise_1_data': {'shape': int64_array([1, 1, 1, 3])}, + 'eltwise_2_data': {'shape': int64_array([1, 1, 3])}, + 'eltwise_3_data': {'shape': int64_array([1, 3])}, + }, nodes_with_edges_only=True) graph_ref = build_graph(nodes_attributes, - [('placeholder_1_data', 'consumer_1'), - ('placeholder_2_data', 'consumer_2'), - ('placeholder_3_data', 'consumer_3'), - ('consumer_1', 'concat'), - ('consumer_2', 'concat'), - ('consumer_3', 'concat'), + [('placeholder_1', 'placeholder_1_data'), + ('placeholder_1_data', 'reshape_1'), + ('reshape_1_const', 'reshape_1_const_data'), + ('reshape_1_const_data', 'reshape_1'), + ('reshape_1', 'reshape_1_data'), + ('reshape_1_data', 'eltwise_1'), + ('placeholder_1_data', 'reshape_2'), + ('reshape_2_const', 'reshape_2_const_data'), + ('reshape_2_const_data', 'reshape_2'), + ('reshape_2', 'reshape_2_data'), + ('reshape_2_data', 'eltwise_2'), + ('placeholder_1_data', 'eltwise_3'), + ('eltwise_1', 'eltwise_1_data'), + ('eltwise_2', 'eltwise_2_data'), + ('eltwise_3', 'eltwise_3_data'), + 
('eltwise_1_data', 'concat'), + ('eltwise_2_data', 'concat'), + ('eltwise_3_data', 'concat'), ], - {'placeholder_1_data': {'shape': int64_array([1, 3, 1, 1]), - 'value': np.ones([1, 3, 1, 1])}, - 'placeholder_2_data': {'shape': int64_array([1, 1, 3]), 'value': np.ones([1, 1, 3])}, - 'placeholder_3_data': {'shape': int64_array([1, 3]), 'value': np.ones([1, 3])}, + {'placeholder_1_data': {'shape': int64_array([1, 3]), 'value': np.ones([1, 3])}, + 'reshape_1_const': {'value': int64_array([0, 1]), 'shape': int64_array([2])}, + 'reshape_1_const_data': {'value': int64_array([0, 1]), + 'shape': int64_array([2])}, + 'reshape_1_data': {'shape': int64_array([1, 1, 1, 3])}, + + 'reshape_2_const': {'value': int64_array([0]), 'shape': int64_array([1])}, + 'reshape_2_const_data': {'value': int64_array([0]), + 'shape': int64_array([1])}, + 'reshape_2_data': {'shape': int64_array([1, 1, 3])}, }, nodes_with_edges_only=True) - pattern = EltwiseInputReshape() - pattern.find_and_replace_pattern(graph) + normalize_eltwise_inputs(graph) (flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True) self.assertTrue(flag, resp) - def test4_constant(self): - # ,--------------->consumer3 ,-->consumer3 - # data---(new_shape1)-->consumer1 => data-->consumer1 - # `-(new_shape2)-->consumer2 `->consumer2 + def test5_constant(self): + # ,-(new_shape)-->consumer3 ,-->consumer3 + # data---(new_shape)-->consumer1 => data-->reshape---->consumer1 + # `-(new_shape)-->consumer2 `-->consumer2 # graph = build_graph(nodes_attributes, - [('placeholder_1_data', 'consumer_1', {'new_shape': int64_array([3, 1, 1])}), - ('placeholder_1_data', 'consumer_2', {'new_shape': int64_array([3, 1, 1])}), - ('placeholder_1_data', 'consumer_3', {'new_shape': int64_array([3, 1, 1])}), - ('consumer_1', 'concat'), - ('consumer_2', 'concat'), - ('consumer_3', 'concat'), + [('placeholder_1', 'placeholder_1_data'), + ('placeholder_1_data', 'eltwise_1'), + ('placeholder_1_data', 'eltwise_2'), + ('placeholder_1_data', 'eltwise_3'), + ('eltwise_1', 'eltwise_1_data'), + ('eltwise_2', 'eltwise_2_data'), + ('eltwise_3', 'eltwise_3_data'), + ('eltwise_1_data', 'concat'), + ('eltwise_2_data', 'concat'), + ('eltwise_3_data', 'concat'), ], - {'placeholder_1_data': {'shape': int64_array([1, 3]), 'value': np.ones([1, 3])}}, + {'placeholder_1_data': {'shape': int64_array([1, 3]), 'value': np.ones([1, 3])}, + 'eltwise_1_data': {'shape': int64_array([1, 1, 3])}, + 'eltwise_2_data': {'shape': int64_array([1, 1, 3])}, + 'eltwise_3_data': {'shape': int64_array([1, 1, 3])}, + }, nodes_with_edges_only=True) graph_ref = build_graph(nodes_attributes, - [('placeholder_1_data', 'consumer_1'), - ('placeholder_1_data', 'consumer_2'), - ('placeholder_1_data', 'consumer_3'), - ('consumer_1', 'concat'), - ('consumer_2', 'concat'), - ('consumer_3', 'concat'), + [('placeholder_1', 'placeholder_1_data'), + ('placeholder_1_data', 'reshape_1'), + ('reshape_1_const', 'reshape_1_const_data'), + ('reshape_1_const_data', 'reshape_1'), + ('reshape_1', 'reshape_1_data'), + ('reshape_1_data', 'eltwise_1'), + ('reshape_1_data', 'eltwise_2'), + ('reshape_1_data', 'eltwise_3'), + ('eltwise_1', 'eltwise_1_data'), + ('eltwise_2', 'eltwise_2_data'), + ('eltwise_3', 'eltwise_3_data'), + ('eltwise_1_data', 'concat'), + ('eltwise_2_data', 'concat'), + ('eltwise_3_data', 'concat'), ], - {'placeholder_1_data': {'shape': int64_array([3, 1, 1]), 'value': np.ones([3, 1, 1])} + {'placeholder_1_data': {'shape': int64_array([1, 3]), 'value': np.ones([1, 3])}, + 'reshape_1_const': {'value': 
int64_array([0]), 'shape': int64_array([1])}, + 'reshape_1_const_data': {'value': int64_array([0]), + 'shape': int64_array([1])}, + 'reshape_1_data': {'shape': int64_array([1, 1, 3])}, }, nodes_with_edges_only=True) - pattern = EltwiseInputReshape() - pattern.find_and_replace_pattern(graph) + normalize_eltwise_inputs(graph) (flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True) self.assertTrue(flag, resp) - def test5_not_constant(self): + def test6_not_constant(self): # ,--------------->consumer3 ,->consumer3 # data---(new_shape1)-->consumer1 => data----->consumer1 # `-(new_shape1)-->consumer2 `-->consumer2 # graph = build_graph(nodes_attributes, - [('placeholder_1_data', 'consumer_1', {'new_shape': int64_array([1, 3])}), - ('placeholder_1_data', 'consumer_2', {'new_shape': int64_array([1, 3])}), - ('placeholder_1_data', 'consumer_3'), - ('consumer_1', 'concat'), - ('consumer_2', 'concat'), - ('consumer_3', 'concat'), + [('placeholder_1', 'placeholder_1_data'), + ('placeholder_1_data', 'eltwise_1'), + ('placeholder_1_data', 'eltwise_2'), + ('placeholder_1_data', 'eltwise_3'), + ('eltwise_1', 'eltwise_1_data'), + ('eltwise_2', 'eltwise_2_data'), + ('eltwise_3', 'eltwise_3_data'), + ('eltwise_1_data', 'concat'), + ('eltwise_2_data', 'concat'), + ('eltwise_3_data', 'concat'), ], - {'placeholder_1_data': {'shape': int64_array([1, 3])}}, nodes_with_edges_only=True) + {'placeholder_1_data': {'shape': int64_array([1, 3])}, + 'eltwise_1_data': {'shape': int64_array([1, 3])}, + 'eltwise_2_data': {'shape': int64_array([1, 3])}, + 'eltwise_3_data': {'shape': int64_array([1, 3])}, + }, nodes_with_edges_only=True) graph_ref = build_graph(nodes_attributes, - [('placeholder_1_data', 'consumer_1', {'new_shape': int64_array([1, 3])}), - ('placeholder_1_data', 'consumer_2', {'new_shape': int64_array([1, 3])}), - ('placeholder_1_data', 'consumer_3'), - ('consumer_1', 'concat'), - ('consumer_2', 'concat'), - ('consumer_3', 'concat'), + [('placeholder_1', 'placeholder_1_data'), + ('placeholder_1_data', 'eltwise_1'), + ('placeholder_1_data', 'eltwise_2'), + ('placeholder_1_data', 'eltwise_3'), + ('eltwise_1', 'eltwise_1_data'), + ('eltwise_2', 'eltwise_2_data'), + ('eltwise_3', 'eltwise_3_data'), + ('eltwise_1_data', 'concat'), + ('eltwise_2_data', 'concat'), + ('eltwise_3_data', 'concat'), ], {'placeholder_1_data': {'shape': int64_array([1, 3])}}, nodes_with_edges_only=True) - pattern = EltwiseInputReshape() - pattern.find_and_replace_pattern(graph) + normalize_eltwise_inputs(graph) (flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True) self.assertTrue(flag, resp) + + def test7_axis1_not_constant(self): + # + # data1(1,3,64,64)----. data(1,3,64,64)-------. + # data2(3,64,1)-------->Eltwise-->data(1,3,64,64)=> data(3,64,1)->Unsqueeze(0)->data(1,3,64,1)-->Eltwise->... 
+ # data3(3,1)------' data(3,1)->Unsqueeze(2, 0)->data(1,3,1,1)-' + # + graph = build_graph(nodes_attributes, [ + ('placeholder_1', 'placeholder_1_data'), + ('placeholder_2', 'placeholder_2_data'), + ('placeholder_3', 'placeholder_3_data'), + ('placeholder_1_data', 'eltwise_1'), + ('placeholder_2_data', 'eltwise_1'), + ('placeholder_3_data', 'eltwise_1'), + ('eltwise_1', 'eltwise_1_data') + ], + {'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])}, + 'placeholder_2_data': {'shape': np.array([3, 64, 1])}, + 'placeholder_3_data': {'shape': np.array([3, 1])}, + 'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])}, + 'eltwise_1' : {'axis': 1} + }, nodes_with_edges_only=True) + + graph_ref = build_graph(nodes_attributes, + [ + ('placeholder_1', 'placeholder_1_data'), + ('placeholder_2', 'placeholder_2_data'), + ('placeholder_3', 'placeholder_3_data'), + ('placeholder_1_data', 'eltwise_1'), + ('placeholder_2_data', 'reshape_1'), + ('reshape_1_const', 'reshape_1_const_data'), + ('reshape_1_const_data', 'reshape_1'), + ('placeholder_3_data', 'reshape_2'), + ('reshape_2_const', 'reshape_2_const_data'), + ('reshape_2_const_data', 'reshape_2'), + ('reshape_1', 'reshape_1_data'), + ('reshape_2', 'reshape_2_data'), + ('reshape_1_data', 'eltwise_1'), + ('reshape_2_data', 'eltwise_1'), + ('eltwise_1', 'eltwise_1_data') + ], + {'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])}, + 'placeholder_2_data': {'shape': np.array([3, 64, 1])}, + 'placeholder_3_data': {'shape': np.array([3, 1])}, + 'reshape_1_const': {'value': int64_array([0]), 'shape': int64_array([1])}, + 'reshape_1_const_data': {'value': int64_array([0]), + 'shape': int64_array([1])}, + 'reshape_1_data': {'shape': np.array([1, 3, 64, 1])}, + 'reshape_2_const': {'value': int64_array([2, 0]), 'shape': int64_array([2])}, + 'reshape_2_const_data': {'value': int64_array([2, 0]), + 'shape': int64_array([2])}, + 'reshape_2_data': {'shape': np.array([1, 3, 1, 1])}, + 'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])} + }, nodes_with_edges_only=True) + + normalize_eltwise_inputs(graph) + + (flag, resp) = compare_graphs(graph, graph_ref, 'eltwise_1', check_op_attrs=True) + self.assertTrue(flag, resp) diff --git a/model-optimizer/extensions/middle/fusings.py b/model-optimizer/extensions/middle/fusings.py index 44f1586..40ce7b2 100644 --- a/model-optimizer/extensions/middle/fusings.py +++ b/model-optimizer/extensions/middle/fusings.py @@ -16,7 +16,7 @@ from extensions.front.div import Div from extensions.front.sub import Sub from extensions.middle.AddFakeQuantizeFuse import AddFakeQuantizeFuse -from extensions.middle.EltwiseInputNormalization import EltwiseInputNormalize +from extensions.middle.EltwiseInputReshape import normalize_eltwise_inputs from extensions.middle.MulFakeQuantizeFuse import MulFakeQuantizeFuse from extensions.middle.RemoveRedundantReshapes import RemoveRedundantReshapes @@ -82,7 +82,7 @@ class Fusing(MiddleReplacementPattern): for_graph_and_each_sub_graph_recursively(graph, fuse_mul_add_sequence) for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up()) - EltwiseInputNormalize().find_and_replace_pattern(graph) + normalize_eltwise_inputs(graph) for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up()) # Fusing linear operation to Convolution @@ -96,7 +96,7 @@ class Fusing(MiddleReplacementPattern): for_graph_and_each_sub_graph_recursively(graph, fuse_linear_ops) for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up()) - EltwiseInputNormalize().find_and_replace_pattern(graph) + 
diff --git a/model-optimizer/extensions/middle/fusings.py b/model-optimizer/extensions/middle/fusings.py
index 44f1586..40ce7b2 100644
--- a/model-optimizer/extensions/middle/fusings.py
+++ b/model-optimizer/extensions/middle/fusings.py
@@ -16,7 +16,7 @@
 from extensions.front.div import Div
 from extensions.front.sub import Sub
 from extensions.middle.AddFakeQuantizeFuse import AddFakeQuantizeFuse
-from extensions.middle.EltwiseInputNormalization import EltwiseInputNormalize
+from extensions.middle.EltwiseInputReshape import normalize_eltwise_inputs
 from extensions.middle.MulFakeQuantizeFuse import MulFakeQuantizeFuse
 from extensions.middle.RemoveRedundantReshapes import RemoveRedundantReshapes
@@ -82,7 +82,7 @@ class Fusing(MiddleReplacementPattern):
         for_graph_and_each_sub_graph_recursively(graph, fuse_mul_add_sequence)
         for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())
 
-        EltwiseInputNormalize().find_and_replace_pattern(graph)
+        normalize_eltwise_inputs(graph)
         for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())
 
         # Fusing linear operation to Convolution
@@ -96,7 +96,7 @@ class Fusing(MiddleReplacementPattern):
         for_graph_and_each_sub_graph_recursively(graph, fuse_linear_ops)
         for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())
 
-        EltwiseInputNormalize().find_and_replace_pattern(graph)
+        normalize_eltwise_inputs(graph)
         for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())
 
         MarkNodesToFuseUpToFakeQuantize().find_and_replace_pattern(graph)
diff --git a/model-optimizer/mo/front/common/partial_infer/eltwise.py b/model-optimizer/mo/front/common/partial_infer/eltwise.py
index c111bba..f0f26e3 100644
--- a/model-optimizer/mo/front/common/partial_infer/eltwise.py
+++ b/model-optimizer/mo/front/common/partial_infer/eltwise.py
@@ -14,17 +14,14 @@
 limitations under the License.
 """
 
-import networkx as nx
 import numpy as np
 
 from mo.front.common.partial_infer.utils import int64_array
-from mo.graph.graph import Node
 
 
 def eltwise_infer(node, op=None, **kwargs):
     raw_inputs = [(inp, attr) for inp, attr in node.get_sorted_inputs()
                   if 'control_flow_edge' not in attr or not attr['control_flow_edge']]
-    inputs = [Node(node.graph, inp) for inp, attr in raw_inputs]
     shapes = [node.graph.node[inp]['shape'] for inp, attr in raw_inputs]
     values = [node.graph.node[inp]['value'] for inp, attr in raw_inputs]
@@ -53,14 +50,6 @@ def eltwise_infer(node, op=None, **kwargs):
 
             shapes[id] = new_shape
 
-            # Save shape for further transformation that applies this shapes for input nodes
-            # We set new_shape attribute on edge for given input node
-            edge_attrs = node.graph.get_edge_data(inputs[id].id, node.id)[0]
-
-            nx.set_edge_attributes(G=node.graph,
-                                   values={(inputs[id].id, node.id, 0): new_shape},
-                                   name='new_shape')
-
             # Reshape value to correctly calculate output shape
             if values[id] is not None:
                 values[id] = np.reshape(values[id], new_shape)
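After this change eltwise_infer no longer publishes `new_shape` attributes on input edges for a later pass to consume; it aligns the input shapes locally and reshapes constant values so the op callback broadcasts correctly. A simplified stand-alone model of that behavior, assuming plain numpy and right-alignment (the axis-aware case works as sketched earlier):

    import numpy as np

    # Two eltwise inputs: a 4-D activation and a 2-D constant.
    shapes = [np.array([1, 3, 64, 64]), np.array([64, 1])]
    values = [None, np.zeros((64, 1))]

    max_rank = max(len(s) for s in shapes)
    for i, s in enumerate(shapes):
        if len(s) < max_rank:
            # Prepend 1s up to the common rank; no edge attributes are written.
            new_shape = np.concatenate([np.ones(max_rank - len(s), dtype=np.int64), s])
            shapes[i] = new_shape
            # Reshape the constant so the output shape/value can be computed.
            if values[i] is not None:
                values[i] = np.reshape(values[i], new_shape)

    assert list(shapes[1]) == [1, 1, 64, 1]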
diff --git a/model-optimizer/mo/utils/shape.py b/model-optimizer/mo/utils/shape.py
index 90ba349..4329eab 100644
--- a/model-optimizer/mo/utils/shape.py
+++ b/model-optimizer/mo/utils/shape.py
@@ -38,13 +38,13 @@ def get_canonical_axis_index_node(rank: Node, axis: int) -> Node:
     graph = rank.graph
     name = rank.soft_get('name', rank.id)
     if axis < 0:
-        axis = Const(graph, {'name': name + '/negative_axis', 'value': int64_array([axis])}).create_node()
+        axis = Const(graph, {'name': name + '/negative_axis', 'value': int64_array(axis)}).create_node()
         add = Add(graph, {'name': name + '/positive_axis'}).create_node()
         rank.out_port(0).connect(add.in_port(0))
         axis.out_port(0).connect(add.in_port(1))
         return add
     else:
-        return Const(graph, {'name': name + '/positive_axis', 'value': int64_array([axis])}).create_node()
+        return Const(graph, {'name': name + '/positive_axis', 'value': int64_array(axis)}).create_node()
 
 
 def get_range_node_of_idxs(rank: Node, begin: int, end: int,
@@ -66,20 +66,20 @@ def get_range_node_of_idxs(rank: Node, begin: int, end: int,
     end_idx = get_canonical_axis_index_node(rank, end)
 
     if not include_begin:
-        const = Const(graph, {'value': int64_array([1]), 'name': name + '/exclude_begin/value'}).create_node()
+        const = Const(graph, {'value': int64_array(1), 'name': name + '/exclude_begin/value'}).create_node()
         add = Add(graph, {'name': name + '/exclude_begin'}).create_node()
         start_idx.out_port(0).connect(add.in_port(0))
         const.out_port(0).connect(add.in_port(1))
         start_idx = add
 
     if include_end:
-        const = Const(graph, {'value': int64_array([1]), 'name': name + '/including_end/value'}).create_node()
+        const = Const(graph, {'value': int64_array(1), 'name': name + '/including_end/value'}).create_node()
         add = Add(graph, {'name': name + '/including_end'}).create_node()
         end_idx.out_port(0).connect(add.in_port(0))
         const.out_port(0).connect(add.in_port(1))
         end_idx = add
 
-    delta = Const(graph, {'name': name + '/delta', 'value': int64_array([1])}).create_node()
+    delta = Const(graph, {'name': name + '/delta', 'value': int64_array(1)}).create_node()
 
     range_node = Range(graph, {'name': name + '/range_idxs'}).create_node()
     start_idx.out_port(0).connect(range_node.in_port(0))
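The shape.py change feeds Range with 0-D constants (`int64_array(1)`) instead of one-element 1-D arrays (`int64_array([1])`), presumably so that Range receives the scalar start/stop/delta inputs its specification expects while its output stays a 1-D index tensor. A hypothetical numeric stand-in for `get_range_node_of_idxs` (plain numpy rather than graph nodes) illustrates the intended semantics:

    import numpy as np

    def range_of_idxs(rank, begin, end, include_begin=True, include_end=False):
        # Canonicalize negative axis indices against the rank, as
        # get_canonical_axis_index_node does with Const/Add nodes.
        start = begin if begin >= 0 else rank + begin
        stop = end if end >= 0 else rank + end
        if not include_begin:
            start += 1
        if include_end:
            stop += 1
        # Scalar start/stop/delta, 1-D output - like the Range node.
        return np.arange(start, stop, 1, dtype=np.int64)

    assert list(range_of_idxs(4, 0, -1, include_end=True)) == [0, 1, 2, 3]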
Can not decompose."); - auto data = input_value(0); - auto data_shape = data.get_shape(); - auto output_shape = get_output_shape(0); - return {builder::opset1::reshape(data, output_shape)}; -} - -bool ngraph::op::v0::Unsqueeze::visit_attributes(AttributeVisitor& visitor) +bool op::v0::Unsqueeze::visit_attributes(AttributeVisitor& visitor) { return true; } -shared_ptr op::Unsqueeze::clone_with_new_inputs(const OutputVector& new_args) const +shared_ptr op::v0::Unsqueeze::clone_with_new_inputs(const OutputVector& new_args) const { if (new_args.size() != 2) { diff --git a/ngraph/test/backend/fused_op.in.cpp b/ngraph/test/backend/fused_op.in.cpp index 8aabaaf..38d451a 100644 --- a/ngraph/test/backend/fused_op.in.cpp +++ b/ngraph/test/backend/fused_op.in.cpp @@ -1487,7 +1487,7 @@ NGRAPH_TEST(${BACKEND_NAME}, unsqueeze) auto data_node = make_shared(element::f32, Shape{4, 2}); auto axes_node = make_shared(element::i64, Shape{2}, vector{1, 2}); - auto squeeze = make_shared(data_node, axes_node); + auto squeeze = make_shared(data_node, axes_node); auto function = make_shared(NodeVector{squeeze}, ParameterVector{data_node}); auto test_case = test::TestCase(function); diff --git a/ngraph/test/constant_folding.cpp b/ngraph/test/constant_folding.cpp index 8d860ae..d563b73 100644 --- a/ngraph/test/constant_folding.cpp +++ b/ngraph/test/constant_folding.cpp @@ -189,7 +189,7 @@ TEST(constant_folding, constant_unsqueeze) auto constant = make_shared(element::f32, shape_in, values_in); vector values_axes{2, 3}; auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes); - auto unsqueeze = make_shared(constant, constant_axes); + auto unsqueeze = make_shared(constant, constant_axes); unsqueeze->set_friendly_name("test"); auto f = make_shared(unsqueeze, ParameterVector{}); @@ -197,7 +197,7 @@ TEST(constant_folding, constant_unsqueeze) pass_manager.register_pass(); pass_manager.run_passes(f); - ASSERT_EQ(count_ops_of_type(f), 0); + ASSERT_EQ(count_ops_of_type(f), 0); ASSERT_EQ(count_ops_of_type(f), 1); auto new_const = diff --git a/ngraph/test/op_is.cpp b/ngraph/test/op_is.cpp index da0387b..5aedef1 100644 --- a/ngraph/test/op_is.cpp +++ b/ngraph/test/op_is.cpp @@ -877,7 +877,7 @@ namespace void op_is_Unsqueeze() { - op::Unsqueeze node; + op::v0::Unsqueeze node; EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node)); EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node)); EXPECT_FALSE(op::is_binary_elementwise_comparison(&node)); diff --git a/ngraph/test/runtime/opset0_tbl.hpp b/ngraph/test/runtime/opset0_tbl.hpp index 422d698..934fe41 100644 --- a/ngraph/test/runtime/opset0_tbl.hpp +++ b/ngraph/test/runtime/opset0_tbl.hpp @@ -126,5 +126,5 @@ NGRAPH_OP(Tan, ngraph::op) NGRAPH_OP(Tanh, ngraph::op) NGRAPH_OP(TensorIterator, ngraph::op) NGRAPH_OP(Tile, ngraph::op::v0) -NGRAPH_OP(Unsqueeze, ngraph::op) +NGRAPH_OP(Unsqueeze, ngraph::op::v0) NGRAPH_OP(Xor, ngraph::op) diff --git a/ngraph/test/type_prop/unsqueeze.cpp b/ngraph/test/type_prop/unsqueeze.cpp index e9a8025..484a60b 100644 --- a/ngraph/test/type_prop/unsqueeze.cpp +++ b/ngraph/test/type_prop/unsqueeze.cpp @@ -26,7 +26,7 @@ TEST(type_prop, unsqueeze) auto param = make_shared(element::f32, Shape{4, 1, 4, 1, 8}); auto axes_node = make_shared(element::u64, Shape{2}, vector{1, 2}); - auto unsqueeze = make_shared(param, axes_node); + auto unsqueeze = make_shared(param, axes_node); ASSERT_EQ(unsqueeze->get_element_type(), element::f32); ASSERT_EQ(unsqueeze->get_shape(), (Shape{4, 1, 1, 1, 4, 1, 8})); @@ -37,7 +37,7 @@ TEST(type_prop, 
diff --git a/ngraph/test/backend/fused_op.in.cpp b/ngraph/test/backend/fused_op.in.cpp
index 8aabaaf..38d451a 100644
--- a/ngraph/test/backend/fused_op.in.cpp
+++ b/ngraph/test/backend/fused_op.in.cpp
@@ -1487,7 +1487,7 @@ NGRAPH_TEST(${BACKEND_NAME}, unsqueeze)
     auto data_node = make_shared<op::Parameter>(element::f32, Shape{4, 2});
     auto axes_node =
         make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{1, 2});
 
-    auto squeeze = make_shared<op::Unsqueeze>(data_node, axes_node);
+    auto squeeze = make_shared<op::v0::Unsqueeze>(data_node, axes_node);
     auto function = make_shared<Function>(NodeVector{squeeze}, ParameterVector{data_node});
 
     auto test_case = test::TestCase<TestEngine>(function);
diff --git a/ngraph/test/constant_folding.cpp b/ngraph/test/constant_folding.cpp
index 8d860ae..d563b73 100644
--- a/ngraph/test/constant_folding.cpp
+++ b/ngraph/test/constant_folding.cpp
@@ -189,7 +189,7 @@ TEST(constant_folding, constant_unsqueeze)
     auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
     vector<int64_t> values_axes{2, 3};
     auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
-    auto unsqueeze = make_shared<op::Unsqueeze>(constant, constant_axes);
+    auto unsqueeze = make_shared<op::v0::Unsqueeze>(constant, constant_axes);
     unsqueeze->set_friendly_name("test");
     auto f = make_shared<Function>(unsqueeze, ParameterVector{});
 
@@ -197,7 +197,7 @@ TEST(constant_folding, constant_unsqueeze)
     pass_manager.register_pass<pass::ConstantFolding>();
     pass_manager.run_passes(f);
 
-    ASSERT_EQ(count_ops_of_type<op::Unsqueeze>(f), 0);
+    ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(f), 0);
     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
 
     auto new_const =
diff --git a/ngraph/test/op_is.cpp b/ngraph/test/op_is.cpp
index da0387b..5aedef1 100644
--- a/ngraph/test/op_is.cpp
+++ b/ngraph/test/op_is.cpp
@@ -877,7 +877,7 @@ namespace
     void op_is_Unsqueeze()
     {
-        op::Unsqueeze node;
+        op::v0::Unsqueeze node;
         EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
         EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
         EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
diff --git a/ngraph/test/runtime/opset0_tbl.hpp b/ngraph/test/runtime/opset0_tbl.hpp
index 422d698..934fe41 100644
--- a/ngraph/test/runtime/opset0_tbl.hpp
+++ b/ngraph/test/runtime/opset0_tbl.hpp
@@ -126,5 +126,5 @@ NGRAPH_OP(Tan, ngraph::op)
 NGRAPH_OP(Tanh, ngraph::op)
 NGRAPH_OP(TensorIterator, ngraph::op)
 NGRAPH_OP(Tile, ngraph::op::v0)
-NGRAPH_OP(Unsqueeze, ngraph::op)
+NGRAPH_OP(Unsqueeze, ngraph::op::v0)
 NGRAPH_OP(Xor, ngraph::op)
diff --git a/ngraph/test/type_prop/unsqueeze.cpp b/ngraph/test/type_prop/unsqueeze.cpp
index e9a8025..484a60b 100644
--- a/ngraph/test/type_prop/unsqueeze.cpp
+++ b/ngraph/test/type_prop/unsqueeze.cpp
@@ -26,7 +26,7 @@ TEST(type_prop, unsqueeze)
     auto param = make_shared<op::Parameter>(element::f32, Shape{4, 1, 4, 1, 8});
     auto axes_node =
         make_shared<op::Constant>(element::u64, Shape{2}, vector<int64_t>{1, 2});
-    auto unsqueeze = make_shared<op::Unsqueeze>(param, axes_node);
+    auto unsqueeze = make_shared<op::v0::Unsqueeze>(param, axes_node);
 
     ASSERT_EQ(unsqueeze->get_element_type(), element::f32);
     ASSERT_EQ(unsqueeze->get_shape(), (Shape{4, 1, 1, 1, 4, 1, 8}));
@@ -37,7 +37,7 @@ TEST(type_prop, unsqueeze_dynamic)
     auto param = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(5));
     auto axes_node =
         make_shared<op::Constant>(element::u64, Shape{2}, vector<int64_t>{1, 2});
-    auto unsqueeze = make_shared<op::Unsqueeze>(param, axes_node);
+    auto unsqueeze = make_shared<op::v0::Unsqueeze>(param, axes_node);
 
     ASSERT_EQ(unsqueeze->get_element_type(), element::f32);
     EXPECT_TRUE(
-- 
2.7.4