Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / onnx / conv_ext.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import numpy as np
18
19 from mo.front.common.extractors.utils import layout_attrs
20 from mo.front.extractor import FrontExtractorOp
21 from mo.front.onnx.extractors.utils import onnx_attr, get_onnx_autopad
22 from mo.ops.convolution import Convolution
23 from mo.utils.error import Error
24 from mo.front.common.partial_infer.utils import int64_array
25
26
27 class ConvFrontExtractor(FrontExtractorOp):
28     op = 'Conv'
29     enabled = True
30
31     @staticmethod
32     def extract(node):
33         # Extract pads attribute
34         # In case if pads is not specified it will be set in default (1) in infer function
35         pads = onnx_attr(node, 'pads', 'ints', default=None, dst_type=lambda x: np.array(x, dtype=np.int64))
36         assert pads is None or len(pads) % 2 == 0
37         final_pad = None
38         if pads is not None:
39             pads = pads.reshape([2, -1])
40             pads = np.transpose(pads)
41             final_pad = np.array([[0, 0], [0, 0], *pads], dtype=np.int64)
42
43         # Extract dilations attribute
44         # In case if dilations is not specified it will be set in default (1) in infer function
45         dilations = onnx_attr(node, 'dilations', 'ints', default=None, dst_type=lambda x: np.array(x, dtype=np.int64))
46         final_dilations = np.array([1, 1, *dilations], dtype=np.int64) if dilations is not None else None
47
48         # Extract dilations attribute
49         # In case if dilations is not specified it will be set in default (1) in infer function
50         strides = onnx_attr(node, 'strides', 'ints', default=None, dst_type=lambda x: np.array(x, dtype=np.int64))
51         final_strides = np.array([1, 1, *strides], dtype=np.int64) if strides is not None else None
52
53         kernel_shape = onnx_attr(node, 'kernel_shape', 'ints', default=None)
54         auto_pad = onnx_attr(node, 'auto_pad', 's', default=None, dst_type=get_onnx_autopad)
55         group = onnx_attr(node, 'group', 'i', default=1, dst_type=lambda x: np.array(x, dtype=np.int64))
56
57         attrs = {
58             'op': __class__.op,
59             'auto_pad': auto_pad,
60             'bias_addable': True,
61             'bias_term': None,
62             'pad': final_pad,
63             'pad_spatial_shape': np.array(pads, dtype=np.int64) if pads is not None else None,
64             'dilation': final_dilations,
65             'output_spatial_shape': None,
66             'output_shape': None,
67             'stride': final_strides,
68             'group': group,
69             'output': None,
70             'kernel_spatial': np.array(kernel_shape, dtype=np.int64) if kernel_shape is not None else None,
71
72             'input_feature_channel': 1,
73             'output_feature_channel': 0,
74             'kernel_spatial_idx': None,  # Will be calculated in infer function (np.array([2, 3]))
75
76             'spatial_dims': None,  # Will be calculated in infer function
77             'channel_dims': np.array([1], dtype=np.int64),
78             'batch_dims': np.array([0], dtype=np.int64),
79             'layout': 'NCHW'
80         }
81
82         # update the attributes of the node
83         Convolution.update_node_stat(node, attrs)
84         return __class__.enabled
85
86
87 class ConvTransposeFrontExtractor(FrontExtractorOp):
88     op = 'ConvTranspose'
89     enabled = True
90
91     @staticmethod
92     def get_pad(node, input_shape, kernel_shape):
93         # Reference: https://github.com/onnx/onnx/blob/master/docs/Operators.md#ConvTranspose
94         input_shape = node.in_node(0).shape
95         pad = np.zeros((len(input_shape), 2), dtype=np.int64)
96         total_padding = int64_array([node.stride[node.spatial_dims][x] *
97                                      (input_shape[node.spatial_dims][x] - 1) +
98                                      node.output_padding[node.spatial_dims][x] +
99                                      kernel_shape[node.kernel_spatial_idx][x] -
100                                      node.output_spatial_shape[x] for x in range(len(node.spatial_dims))])
101         if node.has_valid('auto_pad') and node.auto_pad != 'same_upper':
102             pad[node.spatial_dims] = int64_array(
103                 [[total_padding[x] / 2, total_padding[x] - (total_padding[x] // 2)] for x in
104                  range(len(node.spatial_dims))])
105         else:
106             pad[node.spatial_dims] = int64_array(
107                 [[total_padding[x] - (total_padding[x] // 2), total_padding[x] / 2] for x in
108                  range(len(node.spatial_dims))])
109         return pad
110
111     @staticmethod
112     def extract(node):
113         pads = onnx_attr(node, 'pads', 'ints', dst_type=int64_array)
114         auto_pad = onnx_attr(node, 'auto_pad', 's', default=None, dst_type=get_onnx_autopad)
115
116         if pads is not None:
117             if len(pads) % 2 != 0:
118                 raise Error(
119                     'ConvTranspose node {} specifies pads = {} which has odd number of elements. The model is not correct.',
120                     node.soft_get('name'),
121                     pads
122                 )
123             pads = pads.reshape([2, -1])
124             pads = np.transpose(pads)
125
126         final_pads = int64_array([[0, 0], [0, 0], *pads]) if pads is not None else None
127
128         dilations = onnx_attr(node, 'dilations', 'ints', default=None)
129         final_dilations = int64_array([1, 1, *dilations]) if dilations is not None else None
130
131         strides = onnx_attr(node, 'strides', 'ints', default=None)
132         final_strides = int64_array([1, 1, *strides]) if strides is not None else None
133
134         kernel_shape = onnx_attr(node, 'kernel_shape', 'ints', dst_type=int64_array)
135
136         if kernel_shape is None:
137             raise Error(
138                 'ConvTranspose node {} doesn\'t have explicitly defined kernel_shape. It is not supported.',
139                 node.soft_get('name')
140             )
141
142         output_padding = onnx_attr(node, 'output_padding', 'ints', default=None)
143         final_output_padding = int64_array([0, 0, *output_padding]) if output_padding is not None else None
144
145         output_shape = onnx_attr(node, 'output_shape', 'ints', default=None, dst_type=int64_array)
146
147         attrs = {
148             'type': 'Deconvolution',
149             'op': 'Deconv2D',
150             'auto_pad': auto_pad,
151             'bias_addable': True,
152             'bias_term': None,  # will be deduced later; not really needed
153             'pad': final_pads,
154             'dilation': final_dilations,
155             'output_spatial_shape': output_shape,
156             'output_shape': None,
157             'output_padding': final_output_padding,
158             'stride': final_strides,
159             'group': onnx_attr(node, 'group', 'i', default=1),
160             'output': None,
161
162             'spatial_dims': None,  # Will be calculated in infer function
163             'channel_dims': int64_array([1]),
164             'batch_dims': int64_array([0]),
165             'layout': 'NCHW',
166
167             'input_feature_channel': 0,
168             'output_feature_channel': 1,
169             'get_pad': ConvTransposeFrontExtractor.get_pad
170         }
171
172         # update the attributes of the node
173         Convolution.update_node_stat(node, attrs)
174         return __class__.enabled