From 0bd8d0951dcb4063c0f7552a7404bd7f0e7b6e6f Mon Sep 17 00:00:00 2001 From: Don Jang Date: Thu, 26 Aug 2021 16:28:35 -0700 Subject: [PATCH] [Static Runtime] Remove unnecessary fb::equally_split nodes (#64022) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64022 Test Plan: - Added unittest `StaticRuntime.RemoveEquallySplitListUnpack`. Reviewed By: hlu1 Differential Revision: D30472189 fbshipit-source-id: 36040b0146f4be9d0d0fda293f7205f43aad0b87 --- torch/csrc/jit/runtime/static/passes.cpp | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/torch/csrc/jit/runtime/static/passes.cpp b/torch/csrc/jit/runtime/static/passes.cpp index c8e1107..1133e39 100644 --- a/torch/csrc/jit/runtime/static/passes.cpp +++ b/torch/csrc/jit/runtime/static/passes.cpp @@ -412,6 +412,7 @@ void ReplaceWithCopy( // c10::AliasAnalysisKind::PURE_FUNCTION to make alias analysis work. void FuseListUnpack(std::shared_ptr& graph) { auto nodes = graph->nodes(); + std::vector equally_splits_to_remove; for (auto it = nodes.begin(); it != nodes.end(); ++it) { Node* node = *it; const char* node_qual_string = node->kind().toQualString(); @@ -445,8 +446,22 @@ void FuseListUnpack(std::shared_ptr& graph) { it_next.destroyCurrent(); // remove list_unpack node->eraseOutput(0); + + if (strcmp(node_qual_string, "fb::equally_split") == 0 && + node->outputs().size() == 1) { + // This captures a case of `y = fb::equally_split(x, 1, _)` where y + // becomes just an alias of x. + // If this case is found, replace y with x to avoid executing this op. + equally_splits_to_remove.push_back(node); + } } } + + for (Node* node : equally_splits_to_remove) { + node->output(0)->replaceAllUsesWith(node->input(0)); + node->destroy(); + } + #ifndef NDEBUG graph->lint(); AliasDb db2(graph); -- 2.7.4