Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / ops / slice.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18
19 import numpy as np
20
21 from mo.graph.graph import Node, Graph
22 from mo.ops.op import Op
23
24
25 class Slice(Op):
26     op = 'Slice'
27     enabled = True
28
29     def __init__(self, graph: Graph, attrs: dict):
30         super().__init__(graph, {
31             'type': __class__.op,
32             'op': 'Slice',
33             'in_ports_count': 3,
34             'out_ports_count': 1,
35             'infer': __class__.infer
36         }, attrs)
37
38     def supported_attrs(self):
39         return ['start', 'end', 'axis']
40
41     @staticmethod
42     def infer(node: Node):
43         if len(node.in_nodes()) == 1:
44             # Caffe or ONNX
45             if node.has('start') and node.has('end') and node.has('axis'):
46                 # ONNX case
47                 if node.has_valid('start') and node.has_valid('end') and node.has('axis'):
48                     start = node.start
49                     end = node.end
50                     axis = node.axis
51                 else:
52                     log.warning('Incorrect slice operation: no starts or end attr')
53                     return
54             else:
55                 # Caffe case
56                 from mo.front.common.partial_infer.slice import caffe_slice_infer
57                 caffe_slice_infer(node)
58         elif len(node.in_nodes()) == 3:
59             # TF case
60             start_node = node.in_node(1)
61             size_node = node.in_node(2)
62             if start_node.has_valid('value') and size_node.has_valid('value'):
63                 start = np.array(node.in_node(1).value, dtype=np.int64)
64                 size = np.array(node.in_node(2).value, dtype=np.int64)
65                 end = start + size
66                 axis = None
67
68                 # Delete edges to start, size nodes
69                 node.graph.remove_edge(node.in_node(1).id, node.id)
70                 node.graph.remove_edge(node.in_node(2).id, node.id)
71
72                 node['start'] = start
73                 node['end'] = end
74                 node['axis'] = None
75             else:
76                 log.warning('Incorrect slice operation: no starts or end attr')
77                 return
78         else:
79             log.warning('Incorrect number of input nodes in slice operation')
80             return
81
82         input_shape = node.in_node(0).shape
83         # Check for situation when size[i] == -1 in TF
84         for i in range(start.size):
85             if end[i] < start[i]:
86                 end[i] = input_shape[i]
87         # Update end param
88         node.end = end
89         value = node.in_node(0).value
90
91         # If value is None create dummy vaue for shape propogation
92         if value is None:
93             value = np.zeros(input_shape)
94
95         # Following ONNX and TF specification, in case of unknown axis, axises should be in greater order
96         if axis is None:
97             axis = [x for x in range(len(start))]
98
99         # Calculate output value for slice operation
100         slice_idx = [None for x in range(len(node.in_node().shape))]
101         shrink_axis_mask = [False for x in range(len(node.in_node().shape))]
102         for id in range(len(axis)):
103             # Ranged for output value for specified axis
104             slice_idx[axis[id]] = slice(start[id], end[id], 1)
105
106         # TODO: check whether this check is really important
107         for axis, s in enumerate(slice_idx):
108             if s is None:
109                 slice_idx[axis] = slice(0, input_shape[axis], 1)
110
111         # Add new parameters to node
112         node['slices'] = np.array(slice_idx)
113         node['shrink_axis_mask'] = np.array(shrink_axis_mask)
114
115         value = value[tuple(slice_idx)]
116         node.out_node().value = np.array(value) if node.in_node(0).value is not None else None
117         node.out_node().shape = np.array(value.shape)