"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
def val(self, key, valtype, default=None):
attr = self.str(key, default)
+ attr = None if attr == 'None' else attr
if valtype is None:
return attr
else:
- if not isinstance(attr, valtype):
+ if not isinstance(attr, valtype) and attr is not None:
return valtype(attr)
else:
return attr
model_params._param_names = arg_keys
model_params._aux_names = aux_keys
return model_params
+
+
+def init_rnn_states(model_nodes):
+ states = {}
+ for i, node in enumerate(model_nodes):
+ if node['op'] == 'RNN':
+ for i in node['inputs'][2:]:
+ attrs = get_mxnet_layer_attrs(model_nodes[i[0]])
+ shape = attrs.tuple('__shape__', int, None)
+ if shape:
+ states.update({model_nodes[i[0]]['name']: shape})
+ return states
\ No newline at end of file