Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / ops / crop.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18
19 import numpy as np
20
21 from mo.front.caffe.extractors.utils import get_canonical_axis_index
22 from mo.graph.graph import Node, Graph
23 from mo.ops.op import Op, PermuteAttrs
24
25
26 class Crop(Op):
27     op = 'Crop'
28
29     def __init__(self, graph: Graph, attrs: dict):
30         super().__init__(graph, {
31             'kind': 'op',
32             'type': __class__.op,
33             'op': __class__.op,
34             'infer': __class__.infer,
35             'in_ports_count': 2,
36             'out_ports_count': 1,
37         }, attrs)
38
39     def backend_attrs(self):
40         return [
41             ('axis', lambda node: None if not node.has_valid('axis') else ','.join(map(str, node.axis))),
42             ('offset', lambda node: None if not node.has_valid('offset') else ','.join(map(str, node.offset))),
43
44             ('dim', lambda node: None if not node.has_valid('dim') else ','.join(map(str, node.dim))),
45
46             ('crop_begin', lambda node: None if not node.has_valid('crop_begin') else ','.join(map(str, node.crop_begin))),
47             ('crop_end', lambda node: None if not node.has_valid('crop_end') else ','.join(map(str, node.crop_end))),
48         ]
49
50     @staticmethod
51     def infer(node: Node):
52         """
53         Crops the shape of the output blob according to input ones be specified params.
54         Detailed Crop description can be found in IR Catalog specification.
55         In short: crop layer can be represented in three ways:
56             1. Two inputs, where the shape of the second input is crop dim (axis and offset attrs)
57             2. One input and dim, axis and offset attributes.
58             3. Ont input and axis, crop_begin and crop_end attributes
59         """
60
61         input_count = len(node.in_nodes())
62
63         if input_count == 2:
64             Crop._two_inputs_infer(node)
65         elif input_count == 1:
66             Crop._one_input_infer(node)
67         else:
68             log.error('Wrong number of input tensors ({}) in {}'.format(input_count, node.name))
69             return
70
71     @staticmethod
72     def _one_input_infer(node: Node):
73         input_shape = np.array(node.in_node().shape)
74
75         if input_shape is None:
76             log.error('input_shape is none for {} node'.format(node.name))
77             return
78
79         if not node.has_valid('axis'):
80             log.error('axis attribute is missing for {} node. should be set in crop extractor'.format(node.name))
81             return
82
83         output_shape = input_shape
84         if node.has_valid('dim'):
85             if len(node.dim) != len(node.axis):
86                 log.error('number of axis should match number of dim')
87                 return
88             output_shape[node.axis] = node.dim
89         elif node.has_valid('crop_begin') and node.has_valid('crop_end'):
90             if len(node.crop_begin) != len(node.axis) or len(node.crop_end) != len(node.axis):
91                 log.error('number of crop_begin/crop_end should match number of axis')
92                 return
93             output_shape[node.axis] = output_shape[node.axis] - node.crop_begin - node.crop_end
94         else:
95             log.error('Crop node {} should have either dim or crop_begin and crop_end attributes'.format(node.name))
96             return
97
98         node.out_node().shape = np.array(output_shape)
99         PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
100
101     @staticmethod
102     def _two_inputs_infer(node: Node):
103         N = len(node.in_nodes())
104
105         shapes = [node.in_node(i).shape for i in range(N)]
106         if any(s is None for s in shapes):
107             log.error('Not all input shapes were defined for {} node'.format(node.name))
108             return
109
110         if not node.has_valid('axis'):
111             log.error('axis attribute is missing for {} node. should be set in crop extractor'.format(node.name))
112             return
113
114         if not node.has_valid('offset'):
115             log.error('offset attribute is missing for {} node. should be set in crop extractor'.format(node.name))
116             return
117
118         input_shape = np.array(shapes[0])
119         start_axis = get_canonical_axis_index(input_shape, node.axis)
120         node.axis = start_axis
121
122         reference_shape = np.array(shapes[1])
123         input_dim = input_shape.size
124
125         # set new shape to current shape
126         new_shape = input_shape.copy()
127         ir_axis = []
128         ir_offset = []
129         dim = []
130
131         for i in range(0, input_dim):
132             if i < start_axis:
133                 new_shape[i] = input_shape[i]
134                 continue
135
136             crop_offset = 0
137             if len(node.offset) == 1:
138                 crop_offset = node.offset[0]
139             elif len(node.offset) > 1:
140                 crop_offset = node.offset[i - start_axis]
141
142             if input_shape[i] - crop_offset < reference_shape[i]:
143                 log.error('The crop for dimension is out of bounds in ' + node.node)
144                 return
145
146             dim.append(reference_shape[i])
147             ir_axis.append(i)
148             ir_offset.append(crop_offset)
149             new_shape[i] = reference_shape[i]
150
151         node.axis = ir_axis
152         node.offset = ir_offset
153         node['dim'] = dim
154         node.out_node().shape = new_shape
155         PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])