From cd391389ceb677f250009163ed16def52f2540eb Mon Sep 17 00:00:00 2001 From: Evgenya Stepyreva Date: Fri, 18 Sep 2020 14:42:16 +0300 Subject: [PATCH] [ MO ] Complete weights layout permutation (#2299) * MO TF: FQPerChannel extractor * [ MO ] Complete weights layout permutation * removed deleted file out of BOM * Bring back stashed changes * Skip if no weights permutation * Conditional permutation * Comments --- model-optimizer/automation/package_BOM.txt | 1 - .../front/tf/FakeQuantWithMinMaxVars_ext.py | 18 ++++ .../extensions/middle/ApplyPermutations.py | 14 ++- .../middle/LayoutChangeForConstantShapePaths.py | 10 +- .../middle/MarkSubgraphsWithCorrectLayout.py | 89 ++++++++++++++++ .../extensions/middle/quantize_fuses.py | 15 +-- .../middle/weights_permute_normalizer_test.py | 118 --------------------- .../extensions/middle/wights_permute_normalizer.py | 51 --------- model-optimizer/extensions/ops/gather.py | 2 +- model-optimizer/extensions/ops/transpose.py | 3 +- model-optimizer/mo/graph/perm_inputs.py | 19 ++++ model-optimizer/mo/ops/convolution.py | 6 +- model-optimizer/mo/ops/deconvolution.py | 4 +- 13 files changed, 146 insertions(+), 204 deletions(-) delete mode 100644 model-optimizer/extensions/middle/weights_permute_normalizer_test.py delete mode 100644 model-optimizer/extensions/middle/wights_permute_normalizer.py diff --git a/model-optimizer/automation/package_BOM.txt b/model-optimizer/automation/package_BOM.txt index 9301c42..ae7f4de 100644 --- a/model-optimizer/automation/package_BOM.txt +++ b/model-optimizer/automation/package_BOM.txt @@ -582,7 +582,6 @@ extensions/middle/UnsqueezeTileReshapeBlockToInterpolate.py extensions/middle/UpsampleToResample.py extensions/middle/UselessMerge.py extensions/middle/UselessSplitEraser.py -extensions/middle/wights_permute_normalizer.py extensions/ops/__init__.py extensions/ops/accum.py extensions/ops/activation_ops.py diff --git a/model-optimizer/extensions/front/tf/FakeQuantWithMinMaxVars_ext.py b/model-optimizer/extensions/front/tf/FakeQuantWithMinMaxVars_ext.py index 302b3a8..a9995a9 100644 --- a/model-optimizer/extensions/front/tf/FakeQuantWithMinMaxVars_ext.py +++ b/model-optimizer/extensions/front/tf/FakeQuantWithMinMaxVars_ext.py @@ -34,3 +34,21 @@ class FakeQuantWithMinMaxVarsExtractor(FrontExtractorOp): 'narrow_range': narrow_range, 'num_bits': num_bits}) return cls.enabled + + +class FakeQuantWithMinMaxVarsPerChannelExtractor(FrontExtractorOp): + op = 'FakeQuantWithMinMaxVarsPerChannel' + enabled = True + + @classmethod + def extract(cls, node): + narrow_range = node.pb.attr['narrow_range'].b + num_bits = node.pb.attr['num_bits'].i + levels = 2 ** num_bits - int(narrow_range) + + # we prepare this operation to be converted to FakeQuantize op, + # but input reconnection is needed, so we don't set infer function and type attribute + Op.update_node_stat(node, {'op': 'FakeQuantWithMinMaxVars', 'levels': levels, + 'narrow_range': narrow_range, 'num_bits': num_bits}) + + return cls.enabled diff --git a/model-optimizer/extensions/middle/ApplyPermutations.py b/model-optimizer/extensions/middle/ApplyPermutations.py index a065ea4..afa623f 100644 --- a/model-optimizer/extensions/middle/ApplyPermutations.py +++ b/model-optimizer/extensions/middle/ApplyPermutations.py @@ -24,6 +24,7 @@ from extensions.middle.LayoutChangeForConstantShapePaths import LayoutChangeForC from extensions.middle.pass_separator import PostMiddleStart from mo.front.common.partial_infer.utils import int64_array from mo.graph.graph import Graph, Node +from mo.graph.perm_inputs import get_node_with_permutation from mo.graph.port import Port from mo.middle.replacement import MiddleReplacementPattern from mo.utils.error import Error @@ -47,6 +48,7 @@ class ApplyPermutation(MiddleReplacementPattern): self.permute_op_nodes_attrs(graph) self.shape_of_sub_graph_reinference(graph) self.permute_input_data(graph) + graph.graph['layout'] = 'NCHW' @staticmethod def merge_nodes_permutations(graph: Graph): @@ -94,7 +96,8 @@ class ApplyPermutation(MiddleReplacementPattern): def permute_data_nodes_attrs(graph: Graph): # Iterate over all data nodes and apply permutation if exists for node in graph.get_data_nodes(): - if not node.has_valid('permutation'): + if not node.has_valid('permutation') or \ + all([attrs.get('input_permutation', False) for u, v, attrs in graph.out_edges(node.id, data=True)]): continue if len( @@ -126,8 +129,6 @@ class ApplyPermutation(MiddleReplacementPattern): @staticmethod def permute_input_data(graph: Graph): - if graph.graph['layout'] != 'NHWC': - return for node in graph.get_op_nodes(): input_permutations = [(in_port, edge_attrs['input_permutation']) for in_port, edge_attrs in node.in_edges().items() if edge_attrs.get('input_permutation') is not None] @@ -136,9 +137,12 @@ class ApplyPermutation(MiddleReplacementPattern): direction, port = port_info.split(':') port = int(port) port_to_check = node.in_port(port) if direction == 'input' else node.out_port(port) - if not is_input_data_in_correct_layout(node, in_port) and len(port_to_check.data.get_shape()) >= 4: + permutation_data_node = get_node_with_permutation(node, port_info) + + if permutation_data_node.has_and_set('permutation') and \ + not is_input_data_in_correct_layout(node, in_port) and \ + len(port_to_check.data.get_shape()) >= 4: permutation(node, port_info, in_port) - graph.graph['layout'] = 'NCHW' @staticmethod def shape_of_sub_graph_reinference(graph: Graph): diff --git a/model-optimizer/extensions/middle/LayoutChangeForConstantShapePaths.py b/model-optimizer/extensions/middle/LayoutChangeForConstantShapePaths.py index 9748acc..a92e74d 100644 --- a/model-optimizer/extensions/middle/LayoutChangeForConstantShapePaths.py +++ b/model-optimizer/extensions/middle/LayoutChangeForConstantShapePaths.py @@ -46,13 +46,6 @@ class LayoutChangeForConstantShapePaths(MiddleReplacementPattern): next_in_ports.update(out_port.get_destinations()) return next_in_ports - def mark_node_as_in_correct_layout_by_in_port(self, in_port): - next_in_ports = self.get_next_in_ports(in_port) - in_port.__setattr__('input_permutation', None) - mark_input_as_in_correct_layout(in_port.node, in_port.idx) - for port in next_in_ports: - mark_output_as_in_correct_layout(port.get_source().node, port.get_source().idx) - def find_shape_subgraph_endpoints(self, out_ports: List[Port], visited: set = None, action: callable = None) -> Set[Port]: """ @@ -108,8 +101,7 @@ class LayoutChangeForConstantShapePaths(MiddleReplacementPattern): shape.out_port(0).get_connection().insert_node(gather) # 2. Inserting Gather/Transpose to NC* format - shape_sub_graph_end_points = self.find_shape_subgraph_endpoints( - [shape.out_port(0) for shape in shape_ops], None, self.mark_node_as_in_correct_layout_by_in_port) + shape_sub_graph_end_points = self.find_shape_subgraph_endpoints([shape.out_port(0) for shape in shape_ops]) for in_port in shape_sub_graph_end_points: name = in_port.node.soft_get('name', in_port.node.id) shape = in_port.data.get_shape() diff --git a/model-optimizer/extensions/middle/MarkSubgraphsWithCorrectLayout.py b/model-optimizer/extensions/middle/MarkSubgraphsWithCorrectLayout.py index 34c4bff..719fc72 100644 --- a/model-optimizer/extensions/middle/MarkSubgraphsWithCorrectLayout.py +++ b/model-optimizer/extensions/middle/MarkSubgraphsWithCorrectLayout.py @@ -16,10 +16,14 @@ import logging as log from collections import deque +from typing import Set + from extensions.middle.InsertLayoutPropagationTransposes import InsertLayoutPropagationTranspose, \ mark_as_correct_data_layout +from extensions.middle.LayoutChangeForConstantShapePaths import LayoutChangeForConstantShapePaths from extensions.middle.pass_separator import PostMiddleStart from mo.graph.graph import Graph, Node +from mo.graph.port import Port from mo.middle.replacement import MiddleReplacementPattern @@ -121,3 +125,88 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern): for visited_node in marked_nodes: mark_as_correct_data_layout(visited_node) visited_node['nchw_layout'] = True + + for node in self.get_ports_and_nodes_on_weights(graph)[1]: + mark_as_correct_data_layout(node) + node['nchw_layout'] = True + if node.soft_get('type') == 'Const': # WA for Const op deletion during clean_up + node.out_node()['nchw_layout'] = True + + for node in self.get_ports_and_nodes_on_shape_subgraphs(graph)[1]: + mark_as_correct_data_layout(node) + node['nchw_layout'] = True + if node.soft_get('type') == 'Const': # WA for Const op deletion during clean_up + node.out_node()['nchw_layout'] = True + + @staticmethod + def walk_up_from_in_ports_to_out_ports(in_ports: Set[Port], out_ports: Set[Port], + visited_ports: Set[Port] = None, visited_nodes: Set[Node] = None): + """" + Returns all intermediate ports and nodes of such a sub-graph: + + out_ports + | | + \/ \/ + . . . + | | + \/ \/ + in_ports + """ + if visited_ports is None: + visited_ports = set() + if visited_nodes is None: + visited_nodes = set() + + deque_of_in_ports = deque(in_ports) + while len(deque_of_in_ports): + in_port = deque_of_in_ports.popleft() + source_node = in_port.get_source().node + if in_port in visited_ports: # do not check visited_nodes as search is based on ports + continue + visited_ports.update({in_port, in_port.get_source()}) + if in_port.get_source() in out_ports: # reached source marked to stop the search + if not len(in_port.get_source().node.in_ports()): # for Constants and Parameters to be visited + visited_nodes.add(in_port.get_source().node) + continue + deque_of_in_ports.extend([port for port in source_node.in_ports().values() if not port.disconnected()]) + visited_nodes.add(source_node) + return visited_ports, visited_nodes + + @staticmethod + def get_ports_and_nodes_on_weights(graph): + get_weights_port_index = lambda node: node.weights_index if node.has_valid('weights_index') else 1 + weighted_layer_type_to_in_weights_port = { + 'Convolution': get_weights_port_index, + 'DeformableConvolution': get_weights_port_index, + 'Deconvolution': get_weights_port_index, + 'BinaryConvolution': get_weights_port_index, + } + nodes = graph.get_op_nodes() + weighted_types = list(weighted_layer_type_to_in_weights_port.keys()) + + # collect all input ports with weights + weight_ports = set() + start_ports = set() + for node in nodes: + node_type = node.soft_get('type', 'unknown') + if node_type not in weighted_types: + if node_type in ['Const', 'Parameter', 'ShapeOf']: + start_ports.add(node.out_port(0)) + continue + weight_port_idx = weighted_layer_type_to_in_weights_port[node_type](node) + assert node.is_in_port_connected(weight_port_idx), \ + 'Unexpected port configuration of {} node with name=`{}`'.format(node_type, + node.soft_get('name', node.id)) + weight_ports.add(node.in_port(weight_port_idx)) + + # collect all sub-graphs that start with Constant/Parameter/ShapeOf and end at in_port as weights + ports, nodes = MarkSubGraphsWithCorrectLayout.walk_up_from_in_ports_to_out_ports(weight_ports, start_ports) + return ports, nodes + + @staticmethod + def get_ports_and_nodes_on_shape_subgraphs(graph): + shape_sources = {shape_of.out_port(0) for shape_of in graph.get_op_nodes(type='ShapeOf')} + end_points = LayoutChangeForConstantShapePaths().find_shape_subgraph_endpoints( + [shape.out_port(0) for shape in graph.get_op_nodes(type='ShapeOf')]) + ports, nodes = MarkSubGraphsWithCorrectLayout.walk_up_from_in_ports_to_out_ports(end_points, shape_sources) + return ports, nodes diff --git a/model-optimizer/extensions/middle/quantize_fuses.py b/model-optimizer/extensions/middle/quantize_fuses.py index d479bc1..26f8876 100644 --- a/model-optimizer/extensions/middle/quantize_fuses.py +++ b/model-optimizer/extensions/middle/quantize_fuses.py @@ -124,17 +124,4 @@ class FakeQuantizeFuse(MiddleReplacementPattern): port.get_source().connect(fuse_node_duplicate.in_port(idx)) fuse_node_duplicate.infer(fuse_node_duplicate) - first_port_fusion = False - - if 'permutation' in quantize_node.in_edge(0): - permutation = quantize_node.in_edge(0)['permutation'] - if permutation is None: - continue - - perm_rank = permutation.perm.size - - if not all([quantize_node.in_port(i).data.get_shape().size == perm_rank for i in range(1, 5)]): - continue - - for i in range(1, 5): - quantize_node.in_edge(i)['permutation'] = permutation + first_port_fusion = False \ No newline at end of file diff --git a/model-optimizer/extensions/middle/weights_permute_normalizer_test.py b/model-optimizer/extensions/middle/weights_permute_normalizer_test.py deleted file mode 100644 index bfd753a..0000000 --- a/model-optimizer/extensions/middle/weights_permute_normalizer_test.py +++ /dev/null @@ -1,118 +0,0 @@ -""" - Copyright (C) 2018-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" -import unittest - -from extensions.middle.wights_permute_normalizer import WeightsPermuteNormalizer -from mo.graph.graph import Node -from mo.utils.unittest.graph import build_graph - -nodes_attributes = { - 'placeholder': {'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'}, - 'placeholder_data': {'value': None, 'shape': None, 'kind': 'data'}, - - 'quantize': {'type': 'FakeQuantize', 'kind': 'op', 'op': 'FakeQuantize'}, - 'quantize_data': {'value': None, 'shape': None, 'kind': 'data'}, - - 'const_1': {'type': 'Const', 'kind': 'op', 'op': 'Const'}, - 'const_1_data': {'value': None, 'shape': None, 'kind': 'data'}, - - 'const_2': {'type': 'Const', 'kind': 'op', 'op': 'Const'}, - 'const_2_data': {'value': None, 'shape': None, 'kind': 'data'}, - - 'const_3': {'type': 'Const', 'kind': 'op', 'op': 'Const'}, - 'const_3_data': {'value': None, 'shape': None, 'kind': 'data'}, - - 'const_4': {'type': 'Const', 'kind': 'op', 'op': 'Const'}, - 'const_4_data': {'value': None, 'shape': None, 'kind': 'data'}, - - 'const_5': {'type': 'Const', 'kind': 'op', 'op': 'Const'}, - 'const_5_data': {'value': None, 'shape': None, 'kind': 'data'}, - - 'conv_1': {'type': 'Convolution', 'kind': 'op', 'op': 'Conv2D', 'layout': 'NHWC'}, - 'conv_1_w': {'value': None, 'shape': None, 'kind': 'data'}, - 'conv_1_b': {'value': None, 'shape': None, 'kind': 'data'}, - 'const_conv_1_w': {'value': None, 'shape': None, 'kind': 'op'}, - 'const_conv_1_b': {'value': None, 'shape': None, 'kind': 'op'}, - 'conv_1_data': {'value': None, 'shape': None, 'kind': 'data'}, -} - - -class WeightNormalizationTests(unittest.TestCase): - def test_normalize_weights_test1(self): - # FakeQuantize---,->Conv - # Placeholder--' - graph = build_graph(nodes_attributes, - [('placeholder', 'placeholder_data'), - ('const_1', 'const_1_data'), - ('const_2', 'const_2_data'), - ('const_3', 'const_3_data'), - ('const_4', 'const_4_data'), - ('const_5', 'const_5_data'), - ('quantize', 'quantize_data'), - ('conv_1', 'conv_1_data'), - ('const_1_data', 'quantize'), - ('const_2_data', 'quantize'), - ('const_3_data', 'quantize'), - ('const_4_data', 'quantize'), - ('const_5_data', 'quantize'), - ('placeholder_data', 'conv_1'), - ('quantize_data', 'conv_1', {'in': 1, 'permutation': "[3, 2, 0, 1]"}), - ], - {}, - nodes_with_edges_only=True - ) - - pattern = WeightsPermuteNormalizer() - pattern.find_and_replace_pattern(graph) - - conv = Node(graph, 'conv_1') - quantize = Node(graph, 'quantize') - - self.assertTrue('permutation' in conv.in_edge(1) and conv.in_edge(1)['permutation'] == "[3, 2, 0, 1]") - self.assertTrue('permutation' in quantize.in_edge(0) and quantize.in_edge(0)['permutation'] == "[3, 2, 0, 1]") - - def test_normalize_weights_test2(self): - # Quantize---,->Conv - # Placeholder--' - graph = build_graph(nodes_attributes, - [('placeholder', 'placeholder_data'), - ('const_1', 'const_1_data'), - ('const_2', 'const_2_data'), - ('const_3', 'const_3_data'), - ('const_4', 'const_4_data'), - ('const_5', 'const_5_data'), - ('quantize', 'quantize_data'), - ('conv_1', 'conv_1_data'), - ('const_1_data', 'quantize'), - ('const_2_data', 'quantize'), - ('const_3_data', 'quantize'), - ('const_4_data', 'quantize'), - ('const_5_data', 'quantize'), - ('quantize_data', 'conv_1', {'in': 0}), - ('conv_1_w', 'conv_1', {'in': 1}), - ], - {}, - nodes_with_edges_only=True - ) - - pattern = WeightsPermuteNormalizer() - pattern.find_and_replace_pattern(graph) - - conv = Node(graph, 'conv_1') - quantize = Node(graph, 'quantize') - - self.assertTrue('permutation' not in conv.in_edge(1)) - self.assertTrue('permutation' not in quantize.in_edge(0)) diff --git a/model-optimizer/extensions/middle/wights_permute_normalizer.py b/model-optimizer/extensions/middle/wights_permute_normalizer.py deleted file mode 100644 index a9dd2b0..0000000 --- a/model-optimizer/extensions/middle/wights_permute_normalizer.py +++ /dev/null @@ -1,51 +0,0 @@ -""" - Copyright (C) 2018-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" - -from mo.graph.graph import Graph -from mo.middle.replacement import MiddleReplacementPattern - - -class WeightsPermuteNormalizer(MiddleReplacementPattern): - """ - We propagate PermuteAttr from weights port of Convolution and MatMul to real Const that contains it - """ - enabled = True - - @staticmethod - def pattern(): - return dict( - nodes=[ - ('const_data', dict(kind='data')), - ('const', dict(type='Const')), - ('quantize', dict(type='FakeQuantize')), - ('quantize_data', dict(kind='data')), - ('conv', dict(type=lambda type: type in ['Convolution', 'MatMul'])), - ], - edges=[ - ('const', 'const_data'), - ('const_data', 'quantize', {'in': 0}), - ('quantize', 'quantize_data'), - ('quantize_data', 'conv', {'in': 1}), - ] - ) - - def replace_pattern(self, graph: Graph, match: dict): - conv = match['conv'] - if 1 not in conv.in_edges() or 'permutation' not in conv.in_edge(1): - return - - perm = conv.in_edge(1)['permutation'] - match['quantize'].in_port(0).permutation = perm diff --git a/model-optimizer/extensions/ops/gather.py b/model-optimizer/extensions/ops/gather.py index 6c444cb..ea5410e 100644 --- a/model-optimizer/extensions/ops/gather.py +++ b/model-optimizer/extensions/ops/gather.py @@ -60,7 +60,7 @@ class Gather(Op): assert axis is not None axis = get_canonical_axis_index(data_shape, axis) - # we import PermuteInputs because it uses Gather inside and we have recursive imports + # we import PermuteInputs locally because it uses Gather inside and we have recursive imports from mo.graph.perm_inputs import PermuteInputs PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0', 'axis') diff --git a/model-optimizer/extensions/ops/transpose.py b/model-optimizer/extensions/ops/transpose.py index 930a57b..6cfe3a4 100644 --- a/model-optimizer/extensions/ops/transpose.py +++ b/model-optimizer/extensions/ops/transpose.py @@ -17,7 +17,6 @@ import numpy as np from mo.graph.graph import Graph -from mo.graph.perm_inputs import PermuteInputs from mo.ops.op import Op @@ -48,6 +47,8 @@ class Transpose(Op): 'Cannot infer `{}` due to both order and reverse_order was set'.format(node.soft_get('name')) order = np.arange(len(input_shape))[::-1] # Reverse order else: + # we import PermuteInputs locally because it uses Transpose inside and we have recursive imports + from mo.graph.perm_inputs import PermuteInputs assert len(connected_ports) == 2 and 0 in in_ports and 1 in in_ports, \ "{} node `{}` should have 2 input ports, where 0-input is a data input and 1-input represents " \ "Transpose `order`".format(node.op, node.id) diff --git a/model-optimizer/mo/graph/perm_inputs.py b/model-optimizer/mo/graph/perm_inputs.py index 1645799..b2983ce 100644 --- a/model-optimizer/mo/graph/perm_inputs.py +++ b/model-optimizer/mo/graph/perm_inputs.py @@ -16,6 +16,7 @@ import networkx as nx from extensions.ops.gather import Gather +from extensions.ops.transpose import Transpose from mo.front.common.partial_infer.utils import int64_array from mo.graph.graph import Node from mo.ops.const import Const @@ -153,6 +154,23 @@ def shape(op_node: Node, port_info: str, input_port: int): op_node.infer(op_node) +def transpose(op_node: Node, port_info: str, input_port: int): + graph = op_node.graph + permutation_data_node = get_node_with_permutation(op_node, port_info) + assert permutation_data_node.has_and_set('permutation'), \ + 'Data node "{}" does not have permutation for node {}, port_info "{}".'.format( + permutation_data_node.id, op_node.id, port_info) + permutation = permutation_data_node.permutation + if len(permutation.perm) == 0: + return + + transpose_name = op_node.soft_get('name', op_node.id) + '/Transpose' + from mo.front.tf.graph_utils import create_op_with_const_inputs # avoiding recursive imports + transpose = create_op_with_const_inputs( + graph, Transpose, {1: permutation.perm}, {'name': transpose_name, 'override_output_shape': True}) + op_node.in_port(input_port).get_connection().insert_node(transpose) + + class PermuteInputs: common_inv_permutation = lambda node, port_info, input_port: axis(node, port_info, input_port) @@ -160,6 +178,7 @@ class PermuteInputs: 'axis': common_inv_permutation, 'order': lambda node, port_info, input_port: order(node, port_info, input_port), 'shape': lambda node, port_info, input_port: shape(node, port_info, input_port), + 'transpose': lambda node, port_info, input_port: transpose(node, port_info, input_port), } def set_input_permutation(self, node1: Node, node2: Node, port_info: str, permutation_rule: str): diff --git a/model-optimizer/mo/ops/convolution.py b/model-optimizer/mo/ops/convolution.py index 95471b0..55823c5 100644 --- a/model-optimizer/mo/ops/convolution.py +++ b/model-optimizer/mo/ops/convolution.py @@ -22,6 +22,7 @@ from mo.front.common.partial_infer.utils import int64_array, float_array, mark_i tf_window_op_pad_infer from mo.front.onnx.extractors.utils import get_backend_pad from mo.graph.graph import Node, Graph +from mo.graph.perm_inputs import PermuteInputs from mo.ops.op import Op, PermuteAttrs from mo.utils.error import Error @@ -264,5 +265,6 @@ class Convolution(Op): ('output_feature_channel', 'input:{}'.format(weights_index)), ]) - PermuteAttrs.set_permutation(node.in_node(weights_index), node, - node.get_weights_permute if node.has_valid('get_weights_permute') else None) + PermuteAttrs.set_permutation(node.in_node(weights_index), node, node.soft_get('get_weights_permute', None)) + PermuteInputs().set_input_permutation( + node.in_node(weights_index), node, 'input:{}'.format(weights_index), 'transpose') diff --git a/model-optimizer/mo/ops/deconvolution.py b/model-optimizer/mo/ops/deconvolution.py index c4982b3..f598fd9 100644 --- a/model-optimizer/mo/ops/deconvolution.py +++ b/model-optimizer/mo/ops/deconvolution.py @@ -113,8 +113,8 @@ class Deconvolution(Op): ('output_feature_channel', 'input:1'), ]) - PermuteAttrs.set_permutation(node.in_node(1), node, - node.get_weights_permute if node.has_valid('get_weights_permute') else None) + PermuteAttrs.set_permutation(node.in_node(1), node, node.soft_get('get_weights_permute', None)) + PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:1', 'transpose') PermuteInputs().set_input_permutation(node.in_node(2), node, 'input:0', 'shape') node['force_precision_in_ports'] = {2: 'int64'} -- 2.7.4