2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
21 from mo.front.caffe.extractors.utils import get_canonical_axis_index
22 from mo.graph.graph import Node, Graph
23 from mo.ops.op import Op, PermuteAttrs
# Constructor: calls the base-class constructor with `graph` and a dict of default
# attributes for this op (the base class is presumably Op -- the class header is not
# visible in this listing). The visible entry registers the enclosing class's `infer`
# as the node's shape-inference function.
# NOTE(review): lines are missing from this listing -- the remaining default entries
# and the end of the super().__init__(...) call (presumably also passing `attrs`) are
# not shown; restore them from upstream before changing this constructor.
29 def __init__(self, graph: Graph, attrs: dict):
30 super().__init__(graph, {
# Shape inference for this op is delegated to infer() below.
34 'infer': __class__.infer,
# Attributes written to the IR for this op. Each visible entry pairs an attribute
# name with a callable that returns None when the node has no valid value for that
# attribute (presumably so it is omitted from the IR -- confirm against the
# serializer), and otherwise the attribute's values joined into a comma-separated string.
# NOTE(review): the enclosing `return [` ... `]` (and possibly other entries) are
# missing from this listing.
# NOTE(review): map(str, node.<attr>) requires an iterable. _two_inputs_infer assigns a
# scalar start_axis to node.axis; unless that is later replaced by a list, serializing
# 'axis' here would raise TypeError -- confirm against the full method.
39 def backend_attrs(self):
# Axes being cropped.
41 ('axis', lambda node: None if not node.has_valid('axis') else ','.join(map(str, node.axis))),
# Crop start offset(s).
42 ('offset', lambda node: None if not node.has_valid('offset') else ','.join(map(str, node.offset))),
# Output size(s) along the cropped axes.
44 ('dim', lambda node: None if not node.has_valid('dim') else ','.join(map(str, node.dim))),
# Amounts trimmed from the start / end of each cropped axis.
46 ('crop_begin', lambda node: None if not node.has_valid('crop_begin') else ','.join(map(str, node.crop_begin))),
47 ('crop_end', lambda node: None if not node.has_valid('crop_end') else ','.join(map(str, node.crop_end))),
def infer(node: Node):
    """
    Infer the output shape of a Crop node from its input shape(s) and attributes.

    A detailed Crop description can be found in the IR Catalog specification.
    In short, a crop layer can be represented in three ways:
    1. Two inputs, where the shape of the second input is the crop dim (axis and offset attrs).
    2. One input with dim, axis and offset attributes.
    3. One input with axis, crop_begin and crop_end attributes.

    The two-input form is handled by ``Crop._two_inputs_infer`` and both one-input
    forms by ``Crop._one_input_infer``. Any other number of inputs is logged as an
    error and no output shape is inferred.

    :param node: the Crop node whose output shape is computed.
    """
    input_count = len(node.in_nodes())

    if input_count == 2:
        Crop._two_inputs_infer(node)
    elif input_count == 1:
        Crop._one_input_infer(node)
    else:
        log.error('Wrong number of input tensors ({}) in {}'.format(input_count, node.name))
def _one_input_infer(node: Node):
    """
    Infer the output shape of a single-input Crop node.

    The output shape equals the input shape except along ``node.axis``, which is
    either set to ``node.dim`` (when present) or reduced by ``crop_begin`` and
    ``crop_end``. On any validation failure an error is logged and the output
    shape is left unset.

    :param node: Crop node with one input. It must carry ``axis`` and either
        ``dim`` or both ``crop_begin`` and ``crop_end``, each with one entry per axis.
    """
    raw_shape = node.in_node().shape
    # Test for None *before* converting: np.array(None) is a 0-d object array, not
    # None, so a check after the conversion could never trigger.
    if raw_shape is None:
        log.error('input_shape is none for {} node'.format(node.name))
        return
    # np.array copies, so editing output_shape below leaves the input's shape intact.
    input_shape = np.array(raw_shape)

    if not node.has_valid('axis'):
        log.error('axis attribute is missing for {} node. should be set in crop extractor'.format(node.name))
        return

    output_shape = input_shape
    if node.has_valid('dim'):
        if len(node.dim) != len(node.axis):
            log.error('number of axis should match number of dim')
            return
        # Explicit target size for every cropped axis.
        output_shape[node.axis] = node.dim
    elif node.has_valid('crop_begin') and node.has_valid('crop_end'):
        if len(node.crop_begin) != len(node.axis) or len(node.crop_end) != len(node.axis):
            log.error('number of crop_begin/crop_end should match number of axis')
            return
        # Trim crop_begin from the front and crop_end from the back of each cropped axis.
        output_shape[node.axis] = output_shape[node.axis] - node.crop_begin - node.crop_end
    else:
        log.error('Crop node {} should have either dim or crop_begin and crop_end attributes'.format(node.name))
        return

    node.out_node().shape = np.array(output_shape)
    # Presumably keeps `axis` consistent if input 0's layout gets permuted -- see PermuteAttrs.
    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
# Shape inference for the two-input form of Crop: the first input is the data being
# cropped and the second input's shape is the reference shape to crop to. `axis`
# selects the first cropped dimension and `offset` gives the crop start, either one
# value shared by all cropped dimensions or one value per dimension counted from `axis`.
# Side effects visible here: node.axis and node.offset are rewritten, the output
# node's shape is set, and `axis` is registered with PermuteAttrs against input 0.
# NOTE(review): this listing is missing lines from this method -- e.g. the early
# `return`s after each log.error, the initialisation of `dim` / `ir_offset`, the guard
# that restricts the first loop statement to dimensions before `start_axis`, and
# whatever consumes `dim`. Restore them from upstream before changing the logic; the
# comments below assume they match the upstream code.
102 def _two_inputs_infer(node: Node):
103 N = len(node.in_nodes())
# Collect every input's shape; an error is logged if any is still unknown.
105 shapes = [node.in_node(i).shape for i in range(N)]
106 if any(s is None for s in shapes):
107 log.error('Not all input shapes were defined for {} node'.format(node.name))
110 if not node.has_valid('axis'):
111 log.error('axis attribute is missing for {} node. should be set in crop extractor'.format(node.name))
114 if not node.has_valid('offset'):
115 log.error('offset attribute is missing for {} node. should be set in crop extractor'.format(node.name))
118 input_shape = np.array(shapes[0])
# Normalise `axis` against the input rank (presumably maps a negative axis to a
# non-negative index -- confirm in get_canonical_axis_index).
119 start_axis = get_canonical_axis_index(input_shape, node.axis)
120 node.axis = start_axis
122 reference_shape = np.array(shapes[1])
# input_shape holds a shape, so its .size is the rank of the first input.
123 input_dim = input_shape.size
125 # set new shape to current shape
126 new_shape = input_shape.copy()
131 for i in range(0, input_dim):
133 new_shape[i] = input_shape[i]
# One offset is shared by all cropped dimensions; otherwise the per-dimension list is
# indexed relative to start_axis.
# NOTE(review): if node.offset is empty neither branch assigns crop_offset; this relies
# on a default assigned in a line missing from this listing.
137 if len(node.offset) == 1:
138 crop_offset = node.offset[0]
139 elif len(node.offset) > 1:
140 crop_offset = node.offset[i - start_axis]
# The cropped window [crop_offset, crop_offset + reference_shape[i]) must fit inside
# the input dimension.
142 if input_shape[i] - crop_offset < reference_shape[i]:
# NOTE(review): concatenating node.node raises TypeError if node ids are not str; the
# other messages use '...{}'.format(node.name) -- consider the same here.
143 log.error('The crop for dimension is out of bounds in ' + node.node)
# Record the IR size/offset for this dimension and shrink the output to the reference size.
146 dim.append(reference_shape[i])
148 ir_offset.append(crop_offset)
149 new_shape[i] = reference_shape[i]
# Replace the user-supplied offset with the per-dimension offsets actually used.
152 node.offset = ir_offset
154 node.out_node().shape = new_shape
# Presumably keeps `axis` consistent if input 0's layout gets permuted -- see PermuteAttrs.
155 PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])