"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import logging as log

import numpy as np

from mo.graph.graph import Node
from mo.ops.op import PermuteAttrs
def part_sizes_to_indices(part_sizes: list):
    """
    Calculates indices of splits in the array based on part sizes for the split.

    Output list can be used as the second argument for np.split function
    (running prefix sums of the part sizes, without the final total).

    :param part_sizes: sizes of the consecutive pieces along the split axis
    :return: np.array of split indices suitable for np.split
    """
    idx = 0
    indices = []
    for part_size in part_sizes:
        idx += part_size
        indices.append(idx)
    log.debug("part_sizes: {} --> indices: {}".format(part_sizes, indices))
    # the last element equals the size of the original array and is redundant to numpy
    del indices[-1]
    log.debug("part_sizes: {} --> indices: {}".format(part_sizes, indices))
    return np.array(indices)
def split(input_data_node: Node, node: Node, axis: int, part_sizes: list):
    """
    Partial inference of generic split node.

    Sets the shape (and the value, when the input value is known) of every
    connected output port of the Split-like node and stores the normalized
    split axis on the node.

    @input_data_node: input tensor node, subject to split
    @node: node of one of the Split types
    @axis: split dimension index; may be negative (counted from the end)
    @part_sizes: sizes of all pieces that we split to; at most one entry may
                 be -1 and is then deduced from the input shape

    Returns None. On an inconsistency the error is reported via log.error and
    the output shapes are left unset.
    """
    if input_data_node.shape is None:
        return

    # normalize a negative axis into the [0, rank) range
    if axis < 0:
        axis = input_data_node.shape.size + axis

    if axis < 0 or axis >= input_data_node.shape.size:
        log.error('Model is incorrect: axis for split node is out of range')
        return

    undef_indices = np.argwhere(part_sizes == -1)
    if undef_indices.size > 1:
        log.error('Desired split part sizes have more than one -1 element -- cannot deduce real sizes for them')
        return

    if undef_indices.size == 1:
        # deduce the single unknown part size from the input shape:
        # zero it out first so that the sum covers the known parts only
        undef_index = undef_indices[0]
        part_sizes[undef_index] = 0
        deduced_dim = input_data_node.shape[axis] - np.add.reduce(part_sizes)
        if deduced_dim < 0:
            log.error('Just deduced dimension for the split has negative value that means that split input shape and '
                      'desired parts are not compatible')
            return
        part_sizes[undef_index] = deduced_dim

    all_parts_size = np.add.reduce(part_sizes)
    if all_parts_size != input_data_node.shape[axis]:
        log.error("input.shape[{}] = {} != {} = sum of all parts in part_sizes".format(axis,
                                                                                      input_data_node.shape[axis],
                                                                                      all_parts_size))
        return

    splitted = None
    if input_data_node.value is not None:
        splitted = np.split(input_data_node.value, part_sizes_to_indices(part_sizes), axis)

    # not all outputs from the split could be used so it is necessary to iterate over output edges and infer shape for
    # necessary nodes only
    for _, dst, edge_attrs in node.graph.out_edges(node.id, data=True):
        out_port = edge_attrs['out']
        out_node = node.out_node(out_port)

        new_out_shape = input_data_node.shape.copy()
        new_out_shape[axis] = part_sizes[out_port]
        node.out_node(out_port).shape = new_out_shape
        if splitted is not None:
            out_node.value = splitted[out_port]
            assert all(out_node.value.shape == out_node.shape)

    # keep the normalized axis on the node for later passes
    assert not node.has_valid('axis') or node.axis == axis
    node['axis'] = axis
    # WARNING: != 4 is supposed to work for NHWC to NCHW translation only.
    # if other global permutations happen this will fail
    # TODO: redesign it to have this logic built in NHWC to NCHW translation pass; it requires
    # additional attributes with layout to be propagated through the network
    if len(input_data_node.shape) != 4 and node.has_valid('dim_attrs') and 'axis' in node.dim_attrs:
        log.warning('Removed "axis" attribute from the scope of the model relayout pass because len(input.shape) == {} '
                    '!= 4 for node {}'.format(len(input_data_node.shape), node.soft_get('name')))
        node.dim_attrs.remove('axis')
        assert 'axis' not in node.dim_attrs
    log.debug('output shapes after split: {}'.format([v.shape for k, v in node.out_nodes().items()]))
def tf_split_infer(node):
    """
    Partial infer of split node similar to Split op of TF.

    Reads the scalar split dimension from input 0, splits input 1 into
    node.num_split equal parts, then drops the split_dim edge so that the
    tensor input becomes the only remaining one.
    """
    # Two inputs: [split_dim, input]
    assert len(node.in_nodes()) == 2, 'Node "{}" must have exactly two inputs'.format(node.soft_get('name'))
    split_dim = node.in_node(0).value
    if split_dim is None:
        # fixed: the placeholder was previously left unfilled (no .format call)
        log.error('split_dim value for node {} is None. Cannot do shape inference.'.format(node.soft_get('name')))
        return

    assert split_dim.ndim == 0, 'The split dimension for node "{}" must be a scalar.'.format(node.soft_get('name'))
    split_dim = split_dim.item()
    input_node = node.in_node(1)

    if input_node.shape is None:
        log.error('Input shape for node {} is not defined'.format(node.soft_get('name')))
        return

    log.debug('input shape for split: {}, should be split along {} dim'.format(input_node.shape, split_dim))
    split_dim_size = input_node.shape[split_dim]
    log.debug('split_dim_size type = {}'.format(type(split_dim_size)))

    if split_dim_size % node.num_split != 0:
        log.error("split_dim cannot be evenly divided by a given number of parts")
        return

    # split_dim is a numpy array, axis is split_dim[0]
    log.debug('split_dim_size = {}, node.num_split = {}, div = {}, typeof div = {}'.format(
        split_dim_size, node.num_split, split_dim_size / node.num_split, type(split_dim_size / node.num_split)))
    split(input_node, node, split_dim, [int(split_dim_size / node.num_split)] * node.num_split)
    # the split_dim input has been consumed; keep only the tensor input
    node.graph.remove_edge(node.in_node(0).id, node.id)
    node['input_port'] = 1

    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:1')])
def tf_split_v_infer(node: Node):
    """
    Partial infer of split node similar to SplitV op of TF.

    Accepts either the three-input form [input, size_splits, split_dim] or a
    one-input form with 'axis' and 'size_splits' already stored as attributes.
    """
    if len(node.in_nodes()) == 1 and not (node.has_valid('axis') and node.has_valid('size_splits')):
        return

    if len(node.in_nodes()) == 3 and (node.has_valid('axis') or node.has_valid('size_splits')):
        return

    # Three inputs: [input, size_splits, split_dim)
    if len(node.in_nodes()) == 3:
        split_dim = node.in_node(2).value
        assert split_dim.ndim == 0
        split_dim = split_dim.item()
        size_splits = node.in_node(1).value
        # the auxiliary inputs are consumed here and removed from the graph
        node.graph.remove_edge(node.in_node(1).id, node.id)
        node.graph.remove_edge(node.in_node(2).id, node.id)
    else:
        split_dim = node.axis
        size_splits = node.size_splits

    if split_dim is None:
        # fixed: the placeholder was previously left unfilled (no .format call)
        log.error('split_dim value for node {} is None. Cannot do shape inference.'.format(node.soft_get('name')))
        return

    input_node = node.in_node(0)
    if input_node.shape is None or size_splits is None:
        log.error('input shape or size of splits are not defined for node {}'.format(node.soft_get('name')))
        return

    log.debug('split_dim = {}, input.shape = {}, size_splits.value = {}'.format(split_dim, input_node.shape,
                                                                                size_splits))

    # split_dim is a numpy array, axis is split_dim
    split(input_node, node, split_dim, size_splits)

    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
def tf_unpack_infer(node: Node):
    """
    Partial infer of an Unpack (TF) node: split the single input into
    node.num_split equal parts along node.axis. When num_split is not set it
    defaults to the size of the dimension being unpacked.
    """
    if len(node.in_nodes()) != 1:
        log.debug('Unpack node "{}" must have one input.'.format(node.name))
        return

    in_shape = node.in_node().shape
    if in_shape is None:
        log.debug('Unpack node "{}" input node shape is not defined.'.format(node.name))
        return

    split_dim = node.axis
    log.debug('input shape for unpack: {}, should be split along {} dim'.format(in_shape, split_dim))
    split_dim_size = in_shape[split_dim]
    log.debug('split_dim_size type = {}'.format(type(split_dim_size)))

    if node.num_split is not None and node.num_split != split_dim_size:
        log.debug('The unpack where num to unpack is not equal to the size of the dimension to unpack is not supported')
        return

    if node.num_split is None:
        node.num_split = split_dim_size

    if split_dim_size % node.num_split != 0:
        log.error("split_dim cannot be evenly divided by a given number of parts")
        return

    split(node.in_node(), node, split_dim, [int(split_dim_size / node.num_split)] * node.num_split)
    # node shapes will be squeezed in the separate pass