Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / SliceConverter.py
index f6e925b..e4c0266 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  limitations under the License.
 """
 
-import networkx as nx
 import numpy as np
 
+from mo.graph.graph import Graph
 from mo.middle.replacement import MiddleReplacementPattern
+from mo.ops.const import Const
 from mo.ops.crop import Crop
-from mo.ops.op import Op
+from mo.ops.strided_slice import StridedSlice
 
 
 def convert_negative_indices(indices: np.array, shape: np.array):
@@ -36,6 +37,10 @@ class ConvertSlice(MiddleReplacementPattern):
     enabled = True
     op = "Slice"
 
+    def run_after(self):
+        from extensions.middle.pass_separator import MiddleStart
+        return [MiddleStart]
+
     def pattern(self):
         return dict(
             nodes=[
@@ -44,7 +49,7 @@ class ConvertSlice(MiddleReplacementPattern):
             edges=[]
         )
 
-    def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
+    def replace_pattern(self, graph: Graph, match: dict):
         node = match['slice']
         # Caffe case
         if not node.has_valid('start') or not node.has_valid('end'):
@@ -52,31 +57,50 @@ class ConvertSlice(MiddleReplacementPattern):
 
         begin = node.start
         end = node.end
+        axis = node.axis if node.has_valid('axis') else range(begin.size)
+        
 
         input = node.in_node(0)
         output_data = node.out_node()
 
         # Check whether operation use only one axis or not
+        axes_begin = np.zeros(len(input.shape), dtype=np.int32)
+        axes_end = np.zeros(len(input.shape), dtype=np.int32)
+        begin_ext = np.zeros(len(input.shape), dtype=np.int32)
+        end_ext = np.zeros(len(input.shape), dtype=np.int32)
         dims = 0
         axes = np.zeros(begin.size)
-        for i in range(begin.size):
-            if begin[i] != 0 or end[i] != input.shape[i]:
+        for i in range(len(axis)):
+            if begin[i] != 0 or end[i] < input.shape[i]:
                 dims += 1
                 axes[i] = 1
+                if begin[i] != 0:
+                    axes_begin[axis[i]] = 1
+                    begin_ext[axis[i]] = begin[i]
+                if end[i] < input.shape[i]:
+                    axes_end[axis[i]] = 1
+                    end_ext[axis[i]] = end[i]
         axes = np.array(axes, dtype=bool)
-        if dims == 0:
-            return
-        elif dims == 1:
-            # If Slice use only one axis, than
+
+        if dims == 1 or dims == 0:
+            # If Slice use only one axis or no axis, than
             # convert Slice to StridedSlice
+            ss = StridedSlice(graph, dict(new_axis_mask=np.zeros(len(output_data.shape), dtype=np.int32),
+                                          shrink_axis_mask=np.zeros(len(output_data.shape), dtype=np.int32),
+                                          ellipsis_mask=np.zeros(len(output_data.shape), dtype=np.int32),
+                                          begin_mask=axes_begin,
+                                          end_mask=axes_end))
+
+            convert_negative_indices(begin_ext, input.shape)
+            convert_negative_indices(end_ext, input.shape)
 
-            node['op'] = 'StridedSlice'
-            node['type'] = 'StridedSlice'
-            node['new_axis_mask'] = np.zeros(len(output_data.shape), dtype=np.bool)
-            node['shrink_axis_mask'] = np.zeros(len(output_data.shape), dtype=np.bool)
+            begin_node = Const(graph, {'name': 'begin', 'value': begin_ext, 'force_precision': 'I32'}).create_node_with_data()
+            end_node = Const(graph, {'name': 'end', 'value': end_ext, 'force_precision': 'I32'}).create_node_with_data()
 
-            convert_negative_indices(begin, input.shape)
-            convert_negative_indices(end, input.shape)
+            ss.create_node_with_data(inputs=[input, begin_node, end_node], data_nodes=[output_data])
+            # Remove unnecessary edges from and to to Slice vertex
+            graph.remove_edge(input.id, node.id)
+            graph.remove_edge(node.id, output_data.id)
         else:
             # If Slice use more than one axis use Crop layer
             crop = Crop(graph, dict(axis=np.arange(begin.size)[axes],