Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / one-cmds / one-import-onnx
index 24edea6..b9c773b 100644 (file)
@@ -41,6 +41,89 @@ import onelib.utils as oneutils
 sys.tracebacklimit = 0
 
 
+# Class to rename input/output to prevent issues while import ONNX models
+class TidyIONames:
+    def __init__(self, onnx_model):
+        self.input_nodes = []
+        self.output_nodes = []
+        self.remap_inputs = []
+        self.remap_outputs = []
+        self.initializers = []
+        self.onnx_model = onnx_model
+        # some models may have initializers as inputs. ignore them.
+        for initializer in onnx_model.graph.initializer:
+            self.initializers.append(initializer.name)
+
+    def order(self):
+        for idx in range(0, len(self.onnx_model.graph.input)):
+            name = self.onnx_model.graph.input[idx].name
+            if not name in self.initializers:
+                self.input_nodes.append(name)
+                self.remap_inputs.append('i_' + format(idx + 1, '04d') + '_' + name)
+        for idx in range(0, len(self.onnx_model.graph.output)):
+            name = self.onnx_model.graph.output[idx].name
+            self.output_nodes.append(name)
+            self.remap_outputs.append('o_' + format(idx + 1, '04d') + '_' + name)
+
+    # exclude special characters in names
+    def sanitize(self):
+        for idx in range(0, len(self.onnx_model.graph.input)):
+            name = self.onnx_model.graph.input[idx].name
+            if not name in self.initializers:
+                if '.' in name or ':' in name or name[:1].isdigit():
+                    self.input_nodes.append(name)
+                    name_alt = name.replace('.', '_')
+                    name_alt = name_alt.replace(':', '_')
+                    if name_alt[:1].isdigit():
+                        name_alt = 'a_' + name_alt
+                    self.remap_inputs.append(name_alt)
+        for idx in range(0, len(self.onnx_model.graph.output)):
+            name = self.onnx_model.graph.output[idx].name
+            if '.' in name or ':' in name or name[:1].isdigit():
+                self.output_nodes.append(name)
+                name_alt = name.replace('.', '_')
+                name_alt = name_alt.replace(':', '_')
+                if name_alt[:1].isdigit():
+                    name_alt = 'a_' + name_alt
+                self.remap_outputs.append(name_alt)
+
+    def update(self):
+        # change names for graph input
+        for i in range(len(self.onnx_model.graph.input)):
+            if self.onnx_model.graph.input[i].name in self.input_nodes:
+                to_rename = self.onnx_model.graph.input[i].name
+                idx = self.input_nodes.index(to_rename)
+                self.onnx_model.graph.input[i].name = self.remap_inputs[idx]
+        # change names of all nodes in the graph
+        for i in range(len(self.onnx_model.graph.node)):
+            # check node.input is to change to remap_inputs or remap_outputs
+            for j in range(len(self.onnx_model.graph.node[i].input)):
+                if self.onnx_model.graph.node[i].input[j] in self.input_nodes:
+                    to_rename = self.onnx_model.graph.node[i].input[j]
+                    idx = self.input_nodes.index(to_rename)
+                    self.onnx_model.graph.node[i].input[j] = self.remap_inputs[idx]
+                if self.onnx_model.graph.node[i].input[j] in self.output_nodes:
+                    to_rename = self.onnx_model.graph.node[i].input[j]
+                    idx = self.output_nodes.index(to_rename)
+                    self.onnx_model.graph.node[i].input[j] = self.remap_outputs[idx]
+            # check node.output is to change to remap_inputs or remap_outputs
+            for j in range(len(self.onnx_model.graph.node[i].output)):
+                if self.onnx_model.graph.node[i].output[j] in self.output_nodes:
+                    to_rename = self.onnx_model.graph.node[i].output[j]
+                    idx = self.output_nodes.index(to_rename)
+                    self.onnx_model.graph.node[i].output[j] = self.remap_outputs[idx]
+                if self.onnx_model.graph.node[i].output[j] in self.input_nodes:
+                    to_rename = self.onnx_model.graph.node[i].output[j]
+                    idx = self.input_nodes.index(to_rename)
+                    self.onnx_model.graph.node[i].output[j] = self.remap_inputs[idx]
+        # change names for graph output
+        for i in range(len(self.onnx_model.graph.output)):
+            if self.onnx_model.graph.output[i].name in self.output_nodes:
+                to_rename = self.onnx_model.graph.output[i].name
+                idx = self.output_nodes.index(to_rename)
+                self.onnx_model.graph.output[i].name = self.remap_outputs[idx]
+
+
 def get_driver_cfg_section():
     return "one-import-onnx"
 
@@ -135,63 +218,32 @@ def _apply_verbosity(verbosity):
         os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
 
 
+# TF2.12.1 tries to sanitize special characters, '.:' and maybe others and then fails with
+# 'IndexError: tuple index out of range' error from somewhere else.
+# This method is to prevent this IndexError.
+def _sanitize_io_names(onnx_model):
+    sanitizer = TidyIONames(onnx_model)
+    sanitizer.sanitize()
+    sanitizer.update()
+
+
 # The index of input/output is added in front of the name. For example,
 # Original input names: 'a', 'c', 'b'
-# Renamed: '0001_a', '0002_c', '0003_b'
+# Renamed: 'i_0001_a', 'i_0002_c', 'i_0003_b'
 # This will preserve I/O order after import.
 def _remap_io_names(onnx_model):
     # gather existing name of I/O and generate new name of I/O in sort order
-    input_nodes = []
-    output_nodes = []
-    remap_inputs = []
-    remap_outputs = []
-    initializers = []
-    # some models may have initializers as inputs. ignore them.
-    for initializer in onnx_model.graph.initializer:
-        initializers.append(initializer.name)
-    for idx in range(0, len(onnx_model.graph.input)):
-        name = onnx_model.graph.input[idx].name
-        if not name in initializers:
-            input_nodes.append(name)
-            remap_inputs.append(format(idx + 1, '04d') + '_' + name)
-    for idx in range(0, len(onnx_model.graph.output)):
-        name = onnx_model.graph.output[idx].name
-        output_nodes.append(name)
-        remap_outputs.append(format(idx + 1, '04d') + '_' + name)
-    # change names for graph input
-    for i in range(len(onnx_model.graph.input)):
-        if onnx_model.graph.input[i].name in input_nodes:
-            to_rename = onnx_model.graph.input[i].name
-            idx = input_nodes.index(to_rename)
-            onnx_model.graph.input[i].name = remap_inputs[idx]
-    # change names of all nodes in the graph
-    for i in range(len(onnx_model.graph.node)):
-        # check node.input is to change to remap_inputs or remap_outputs
-        for j in range(len(onnx_model.graph.node[i].input)):
-            if onnx_model.graph.node[i].input[j] in input_nodes:
-                to_rename = onnx_model.graph.node[i].input[j]
-                idx = input_nodes.index(to_rename)
-                onnx_model.graph.node[i].input[j] = remap_inputs[idx]
-            if onnx_model.graph.node[i].input[j] in output_nodes:
-                to_rename = onnx_model.graph.node[i].input[j]
-                idx = output_nodes.index(to_rename)
-                onnx_model.graph.node[i].input[j] = remap_outputs[idx]
-        # check node.output is to change to remap_inputs or remap_outputs
-        for j in range(len(onnx_model.graph.node[i].output)):
-            if onnx_model.graph.node[i].output[j] in output_nodes:
-                to_rename = onnx_model.graph.node[i].output[j]
-                idx = output_nodes.index(to_rename)
-                onnx_model.graph.node[i].output[j] = remap_outputs[idx]
-            if onnx_model.graph.node[i].output[j] in input_nodes:
-                to_rename = onnx_model.graph.node[i].output[j]
-                idx = input_nodes.index(to_rename)
-                onnx_model.graph.node[i].output[j] = remap_inputs[idx]
-    # change names for graph output
-    for i in range(len(onnx_model.graph.output)):
-        if onnx_model.graph.output[i].name in output_nodes:
-            to_rename = onnx_model.graph.output[i].name
-            idx = output_nodes.index(to_rename)
-            onnx_model.graph.output[i].name = remap_outputs[idx]
+    remapper = TidyIONames(onnx_model)
+    remapper.order()
+    remapper.update()
+
+
+def _check_ext():
+    dir_path = os.path.dirname(os.path.realpath(__file__))
+    ext_path = os.path.join(dir_path, 'one-import-onnx-ext')
+    if (os.path.isfile(ext_path)):
+        return ext_path
+    return None
 
 
 def _convert(args):
@@ -200,6 +252,7 @@ def _convert(args):
     # get file path to log
     dir_path = os.path.dirname(os.path.realpath(__file__))
     logfile_path = os.path.realpath(args.output_path) + '.log'
+    ext_path = _check_ext()
 
     with open(logfile_path, 'wb') as f, tempfile.TemporaryDirectory() as tmpdir:
         # save intermediate
@@ -207,6 +260,7 @@ def _convert(args):
             tmpdir = os.path.dirname(logfile_path)
         # convert onnx to tf saved model
         onnx_model = onnx.load(getattr(args, 'input_path'))
+        _sanitize_io_names(onnx_model)
         if _onnx_legalizer_enabled:
             options = onnx_legalizer.LegalizeOptions
             options.unroll_rnn = oneutils.is_valid_attr(args, 'unroll_rnn')
@@ -219,6 +273,30 @@ def _convert(args):
                 fixed_path = os.path.join(tmpdir,
                                           os.path.splitext(basename)[0] + '~.onnx')
                 onnx.save(onnx_model, fixed_path)
+
+        if ext_path:
+            # save onnx_model to temporary alt file
+            basename = os.path.basename(getattr(args, 'input_path'))
+            alt_path = os.path.join(tmpdir, os.path.splitext(basename)[0] + '-alt.onnx')
+            onnx.save(onnx_model, alt_path)
+
+            # call extension with options
+            ext_cmd = [ext_path]
+            if oneutils.is_valid_attr(args, 'unroll_rnn'):
+                ext_cmd.append('--unroll_rnn')
+            if oneutils.is_valid_attr(args, 'unroll_lstm'):
+                ext_cmd.append('--unroll_lstm')
+            if oneutils.is_valid_attr(args, 'experimental_disable_batchmatmul_unfold'):
+                ext_cmd.append('--experimental_disable_batchmatmul_unfold')
+            if oneutils.is_valid_attr(args, 'save_intermediate'):
+                ext_cmd.append('--save_intermediate')
+            if oneutils.is_valid_attr(args, 'keep_io_order'):
+                ext_cmd.append('--keep_io_order')
+            ext_cmd.append(alt_path)
+            ext_cmd.append(getattr(args, 'output_path'))
+            oneutils.run(ext_cmd, logfile=f)
+            return
+
         tf_savedmodel = onnx_tf.backend.prepare(onnx_model)
 
         savedmodel_name = os.path.splitext(os.path.basename(