"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
limitations under the License.
"""
-import numpy as np
-import networkx as nx
+from copy import deepcopy
+
import logging as log
+import numpy as np
+from extensions.middle.SliceConverter import ConvertSlice
from extensions.ops.splitv import SplitV
-from mo.graph.graph import Node
+from mo.front.common.partial_infer.utils import int64_array
+from mo.graph.graph import Node, Graph, add_opoutput
+from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.op import Op
from mo.ops.reshape import Reshape
-from mo.middle.replacement import MiddleReplacementPattern
-from extensions.middle.SliceConverter import ConvertSlice
+
class ConvertGroupedStridedSlice(MiddleReplacementPattern):
"""
def run_after(self):
return [ConvertSlice]
- def find_and_replace_pattern(self, graph: nx.MultiDiGraph):
+ def run_before(self):
+ from extensions.middle.pass_separator import MiddleFinish
+ return [MiddleFinish]
+
+ def find_and_replace_pattern(self, graph: Graph):
# Iterate over all data nodes and find all with >= 1 consumers
data_nodes = [Node(graph, node) for node in graph.node if Node(graph, node).kind == 'data']
for input_data in data_nodes:
input_shape = np.array(input_data.shape)
# Get all StridedSlice consumers
- out_nodes = [node for node in input_data.out_nodes() if node.op == 'StridedSlice']
+ out_nodes = [node for node in input_data.out_nodes() if node.op == 'StridedSlice' and node.in_node(0).name == input_data.name]
if len(out_nodes) < 1:
continue
valid_for_replacement = True
+ for node in out_nodes:
+ if len(node.slices) != len(out_nodes[0].slices):
+ valid_for_replacement = False
+
# Detect dimension for splitting
split_channel_dim = None
for dim_id, s in enumerate(out_nodes[0].slices):
# split_dims contains tuples with split range and output data node
split_dims = []
for out_id, node in enumerate(out_nodes):
- # Check that StridedSlice op has no shrink_axis_mask attribute
- if not np.all([x == False for x in node.shrink_axis_mask]):
- valid_for_replacement = False
# Check that StridedSlice op has stride eq 1 and splits only feature channel
for id, s in enumerate(node.slices):
l, r, stride = s.start, s.stop, s.step
# Check feature split intersection
final_data_nodes_list = []
- sorted_split_dims = sorted(split_dims)
+ sorted_split_dims = sorted(split_dims, key=lambda item: (item[0], item[1]))
+
+ # check if we have similar StridedSlice operations with different outputs
+ prev_sd = sorted_split_dims[0]
+ to_remove = []
+ for i in range(1, len(sorted_split_dims)):
+ if sorted_split_dims[i][0] == prev_sd[0] and sorted_split_dims[i][1] == prev_sd[1] and sorted_split_dims[i][2].name != prev_sd[2].name:
+ cur_node = sorted_split_dims[i][2]
+ for out in cur_node.out_nodes():
+ attrs = deepcopy(graph.get_edge_data(cur_node.id, out.id)[0])
+ graph.remove_edge(cur_node.id, out.id)
+ graph.add_edge(prev_sd[2].id, out.id, **attrs)
+ to_remove.append(i)
+
+ for ind in reversed(to_remove):
+ sorted_split_dims.pop(ind)
+
size_splits = []
prev_r = 0
for l, r, out in sorted_split_dims:
shape = np.array(input_shape)
size_splits.append(l - prev_r)
shape[split_channel_dim] = l - prev_r
- data_node = Op._create_data_node(graph, 'fake_data', {'shape': shape, 'is_output': True})
+ data_node = Op._create_data_node(graph, 'fake_data', {'shape': shape})
+ add_opoutput(graph, data_node.id, 0, False)
final_data_nodes_list.append(data_node)
-
prev_r = r
size_splits.append(r - l)
final_data_nodes_list.append(out)
shape = input_shape.copy()
shape[split_channel_dim] = input_shape[split_channel_dim] - prev_r
size_splits.append(input_shape[split_channel_dim] - prev_r)
- data_node = Op._create_data_node(graph, 'fake_data', {'shape': shape, 'is_output': True})
+ data_node = Op._create_data_node(graph, 'fake_data', {'shape': shape})
+ add_opoutput(graph, data_node.id, 0, False)
final_data_nodes_list.append(data_node)
if not valid_for_replacement:
continue
+ for node in out_nodes:
+ if not np.all([x == 0 for x in node.shrink_axis_mask]):
+ out_node = node.out_node()
+ if np.any(node['shrink_axis_mask']):
+ self.add_reshape_for_shrink(graph, node)
+ if np.any(node['new_axis_mask']):
+ self.add_reshape_for_new(graph, node)
+
+ for i in range(len(final_data_nodes_list)):
+ if final_data_nodes_list[i].name == out_node.name:
+ final_data_nodes_list[i] = node.out_node()
+ break
+
# Insert Split layer and remove old StridedSlice layers
# 1. Remove connections from input_data to StridedSlice ops
out_data_nodes = []
# 2. Create Split layer and reorder outputs
split = SplitV(graph, dict(name=name_for_future_split + "/Split", axis=split_channel_dim,
- size_splits=size_splits))
+ size_splits=size_splits, out_ports_count=len(size_splits)))
split.create_node_with_data(inputs=[input_data], data_nodes=final_data_nodes_list)
+
+ @staticmethod
+ def add_reshape_for_shrink(graph: Graph, ss_node):
+ # add Reshape for shrink_axis_mask
+ log.info("StridedSlice op with shrink mask '{}' has been detected".format(ss_node.id))
+ node = ss_node
+
+ if len(node.in_nodes()) != 4 or len(node.out_nodes()) != 1:
+ return
+
+ shape_out = node.out_node().shape
+ dim = shape_out.copy()
+ ss_shape = []
+ k = 0
+
+ # Don't permute reshape if channels were squeezed
+ dont_permute = False
+ if graph.graph['layout'] == 'NHWC' and node['shrink_axis_mask'][-1] == 1:
+ dont_permute = True
+
+ for i in range(0, len(node['shrink_axis_mask'])):
+ if not node['shrink_axis_mask'][i]:
+ ss_shape.append(shape_out[k])
+ k = k + 1
+ else:
+ node['shrink_axis_mask'][i] = 0
+ ss_shape.append(1)
+
+ out_node = node.out_node(0)
+
+ # insert data node for StridedSlice
+ data_node = Op._create_data_node(graph, node.name + "/Reshape_shrink_data", {'shape': int64_array(ss_shape)})
+ attrs = deepcopy(graph.get_edge_data(node.id, out_node.id)[0])
+ graph.remove_edge(node.id, out_node.id)
+ graph.add_edge(node.id, data_node.id, **attrs)
+
+ # insert Reshape
+ if dont_permute:
+ reshape = Reshape(graph, dict(name=node.name + "/Reshape_shrink",
+ dim=np.array(dim, dtype=np.int64), nchw_layout=True))
+ reshape_data_node = reshape.create_node_with_data([data_node], reshape.attrs,
+ data_nodes=[out_node])
+ reshape_data_node['nchw_layout'] = True
+ else:
+ reshape = Reshape(graph, dict(name=node.name + "/Reshape_shrink",
+ dim=np.array(dim, dtype=np.int64)))
+ reshape_data_node = reshape.create_node_with_data([data_node], reshape.attrs,
+ data_nodes=[out_node])
+
+ @staticmethod
+ def add_reshape_for_new(graph: Graph, ss_node):
+ log.info("StridedSlice op with new axis mask '{}' has been detected".format(ss_node.id))
+ node = ss_node
+
+ if len(node.in_nodes()) != 4 or len(node.out_nodes()) != 1:
+ return
+
+ shape_out = node.out_node().shape
+ dim = shape_out.copy()
+ ss_shape = []
+ for i in range(0, len(node['new_axis_mask'])):
+ if not node['new_axis_mask'][i]:
+ ss_shape.append(shape_out[i])
+ else:
+ node['new_axis_mask'][i] = 0
+
+ out_node = node.out_node(0)
+ # insert data node for StridedSlice
+ data_node = Op._create_data_node(graph, node.name + "/Reshape_new_data", {'shape': ss_shape})
+ attrs = deepcopy(graph.get_edge_data(node.id, out_node.id)[0])
+ graph.remove_edge(node.id, out_node.id)
+ graph.add_edge(node.id, data_node.id, **attrs)
+
+ # insert Reshape
+ reshape = Reshape(graph, dict(name=node.name + "/Reshape_new",
+ dim=np.array(dim, dtype=np.int64)))
+ reshape.create_node_with_data([data_node], reshape.attrs, data_nodes=[out_node])