Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / mxnet / extractors / utils.py
index 8c8d23d..3358ccd 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -87,10 +87,11 @@ class AttrDictionary(object):
 
     def val(self, key, valtype, default=None):
         attr = self.str(key, default)
+        attr = None if attr == 'None' else attr
         if valtype is None:
             return attr
         else:
-            if not isinstance(attr, valtype):
+            if not isinstance(attr, valtype) and attr is not None:
                 return valtype(attr)
             else:
                 return attr
@@ -178,3 +179,15 @@ def load_params(input_model, data_names = ('data',)):
     model_params._param_names = arg_keys
     model_params._aux_names = aux_keys
     return model_params
+
+
+def init_rnn_states(model_nodes):
+    states = {}
+    for i, node in enumerate(model_nodes):
+        if node['op'] == 'RNN':
+            for i in node['inputs'][2:]:
+                attrs = get_mxnet_layer_attrs(model_nodes[i[0]])
+                shape = attrs.tuple('__shape__', int, None)
+                if shape:
+                    states.update({model_nodes[i[0]]['name']: shape})
+    return states
\ No newline at end of file