Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / mxnet / extractors / utils.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import mxnet as mx
18
19 from mo.utils.error import Error
20 from mo.utils.str_to import StrTo
21 from mo.utils.utils import refer_to_faq_msg
22
23
24 class AttrDictionary(object):
25     def __init__(self, dict):
26         self._dict = dict
27
28     def is_valid(self):
29         return not self._dict is None
30
31     def dict(self):
32         return self._dict
33
34     def add_dict(self, dict):
35         self._dict.update(dict)
36
37     def set(self, key, value):
38         self._dict[key] = value
39
40     def remove(self, key):
41         if key in self._dict:
42             del self._dict[key]
43
44     def str(self, key, default=None):
45         if not self.is_valid:
46             if default is None:
47                 raise ValueError("Missing required parameter: " + key)
48         if key in self._dict:
49             return self._dict[key]
50         return default
51
52     def bool(self, key, default=None):
53         attr = self.str(key, default)
54         if isinstance(attr, str):
55             if attr.isdigit():
56                 return bool(int(attr))
57             return StrTo.bool(attr)
58         else:
59             return attr
60
61     def float(self, key, default=None):
62         return self.val(key, float, default)
63
64     def int(self, key, default=None):
65         return self.val(key, int, default)
66
67     def tuple(self, key, valtype=str, default=None):
68         attr = self.str(key, default)
69         if attr is None:
70             return default
71         if isinstance(attr, str):
72             if (not '(' in attr and not ')' in attr) and (not '[' in attr and not ']' in attr):
73                 return (valtype(attr),)
74             if (not attr) or (not attr[1:-1].split(',')[0]):
75                 return tuple([valtype(x) for x in default])
76             return StrTo.tuple(valtype, attr)
77         else:
78             return tuple([valtype(x) for x in attr])
79
80     def list(self, key, valtype, default=None, sep=","):
81         attr = self.str(key, default)
82         if isinstance(attr, list):
83             attr = [valtype(x) for x in attr]
84             return attr
85         else:
86             return StrTo.list(attr, valtype, sep)
87
88     def val(self, key, valtype, default=None):
89         attr = self.str(key, default)
90         attr = None if attr == 'None' else attr
91         if valtype is None:
92             return attr
93         else:
94             if not isinstance(attr, valtype) and attr is not None:
95                 return valtype(attr)
96             else:
97                 return attr
98
99     def has(self, key):
100         if not self.is_valid:
101             return False
102         else:
103             return key in self._dict
104
105
106 def get_mxnet_node_edges(node: dict, node_id: [int, str], nodes_list: list, index_node_key: dict):
107     edge_list = []
108     for in_port, src_node_id in enumerate(node['inputs']):
109         src_node = src_node_id[0]
110         dest_port = src_node_id[1]
111         edge_attrs = {
112             'in': in_port,
113             'out': dest_port,
114             # debug anchor for name of tensor consumed at this input port
115             'fw_tensor_debug_info': [(nodes_list[src_node]['name'], src_node_id[1])],
116             'in_attrs': ['in'],
117             'out_attrs': ['out'],
118             'data_attrs': ['fw_tensor_debug_info']
119         }
120         edge = (index_node_key[src_node], index_node_key[node_id], edge_attrs)
121         edge_list.append(edge)
122     return edge_list
123
124
125 def get_mxnet_layer_attrs(json_dic: dict):
126     attr = 'param'
127     if 'attr' in json_dic:
128         attr = 'attr'
129     elif 'attrs' in json_dic:
130         attr = 'attrs'
131     return AttrDictionary(json_dic[attr] if attr in json_dic else {})
132
133
134 def get_json_layer_attrs(json_dic):
135     attr = 'param'
136     if 'attr' in json_dic:
137         attr = 'attr'
138     elif 'attrs' in json_dic:
139         attr = 'attrs'
140     return json_dic[attr]
141
142
143 def load_params(input_model, data_names = ('data',)):
144     arg_params = {}
145     aux_params = {}
146     arg_keys = []
147     aux_keys = []
148     file_format = input_model.split('.')[-1]
149     loaded_weight = mx.nd.load(input_model)
150     if file_format == 'params':
151         for key in loaded_weight:
152             keys = key.split(':')
153             if len(keys)>1 and 'aux' == keys[0]:
154                 aux_keys.append(keys[1])
155                 aux_params[keys[1]] = loaded_weight[key]
156             elif len(keys)>1 and 'arg' == keys[0]:
157                 arg_keys.append(keys[1])
158                 arg_params[keys[1]] = loaded_weight[key]
159             else:
160                 arg_keys.append(key)
161                 arg_params[key] = loaded_weight[key]
162     elif file_format == 'nd':
163         for key in loaded_weight:
164             if 'auxs' in input_model:
165                 aux_keys.append(key)
166                 aux_params[key] = loaded_weight[key]
167             elif 'args' in input_model:
168                 arg_keys.append(key)
169                 arg_params[key] = loaded_weight[key]
170     else:
171         raise Error(
172             'Unsupported Input model file type {}. Model Optimizer support only .params and .nd files format. ' +
173             refer_to_faq_msg(85), file_format)
174
175     data = mx.sym.Variable(data_names[0])
176     model_params = mx.mod.Module(data, data_names=(data_names[0],), label_names=(data_names[0],))
177     model_params._arg_params = arg_params
178     model_params._aux_params = aux_params
179     model_params._param_names = arg_keys
180     model_params._aux_names = aux_keys
181     return model_params
182
183
184 def init_rnn_states(model_nodes):
185     states = {}
186     for i, node in enumerate(model_nodes):
187         if node['op'] == 'RNN':
188             for i in node['inputs'][2:]:
189                 attrs = get_mxnet_layer_attrs(model_nodes[i[0]])
190                 shape = attrs.tuple('__shape__', int, None)
191                 if shape:
192                     states.update({model_nodes[i[0]]['name']: shape})
193     return states