Publishing R5 content (#72)
[platform/upstream/dldt.git] / model-optimizer / mo / front / mxnet / extractors / utils.py
1 """
2  Copyright (c) 2018 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import mxnet as mx
18
19 from mo.utils.error import Error
20 from mo.utils.str_to import StrTo
21 from mo.utils.utils import refer_to_faq_msg
22
23
24 class AttrDictionary(object):
25     def __init__(self, dict):
26         self._dict = dict
27
28     def is_valid(self):
29         return not self._dict is None
30
31     def dict(self):
32         return self._dict
33
34     def add_dict(self, dict):
35         self._dict.update(dict)
36
37     def set(self, key, value):
38         self._dict[key] = value
39
40     def remove(self, key):
41         if key in self._dict:
42             del self._dict[key]
43
44     def str(self, key, default=None):
45         if not self.is_valid:
46             if default is None:
47                 raise ValueError("Missing required parameter: " + key)
48         if key in self._dict:
49             return self._dict[key]
50         return default
51
52     def bool(self, key, default=None):
53         attr = self.str(key, default)
54         if isinstance(attr, str):
55             if attr.isdigit():
56                 return bool(int(attr))
57             return StrTo.bool(attr)
58         else:
59             return attr
60
61     def float(self, key, default=None):
62         return self.val(key, float, default)
63
64     def int(self, key, default=None):
65         return self.val(key, int, default)
66
67     def tuple(self, key, valtype=str, default=None):
68         attr = self.str(key, default)
69         if attr is None:
70             return default
71         if isinstance(attr, str):
72             if (not '(' in attr and not ')' in attr) and (not '[' in attr and not ']' in attr):
73                 return (valtype(attr),)
74             if (not attr) or (not attr[1:-1].split(',')[0]):
75                 return tuple([valtype(x) for x in default])
76             return StrTo.tuple(valtype, attr)
77         else:
78             return tuple([valtype(x) for x in attr])
79
80     def list(self, key, valtype, default=None, sep=","):
81         attr = self.str(key, default)
82         if isinstance(attr, list):
83             attr = [valtype(x) for x in attr]
84             return attr
85         else:
86             return StrTo.list(attr, valtype, sep)
87
88     def val(self, key, valtype, default=None):
89         attr = self.str(key, default)
90         if valtype is None:
91             return attr
92         else:
93             if not isinstance(attr, valtype):
94                 return valtype(attr)
95             else:
96                 return attr
97
98     def has(self, key):
99         if not self.is_valid:
100             return False
101         else:
102             return key in self._dict
103
104
105 def get_mxnet_node_edges(node: dict, node_id: [int, str], nodes_list: list, index_node_key: dict):
106     edge_list = []
107     for in_port, src_node_id in enumerate(node['inputs']):
108         src_node = src_node_id[0]
109         dest_port = src_node_id[1]
110         edge_attrs = {
111             'in': in_port,
112             'out': dest_port,
113             # debug anchor for name of tensor consumed at this input port
114             'fw_tensor_debug_info': [(nodes_list[src_node]['name'], src_node_id[1])],
115             'in_attrs': ['in'],
116             'out_attrs': ['out'],
117             'data_attrs': ['fw_tensor_debug_info']
118         }
119         edge = (index_node_key[src_node], index_node_key[node_id], edge_attrs)
120         edge_list.append(edge)
121     return edge_list
122
123
124 def get_mxnet_layer_attrs(json_dic: dict):
125     attr = 'param'
126     if 'attr' in json_dic:
127         attr = 'attr'
128     elif 'attrs' in json_dic:
129         attr = 'attrs'
130     return AttrDictionary(json_dic[attr] if attr in json_dic else {})
131
132
133 def get_json_layer_attrs(json_dic):
134     attr = 'param'
135     if 'attr' in json_dic:
136         attr = 'attr'
137     elif 'attrs' in json_dic:
138         attr = 'attrs'
139     return json_dic[attr]
140
141
142 def load_params(input_model, data_names = ('data',)):
143     arg_params = {}
144     aux_params = {}
145     arg_keys = []
146     aux_keys = []
147     file_format = input_model.split('.')[-1]
148     loaded_weight = mx.nd.load(input_model)
149     if file_format == 'params':
150         for key in loaded_weight:
151             keys = key.split(':')
152             if len(keys)>1 and 'aux' == keys[0]:
153                 aux_keys.append(keys[1])
154                 aux_params[keys[1]] = loaded_weight[key]
155             elif len(keys)>1 and 'arg' == keys[0]:
156                 arg_keys.append(keys[1])
157                 arg_params[keys[1]] = loaded_weight[key]
158             else:
159                 arg_keys.append(key)
160                 arg_params[key] = loaded_weight[key]
161     elif file_format == 'nd':
162         for key in loaded_weight:
163             if 'auxs' in input_model:
164                 aux_keys.append(key)
165                 aux_params[key] = loaded_weight[key]
166             elif 'args' in input_model:
167                 arg_keys.append(key)
168                 arg_params[key] = loaded_weight[key]
169     else:
170         raise Error(
171             'Unsupported Input model file type {}. Model Optimizer support only .params and .nd files format. ' +
172             refer_to_faq_msg(85), file_format)
173
174     data = mx.sym.Variable(data_names[0])
175     model_params = mx.mod.Module(data, data_names=(data_names[0],), label_names=(data_names[0],))
176     model_params._arg_params = arg_params
177     model_params._aux_params = aux_params
178     model_params._param_names = arg_keys
179     model_params._aux_names = aux_keys
180     return model_params