Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / common / partial_infer / crop.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18
19 import numpy as np
20
21 from mo.front.caffe.extractors.utils import get_canonical_axis_index
22
23
24 def crop_infer(node):
25     """
26     Crops the shape of the output blob according to input ones be specified params.
27     Node should have 2 input blobs - 1st blob is getting cropped by specified axis according
28     to the the 2nd (reference) blob.
29     The result blob is written to output node shape, and reference blob is removed from graph.
30     In order to save the reference dims, it is written to dims parameter.
31
32     Parameters
33     ----------
34     node
35
36
37     """
38     N = len(node.in_nodes())
39     if N < 2:
40         log.debug('Wrong number of bottom blobs in ' + node.node)
41         return
42
43     shapes = [node.in_node(i).shape for i in range(N)]
44     if any(s is None for s in shapes):
45         return
46
47     input_shape = np.array(shapes[0])
48     start_axis = get_canonical_axis_index(input_shape, node.axis)
49     node.axis = start_axis
50
51     reference_shape = np.array(shapes[1])
52     input_dim = input_shape.size
53
54     # set new shape to current shape
55     new_shape = input_shape.copy()
56     ir_axis = []
57     ir_offset = []
58     dim = []
59
60     for i in range(0, input_dim):
61         if i < start_axis:
62             new_shape[i] = input_shape[i]
63             continue
64
65         crop_offset = 0
66         if len(node.offset) == 1:
67             crop_offset = node.offset[0]
68         elif len(node.offset) > 1:
69             crop_offset = node.offset[i - start_axis]
70
71         if input_shape[i] - crop_offset < reference_shape[i]:
72             log.error('The crop for dimension is out of bounds in ' + node.node)
73             return
74
75         dim.append(reference_shape[i])
76         ir_axis.append(i)
77         ir_offset.append(crop_offset)
78         new_shape[i] = reference_shape[i]
79
80     node.axis = ir_axis
81     node.offset = ir_offset
82     node.dim = dim
83     node.out_node().shape = new_shape