2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
21 from mo.graph.graph import Node, Graph
22 from mo.ops.op import Op
29 def __init__(self, graph: Graph, attrs: dict):
30 super().__init__(graph, {
35 'infer': __class__.infer
def supported_attrs(self):
    """
    Return the attribute names this op declares as supported.

    The result is always a freshly built list ``['start', 'end', 'axis']``,
    so callers may mutate it without affecting later calls.
    """
    declared_names = ['start', 'end', 'axis']
    return declared_names
42 def infer(node: Node):
43 if len(node.in_nodes()) == 1:
45 if node.has('start') and node.has('end') and node.has('axis'):
47 if node.has_valid('start') and node.has_valid('end') and node.has('axis'):
52 log.warning('Incorrect slice operation: no starts or end attr')
56 from mo.front.common.partial_infer.slice import caffe_slice_infer
57 caffe_slice_infer(node)
58 elif len(node.in_nodes()) == 3:
60 start_node = node.in_node(1)
61 size_node = node.in_node(2)
62 if start_node.has_valid('value') and size_node.has_valid('value'):
63 start = np.array(node.in_node(1).value, dtype=np.int64)
64 size = np.array(node.in_node(2).value, dtype=np.int64)
68 # Delete edges to start, size nodes
69 node.graph.remove_edge(node.in_node(1).id, node.id)
70 node.graph.remove_edge(node.in_node(2).id, node.id)
76 log.warning('Incorrect slice operation: no starts or end attr')
79 log.warning('Incorrect number of input nodes in slice operation')
82 input_shape = node.in_node(0).shape
83 # Check for situation when size[i] == -1 in TF
84 for i in range(start.size):
86 end[i] = input_shape[i]
89 value = node.in_node(0).value
91 # If value is None create dummy vaue for shape propogation
93 value = np.zeros(input_shape)
95 # Following ONNX and TF specification, in case of unknown axis, axises should be in greater order
97 axis = [x for x in range(len(start))]
99 # Calculate output value for slice operation
100 slice_idx = [None for x in range(len(node.in_node().shape))]
101 shrink_axis_mask = [False for x in range(len(node.in_node().shape))]
102 for id in range(len(axis)):
103 # Ranged for output value for specified axis
104 slice_idx[axis[id]] = slice(start[id], end[id], 1)
106 # TODO: check whether this check is really important
107 for axis, s in enumerate(slice_idx):
109 slice_idx[axis] = slice(0, input_shape[axis], 1)
111 # Add new parameters to node
112 node['slices'] = np.array(slice_idx)
113 node['shrink_axis_mask'] = np.array(shrink_axis_mask)
115 value = value[tuple(slice_idx)]
116 node.out_node().value = np.array(value) if node.in_node(0).value is not None else None
117 node.out_node().shape = np.array(value.shape)