From 80e09dfa3b8b5d224d4506732dbf2ed8287c7280 Mon Sep 17 00:00:00 2001 From: Evgenya Stepyreva Date: Fri, 21 Aug 2020 16:27:02 +0300 Subject: [PATCH] [ MO ] Change layout of Shape sub-graphs once (#1875) * [ MO ] Change layout of Shape sub-graphs once --- .../extensions/middle/InsertLayoutPropagationTransposes.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/model-optimizer/extensions/middle/InsertLayoutPropagationTransposes.py b/model-optimizer/extensions/middle/InsertLayoutPropagationTransposes.py index eea7557..50ec7b5 100644 --- a/model-optimizer/extensions/middle/InsertLayoutPropagationTransposes.py +++ b/model-optimizer/extensions/middle/InsertLayoutPropagationTransposes.py @@ -46,12 +46,15 @@ class InsertLayoutPropagationTranspose(MiddleReplacementPattern): 1. The node is marked as 'reinterp_shape' attribute 2. The node is *not* marked as getting input in correct layout (implicitly imply that the input is on port 0) 3. The input shape rank is not less than 4 + 4. Node is not a part of shape sub-graph (layout permutation is handled separately for such a sub-graph) + :param node: node to check :return: result of the check """ return node.has_and_set('reinterp_shape') and \ not is_input_data_in_correct_layout(node, 0) and \ - len(node.in_port(0).data.get_shape()) >= 4 + len(node.in_port(0).data.get_shape()) >= 4 and \ + all([port.data.get_value() is None for port in node.out_ports().values() if not port.disconnected()]) @staticmethod def is_nhwc_to_nchw_transpose_needed(node: Node): @@ -61,12 +64,14 @@ class InsertLayoutPropagationTranspose(MiddleReplacementPattern): 1. The node is marked as 'reinterp_shape' attribute 2. The node is *not* marked as generating output in correct layout (implicitly imply that the output port is 0) 3. The output shape rank is not less than 4 + 4. Node is not a part of shape sub-graph (layout permutation is handled separately for such a sub-graph) :param node: node to check :return: result of the check """ return node.has_and_set('reinterp_shape') and \ not is_output_data_in_correct_layout(node, 0) and \ - len(node.out_port(0).data.get_shape()) >= 4 + len(node.out_port(0).data.get_shape()) >= 4 and \ + all([port.data.get_value() is None for port in node.out_ports().values() if not port.disconnected()]) def find_and_replace_pattern(self, graph: Graph): -- 2.7.4