"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
import numpy as np
from extensions.middle.ConvertGroupedStridedSlice import ConvertGroupedStridedSlice
+from mo.graph.graph import Node
from mo.utils.unittest.graph import build_graph, compare_graphs
nodes_attributes = {
'placeholder_1': {'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
+ 'placeholder_2': {'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
+ 'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
+ 'placeholder_begin_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
+ 'placeholder_end_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
+ 'placeholder_stride_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
# StridedSlice layers
'sslice_1': {'type': None, 'kind': 'op', 'op': 'StridedSlice', 'slices': None,
- 'shrink_axis_mask': np.array([False, False, False, False])},
+ 'shrink_axis_mask': np.array([0, 0, 0, 0])},
'sslice_1_data': {'value': None, 'shape': None, 'kind': 'data'},
'sslice_2': {'type': None, 'kind': 'op', 'op': 'StridedSlice', 'slices': None,
- 'shrink_axis_mask': np.array([False, False, False, False])},
+ 'shrink_axis_mask': np.array([0, 0, 0, 0])},
'sslice_2_data': {'value': None, 'shape': None, 'kind': 'data'},
'sslice_3': {'type': None, 'kind': 'op', 'op': 'StridedSlice', 'slices': None,
- 'shrink_axis_mask': np.array([False, False, False, False])},
+ 'shrink_axis_mask': np.array([0, 0, 0, 0])},
'sslice_3_data': {'value': None, 'shape': None, 'kind': 'data'},
# Split layer
'split_1': {'type': 'Split', 'kind': 'op', 'op': 'SplitV'},
# Concat1 operation
'concat_1': {'type': 'Concat', 'kind': 'op', 'op': 'Concat'},
'concat_1_data': {'value': None, 'shape': None, 'kind': 'data'},
+ 'op_output': {'kind': 'op', 'op': 'OpOutput'},
+ 'op_output_1': {'kind': 'op', 'op': 'OpOutput'},
+ 'op_output_2': {'kind': 'op', 'op': 'OpOutput'},
+ # Reshape layer
+ 'sslice_1/Reshape_shrink': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
+ 'sslice_1/Reshape_shrink_data': {'value': None, 'shape': None, 'kind': 'data'},
+ 'sslice_2/Reshape_shrink': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
+ 'sslice_2/Reshape_shrink_data': {'value': None, 'shape': None, 'kind': 'data'},
+ 'sslice_2/Reshape_new': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
+ 'sslice_2/Reshape_new_data': {'value': None, 'shape': None, 'kind': 'data'},
}
('sslice_1_data', 'concat_1'),
('sslice_2_data', 'concat_1'),
('sslice_3_data', 'concat_1'),
- ('concat_1', 'concat_1_data')
+ ('concat_1', 'concat_1_data'),
+ ('concat_1_data', 'op_output')
],
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
[slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(36, 54, 1)])},
'sslice_3_data': {'shape': np.array([1, 227, 227, 18])},
- 'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+ 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
})
graph.graph['layout'] = 'NHWC'
('split_1_data', 'concat_1'),
('split_2_data', 'concat_1'),
('split_3_data', 'concat_1'),
- ('concat_1', 'concat_1_data')
+ ('concat_1', 'concat_1_data'),
+ ('concat_1_data', 'op_output')
+
],
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
'split_1': {'axis': 3},
'split_1_data': {'shape': np.array([1, 227, 227, 18])},
'split_2_data': {'shape': np.array([1, 227, 227, 18])},
'split_3_data': {'shape': np.array([1, 227, 227, 18])},
- 'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+ 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
})
pattern = ConvertGroupedStridedSlice()
('sslice_1_data', 'concat_1'),
('sslice_2_data', 'concat_1'),
('sslice_3_data', 'concat_1'),
- ('concat_1', 'concat_1_data')
+ ('concat_1', 'concat_1_data'),
+ ('concat_1_data', 'op_output')
],
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
[slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 19, 1)])},
'sslice_3_data': {'shape': np.array([1, 227, 227, 19])},
- 'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+ 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
})
graph.graph['layout'] = 'NHWC'
('split_1_data', 'concat_1'),
('split_2_data', 'concat_1'),
('split_3_data', 'concat_1'),
- ('concat_1', 'concat_1_data')
+ ('concat_1', 'concat_1_data'),
+ ('concat_1_data', 'op_output')
],
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
'split_1': {'axis': 3},
'split_1_data': {'shape': np.array([1, 227, 227, 18])},
'split_2_data': {'shape': np.array([1, 227, 227, 17])},
'split_3_data': {'shape': np.array([1, 227, 227, 19])},
- 'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+ 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
})
pattern = ConvertGroupedStridedSlice()
('sslice_1_data', 'concat_1'),
('sslice_2_data', 'concat_1'),
('sslice_3_data', 'concat_1'),
- ('concat_1', 'concat_1_data')
+ ('concat_1', 'concat_1_data'),
+ ('concat_1_data', 'op_output')
],
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
[slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 19, 1)])},
'sslice_3_data': {'shape': np.array([1, 227, 227, 19])},
- 'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+ 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
})
graph.graph['layout'] = 'NHWC'
('sslice_1_data', 'concat_1'),
('sslice_2_data', 'concat_1'),
('sslice_3_data', 'concat_1'),
- ('concat_1', 'concat_1_data')
+ ('concat_1', 'concat_1_data'),
+ ('concat_1_data', 'op_output')
],
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
[slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 19, 1)])},
'sslice_3_data': {'shape': np.array([1, 227, 227, 19])},
- 'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+ 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
})
pattern = ConvertGroupedStridedSlice()
('sslice_1_data', 'concat_1'),
('sslice_2_data', 'concat_1'),
('sslice_3_data', 'concat_1'),
- ('concat_1', 'concat_1_data')
+ ('concat_1', 'concat_1_data'),
+ ('concat_1_data', 'op_output')
],
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
[slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 19, 1)])},
'sslice_3_data': {'shape': np.array([1, 227, 227, 19])},
- 'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+ 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
})
graph.graph['layout'] = 'NHWC'
('sslice_1_data', 'concat_1'),
('sslice_2_data', 'concat_1'),
('sslice_3_data', 'concat_1'),
- ('concat_1', 'concat_1_data')
+ ('concat_1', 'concat_1_data'),
+ ('concat_1_data', 'op_output')
],
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
[slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 19, 1)])},
'sslice_3_data': {'shape': np.array([1, 227, 227, 19])},
- 'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+ 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
})
pattern = ConvertGroupedStridedSlice()
('sslice_1_data', 'concat_1'),
('sslice_2_data', 'concat_1'),
('sslice_3_data', 'concat_1'),
- ('concat_1', 'concat_1_data')
+ ('concat_1', 'concat_1_data'),
+ ('concat_1_data', 'op_output'),
],
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
[slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(1, 19, 1)])},
'sslice_3_data': {'shape': np.array([1, 227, 227, 18])},
- 'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+ 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
})
graph.graph['layout'] = 'NHWC'
('split_2_data', 'concat_1'),
('split_3_data', 'concat_1'),
('split_4_data', 'concat_1'),
- ('concat_1', 'concat_1_data')
+ ('concat_1', 'concat_1_data'),
+ ('concat_1_data', 'op_output'),
+ ('split_1_data', 'op_output_1')
],
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
'split_1': {'axis': 3},
'split_2_data': {'shape': np.array([1, 227, 227, 18])},
'split_3_data': {'shape': np.array([1, 227, 227, 17])},
'split_4_data': {'shape': np.array([1, 227, 227, 18])},
- 'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+ 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
})
pattern = ConvertGroupedStridedSlice()
('sslice_2', 'sslice_2_data'),
('sslice_1_data', 'concat_1'),
('sslice_2_data', 'concat_1'),
- ('concat_1', 'concat_1_data')
+ ('concat_1', 'concat_1_data'),
+ ('concat_1_data', 'op_output')
],
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
[slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(27, 45, 1)])},
'sslice_2_data': {'shape': np.array([1, 227, 227, 18])},
- 'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+ 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
})
graph.graph['layout'] = 'NHWC'
('split_1', 'split_4_data'),
('split_1_data', 'concat_1'),
('split_3_data', 'concat_1'),
- ('concat_1', 'concat_1_data')
+ ('concat_1', 'concat_1_data'),
+ ('concat_1_data', 'op_output'),
+ ('split_2_data', 'op_output_1'),
+ ('split_4_data', 'op_output_2'),
],
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
'split_1': {'axis': 3},
'split_2_data': {'shape': np.array([1, 227, 227, 9])},
'split_3_data': {'shape': np.array([1, 227, 227, 18])},
'split_4_data': {'shape': np.array([1, 227, 227, 9])},
- 'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+ 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
})
pattern = ConvertGroupedStridedSlice()
('sslice_2', 'sslice_2_data'),
('sslice_1_data', 'concat_1'),
('sslice_2_data', 'concat_1'),
- ('concat_1', 'concat_1_data')
+ ('concat_1', 'concat_1_data'),
+ ('concat_1_data', 'op_output')
],
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
[slice(0, 1, 1), slice(10, 227, 1), slice(0, 227, 1), slice(27, 45, 1)])},
'sslice_2_data': {'shape': np.array([1, 217, 227, 18])},
- 'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+ 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
})
graph.graph['layout'] = 'NHWC'
('sslice_2', 'sslice_2_data'),
('sslice_1_data', 'concat_1'),
('sslice_2_data', 'concat_1'),
- ('concat_1', 'concat_1_data')
+ ('concat_1', 'concat_1_data'),
+ ('concat_1_data', 'op_output')
],
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
[slice(0, 1, 1), slice(10, 227, 1), slice(0, 227, 1), slice(27, 45, 1)])},
'sslice_2_data': {'shape': np.array([1, 217, 227, 18])},
- 'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+ 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
})
pattern = ConvertGroupedStridedSlice()
('sslice_2', 'sslice_2_data'),
('sslice_1_data', 'concat_1'),
('sslice_2_data', 'concat_1'),
- ('concat_1', 'concat_1_data')
+ ('concat_1', 'concat_1_data'),
+ ('concat_1_data', 'op_output')
],
{'placeholder_1_data': {'shape': np.array([1, 54, 54, 3])},
[slice(0, 1, 1), slice(18, 36, 1), slice(0, 54, 1), slice(0, 3, 1)])},
'sslice_2_data': {'shape': np.array([1, 18, 54, 3])},
- 'concat_1_data': {'shape': np.array([1, 54, 54, 3]), 'is_output': True},
+ 'concat_1_data': {'shape': np.array([1, 54, 54, 3])},
})
graph.graph['layout'] = 'NHWC'
('split_1', 'split_3_data'),
('split_1_data', 'concat_1'),
('split_3_data', 'concat_1'),
- ('concat_1', 'concat_1_data')
+ ('concat_1', 'concat_1_data'),
+ ('concat_1_data', 'op_output'),
+ ('split_2_data', 'op_output_1')
],
{'placeholder_1_data': {'shape': np.array([1, 54, 54, 3])},
'split_1': {'axis': 1},
'split_1_data': {'shape': np.array([1, 18, 54, 3])},
'split_2_data': {'shape': np.array([1, 18, 54, 3])},
'split_3_data': {'shape': np.array([1, 18, 54, 3])},
- 'concat_1_data': {'shape': np.array([1, 54, 54, 3]), 'is_output': True},
+ 'concat_1_data': {'shape': np.array([1, 54, 54, 3])},
+ })
+
+ pattern = ConvertGroupedStridedSlice()
+ pattern.find_and_replace_pattern(graph)
+
+ (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data', check_op_attrs=True)
+ self.assertTrue(flag, resp)
+
+
+class AddReshapeAfterStridedSliceTests(unittest.TestCase):
+ def test_ss_1_shrink_last(self):
+ graph = build_graph(nodes_attributes,
+ [('placeholder_1', 'placeholder_1_data'),
+ ('placeholder_1_data', 'sslice_1'),
+ ('placeholder_begin_data', 'sslice_1'),
+ ('placeholder_end_data', 'sslice_1'),
+ ('placeholder_stride_data', 'sslice_1'),
+ ('sslice_1', 'sslice_1_data'),
+ ('sslice_1_data', 'op_output')
+ ],
+ {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
+ 'sslice_1': {'slices': np.array([slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)]),
+ 'shrink_axis_mask': [0, 0, 1, 0],
+ 'new_axis_mask': np.array([0, 0, 0, 0])},
+ 'sslice_1_data': {'shape': np.array([1, 227, 54])},
+ })
+ graph.graph['layout'] = 'NHWC'
+
+ graph_ref = build_graph(nodes_attributes,
+ [('placeholder_1', 'placeholder_1_data'),
+ ('placeholder_1_data', 'sslice_1'),
+ ('placeholder_begin_data', 'sslice_1'),
+ ('placeholder_end_data', 'sslice_1'),
+ ('placeholder_stride_data', 'sslice_1'),
+ ('sslice_1', 'sslice_1/Reshape_shrink_data'),
+ ('sslice_1/Reshape_shrink_data', 'sslice_1/Reshape_shrink'),
+ ('sslice_1/Reshape_shrink', 'sslice_1_data'),
+ ('sslice_1_data', 'op_output')
+ ],
+ {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
+ 'sslice_1': {'slices': np.array(
+ [slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)]),
+ 'shrink_axis_mask': np.array([0, 0, 0, 0]),
+ 'new_axis_mask': np.array([0, 0, 0, 0])},
+ 'sslice_1_data': {'shape': np.array([1, 227, 54])},
+ 'sslice_1/Reshape_shrink': {'dim': np.array([1, 227, 54])},
+ 'sslice_1/Reshape_shrink_data': {'shape': np.array([1, 227, 1, 54])}
+ })
+
+ pattern = ConvertGroupedStridedSlice()
+ pattern.add_reshape_for_shrink(graph, Node(graph, 'sslice_1'))
+
+ (flag, resp) = compare_graphs(graph, graph_ref, 'sslice_1_data', check_op_attrs=True)
+ graph.clear()
+ graph_ref.clear()
+ self.assertTrue(flag, resp)
+
+ def test_ss_1_shrink(self):
+ graph = build_graph(nodes_attributes,
+ [('placeholder_1', 'placeholder_1_data'),
+ ('placeholder_1_data', 'sslice_2'),
+ ('placeholder_begin_data', 'sslice_2'),
+ ('placeholder_end_data', 'sslice_2'),
+ ('placeholder_stride_data', 'sslice_2'),
+ ('sslice_2', 'sslice_2_data'),
+ ('sslice_2_data', 'placeholder_2'),
+ ('placeholder_2', 'placeholder_2_data'),
+ ('sslice_2_data', 'op_output')
+ ],
+ {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
+ 'sslice_2': {'slices': np.array([slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)]),
+ 'shrink_axis_mask': [0, 0, 1, 0],
+ 'new_axis_mask': np.array([0, 0, 0, 0])},
+ 'sslice_2_data': {'shape': np.array([1, 227, 54])}
+ })
+ graph.graph['layout'] = 'NHWC'
+
+ graph_ref = build_graph(nodes_attributes,
+ [('placeholder_1', 'placeholder_1_data'),
+ ('placeholder_1_data', 'sslice_2'),
+ ('placeholder_begin_data', 'sslice_2'),
+ ('placeholder_end_data', 'sslice_2'),
+ ('placeholder_stride_data', 'sslice_2'),
+ ('sslice_2', 'sslice_2/Reshape_shrink_data'),
+ ('sslice_2/Reshape_shrink_data', 'sslice_2/Reshape_shrink'),
+ ('sslice_2/Reshape_shrink', 'sslice_2_data'),
+ ('sslice_2_data', 'placeholder_2'),
+ ('placeholder_2', 'placeholder_2_data'),
+ ('sslice_2_data', 'op_output')
+ ],
+ {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
+ 'sslice_2': {'slices': np.array([slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)]),
+ 'shrink_axis_mask': np.array([0, 0, 0, 0]),
+ 'new_axis_mask': np.array([0, 0, 0, 0])},
+ 'sslice_2_data': {'shape': np.array([1, 227, 54])},
+ 'sslice_2/Reshape_shrink': {'dim': np.array([1, 227, 54])},
+ 'sslice_2/Reshape_shrink_data': {'shape': np.array([1, 227, 1, 54])},
+ })
+
+ pattern = ConvertGroupedStridedSlice()
+ pattern.add_reshape_for_shrink(graph, Node(graph, 'sslice_2'))
+
+ (flag, resp) = compare_graphs(graph, graph_ref, 'sslice_2_data', check_op_attrs=True)
+ graph.clear()
+ graph_ref.clear()
+ self.assertTrue(flag, resp)
+
+ def test_ss_2_shrink(self):
+ graph = build_graph(nodes_attributes,
+ [('placeholder_1', 'placeholder_1_data'),
+ ('placeholder_1_data', 'sslice_2'),
+ ('placeholder_begin_data', 'sslice_2'),
+ ('placeholder_end_data', 'sslice_2'),
+ ('placeholder_stride_data', 'sslice_2'),
+ ('sslice_2', 'sslice_2_data'),
+ ('sslice_2_data', 'placeholder_2'),
+ ('placeholder_2', 'placeholder_2_data'),
+ ('sslice_2_data', 'op_output')
+ ],
+ {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
+ 'sslice_2': {
+ 'slices': np.array([slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1)]),
+ 'shrink_axis_mask': np.array([0, 1, 0, 1]),
+ 'new_axis_mask': np.array([0, 0, 0, 0])},
+ 'sslice_2_data': {'shape': np.array([1, 227])}
+ })
+ graph.graph['layout'] = 'NHWC'
+
+ graph_ref = build_graph(nodes_attributes,
+ [('placeholder_1', 'placeholder_1_data'),
+ ('placeholder_1_data', 'sslice_2'),
+ ('placeholder_begin_data', 'sslice_2'),
+ ('placeholder_end_data', 'sslice_2'),
+ ('placeholder_stride_data', 'sslice_2'),
+ ('sslice_2', 'sslice_2/Reshape_shrink_data'),
+ ('sslice_2/Reshape_shrink_data', 'sslice_2/Reshape_shrink'),
+ ('sslice_2/Reshape_shrink', 'sslice_2_data'),
+ ('sslice_2_data', 'placeholder_2'),
+ ('placeholder_2', 'placeholder_2_data'),
+ ('sslice_2_data', 'op_output')
+ ],
+ {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
+ 'sslice_2': {'slices': np.array(
+ [slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1)]),
+ 'shrink_axis_mask': np.array([0, 0, 0, 0]),
+ 'new_axis_mask': np.array([0, 0, 0, 0])},
+ 'sslice_2_data': {'shape': np.array([1, 227])},
+ 'sslice_2/Reshape_shrink': {'dim': np.array([1, 227])},
+ 'sslice_2/Reshape_shrink_data': {'shape': np.array([1, 1, 227, 1])},
+ })
+
+ pattern = ConvertGroupedStridedSlice()
+ pattern.add_reshape_for_shrink(graph, Node(graph, 'sslice_2'))
+
+ (flag, resp) = compare_graphs(graph, graph_ref, 'sslice_2_data', check_op_attrs=True)
+ graph.clear()
+ graph_ref.clear()
+ self.assertTrue(flag, resp)
+
+ def test_ss_1_new(self):
+ graph = build_graph(nodes_attributes,
+ [('placeholder_1', 'placeholder_1_data'),
+ ('placeholder_1_data', 'sslice_2'),
+ ('placeholder_begin_data', 'sslice_2'),
+ ('placeholder_end_data', 'sslice_2'),
+ ('placeholder_stride_data', 'sslice_2'),
+ ('sslice_2', 'sslice_2_data'),
+ ('sslice_2_data', 'placeholder_2'),
+ ('placeholder_2', 'placeholder_2_data'), ],
+ {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
+ 'sslice_2': {'slices': np.array(
+ [slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 54, 1)]),
+ 'shrink_axis_mask': np.array([0, 0, 0, 0, 0]),
+ 'new_axis_mask': np.array([0, 1, 0, 0, 0])},
+ 'sslice_2_data': {'shape': np.array([1, 1, 227, 227, 54])}
+ })
+ graph.graph['layout'] = 'NHWC'
+
+ graph_ref = build_graph(nodes_attributes,
+ [('placeholder_1', 'placeholder_1_data'),
+ ('placeholder_1_data', 'sslice_2'),
+ ('placeholder_begin_data', 'sslice_2'),
+ ('placeholder_end_data', 'sslice_2'),
+ ('placeholder_stride_data', 'sslice_2'),
+ ('sslice_2', 'sslice_2/Reshape_new_data'),
+ ('sslice_2/Reshape_new_data', 'sslice_2/Reshape_new'),
+ ('sslice_2/Reshape_new', 'sslice_2_data'),
+ ('sslice_2_data', 'placeholder_2'),
+ ('placeholder_2', 'placeholder_2_data')],
+ {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
+ 'sslice_2': {'slices': np.array(
+ [slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1),
+ slice(0, 54, 1)]),
+ 'shrink_axis_mask': np.array([0, 0, 0, 0, 0]),
+ 'new_axis_mask': np.array([0, 0, 0, 0, 0])},
+ 'sslice_2_data': {'shape': np.array([1, 1, 227, 227, 54])},
+ 'sslice_2/Reshape_new': {'dim': np.array([1, 1, 227, 227, 54])},
+ 'sslice_2/Reshape_new_data': {'shape': np.array([1, 227, 227, 54])},
+ })
+
+ pattern = ConvertGroupedStridedSlice()
+ pattern.add_reshape_for_new(graph, Node(graph, 'sslice_2'))
+
+ (flag, resp) = compare_graphs(graph, graph_ref, 'sslice_2_data', check_op_attrs=True)
+ graph.clear()
+ graph_ref.clear()
+ self.assertTrue(flag, resp)
+
+ def test_ss_shrink_new(self):
+ graph = build_graph(nodes_attributes,
+ [('placeholder_1', 'placeholder_1_data'),
+ ('placeholder_1_data', 'sslice_2'),
+ ('placeholder_begin_data', 'sslice_2'),
+ ('placeholder_end_data', 'sslice_2'),
+ ('placeholder_stride_data', 'sslice_2'),
+ ('sslice_2', 'sslice_2_data'),
+ ('sslice_2_data', 'placeholder_2'),
+ ('placeholder_2', 'placeholder_2_data'),
+ ('sslice_2_data', 'op_output')
+ ],
+ {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
+ 'sslice_2': {'slices': np.array(
+ [slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)]),
+ 'shrink_axis_mask': np.array([0, 0, 0, 1, 0]),
+ 'new_axis_mask': np.array([0, 1, 0, 0, 0])},
+ 'sslice_2_data': {'shape': np.array([1, 1, 227, 54])}
+ })
+ graph.graph['layout'] = 'NHWC'
+
+ graph_ref = build_graph(nodes_attributes,
+ [('placeholder_1', 'placeholder_1_data'),
+ ('placeholder_1_data', 'sslice_2'),
+ ('placeholder_begin_data', 'sslice_2'),
+ ('placeholder_end_data', 'sslice_2'),
+ ('placeholder_stride_data', 'sslice_2'),
+ ('sslice_2', 'sslice_2/Reshape_new_data'),
+ ('sslice_2/Reshape_new_data', 'sslice_2/Reshape_new'),
+ ('sslice_2/Reshape_new', 'sslice_2/Reshape_shrink_data'),
+ ('sslice_2/Reshape_shrink_data', 'sslice_2/Reshape_shrink'),
+ ('sslice_2/Reshape_shrink', 'sslice_2_data'),
+ ('sslice_2_data', 'placeholder_2'),
+ ('placeholder_2', 'placeholder_2_data'),
+ ('sslice_2_data', 'op_output')
+ ],
+ {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
+ 'sslice_2': {'slices': np.array(
+ [slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1),
+ slice(0, 54, 1)]),
+ 'shrink_axis_mask': np.array([0, 0, 0, 0, 0]),
+ 'new_axis_mask': np.array([0, 0, 0, 0, 0])},
+ 'sslice_2_data': {'shape': np.array([1, 1, 227, 54])},
+ 'sslice_2/Reshape_new': {'dim': np.array([1, 1, 227, 1, 54])},
+ 'sslice_2/Reshape_new_data': {'shape': np.array([1, 227, 1, 54])},
+ 'sslice_2/Reshape_shrink': {'dim': np.array([1, 1, 227, 54])},
+ 'sslice_2/Reshape_shrink_data': {'shape': np.array([1, 1, 227, 1, 54])},
+ })
+
+ pattern = ConvertGroupedStridedSlice()
+ pattern.add_reshape_for_shrink(graph, Node(graph, 'sslice_2'))
+ pattern.add_reshape_for_new(graph, Node(graph, 'sslice_2'))
+
+ (flag, resp) = compare_graphs(graph, graph_ref, 'sslice_2_data', check_op_attrs=True)
+ graph.clear()
+ graph_ref.clear()
+ self.assertTrue(flag, resp)
+
+ # test case with 2 strided slices with the same parameters but different outputs
+ def test_1(self):
+ graph = build_graph(nodes_attributes,
+ [('placeholder_1', 'placeholder_1_data'),
+ ('placeholder_1_data', 'sslice_1'),
+ ('sslice_1', 'sslice_1_data'),
+ ('placeholder_1_data', 'sslice_2'),
+ ('sslice_2', 'sslice_2_data'),
+ ('placeholder_1_data', 'sslice_3'),
+ ('sslice_3', 'sslice_3_data'),
+ ('sslice_1_data', 'concat_1'),
+ ('sslice_2_data', 'concat_1'),
+ ('sslice_3_data', 'placeholder_2'),
+ ('placeholder_2', 'placeholder_2_data'),
+ ('concat_1', 'concat_1_data'),
+ ('concat_1_data', 'op_output'),
+ ('placeholder_2_data', 'op_output')
+ ],
+ {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
+
+ 'sslice_1': {'slices': np.array(
+ [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 27, 1)])},
+ 'sslice_1_data': {'shape': np.array([1, 227, 227, 27])},
+
+ 'sslice_2': {'slices': np.array(
+ [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(27, 54, 1)])},
+ 'sslice_2_data': {'shape': np.array([1, 227, 227, 27])},
+
+ 'sslice_3': {'slices': np.array(
+ [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 27, 1)])},
+ 'sslice_3_data': {'shape': np.array([1, 227, 227, 27])},
+
+ 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
+ })
+ graph.graph['layout'] = 'NHWC'
+
+ graph_ref = build_graph(nodes_attributes,
+ [('placeholder_1', 'placeholder_1_data'),
+ ('placeholder_1_data', 'split_1'),
+ ('split_1', 'split_1_data'),
+ ('split_1', 'split_2_data'),
+ ('split_1_data', 'concat_1'),
+ ('split_2_data', 'concat_1'),
+ ('split_1_data', 'placeholder_2'),
+ ('placeholder_2', 'placeholder_2_data'),
+ ('concat_1', 'concat_1_data'),
+ ('concat_1_data', 'op_output'),
+ ('placeholder_2_data', 'op_output')
+ ],
+ {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
+ 'split_1': {'axis': 3},
+ 'split_1_data': {'shape': np.array([1, 227, 227, 27])},
+ 'split_2_data': {'shape': np.array([1, 227, 227, 27])},
+ 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
})
pattern = ConvertGroupedStridedSlice()