Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / tf / conv_ext.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import numpy as np
17
18 from mo.front.common.partial_infer.utils import convert_tf_padding_to_str, int64_array
19 from mo.front.extractor import FrontExtractorOp
20 from mo.front.tf.extractors.utils import tf_data_format_spatial, tf_data_format_channel, tf_data_format_batch, \
21     tf_int_list
22 from mo.ops.convolution import Convolution
23 from mo.ops.op import PermuteAttrs
24
25
26 class Conv2DFrontExtractor(FrontExtractorOp):
27     op = 'Conv2D'
28     enabled = True
29
30     @staticmethod
31     def extract(node):
32         attrs = tf_create_attrs(node, 2, 3)
33         attrs.update({'op': __class__.op,
34                       'get_group': lambda node: 1,
35                       'get_output_feature_dim': lambda node: node.kernel_shape[node.output_feature_channel],
36                       'get_weights_permute': PermuteAttrs.Permutation(perm=int64_array([3, 2, 0, 1]),
37                                                                       inv=int64_array([2, 3, 1, 0]))
38                       })
39
40         # update the attributes of the node
41         Convolution.update_node_stat(node, attrs)
42         return __class__.enabled
43
44
45 class DepthwiseConv2dNativeFrontExtractor(FrontExtractorOp):
46     op = 'DepthwiseConv2dNative'
47     enabled = True
48
49     @staticmethod
50     def extract(node):
51         attrs = tf_create_attrs(node, 2, 2)
52         attrs.update({'op': __class__.op,
53                       'kernel_spatial_idx': np.array([0, 1], dtype=np.int64),
54                       'get_group': lambda node: node.kernel_shape[node.output_feature_channel],
55                       'get_output_feature_dim': lambda node: node.kernel_shape[-1] * node.kernel_shape[-2],
56                       'get_weights_permute': PermuteAttrs.Permutation(perm=int64_array([2, 3, 0, 1]),
57                                                                       inv=int64_array([2, 3, 0, 1]))
58                       })
59
60         # update the attributes of the node
61         Convolution.update_node_stat(node, attrs)
62         return __class__.enabled
63
64
65 class Conv3DFrontExtractor(FrontExtractorOp):
66     op = 'Conv3D'
67     enabled = True
68
69     @staticmethod
70     def extract(node):
71         attrs = tf_create_attrs(node, 3, 4)
72         attrs.update({'op': __class__.op,
73                       'get_group': lambda node: 1,
74                       'get_output_feature_dim': lambda node: node.kernel_shape[node.output_feature_channel],
75                       'get_weights_permute': PermuteAttrs.Permutation(perm=int64_array([4, 3, 0, 1, 2]),
76                                                                       inv=int64_array([2, 3, 4, 1, 0]))
77                       })
78
79         # update the attributes of the node
80         Convolution.update_node_stat(node, attrs)
81         return __class__.enabled
82
83
84 def tf_create_attrs(node, input_feature_channel, output_feature_channel):
85     data_format = node.pb.attr["data_format"]
86     dilations = tf_int_list(node.pb.attr["dilations"].list)
87     if len(dilations) == 0:
88         dilations = None
89
90     attrs = {
91         'type': 'Convolution',
92         'auto_pad': convert_tf_padding_to_str(node.pb.attr['padding']),
93         'bias_addable': True,
94         'bias_term': False,
95         'dilation': dilations,
96         'stride': tf_int_list(node.pb.attr["strides"].list),
97
98         'channel_dims': tf_data_format_channel(data_format),
99         'batch_dims': tf_data_format_batch(data_format),
100
101         'input_feature_channel': input_feature_channel,
102         'output_feature_channel': output_feature_channel,
103         'layout': data_format.s.decode(),
104
105         # get_group and get_output_feature_dim are special attrs that stores lambdas ( lambda node, kernel_shape:...)
106         # this attrs calls in infer function to calculate output feature dimension and group attr
107         'get_group': None,  # lambda should return group attr for given node
108         'get_output_feature_dim': None,  # lamda should return output feature dimension
109     }
110
111     return attrs