From 172e5c76ab05f1a137eb065b7f221a20eaef514a Mon Sep 17 00:00:00 2001 From: Bert Maher Date: Mon, 23 Aug 2021 17:28:33 -0700 Subject: [PATCH] Fix some memory bugs in onnx passes (#63754) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63754 Running onnx tests with ASAN uncovers several memory errors. These two are caused by: (1) iterating the uses list of a node after mutation, and (2) accessing the `blocks` attribute of a possibly deleted node. To reproduce (this is on a CentOS 7 box): ``` DEBUG=1 CFLAGS="-fsanitize=address" CXXFLAGS="-fsanitize=address" USE_LLVM=$(realpath ../llvm-project/install) CMAKE_PREFIX_PATH=$CONDA_PREFIX python setup.py install LD_PRELOAD=$(realpath /lib64/libasan.so.5) numactl -C3 pytest -v --cov --cov-report xml:test/coverage.xml --cov-append onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset11 -s ``` Test Plan: Imported from OSS Reviewed By: ZolotukhinM Differential Revision: D30493939 Pulled By: bertmaher fbshipit-source-id: e16e19dc9b4c9896e102ca8bf04c8bedfdde87af --- .../csrc/jit/passes/onnx/list_model_parameters.cpp | 6 ++- .../jit/passes/onnx/pattern_conversion/common.cpp | 4 +- .../passes/onnx/remove_inplace_ops_for_onnx.cpp | 45 +++++++++++++--------- 3 files changed, 33 insertions(+), 22 deletions(-) diff --git a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp index ccadf53..9c751bb 100644 --- a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp +++ b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp @@ -76,6 +76,7 @@ std::vector<IValue> getParamAttributes( WithInsertPoint guard(m); std::vector<IValue> parameterIValues = {}; + std::unordered_set<Node*> nodesToDestroy; for (auto it = block->nodes().begin(); it != block->nodes().end();) { Node* n = *it; it++; // node n can be destroyed @@ -142,7 +143,7 @@ std::vector<IValue> getParamAttributes( // This attr is constant for ONNX. 
auto attrVal = tryInsertConstant(*graph, attr); n->output()->replaceAllUsesWith(*attrVal); - n->destroy(); + nodesToDestroy.emplace(n); } } } @@ -156,6 +157,9 @@ std::vector getParamAttributes( std::end(nextParameterIValues)); } } + for (auto n : nodesToDestroy) { + n->destroy(); + } return parameterIValues; } diff --git a/torch/csrc/jit/passes/onnx/pattern_conversion/common.cpp b/torch/csrc/jit/passes/onnx/pattern_conversion/common.cpp index 2854c3a..bc64630 100644 --- a/torch/csrc/jit/passes/onnx/pattern_conversion/common.cpp +++ b/torch/csrc/jit/passes/onnx/pattern_conversion/common.cpp @@ -4,8 +4,8 @@ namespace torch { namespace jit { bool IndexingPatternFinder::IsSameSource(const Node* n, const Node* m) { - const auto& source_n = n->sourceRange().source(); - const auto& source_m = m->sourceRange().source(); + const auto source_n = n->sourceRange().source(); + const auto source_m = m->sourceRange().source(); return ( (source_n->text() == source_m->text()) && (source_n->starting_line_no() == source_m->starting_line_no())); diff --git a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp index 913f4dc..2cef76a 100644 --- a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp +++ b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp @@ -317,26 +317,33 @@ static void PrepareForRemoveMutations(MutationRemover& mr, Block* b) { } for (auto input : b->inputs()) { - for (auto use : input->uses()) { - Node* node = use.user; - if (!mr.inplaceOpVariant(node)) { - continue; - } - auto it = std::find(node->inputs().begin(), node->inputs().end(), input); - if (it != node->inputs().end()) { - int index = std::distance(node->inputs().begin(), it); - std::cerr << "Warning: ONNX Preprocess - Removing mutation from node " - << node->kind().toQualString() << " on block input: '" - << (*it)->debugName() << "'. This changes graph semantics." 
- << std::endl; - - Node* newNode = - addDummyClone(b->owningGraph(), input, false, b->return_node()); - TORCH_INTERNAL_ASSERT(nullptr != newNode); - node->replaceInput(index, newNode->output()); - input->replaceAllUsesAfterNodeWith(node, newNode->output()); + bool needsRestart = false; + do { + needsRestart = false; + for (auto use : input->uses()) { + Node* node = use.user; + if (!mr.inplaceOpVariant(node)) { + continue; + } + auto it = + std::find(node->inputs().begin(), node->inputs().end(), input); + if (it != node->inputs().end()) { + int index = std::distance(node->inputs().begin(), it); + std::cerr << "Warning: ONNX Preprocess - Removing mutation from node " + << node->kind().toQualString() << " on block input: '" + << (*it)->debugName() << "'. This changes graph semantics." + << std::endl; + + Node* newNode = + addDummyClone(b->owningGraph(), input, false, b->return_node()); + TORCH_INTERNAL_ASSERT(nullptr != newNode); + node->replaceInput(index, newNode->output()); + input->replaceAllUsesAfterNodeWith(node, newNode->output()); + needsRestart = true; + break; + } } - } + } while (needsRestart); } } -- 2.7.4