Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / caffe / binary_conv_ext.py
1 """
2  Copyright (c) 2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 from extensions.front.caffe.conv_ext import conv_create_attrs, conv_set_params
18 from mo.front.caffe.extractors.utils import weights_biases
19 from mo.front.common.extractors.utils import layout_attrs
20 from mo.front.extractor import FrontExtractorOp
21 from mo.ops.convolution import Convolution
22 from mo.utils.error import Error
23
24
25 class ConvFrontExtractor(FrontExtractorOp):
26     op = 'ConvolutionBinary'
27     enabled = True
28
29     @staticmethod
30     def extract(node):
31         proto_layer, model_layer = node.pb, node.model_pb
32
33         if not proto_layer:
34             raise Error('Protobuf layer can not be empty')
35
36         conv_param = proto_layer.convolution_param
37         conv_type = 'ConvND' if len(proto_layer.bottom) > 1 else 'Conv2D'
38
39         params = conv_set_params(conv_param, conv_type)
40         attrs = conv_create_attrs(params)
41         attrs.update({'op': __class__.op,
42                       'get_group': lambda node: node.group,
43                       'get_output_feature_dim': lambda node: node.output
44                       })
45
46         # Embed weights and biases as attributes
47         # It will be moved to a separate nodes in special pass
48         attrs.update(
49             weights_biases(conv_param.bias_term, model_layer, start_index=len(proto_layer.bottom), proto=conv_param))
50         attrs.update(layout_attrs())
51
52         # update the attributes of the node
53         Convolution.update_node_stat(node, attrs)
54         return __class__.enabled
55