Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / common / partial_infer / split.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18
19 import numpy as np
20
21 from mo.ops.op import PermuteAttrs
22 from mo.graph.graph import Node
23
24
25 def part_sizes_to_indices(part_sizes: list):
26     """
27     Calculates indices of splits in the array based on part sizes for the split.
28     Output list can be used as the second argument for np.split function.
29     """
30     idx = 0
31     indices = []
32     for part_size in part_sizes:
33         idx += part_size
34         indices.append(idx)
35     # the last element should equal to the size of original array and it is redundant to numpy
36     log.debug("part_sizes: {}   -->   indices: {}".format(part_sizes, indices))
37     del indices[-1]
38     log.debug("part_sizes: {}   -->   indices: {}".format(part_sizes, indices))
39     return np.array(indices)
40
41
42 def split(input_data_node: Node, node: Node, axis: int, part_sizes: list):
43     """
44     Partial inference of generic split node.
45
46     Args:
47         @input: input tensor node, subject to split
48         @node: node of one of the Split types
49         @axis: split dimension index
50         @part_sizes: a NumPy array with sizes of all pieces that we split to
51
52     Returns:
53         int: normalized axis index
54
55     """
56
57     if input_data_node.shape is None:
58         return
59
60     # normalize axis
61     if axis < 0:
62         axis = input_data_node.shape.size + axis
63
64     if axis < 0 or axis >= input_data_node.shape.size:
65         log.error('Model is incorrect: axis for split node is out of range')
66         return
67
68     undef_indices = np.argwhere(part_sizes == -1)
69     if undef_indices.size > 1:
70         log.error('Desired split part sizes have more than one -1 element -- cannot deduce real sizes for them')
71         return
72
73     if undef_indices.size == 1:
74         undef_index = undef_indices[0]
75         part_sizes[undef_index] = 0
76         deduced_dim = input_data_node.shape[axis] - np.add.reduce(part_sizes)
77         if deduced_dim < 0:
78             log.error('Just deduced dimension for the split has negative value that means that split input shape and '
79                       'desired parts are not compatible')
80             return
81
82     all_parts_size = np.add.reduce(part_sizes)
83     if all_parts_size != input_data_node.shape[axis]:
84         log.error("input.shape[{}] = {}  !=  {} = sum of all parts in part_sizes".format(axis,
85                                                                                          input_data_node.shape[axis],
86                                                                                          all_parts_size))
87         return
88
89     splitted = None
90     if input_data_node.value is not None:
91         splitted = np.split(input_data_node.value, part_sizes_to_indices(part_sizes), axis)
92
93     # not all outputs from the split could be used so it is necessary to iterate over output edges and infer shape for
94     # necessary nodes only
95     for _, dst, edge_attrs in node.graph.out_edges(node.id, data=True):
96         out_port = edge_attrs['out']
97         out_node = node.out_node(out_port)
98
99         new_out_shape = input_data_node.shape.copy()
100         new_out_shape[axis] = part_sizes[out_port]
101         node.out_node(out_port).shape = new_out_shape
102         if splitted is not None:
103             out_node.value = splitted[out_port]
104             assert all(out_node.value.shape == out_node.shape)
105
106     assert not node.has_valid('axis') or node.axis == axis
107     node.axis = axis
108     # WARNING: != 4 is supposed to work for NHWC to NCHW translation only.
109     # if other global permutations happen this will fail
110     # TODO: redesign it to have this logic built in NHWC to NCHW translation pass; it requires
111     #       additional attributes with layout to be propagated through the network
112     if len(input_data_node.shape) != 4 and node.has_valid('dim_attrs') and 'axis' in node.dim_attrs:
113         log.warning('Removed "axis" attribute from the scope of the model relayout pass because len(input.shape) == {} '
114                     '!= 4 for node {}'.format(len(input_data_node.shape), node.soft_get('name')))
115         node.dim_attrs.remove('axis')
116         assert 'axis' not in node.dim_attrs
117     log.debug('output shapes after split: {}'.format([v.shape for k, v in node.out_nodes().items()]))
118
119
120 def tf_split_infer(node):
121     """
122     Partial infer of split node similar to Split op of TF.
123     """
124     # Two inputs: [split_dim, input]
125     assert len(node.in_nodes()) == 2, 'Node "{}" must have exactly two inputs'.format(node.soft_get('name'))
126     split_dim = node.in_node(0).value
127     if split_dim is None:
128         log.error('split_dim value for node {} is None. Cannot do shape inference.')
129         return
130
131     assert split_dim.ndim == 0, 'The split dimension for node "{}" must be a scalar.'.format(node.soft_get('name'))
132     split_dim = split_dim.item()
133     input = node.in_node(1)
134
135     if input.shape is None:
136         log.error('Input shape for node {} is not defined'.format(node.soft_get('name')))
137         return
138
139     log.debug('input shape for split: {}, should be split along {} dim'.format(input.shape, split_dim))
140     split_dim_size = input.shape[split_dim]
141     log.debug('split_dim_size type = {}'.format(type(split_dim_size)))
142
143     if split_dim_size % node.num_split != 0:
144         log.error("split_dim cannot be evenly divided by a given number of parts")
145         return
146
147     # split_dim is a numpy array, axis is split_dim[0]
148     log.debug('split_dim_size = {}, node.num_split = {}, div = {}, typeof div = {}'.format(
149         split_dim_size, node.num_split, split_dim_size / node.num_split, type(split_dim_size / node.num_split)))
150     split(input, node, split_dim, [int(split_dim_size / node.num_split)] * node.num_split)
151     node.graph.remove_edge(node.in_node(0).id, node.id)
152     node['input_port'] = 1
153
154     PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:1')])
155
156
157 def tf_split_v_infer(node: Node):
158     """
159     Partial infer of split node similar to SplitV op of TF.
160     """
161
162     if len(node.in_nodes()) == 1 and not (node.has_valid('axis') and node.has_valid('size_splits')):
163         return
164
165     if len(node.in_nodes()) == 3 and (node.has_valid('axis') or node.has_valid('size_splits')):
166         return
167
168     # Three inputs: [input, size_splits, split_dim)
169     if len(node.in_nodes()) == 3:
170         split_dim = node.in_node(2).value
171         assert split_dim.ndim == 0
172         split_dim = split_dim.item()
173         size_splits = node.in_node(1).value
174         node.graph.remove_edge(node.in_node(1).id, node.id)
175         node.graph.remove_edge(node.in_node(2).id, node.id)
176     else:
177         split_dim = node.axis
178         size_splits = node.size_splits
179    
180     if split_dim is None:
181         log.error('split_dim value for node {} is None. Cannot do shape inference.')
182         return
183     
184     input = node.in_node(0)
185     if input.shape is None or size_splits is None:
186         log.error('input shape or size of splits are not defined for node {}'.format(node.soft_get('name')))
187         return
188
189     log.debug('split_dim = {}, input.shape = {}, size_splits.value = {}'.format(split_dim, input.shape, size_splits))
190
191     # split_dim is a numpy array, axis is split_dim
192     split(input, node, split_dim, size_splits)
193
194     PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
195
196
197 def tf_unpack_infer(node: Node):
198     if len(node.in_nodes()) != 1:
199         log.debug('Unpack node "{}" must have one input.'.format(node.name))
200         return
201
202     in_shape = node.in_node().shape
203     if in_shape is None:
204         log.debug('Unpack node "{}" input node shape is not defined.'.format(node.name))
205         return
206
207     split_dim = node.axis
208     log.debug('input shape for unpack: {}, should be split along {} dim'.format(in_shape, split_dim))
209     split_dim_size = in_shape[split_dim]
210     log.debug('split_dim_size type = {}'.format(type(split_dim_size)))
211
212     if node.num_split is not None and node.num_split != split_dim_size:
213         log.debug('The unpack where num to unpack is not equal to the size of the dimension to unpack is not supported')
214         return
215
216     if node.num_split is None:
217         node.num_split = split_dim_size
218
219     if split_dim_size % node.num_split != 0:
220         log.error("split_dim cannot be evenly divided by a given number of parts")
221         return
222
223     split(node.in_node(), node, split_dim, [int(split_dim_size / node.num_split)] * node.num_split)
224     # node shapes will be squeezed in the separate pass