Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / common / partial_infer / slice.py
index bf23763..a63658a 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
 
 import numpy as np
 
-from mo.graph.graph import erase_node
 from mo.utils.error import Error
 
+
 def tf_strided_slice_infer(node):
+    if node.in_node(1).value is None or node.in_node(2).value is None:
+        raise Error('Strided slice layer supports only constant begin and end inputs')
     begin_id = node.in_node(1).value
     end_id = node.in_node(2).value
-    stride = node.in_node(3).value
+    if len(node.in_nodes()) > 3:
+        if node.in_node(3).value is None:
+            raise Error('Strided slice layer supports only constant stride input')
+        stride = node.in_node(3).value
+    else:
+        stride = []
 
     shape = node.in_node(0).shape
 
@@ -32,63 +39,79 @@ def tf_strided_slice_infer(node):
     convert_negative_indices(begin_id, shape)
     convert_negative_indices(end_id, shape)
 
-    test_bit = lambda val, offset: ((1 << offset) & val != 0)
-
     slice_idx = []
-    shrink_axis_mask = []
-    ellipsis_mask = []
-    new_axis_mask = []
-    dims = len(begin_id)
-
+    dims = np.amax(np.array([len(begin_id), len(end_id), len(stride),
+                             len(node.shrink_axis_mask), len(node.new_axis_mask), len(node.ellipsis_mask),
+                             len(node.begin_mask), len(node.end_mask)]))
+
+    # make mask correct length
+    def extend_mask(in_mask, fin_len, zeros=True):
+        mask = list(in_mask)
+        if len(mask) < fin_len:
+            if zeros:
+                mask.extend(np.zeros(dims-len(mask), dtype=np.int32))
+            else:
+                mask.extend(np.ones(dims-len(mask), dtype=np.int32))
+        return np.array(mask, dtype=np.int32)
+
+    for mask in {'new_axis_mask', 'shrink_axis_mask', 'ellipsis_mask'}:
+        node[mask] = extend_mask(node[mask], dims)
+    node.begin_mask = extend_mask(node.begin_mask, dims, False)
+    node.end_mask = extend_mask(node.end_mask, dims, False)
+
+    old_idx = 0
+    ellips_ext = 0
+    id_em = 0
     for idx in range(dims):
-        def_beg = 0 if stride[idx] > 0 else -1
-        def_end = shape[idx] if stride[idx] > 0 else -shape[idx]-1
-        l = begin_id[idx] if not test_bit(node.begin_mask, idx) else def_beg
-        r = end_id[idx] if not test_bit(node.end_mask, idx) else def_end
-
-        # Check shrink_axis_mask
-        shrink_axis_mask.append(test_bit(node.shrink_axis_mask, idx))
-        if shrink_axis_mask[idx]:
-            l, r = l, l + 1
-
-        # Check new_axis_mask
-        new_axis_mask.append(test_bit(node.new_axis_mask, idx))
-        if new_axis_mask[idx]:
+        if node.new_axis_mask[idx]:
             slice_idx.append(np.newaxis)
-
-        # Check ellipsis_mask
-        ellipsis_mask.append(test_bit(node.ellipsis_mask, idx))
-        if ellipsis_mask[idx]:
-            shrink_axis_mask[idx] = False
-            l, r = 0, shape[idx]
-
-        slice_idx.append(slice(l, r, stride[idx]))
-    
-    # if masks length are less than input dims length than add slices and masks for such dims
-    for idx in range(dims, len(shape)):
-        slice_idx.append(slice(0, shape[idx], 1))
-        shrink_axis_mask.append(False)
-        new_axis_mask.append(False)
+        elif node.ellipsis_mask[idx]:
+            ellips_ext = len(shape) - (dims - np.count_nonzero(node.new_axis_mask) - 1)
+            id_em = idx
+            for i in range(0, ellips_ext):
+                slice_idx.append(slice(0, shape[old_idx], 1))
+                old_idx = old_idx + 1
+        else:
+            s = stride[idx] if len(stride) > idx else 1
+            def_beg = 0 if s > 0 else -1
+            def_end = shape[old_idx] if s > 0 else -shape[old_idx]-1
+            l = begin_id[idx] if node.begin_mask[idx] and idx < len(begin_id) else def_beg
+            r = end_id[idx] if node.end_mask[idx] and idx < len(end_id) else def_end
+
+            # Check shrink_axis_mask
+            if node.shrink_axis_mask[idx] and idx < len(shape):
+                slice_idx.append(slice(l, l+1, s))
+            else:
+                slice_idx.append(slice(l, r, s))
+            old_idx = old_idx + 1
 
     value = node.in_node(0).value if node.in_node(0).value is not None else np.zeros(shape)
     # fix for the warning: "FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated use
     # `arr[tuple(seq)]` instead of `arr[seq]`"
     value = value[tuple(slice_idx)]
 
-    for idx, flag in reversed(list(enumerate(shrink_axis_mask))):
+    for idx, flag in reversed(list(enumerate(node.shrink_axis_mask))):
         if flag:
-            value = np.squeeze(value, idx)
+            if ellips_ext > 0 and idx > id_em:
+                idx = idx + ellips_ext - 1
+            try:
+                value = np.squeeze(value, idx)
+            except ValueError:
+                # ignore this error
+                continue
 
     node['slices'] = np.array(slice_idx)
-    node['shrink_axis_mask'] = np.array(shrink_axis_mask)
-    node['new_axis_mask'] = np.array(new_axis_mask)
+    for attr in ('shrink_axis_mask', 'new_axis_mask', 'ellipsis_mask', 'begin_mask', 'end_mask'):
+        node[attr] = np.array(node[attr], dtype=np.int32)
 
     node.out_node().value = np.array(value) if node.in_node(0).value is not None else None
-    node.out_node().shape = np.array(value.shape)
+    node.out_node().shape = np.array(value.shape, dtype=np.int64)
+
+    # change precision to I32 for begin, end, stride inputs
+    for i in range(1, len(node.in_nodes())):
+        inp = node.in_node(i)
+        inp["force_precision"] = "I32"
 
-    #remove inputs converted in attributes
-    #for i in range(1,4):
-    #    node.graph.remove_edge(node.in_node(i).id, node.id)
 
 def convert_negative_indices(indices: np.array, shape: np.array):
     for ind, value in enumerate(indices):