sys.tracebacklimit = 0
+# Class to rename input/output to prevent issues while import ONNX models
+class TidyIONames:
+ def __init__(self, onnx_model):
+ self.input_nodes = []
+ self.output_nodes = []
+ self.remap_inputs = []
+ self.remap_outputs = []
+ self.initializers = []
+ self.onnx_model = onnx_model
+ # some models may have initializers as inputs. ignore them.
+ for initializer in onnx_model.graph.initializer:
+ self.initializers.append(initializer.name)
+
+ def order(self):
+ for idx in range(0, len(self.onnx_model.graph.input)):
+ name = self.onnx_model.graph.input[idx].name
+ if not name in self.initializers:
+ self.input_nodes.append(name)
+ self.remap_inputs.append('i_' + format(idx + 1, '04d') + '_' + name)
+ for idx in range(0, len(self.onnx_model.graph.output)):
+ name = self.onnx_model.graph.output[idx].name
+ self.output_nodes.append(name)
+ self.remap_outputs.append('o_' + format(idx + 1, '04d') + '_' + name)
+
+ # exclude special characters in names
+ def sanitize(self):
+ for idx in range(0, len(self.onnx_model.graph.input)):
+ name = self.onnx_model.graph.input[idx].name
+ if not name in self.initializers:
+ if '.' in name or ':' in name or name[:1].isdigit():
+ self.input_nodes.append(name)
+ name_alt = name.replace('.', '_')
+ name_alt = name_alt.replace(':', '_')
+ if name_alt[:1].isdigit():
+ name_alt = 'a_' + name_alt
+ self.remap_inputs.append(name_alt)
+ for idx in range(0, len(self.onnx_model.graph.output)):
+ name = self.onnx_model.graph.output[idx].name
+ if '.' in name or ':' in name or name[:1].isdigit():
+ self.output_nodes.append(name)
+ name_alt = name.replace('.', '_')
+ name_alt = name_alt.replace(':', '_')
+ if name_alt[:1].isdigit():
+ name_alt = 'a_' + name_alt
+ self.remap_outputs.append(name_alt)
+
+ def update(self):
+ # change names for graph input
+ for i in range(len(self.onnx_model.graph.input)):
+ if self.onnx_model.graph.input[i].name in self.input_nodes:
+ to_rename = self.onnx_model.graph.input[i].name
+ idx = self.input_nodes.index(to_rename)
+ self.onnx_model.graph.input[i].name = self.remap_inputs[idx]
+ # change names of all nodes in the graph
+ for i in range(len(self.onnx_model.graph.node)):
+ # check node.input is to change to remap_inputs or remap_outputs
+ for j in range(len(self.onnx_model.graph.node[i].input)):
+ if self.onnx_model.graph.node[i].input[j] in self.input_nodes:
+ to_rename = self.onnx_model.graph.node[i].input[j]
+ idx = self.input_nodes.index(to_rename)
+ self.onnx_model.graph.node[i].input[j] = self.remap_inputs[idx]
+ if self.onnx_model.graph.node[i].input[j] in self.output_nodes:
+ to_rename = self.onnx_model.graph.node[i].input[j]
+ idx = self.output_nodes.index(to_rename)
+ self.onnx_model.graph.node[i].input[j] = self.remap_outputs[idx]
+ # check node.output is to change to remap_inputs or remap_outputs
+ for j in range(len(self.onnx_model.graph.node[i].output)):
+ if self.onnx_model.graph.node[i].output[j] in self.output_nodes:
+ to_rename = self.onnx_model.graph.node[i].output[j]
+ idx = self.output_nodes.index(to_rename)
+ self.onnx_model.graph.node[i].output[j] = self.remap_outputs[idx]
+ if self.onnx_model.graph.node[i].output[j] in self.input_nodes:
+ to_rename = self.onnx_model.graph.node[i].output[j]
+ idx = self.input_nodes.index(to_rename)
+ self.onnx_model.graph.node[i].output[j] = self.remap_inputs[idx]
+ # change names for graph output
+ for i in range(len(self.onnx_model.graph.output)):
+ if self.onnx_model.graph.output[i].name in self.output_nodes:
+ to_rename = self.onnx_model.graph.output[i].name
+ idx = self.output_nodes.index(to_rename)
+ self.onnx_model.graph.output[i].name = self.remap_outputs[idx]
+
+
def get_driver_cfg_section():
return "one-import-onnx"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+# TF2.12.1 tries to sanitize special characters, '.:' and maybe others and then fails with
+# 'IndexError: tuple index out of range' error from somewhere else.
+# This method is to prevent this IndexError.
+def _sanitize_io_names(onnx_model):
+ sanitizer = TidyIONames(onnx_model)
+ sanitizer.sanitize()
+ sanitizer.update()
+
+
# The index of input/output is added in front of the name. For example,
# Original input names: 'a', 'c', 'b'
-# Renamed: '0001_a', '0002_c', '0003_b'
+# Renamed: 'i_0001_a', 'i_0002_c', 'i_0003_b'
# This will preserve I/O order after import.
def _remap_io_names(onnx_model):
# gather existing name of I/O and generate new name of I/O in sort order
- input_nodes = []
- output_nodes = []
- remap_inputs = []
- remap_outputs = []
- initializers = []
- # some models may have initializers as inputs. ignore them.
- for initializer in onnx_model.graph.initializer:
- initializers.append(initializer.name)
- for idx in range(0, len(onnx_model.graph.input)):
- name = onnx_model.graph.input[idx].name
- if not name in initializers:
- input_nodes.append(name)
- remap_inputs.append(format(idx + 1, '04d') + '_' + name)
- for idx in range(0, len(onnx_model.graph.output)):
- name = onnx_model.graph.output[idx].name
- output_nodes.append(name)
- remap_outputs.append(format(idx + 1, '04d') + '_' + name)
- # change names for graph input
- for i in range(len(onnx_model.graph.input)):
- if onnx_model.graph.input[i].name in input_nodes:
- to_rename = onnx_model.graph.input[i].name
- idx = input_nodes.index(to_rename)
- onnx_model.graph.input[i].name = remap_inputs[idx]
- # change names of all nodes in the graph
- for i in range(len(onnx_model.graph.node)):
- # check node.input is to change to remap_inputs or remap_outputs
- for j in range(len(onnx_model.graph.node[i].input)):
- if onnx_model.graph.node[i].input[j] in input_nodes:
- to_rename = onnx_model.graph.node[i].input[j]
- idx = input_nodes.index(to_rename)
- onnx_model.graph.node[i].input[j] = remap_inputs[idx]
- if onnx_model.graph.node[i].input[j] in output_nodes:
- to_rename = onnx_model.graph.node[i].input[j]
- idx = output_nodes.index(to_rename)
- onnx_model.graph.node[i].input[j] = remap_outputs[idx]
- # check node.output is to change to remap_inputs or remap_outputs
- for j in range(len(onnx_model.graph.node[i].output)):
- if onnx_model.graph.node[i].output[j] in output_nodes:
- to_rename = onnx_model.graph.node[i].output[j]
- idx = output_nodes.index(to_rename)
- onnx_model.graph.node[i].output[j] = remap_outputs[idx]
- if onnx_model.graph.node[i].output[j] in input_nodes:
- to_rename = onnx_model.graph.node[i].output[j]
- idx = input_nodes.index(to_rename)
- onnx_model.graph.node[i].output[j] = remap_inputs[idx]
- # change names for graph output
- for i in range(len(onnx_model.graph.output)):
- if onnx_model.graph.output[i].name in output_nodes:
- to_rename = onnx_model.graph.output[i].name
- idx = output_nodes.index(to_rename)
- onnx_model.graph.output[i].name = remap_outputs[idx]
+ remapper = TidyIONames(onnx_model)
+ remapper.order()
+ remapper.update()
+
+
+def _check_ext():
+ dir_path = os.path.dirname(os.path.realpath(__file__))
+ ext_path = os.path.join(dir_path, 'one-import-onnx-ext')
+ if (os.path.isfile(ext_path)):
+ return ext_path
+ return None
def _convert(args):
# get file path to log
dir_path = os.path.dirname(os.path.realpath(__file__))
logfile_path = os.path.realpath(args.output_path) + '.log'
+ ext_path = _check_ext()
with open(logfile_path, 'wb') as f, tempfile.TemporaryDirectory() as tmpdir:
# save intermediate
tmpdir = os.path.dirname(logfile_path)
# convert onnx to tf saved model
onnx_model = onnx.load(getattr(args, 'input_path'))
+ _sanitize_io_names(onnx_model)
if _onnx_legalizer_enabled:
options = onnx_legalizer.LegalizeOptions
options.unroll_rnn = oneutils.is_valid_attr(args, 'unroll_rnn')
fixed_path = os.path.join(tmpdir,
os.path.splitext(basename)[0] + '~.onnx')
onnx.save(onnx_model, fixed_path)
+
+ if ext_path:
+ # save onnx_model to temporary alt file
+ basename = os.path.basename(getattr(args, 'input_path'))
+ alt_path = os.path.join(tmpdir, os.path.splitext(basename)[0] + '-alt.onnx')
+ onnx.save(onnx_model, alt_path)
+
+ # call extension with options
+ ext_cmd = [ext_path]
+ if oneutils.is_valid_attr(args, 'unroll_rnn'):
+ ext_cmd.append('--unroll_rnn')
+ if oneutils.is_valid_attr(args, 'unroll_lstm'):
+ ext_cmd.append('--unroll_lstm')
+ if oneutils.is_valid_attr(args, 'experimental_disable_batchmatmul_unfold'):
+ ext_cmd.append('--experimental_disable_batchmatmul_unfold')
+ if oneutils.is_valid_attr(args, 'save_intermediate'):
+ ext_cmd.append('--save_intermediate')
+ if oneutils.is_valid_attr(args, 'keep_io_order'):
+ ext_cmd.append('--keep_io_order')
+ ext_cmd.append(alt_path)
+ ext_cmd.append(getattr(args, 'output_path'))
+ oneutils.run(ext_cmd, logfile=f)
+ return
+
tf_savedmodel = onnx_tf.backend.prepare(onnx_model)
savedmodel_name = os.path.splitext(os.path.basename(