Publishing R5 content (#72)
[platform/upstream/dldt.git] / model-optimizer / mo / front / mxnet / extractor.py
1 """
2  Copyright (c) 2018 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 from mo.front.mxnet.extractors.batchnorm import batch_norm_ext
18 from mo.front.mxnet.extractors.concat import concat_ext
19 from mo.front.mxnet.extractors.crop import crop_ext
20 from mo.front.mxnet.extractors.eltwise import eltwise_ext
21 from mo.front.mxnet.extractors.fully_connected import fully_connected_ext
22 from mo.front.mxnet.extractors.l2_normalization import l2_normalization_ext
23 from mo.front.mxnet.extractors.lrn import lrn_ext
24 from mo.front.mxnet.extractors.multibox_detection import multi_box_detection_ext
25 from mo.front.mxnet.extractors.multibox_prior import multi_box_prior_ext
26 from mo.front.mxnet.extractors.null import null_ext
27 from mo.front.mxnet.extractors.scaleshift import scale_shift_ext
28 from mo.front.mxnet.extractors.slice_axis import slice_axis_ext
29 from mo.front.mxnet.extractors.transpose import transpose_ext
30 from mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
31 from mo.graph.graph import Node
32 from mo.utils.error import Error
33 from mo.utils.utils import refer_to_faq_msg
34
35
36 def extractor_wrapper(mxnet_extractor):
37     return lambda node: mxnet_extractor(get_mxnet_layer_attrs(node.symbol_dict))
38
39
40 mxnet_op_extractors = {
41     'BatchNorm': extractor_wrapper(batch_norm_ext),
42     'Crop': extractor_wrapper(crop_ext),
43     'ScaleShift': extractor_wrapper(scale_shift_ext),
44     'slice_axis': extractor_wrapper(slice_axis_ext),
45     'null': lambda node: null_ext(node.symbol_dict),
46     'Concat': extractor_wrapper(concat_ext),
47     'elemwise_add': extractor_wrapper(lambda attrs: eltwise_ext(attrs, infer=lambda a, b: a + b, op_type="sum")),
48     'elemwise_mul': extractor_wrapper(lambda attrs: eltwise_ext(attrs, infer=lambda a, b: a * b, op_type="mul")),
49     '_Plus': extractor_wrapper(lambda attrs: eltwise_ext(attrs, infer=lambda a, b: a + b, op_type="sum")),
50     'FullyConnected': extractor_wrapper(fully_connected_ext),
51     'transpose': extractor_wrapper(transpose_ext),
52     'LRN': extractor_wrapper(lrn_ext),
53     'L2Normalization': extractor_wrapper(l2_normalization_ext),
54     '_contrib_MultiBoxPrior': extractor_wrapper(multi_box_prior_ext),
55     '_contrib_MultiBoxDetection': extractor_wrapper(multi_box_detection_ext),
56     'broadcast_add': extractor_wrapper(lambda attrs: eltwise_ext(attrs, infer=lambda a, b: a + b, op_type="sum")),
57 }
58
59
60 def common_mxnet_fields(node: Node):
61     return {
62         'kind': 'op',
63         'name': node['symbol_dict']['name'],
64         'type': node['symbol_dict']['op'],
65         'op': node['symbol_dict']['op'],
66         'infer': None,
67         'precision': 'FP32'
68     }
69
70
71 def mxnet_op_extractor(node: Node):
72     result = common_mxnet_fields(node)
73     op = result['op']
74     if op not in mxnet_op_extractors:
75         raise Error(
76             "Operation '{}' not supported. Please register it as custom op. " +
77             refer_to_faq_msg(86),
78             op)
79     result_attr = mxnet_op_extractors[op](node)
80
81     if result_attr is None:
82         raise Error('Model Optimizer does not support layer "{}". Please, implement extension. '.format(node.name) +
83                     refer_to_faq_msg(45))
84
85     result.update(result_attr)
86     supported = bool(result_attr)
87     return supported, result