From 9983c24cfcad8e091f55fc728999ca7d5a2d7c65 Mon Sep 17 00:00:00 2001 From: Lara Haidar-Ahmad Date: Thu, 18 Apr 2019 22:25:04 -0700 Subject: [PATCH] Strip doc_string from exported ONNX models (#18882) Summary: Strip the doc_string by default from the exported ONNX models (this string has the stack trace and information about the local repos and folders, which can be confidential). The users can still generate the doc_string by specifying strip_doc_string=False in torch.onnx.export(). Pull Request resolved: https://github.com/pytorch/pytorch/pull/18882 Differential Revision: D14889684 Pulled By: houseroad fbshipit-source-id: 26d2c23c8dc3f484544aa854b507ada429adb9b8 --- test/onnx/test_utility_funs.py | 24 ++++++++++++++++++++++++ torch/csrc/jit/export.cpp | 5 +++-- torch/csrc/jit/export.h | 3 ++- torch/csrc/jit/python_ir.cpp | 9 ++++++--- torch/onnx/utils.py | 16 +++++++++++----- 5 files changed, 46 insertions(+), 11 deletions(-) diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index 0e87e72..fc3a61f 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -6,7 +6,10 @@ import torch.onnx from torch.onnx import utils from torch.onnx.symbolic import _set_opset_version +import onnx + import io +import copy class TestUtilityFuns(TestCase): @@ -137,5 +140,26 @@ class TestUtilityFuns(TestCase): assert node.kind() != "onnx::Transpose" assert len(list(graph.nodes())) == 1 + def test_strip_doc_string(self): + class MyModule(torch.nn.Module): + def forward(self, input): + return torch.exp(input) + x = torch.randn(3, 4) + + def is_model_stripped(f, strip_doc_string=None): + if strip_doc_string is None: + torch.onnx.export(MyModule(), x, f) + else: + torch.onnx.export(MyModule(), x, f, strip_doc_string=strip_doc_string) + model = onnx.load(io.BytesIO(f.getvalue())) + model_strip = copy.copy(model) + onnx.helper.strip_doc_string(model_strip) + return model == model_strip + + # test strip_doc_string=True (default) + 
self.assertTrue(is_model_stripped(io.BytesIO())) + # test strip_doc_string=False + self.assertFalse(is_model_stripped(io.BytesIO(), False)) + if __name__ == '__main__': run_tests() diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp index d6133bf..6e5a6df 100644 --- a/torch/csrc/jit/export.cpp +++ b/torch/csrc/jit/export.cpp @@ -1003,14 +1003,15 @@ std::tuple export_onnx( const std::map& initializers, int64_t onnx_opset_version, bool defer_weight_export, - ::torch::onnx::OperatorExportTypes operator_export_type) { + ::torch::onnx::OperatorExportTypes operator_export_type, + bool strip_doc_string) { auto graph_encoder = GraphEncoder( graph, onnx_opset_version, operator_export_type, initializers, defer_weight_export, - false); + strip_doc_string); return std::make_tuple( graph_encoder.get_model_proto().SerializeAsString(), graph_encoder.get_raw_data_export_map()); diff --git a/torch/csrc/jit/export.h b/torch/csrc/jit/export.h index 1904723..636fa2e 100644 --- a/torch/csrc/jit/export.h +++ b/torch/csrc/jit/export.h @@ -25,7 +25,8 @@ TORCH_API std::tuple export_onnx( int64_t onnx_opset_version, bool defer_weight_export = false, ::torch::onnx::OperatorExportTypes operator_export_type = - ::torch::onnx::OperatorExportTypes::ONNX); + ::torch::onnx::OperatorExportTypes::ONNX, + bool strip_doc_string = true); // For testing purposes TORCH_API std::string pretty_print_onnx( diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp index 1a40aa4..655a64a 100644 --- a/torch/csrc/jit/python_ir.cpp +++ b/torch/csrc/jit/python_ir.cpp @@ -228,7 +228,8 @@ void initPythonIRBindings(PyObject* module_) { const std::map& initializers, int64_t onnx_opset_version, bool defer_weight_export, - ::torch::onnx::OperatorExportTypes operator_export_type) { + ::torch::onnx::OperatorExportTypes operator_export_type, + bool strip_doc_string) { std::string graph; RawDataExportMap export_map; std::tie(graph, export_map) = export_onnx( @@ -236,7 +237,8 @@ void 
initPythonIRBindings(PyObject* module_) { initializers, onnx_opset_version, defer_weight_export, - operator_export_type); + operator_export_type, + strip_doc_string); std::unordered_map python_serialized_export_map; for (auto& kv : export_map) { @@ -255,7 +257,8 @@ void initPythonIRBindings(PyObject* module_) { py::arg("onnx_opset_version") = 0, py::arg("defer_weight_export") = false, py::arg("operator_export_type") = - ::torch::onnx::OperatorExportTypes::ONNX) + ::torch::onnx::OperatorExportTypes::ONNX, + py::arg("strip_doc_string") = true) .def( "_pretty_print_onnx", [](const std::shared_ptr g, diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 9992724..1189aab 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -56,7 +56,7 @@ def set_training(model, mode): def export(model, args, f, export_params=True, verbose=False, training=False, input_names=None, output_names=None, aten=False, export_raw_ir=False, operator_export_type=None, opset_version=None, _retain_param_name=True, - do_constant_folding=False): + do_constant_folding=False, strip_doc_string=True): r""" Export a model into ONNX format. This exporter runs your model once in order to get a trace of its execution to be exported; @@ -112,6 +112,9 @@ def export(model, args, f, export_params=True, verbose=False, training=False, optimization is applied to the model during export. Constant-folding optimization will replace some of the ops that have all constant inputs, with pre-computed constant nodes. + strip_doc_string (bool, default True): if True, strips the field + "doc_string" from the exported model, which contains information about the stack + trace. 
""" if aten or export_raw_ir: assert operator_export_type is None @@ -124,7 +127,8 @@ def export(model, args, f, export_params=True, verbose=False, training=False, operator_export_type = OperatorExportTypes.ONNX _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type=operator_export_type, opset_version=opset_version, - _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding) + _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding, + strip_doc_string=strip_doc_string) # ONNX can't handle constants that are lists of tensors, which can @@ -337,7 +341,8 @@ def _export_to_pretty_string(model, args, f, export_params=True, verbose=False, def _export(model, args, f, export_params=True, verbose=False, training=False, input_names=None, output_names=None, operator_export_type=OperatorExportTypes.ONNX, export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None, propagate=False, - opset_version=None, _retain_param_name=False, do_constant_folding=False): + opset_version=None, _retain_param_name=False, do_constant_folding=False, + strip_doc_string=True): global __IN_ONNX_EXPORT assert __IN_ONNX_EXPORT is False __IN_ONNX_EXPORT = True @@ -355,9 +360,10 @@ def _export(model, args, f, export_params=True, verbose=False, training=False, # TODO: Don't allocate a in-memory string for the protobuf defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE if export_params: - proto, export_map = graph._export_onnx(params_dict, opset_version, defer_weight_export, operator_export_type) + proto, export_map = graph._export_onnx(params_dict, opset_version, defer_weight_export, operator_export_type, + strip_doc_string) else: - proto, export_map = graph._export_onnx({}, opset_version, False, operator_export_type) + proto, export_map = graph._export_onnx({}, opset_version, False, operator_export_type, strip_doc_string) if export_type == ExportTypes.PROTOBUF_FILE: assert(len(export_map) == 0) 
-- 2.7.4