2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
17 from mo.front.caffe.extractors.batchnorm import batch_norm_ext
18 from mo.front.caffe.extractors.concat import concat_ext
19 from mo.front.caffe.extractors.eltwise import eltwise_ext
20 from mo.front.caffe.extractors.inner_product import inner_product_ext
21 from mo.front.caffe.extractors.input import global_input_ext, input_ext
22 from mo.front.caffe.extractors.lrn import lrn_ext
23 from mo.front.caffe.extractors.native_caffe import native_caffe_node_extractor
24 from mo.front.caffe.extractors.permute import permute_ext
25 from mo.front.caffe.extractors.power import power_ext
26 from mo.front.caffe.extractors.relu import relu_ext
27 from mo.front.caffe.extractors.reshape import reshape_ext
28 from mo.front.caffe.extractors.roipooling import roipooling_ext
29 from mo.front.caffe.extractors.scale import scale_ext
30 from mo.front.caffe.extractors.slice import slice_ext
31 from mo.front.common.partial_infer.elemental import copy_shape_infer
32 from mo.front.common.register_custom_ops import extension_op_extractor
33 from mo.front.extractor import CaffePythonFrontExtractorOp
34 from mo.graph.graph import Node
35 from mo.ops.op import Op
36 from mo.utils.error import Error
37 from mo.utils.utils import refer_to_faq_msg
def node_pb_arg(pb_extractor):
    """Adapt a protobuf-level extractor so that it can be called with a node.

    Args:
        pb_extractor: callable taking two positional arguments; it receives
            the ``pb`` and ``model_pb`` attributes of the node, in that order.

    Returns:
        A callable that accepts a single node and returns whatever
        ``pb_extractor`` returns for that node's ``pb`` and ``model_pb``.
    """

    def extract_from_node(node):
        return pb_extractor(node.pb, node.model_pb)

    return extract_from_node
Keys are lowercase layer type names, i.e. the lowercased value of a layer's `type` field in .prototxt.
The full list of Caffe layers is available here: http://caffe.berkeleyvision.org/tutorial/layers.html
# Registry of extractors keyed by lowercase layer type. Every value is built
# with node_pb_arg, so each extractor is called with a single node argument.
48 caffe_type_extractors = {
50 'input': node_pb_arg(input_ext),
51 'globalinput': node_pb_arg(global_input_ext),
# Both spellings of the inner product layer map to the same extractor.
54 'innerproduct': node_pb_arg(inner_product_ext),
55 'inner_product': node_pb_arg(inner_product_ext),
# Dropout ignores both protobuf arguments and always yields the same attributes.
56 'dropout': node_pb_arg(lambda _, __: dict(op='Dropout', infer=copy_shape_infer)),
58 # Normalization Layers
59 'batchnorm': node_pb_arg(batch_norm_ext),
60 'lrn': node_pb_arg(lrn_ext),
63 'power': node_pb_arg(power_ext),
64 'relu': node_pb_arg(relu_ext),
65 'scale': node_pb_arg(scale_ext),
68 'concat': node_pb_arg(concat_ext),
69 'eltwise': node_pb_arg(eltwise_ext),
70 'reshape': node_pb_arg(reshape_ext),
71 'slice': node_pb_arg(slice_ext),
73 # Custom, implemented in IE, SSD-specific
74 'permute': node_pb_arg(permute_ext),
76 # Custom, implemented in IE, Fast-RCNN-specific
77 'roipooling': node_pb_arg(roipooling_ext),
81 def common_caffe_fields(node: Node) -> dict:
# Builds the generic attributes of a Caffe layer node; annotated to return a dict.
# NOTE(review): this excerpt does not show every statement of the function (the
# body of the guard below, the first assignment of layer_type and most of the
# returned dict are not visible) - confirm against the full file.
# Nodes whose 'op' attribute is already set to 'Identity' are special-cased.
82 if node.has_valid('op') and node.op == 'Identity':
# Read from the layer protobuf when the node has one, otherwise from the node itself.
84 pb = node.pb if node.pb else node
# An integer layer type is translated to its name through the LayerType enum
# descriptor reachable from the protobuf message.
86 if isinstance(layer_type, int):
87 layer_type = pb.LayerType.DESCRIPTOR.values_by_number[layer_type].name
# From here on the layer type is always a string.
88 layer_type = str(layer_type)
94 # generic code relies on op; it should be overridden by specific op extractor
96 'precision': 'FP32' # TODO use real precision derived from the model
100 def caffe_extractor(node: Node, lowered_keys_map: dict) -> (bool, dict):
# Extracts the attributes of a Caffe layer node; returns a (supported, attributes) pair.
# lowered_keys_map is looked up with the lowercased layer type and yields the key
# used to index caffe_type_extractors.
# NOTE(review): this excerpt does not show every statement of the function (for
# example where 'name' and 'supported' are first assigned, and the condition
# guarding the raise below) - confirm against the full file.
# Nodes whose 'op' attribute is already set to 'Identity' are special-cased.
101 if node.has_valid('op') and node.op == 'Identity':
# Start from the generic attributes; layer-specific ones are merged in below.
103 result = common_caffe_fields(node)
# The registry lookup is case-insensitive with respect to the layer type.
107 layer_type = result['type'].lower()
108 if layer_type in lowered_keys_map:
109 layer_type = lowered_keys_map[layer_type]
# NOTE(review): assert statements are removed when Python runs with -O, so this
# check is not enforced in optimized mode.
110 assert layer_type in caffe_type_extractors
113 if name: # it is either standard or registered via CustomLayersMapping.xml
114 attrs = caffe_type_extractors[name](node)
115 # intentionally as Python registry if not found returns None
116 if attrs is not None:
# The error message reports the node id of the layer that has no extractor.
121 raise Error('Found custom layer "{}". Model Optimizer does not support this layer. '.format(node.id) +
122 'Please, implement extension. ' +
123 refer_to_faq_msg(45))
# When no truthy 'infer' entry is present, merge in the attributes returned by
# native_caffe_node_extractor.
125 if 'infer' not in result or not result['infer']:
126 result.update(native_caffe_node_extractor(node))
# Merge the phase attributes reported by check_phase into the result.
128 phase_attr = check_phase(node)
129 result.update(phase_attr)
130 return supported, result
def check_phase(node: Node):
    """Return the phase declared in the ``include`` entries of a layer, if any.

    Args:
        node: graph node; ``node.pb`` is inspected only when
            ``node.has_valid('pb')`` is true and the protobuf has an
            ``include`` attribute.

    Returns:
        ``{'phase': <value>}`` taken from the first entry of
        ``node.pb.include`` that has a ``phase`` attribute, or an empty dict
        when there is no such entry.
    """
    if not (node.has_valid('pb') and hasattr(node.pb, 'include')):
        return {}
    for rule in node.pb.include:
        if hasattr(rule, 'phase'):
            return {'phase': rule.phase}
    # Always hand back a dict: caffe_extractor merges the result with
    # dict.update(), which raises TypeError when given None.
    return {}
def register_caffe_python_extractor(op: Op, name: str = None):
    """Store an extractor built from ``op`` in ``CaffePythonFrontExtractorOp.registered_ops``.

    Args:
        op: operation that is passed to ``extension_op_extractor`` together
            with the node being extracted.
        name: key under which the extractor is registered. When it is not
            given, ``op.op`` is used if the operation defines that attribute.

    Raises:
        Error: if no name is given and none can be taken from ``op``.
    """
    # Fall back to the operation's own name when the caller did not supply one.
    if not name and hasattr(op, 'op'):
        name = op.op
    if not name:
        # A space is needed between the two literals: adjacent string literals
        # are concatenated verbatim.
        raise Error("Can not register Op {}. Please, call function 'register_caffe_python_extractor' "
                    "with parameter 'name' .".format(op),
                    refer_to_faq_msg(87))
    CaffePythonFrontExtractorOp.registered_ops[name] = lambda node: extension_op_extractor(node, op)