add a utility function to check whether we are in the middle of ONNX export
author Lu Fang <lufang@fb.com>
Tue, 9 Apr 2019 17:01:48 +0000 (10:01 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 9 Apr 2019 17:07:08 +0000 (10:07 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19050

Reviewed By: yinghai

Differential Revision: D14849878

Pulled By: houseroad

fbshipit-source-id: a0a4a57f5f9f315ba1334edfccc9284a8099d17f

test/onnx/test_utility_funs.py [new file with mode: 0644]
torch/onnx/__init__.py
torch/onnx/utils.py
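
The new API is meant to be queried from model code, typically inside forward(), to pick an export-friendly code path while the exporter traces the model. A minimal usage sketch (the Pad module and shapes below are illustrative only, not part of this change):

    import torch
    import torch.onnx


    class Pad(torch.nn.Module):
        def forward(self, x):
            if torch.onnx.is_in_onnx_export():
                # Export path: pad with a single exporter-friendly op.
                return torch.nn.functional.pad(x, (0, 1))
            # Eager path: an equivalent result built by concatenation.
            return torch.cat([x, x.new_zeros(x.size(0), 1)], dim=1)


    x = torch.randn(3, 4)
    print(Pad()(x).shape)  # torch.Size([3, 5]); is_in_onnx_export() is False here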

diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py
new file mode 100644
index 0000000..fb9b99b
--- /dev/null
+++ b/test/onnx/test_utility_funs.py
@@ -0,0 +1,30 @@
+from __future__ import absolute_import, division, print_function, unicode_literals
+from test_pytorch_common import TestCase, run_tests
+
+import torch
+import torch.onnx
+
+import io
+
+
+class TestUtilityFuns(TestCase):
+
+    def test_is_in_onnx_export(self):
+        test_self = self  # capture the TestCase for use inside forward()
+
+        class MyModule(torch.nn.Module):
+            def forward(self, x):
+                test_self.assertTrue(torch.onnx.is_in_onnx_export())
+                raise ValueError  # abort the export so the test can check the flag is reset
+                return x + 1  # never reached
+
+        x = torch.randn(3, 4)
+        f = io.BytesIO()
+        try:
+            torch.onnx.export(MyModule(), x, f)
+        except ValueError:
+            self.assertFalse(torch.onnx.is_in_onnx_export())
+
+
+if __name__ == '__main__':
+    run_tests()
diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py
index 5a9451e..519a10d 100644
--- a/torch/onnx/__init__.py
+++ b/torch/onnx/__init__.py
@@ -16,7 +16,8 @@ class ExportTypes:
 
 def _export(*args, **kwargs):
     from torch.onnx import utils
-    return utils._export(*args, **kwargs)
+    result = utils._export(*args, **kwargs)
+    return result
 
 
 def export(*args, **kwargs):
@@ -52,3 +53,8 @@ def _run_symbolic_function(*args, **kwargs):
 def _run_symbolic_method(*args, **kwargs):
     from torch.onnx import utils
     return utils._run_symbolic_method(*args, **kwargs)
+
+
+def is_in_onnx_export():
+    from torch.onnx import utils
+    return utils.is_in_onnx_export()
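
(Like the existing wrappers in torch/onnx/__init__.py, the new is_in_onnx_export() imports torch.onnx.utils lazily inside the function body; torch/onnx/utils.py itself imports names from torch.onnx, so a top-level import here would be circular.)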
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index 6338ede..1910a3c 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -19,6 +19,20 @@ from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExpor
 from torch._C import ListType
 
 
+# Module-level flag recording whether we are in the middle of ONNX export.
+__IN_ONNX_EXPORT = False
+
+
+def is_in_onnx_export():
+    r"""
+    Check whether it's in the middle of the ONNX export.
+    This function returns True in the middle of torch.onnx.export().
+    torch.onnx.export should be executed with single thread.
+    """
+    global __IN_ONNX_EXPORT
+    return __IN_ONNX_EXPORT
+
+
 @contextlib.contextmanager
 def set_training(model, mode):
     r"""
@@ -308,52 +322,59 @@ def _export(model, args, f, export_params=True, verbose=False, training=False,
             input_names=None, output_names=None, operator_export_type=OperatorExportTypes.ONNX,
             export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None, propagate=False,
             opset_version=None, _retain_param_name=False):
-    from torch.onnx.symbolic import _default_onnx_opset_version, _set_opset_version
-    if opset_version is None:
-        opset_version = _default_onnx_opset_version
-    _set_opset_version(opset_version)
-    graph, params_dict, torch_out = _model_to_graph(model, args, f, verbose,
-                                                    training, input_names,
-                                                    output_names, operator_export_type,
-                                                    example_outputs, propagate,
-                                                    _retain_param_name)
-
-    # TODO: Don't allocate a in-memory string for the protobuf
-    defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
-    if export_params:
-        proto, export_map = graph._export_onnx(params_dict, opset_version, defer_weight_export, operator_export_type)
-    else:
-        proto, export_map = graph._export_onnx({}, opset_version, False, operator_export_type)
-
-    if export_type == ExportTypes.PROTOBUF_FILE:
-        assert(len(export_map) == 0)
-        torch.serialization._with_file_like(f, "wb", lambda f: f.write(proto))
-    elif export_type in [ExportTypes.ZIP_ARCHIVE, ExportTypes.COMPRESSED_ZIP_ARCHIVE]:
-        import zipfile
-        compression = zipfile.ZIP_DEFLATED \
-            if export_type == ExportTypes.COMPRESSED_ZIP_ARCHIVE \
-            else zipfile.ZIP_STORED
-        with zipfile.ZipFile(f, 'w', compression=compression) as z:
-            z.writestr(ONNX_ARCHIVE_MODEL_PROTO_NAME, proto)
-            for k, v in export_map.items():
-                z.writestr(k, v)
-    elif export_type == ExportTypes.DIRECTORY:
-        import os
-        if os.path.exists(f):
-            assert(os.path.isdir(f))
+    global __IN_ONNX_EXPORT
+    assert __IN_ONNX_EXPORT is False
+    __IN_ONNX_EXPORT = True
+    try:
+        from torch.onnx.symbolic import _default_onnx_opset_version, _set_opset_version
+        if opset_version is None:
+            opset_version = _default_onnx_opset_version
+        _set_opset_version(opset_version)
+        graph, params_dict, torch_out = _model_to_graph(model, args, f, verbose,
+                                                        training, input_names,
+                                                        output_names, operator_export_type,
+                                                        example_outputs, propagate,
+                                                        _retain_param_name)
+
+        # TODO: Don't allocate an in-memory string for the protobuf
+        defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
+        if export_params:
+            proto, export_map = graph._export_onnx(params_dict, opset_version, defer_weight_export, operator_export_type)
         else:
-            os.makedirs(f)
-
-        model_proto_file = os.path.join(f, ONNX_ARCHIVE_MODEL_PROTO_NAME)
-        torch.serialization._with_file_like(
-            model_proto_file, "wb", lambda f: f.write(proto))
+            proto, export_map = graph._export_onnx({}, opset_version, False, operator_export_type)
+
+        if export_type == ExportTypes.PROTOBUF_FILE:
+            assert len(export_map) == 0
+            torch.serialization._with_file_like(f, "wb", lambda f: f.write(proto))
+        elif export_type in [ExportTypes.ZIP_ARCHIVE, ExportTypes.COMPRESSED_ZIP_ARCHIVE]:
+            import zipfile
+            compression = zipfile.ZIP_DEFLATED \
+                if export_type == ExportTypes.COMPRESSED_ZIP_ARCHIVE \
+                else zipfile.ZIP_STORED
+            with zipfile.ZipFile(f, 'w', compression=compression) as z:
+                z.writestr(ONNX_ARCHIVE_MODEL_PROTO_NAME, proto)
+                for k, v in export_map.items():
+                    z.writestr(k, v)
+        elif export_type == ExportTypes.DIRECTORY:
+            import os
+            if os.path.exists(f):
+                assert os.path.isdir(f)
+            else:
+                os.makedirs(f)
 
-        for k, v in export_map.items():
-            weight_proto_file = os.path.join(f, k)
+            model_proto_file = os.path.join(f, ONNX_ARCHIVE_MODEL_PROTO_NAME)
             torch.serialization._with_file_like(
-                weight_proto_file, "wb", lambda f: f.write(v))
-    else:
-        raise RuntimeError('Unknown export type')
+                model_proto_file, "wb", lambda f: f.write(proto))
+
+            for k, v in export_map.items():
+                weight_proto_file = os.path.join(f, k)
+                torch.serialization._with_file_like(
+                    weight_proto_file, "wb", lambda f: f.write(v))
+        else:
+            raise RuntimeError('Unknown export type')
+    finally:
+        assert __IN_ONNX_EXPORT
+        __IN_ONNX_EXPORT = False
     return torch_out
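
The guard added around the body of _export() follows a set-flag / try / finally pattern. The same invariant could be captured by a small context manager; the sketch below is a hypothetical restatement of the pattern (names like _IN_EXPORT and _export_guard are illustrative), not what this commit ships:

    import contextlib

    _IN_EXPORT = False  # hypothetical stand-in for utils.__IN_ONNX_EXPORT


    @contextlib.contextmanager
    def _export_guard():
        # Set the flag for the duration of the export and reset it on every
        # exit path, including exceptions, mirroring the try/finally above.
        global _IN_EXPORT
        assert _IN_EXPORT is False, "torch.onnx.export is not reentrant"
        _IN_EXPORT = True
        try:
            yield
        finally:
            _IN_EXPORT = False

With such a guard, the body of _export() would run under "with _export_guard():", and is_in_onnx_export() would simply read the flag.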