2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
19 from mo.utils.error import Error
20 from mo.utils.str_to import StrTo
21 from mo.utils.utils import refer_to_faq_msg
# AttrDictionary: thin wrapper around a plain attribute dict (accessed as
# self._dict) offering typed getters: str / bool / float / int / tuple /
# list / val.
# NOTE(review): this listing is lossy -- indentation is gone and several
# original lines (some method headers, branch conditions, returns) are
# missing, so the notes below describe only the lines visible here.
24 class AttrDictionary(object):
# Constructor; its body is not visible here -- presumably it stores `dict`
# as self._dict (TODO confirm against the full file).
25 def __init__(self, dict):
# Validity test: True unless the wrapped dict is None. The enclosing `def`
# line is not visible; the check further down that reads `self.is_valid`
# suggests this is is_valid.
29 return not self._dict is None
# Merge the entries of another mapping into the wrapped dict, in place.
34 def add_dict(self, dict):
35 self._dict.update(dict)
# Set (or overwrite) a single attribute in the wrapped dict.
37 def set(self, key, value):
38 self._dict[key] = value
# Remove an attribute; the body is not visible in this listing.
40 def remove(self, key):
# Raw accessor that the typed getters below build on. On the visible path it
# returns self._dict[key]. It raises ValueError("Missing required parameter: ...")
# under a condition whose line is not visible (presumably: key absent and no
# default supplied -- TODO confirm).
44 def str(self, key, default=None):
47 raise ValueError("Missing required parameter: " + key)
49 return self._dict[key]
# Boolean getter. A str value may be converted with bool(int(attr)) (e.g.
# '0' / '1'); the other visible conversion is StrTo.bool(attr). The line that
# chooses between the two is not visible -- presumably a digit check, since
# int() would raise ValueError on a non-numeric string. Handling of non-str
# values is not visible here.
52 def bool(self, key, default=None):
53 attr = self.str(key, default)
54 if isinstance(attr, str):
56 return bool(int(attr))
57 return StrTo.bool(attr)
# Float getter: delegates to val() with float as the target type. Inside the
# method body `float` resolves to the builtin, not to this method, because
# class-body names are not in scope inside methods.
61 def float(self, key, default=None):
62 return self.val(key, float, default)
# Integer getter: delegates to val() with the builtin int as the target type.
64 def int(self, key, default=None):
65 return self.val(key, int, default)
# Tuple getter; valtype converts each element.
#  - a string containing none of '(' ')' '[' ']' becomes the 1-tuple
#    (valtype(attr),);
#  - a bracketed string whose first element is empty (e.g. '()' or '[]')
#    falls back to `default`, converted element-wise;
#  - any other string is parsed by StrTo.tuple(valtype, attr);
#  - the final return converts a non-string iterable element-wise (the
#    header of that branch is not visible).
# Two lines between reading attr and the str check are not visible.
# NOTE(review): an empty string is already caught by the bracket test, so
# `not attr` in the fallback condition never fires for strings. Also, if
# default is None, the fallback iterates None and raises TypeError.
67 def tuple(self, key, valtype=str, default=None):
68 attr = self.str(key, default)
71 if isinstance(attr, str):
72 if (not '(' in attr and not ')' in attr) and (not '[' in attr and not ']' in attr):
73 return (valtype(attr),)
74 if (not attr) or (not attr[1:-1].split(',')[0]):
75 return tuple([valtype(x) for x in default])
76 return StrTo.tuple(valtype, attr)
78 return tuple([valtype(x) for x in attr])
# List getter. A Python list is converted element-wise with valtype; other
# values are parsed by StrTo.list(attr, valtype, sep). The lines between the
# two (presumably a return and an else) are not visible.
80 def list(self, key, valtype, default=None, sep=","):
81 attr = self.str(key, default)
82 if isinstance(attr, list):
83 attr = [valtype(x) for x in attr]
86 return StrTo.list(attr, valtype, sep)
# Generic typed getter used by float() and int(). The literal string 'None'
# is normalised to None. A non-None value that is not already a valtype
# instance presumably gets converted with valtype (the conversion line is not
# visible).
88 def val(self, key, valtype, default=None):
89 attr = self.str(key, default)
90 attr = None if attr == 'None' else attr
94 if not isinstance(attr, valtype) and attr is not None:
# Membership check. Its `def` line is not visible (presumably `has`).
# NOTE(review): `self.is_valid` is referenced, not called. If is_valid is a
# plain method rather than a property (its def line is not visible), the bound
# method object is always truthy, so `not self.is_valid` is always False and
# this guard never triggers -- probably `self.is_valid()` was meant.
100 if not self.is_valid:
103 return key in self._dict
# Build the graph edges that feed node `node_id`: one edge per entry of
# node['inputs']. Each entry is indexed as:
#   [0] -- index of the producer node in nodes_list;
#   [1] -- a port number (stored as dest_port and in the debug info).
# Each edge is a (producer key, consumer key, attrs) triple, with keys looked
# up in index_node_key.
# NOTE(review): the lines that create edge_list, open the edge_attrs dict
# (presumably with 'in'/'out' entries holding in_port and dest_port) and
# return the list are not visible in this listing.
# NOTE(review): the annotation `node_id: [int, str]` is a list literal, not a
# typing construct; Union[int, str] is presumably what was meant.
106 def get_mxnet_node_edges(node: dict, node_id: [int, str], nodes_list: list, index_node_key: dict):
# in_port is the position of this input within node['inputs'].
108 for in_port, src_node_id in enumerate(node['inputs']):
109 src_node = src_node_id[0]
# NOTE(review): despite its name, dest_port comes from the producer's entry
# (src_node_id[1]); it looks like the producer's output port -- confirm.
110 dest_port = src_node_id[1]
114 # debug anchor for name of tensor consumed at this input port
115 'fw_tensor_debug_info': [(nodes_list[src_node]['name'], src_node_id[1])],
117 'out_attrs': ['out'],
118 'data_attrs': ['fw_tensor_debug_info']
# Edge runs from the producer node to the current node.
120 edge = (index_node_key[src_node], index_node_key[node_id], edge_attrs)
121 edge_list.append(edge)
# Return the attribute mapping of an MXNet JSON node, wrapped in
# AttrDictionary. The function checks for both an 'attr' and an 'attrs' key.
# The lines that assign the chosen key name to `attr` are not visible here
# (presumably 'attr' / 'attrs', with some initial default -- TODO confirm).
# If the chosen key is absent from json_dic, an empty AttrDictionary is
# returned instead of raising.
125 def get_mxnet_layer_attrs(json_dic: dict):
127 if 'attr' in json_dic:
129 elif 'attrs' in json_dic:
131 return AttrDictionary(json_dic[attr] if attr in json_dic else {})
# Look up the attribute mapping of a JSON node under an 'attr' or 'attrs' key
# and return the raw dict (not wrapped in AttrDictionary). The lines that
# assign the chosen key name to `attr` are not visible here.
# NOTE(review): unlike get_mxnet_layer_attrs there is no visible fallback; if
# neither key exists, json_dic[attr] raises KeyError unless a line not
# visible in this listing handles that case.
134 def get_json_layer_attrs(json_dic):
136 if 'attr' in json_dic:
138 elif 'attrs' in json_dic:
140 return json_dic[attr]
# Load MXNet weights from a .params or .nd file and attach them to a minimal
# mx.mod.Module.
#   input_model -- path to the weights file. The format is the text after the
#                  last '.'. Note that the file is read with mx.nd.load before
#                  the format is checked.
#   data_names  -- only data_names[0] is used, as the name of the Module's
#                  single input Variable.
# Any format other than 'params' / 'nd' is rejected with a message that
# references FAQ #85 (presumably raise Error(...); that line is not visible).
# NOTE(review): the `mx` import, the initialisation of arg/aux dicts and key
# lists, and the return statement are not visible in this listing.
143 def load_params(input_model, data_names = ('data',)):
148 file_format = input_model.split('.')[-1]
149 loaded_weight = mx.nd.load(input_model)
# .params format: keys are prefixed 'aux:<name>' or 'arg:<name>'. The prefix
# selects the destination dict, and only the part after the first ':' is
# stored. Keys without a recognised prefix are stored under the full key in
# arg_params (the header of that branch is not visible).
150 if file_format == 'params':
151 for key in loaded_weight:
152 keys = key.split(':')
153 if len(keys)>1 and 'aux' == keys[0]:
154 aux_keys.append(keys[1])
155 aux_params[keys[1]] = loaded_weight[key]
156 elif len(keys)>1 and 'arg' == keys[0]:
157 arg_keys.append(keys[1])
158 arg_params[keys[1]] = loaded_weight[key]
161 arg_params[key] = loaded_weight[key]
# .nd format: every key goes to aux_params or arg_params depending on whether
# the file path contains 'auxs' or 'args'.
# NOTE(review): this is a substring test on the whole path (directories
# included), so a directory name containing 'args' or 'auxs' can misroute the
# weights. No visible line stores keys from a path that contains neither.
162 elif file_format == 'nd':
163 for key in loaded_weight:
164 if 'auxs' in input_model:
166 aux_params[key] = loaded_weight[key]
167 elif 'args' in input_model:
169 arg_params[key] = loaded_weight[key]
172 'Unsupported Input model file type {}. Model Optimizer support only .params and .nd files format. ' +
173 refer_to_faq_msg(85), file_format)
# Build a placeholder Module around a single input Variable and attach the
# loaded parameters through its underscore-prefixed (private) attributes.
# NOTE(review): label_names is also set to data_names[0]; confirm this is
# intentional.
175 data = mx.sym.Variable(data_names[0])
176 model_params = mx.mod.Module(data, data_names=(data_names[0],), label_names=(data_names[0],))
177 model_params._arg_params = arg_params
178 model_params._aux_params = aux_params
179 model_params._param_names = arg_keys
180 model_params._aux_names = aux_keys
184 def init_rnn_states(model_nodes):
186 for i, node in enumerate(model_nodes):
187 if node['op'] == 'RNN':
188 for i in node['inputs'][2:]:
189 attrs = get_mxnet_layer_attrs(model_nodes[i[0]])
190 shape = attrs.tuple('__shape__', int, None)
192 states.update({model_nodes[i[0]]['name']: shape})