2 Copyright (c) 2018 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
19 from mo.utils.error import Error
20 from mo.utils.str_to import StrTo
21 from mo.utils.utils import refer_to_faq_msg
24 class AttrDictionary(object):
25 def __init__(self, dict):
29 return not self._dict is None
34 def add_dict(self, dict):
35 self._dict.update(dict)
37 def set(self, key, value):
38 self._dict[key] = value
40 def remove(self, key):
44 def str(self, key, default=None):
47 raise ValueError("Missing required parameter: " + key)
49 return self._dict[key]
52 def bool(self, key, default=None):
53 attr = self.str(key, default)
54 if isinstance(attr, str):
56 return bool(int(attr))
57 return StrTo.bool(attr)
61 def float(self, key, default=None):
62 return self.val(key, float, default)
64 def int(self, key, default=None):
65 return self.val(key, int, default)
67 def tuple(self, key, valtype=str, default=None):
68 attr = self.str(key, default)
71 if isinstance(attr, str):
72 if (not '(' in attr and not ')' in attr) and (not '[' in attr and not ']' in attr):
73 return (valtype(attr),)
74 if (not attr) or (not attr[1:-1].split(',')[0]):
75 return tuple([valtype(x) for x in default])
76 return StrTo.tuple(valtype, attr)
78 return tuple([valtype(x) for x in attr])
80 def list(self, key, valtype, default=None, sep=","):
81 attr = self.str(key, default)
82 if isinstance(attr, list):
83 attr = [valtype(x) for x in attr]
86 return StrTo.list(attr, valtype, sep)
88 def val(self, key, valtype, default=None):
89 attr = self.str(key, default)
93 if not isinstance(attr, valtype):
102 return key in self._dict
105 def get_mxnet_node_edges(node: dict, node_id: [int, str], nodes_list: list, index_node_key: dict):
107 for in_port, src_node_id in enumerate(node['inputs']):
108 src_node = src_node_id[0]
109 dest_port = src_node_id[1]
113 # debug anchor for name of tensor consumed at this input port
114 'fw_tensor_debug_info': [(nodes_list[src_node]['name'], src_node_id[1])],
116 'out_attrs': ['out'],
117 'data_attrs': ['fw_tensor_debug_info']
119 edge = (index_node_key[src_node], index_node_key[node_id], edge_attrs)
120 edge_list.append(edge)
124 def get_mxnet_layer_attrs(json_dic: dict):
126 if 'attr' in json_dic:
128 elif 'attrs' in json_dic:
130 return AttrDictionary(json_dic[attr] if attr in json_dic else {})
133 def get_json_layer_attrs(json_dic):
135 if 'attr' in json_dic:
137 elif 'attrs' in json_dic:
139 return json_dic[attr]
142 def load_params(input_model, data_names = ('data',)):
147 file_format = input_model.split('.')[-1]
148 loaded_weight = mx.nd.load(input_model)
149 if file_format == 'params':
150 for key in loaded_weight:
151 keys = key.split(':')
152 if len(keys)>1 and 'aux' == keys[0]:
153 aux_keys.append(keys[1])
154 aux_params[keys[1]] = loaded_weight[key]
155 elif len(keys)>1 and 'arg' == keys[0]:
156 arg_keys.append(keys[1])
157 arg_params[keys[1]] = loaded_weight[key]
160 arg_params[key] = loaded_weight[key]
161 elif file_format == 'nd':
162 for key in loaded_weight:
163 if 'auxs' in input_model:
165 aux_params[key] = loaded_weight[key]
166 elif 'args' in input_model:
168 arg_params[key] = loaded_weight[key]
171 'Unsupported Input model file type {}. Model Optimizer support only .params and .nd files format. ' +
172 refer_to_faq_msg(85), file_format)
174 data = mx.sym.Variable(data_names[0])
175 model_params = mx.mod.Module(data, data_names=(data_names[0],), label_names=(data_names[0],))
176 model_params._arg_params = arg_params
177 model_params._aux_params = aux_params
178 model_params._param_names = arg_keys
179 model_params._aux_names = aux_keys