from torch._C import ListType
+# Module-level flag recording whether torch.onnx.export() is currently running.
+__IN_ONNX_EXPORT = False
+
+
+def is_in_onnx_export():
+ r"""
+ Check whether it's in the middle of the ONNX export.
+ This function returns True in the middle of torch.onnx.export().
+ torch.onnx.export should be executed with a single thread.
+ """
+ global __IN_ONNX_EXPORT
+ return __IN_ONNX_EXPORT
+
+
@contextlib.contextmanager
def set_training(model, mode):
r"""
input_names=None, output_names=None, operator_export_type=OperatorExportTypes.ONNX,
export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None, propagate=False,
opset_version=None, _retain_param_name=False):
- from torch.onnx.symbolic import _default_onnx_opset_version, _set_opset_version
- if opset_version is None:
- opset_version = _default_onnx_opset_version
- _set_opset_version(opset_version)
- graph, params_dict, torch_out = _model_to_graph(model, args, f, verbose,
- training, input_names,
- output_names, operator_export_type,
- example_outputs, propagate,
- _retain_param_name)
-
- # TODO: Don't allocate a in-memory string for the protobuf
- defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
- if export_params:
- proto, export_map = graph._export_onnx(params_dict, opset_version, defer_weight_export, operator_export_type)
- else:
- proto, export_map = graph._export_onnx({}, opset_version, False, operator_export_type)
-
- if export_type == ExportTypes.PROTOBUF_FILE:
- assert(len(export_map) == 0)
- torch.serialization._with_file_like(f, "wb", lambda f: f.write(proto))
- elif export_type in [ExportTypes.ZIP_ARCHIVE, ExportTypes.COMPRESSED_ZIP_ARCHIVE]:
- import zipfile
- compression = zipfile.ZIP_DEFLATED \
- if export_type == ExportTypes.COMPRESSED_ZIP_ARCHIVE \
- else zipfile.ZIP_STORED
- with zipfile.ZipFile(f, 'w', compression=compression) as z:
- z.writestr(ONNX_ARCHIVE_MODEL_PROTO_NAME, proto)
- for k, v in export_map.items():
- z.writestr(k, v)
- elif export_type == ExportTypes.DIRECTORY:
- import os
- if os.path.exists(f):
- assert(os.path.isdir(f))
+ global __IN_ONNX_EXPORT
+ assert __IN_ONNX_EXPORT is False
+ __IN_ONNX_EXPORT = True
+ try:
+ from torch.onnx.symbolic import _default_onnx_opset_version, _set_opset_version
+ if opset_version is None:
+ opset_version = _default_onnx_opset_version
+ _set_opset_version(opset_version)
+ graph, params_dict, torch_out = _model_to_graph(model, args, f, verbose,
+ training, input_names,
+ output_names, operator_export_type,
+ example_outputs, propagate,
+ _retain_param_name)
+
+ # TODO: Don't allocate an in-memory string for the protobuf
+ defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
+ if export_params:
+ proto, export_map = graph._export_onnx(params_dict, opset_version, defer_weight_export, operator_export_type)
else:
- os.makedirs(f)
-
- model_proto_file = os.path.join(f, ONNX_ARCHIVE_MODEL_PROTO_NAME)
- torch.serialization._with_file_like(
- model_proto_file, "wb", lambda f: f.write(proto))
+ proto, export_map = graph._export_onnx({}, opset_version, False, operator_export_type)
+
+ if export_type == ExportTypes.PROTOBUF_FILE:
+ assert(len(export_map) == 0)
+ torch.serialization._with_file_like(f, "wb", lambda f: f.write(proto))
+ elif export_type in [ExportTypes.ZIP_ARCHIVE, ExportTypes.COMPRESSED_ZIP_ARCHIVE]:
+ import zipfile
+ compression = zipfile.ZIP_DEFLATED \
+ if export_type == ExportTypes.COMPRESSED_ZIP_ARCHIVE \
+ else zipfile.ZIP_STORED
+ with zipfile.ZipFile(f, 'w', compression=compression) as z:
+ z.writestr(ONNX_ARCHIVE_MODEL_PROTO_NAME, proto)
+ for k, v in export_map.items():
+ z.writestr(k, v)
+ elif export_type == ExportTypes.DIRECTORY:
+ import os
+ if os.path.exists(f):
+ assert(os.path.isdir(f))
+ else:
+ os.makedirs(f)
- for k, v in export_map.items():
- weight_proto_file = os.path.join(f, k)
+ model_proto_file = os.path.join(f, ONNX_ARCHIVE_MODEL_PROTO_NAME)
torch.serialization._with_file_like(
- weight_proto_file, "wb", lambda f: f.write(v))
- else:
- raise RuntimeError('Unknown export type')
+ model_proto_file, "wb", lambda f: f.write(proto))
+
+ for k, v in export_map.items():
+ weight_proto_file = os.path.join(f, k)
+ torch.serialization._with_file_like(
+ weight_proto_file, "wb", lambda f: f.write(v))
+ else:
+ raise RuntimeError('Unknown export type')
+ finally:
+ assert __IN_ONNX_EXPORT
+ __IN_ONNX_EXPORT = False
return torch_out