[Static Runtime] Remove unnecessary fb::equally_split nodes (#64022)
authorDon Jang <djang@fb.com>
Thu, 26 Aug 2021 23:28:35 +0000 (16:28 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 26 Aug 2021 23:29:43 +0000 (16:29 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64022

Test Plan: - Added unittest `StaticRuntime.RemoveEquallySplitListUnpack`.

Reviewed By: hlu1

Differential Revision: D30472189

fbshipit-source-id: 36040b0146f4be9d0d0fda293f7205f43aad0b87

torch/csrc/jit/runtime/static/passes.cpp

index c8e1107..1133e39 100644 (file)
@@ -412,6 +412,7 @@ void ReplaceWithCopy(
 // c10::AliasAnalysisKind::PURE_FUNCTION to make alias analysis work.
 void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
   auto nodes = graph->nodes();
+  std::vector<Node*> equally_splits_to_remove;
   for (auto it = nodes.begin(); it != nodes.end(); ++it) {
     Node* node = *it;
     const char* node_qual_string = node->kind().toQualString();
@@ -445,8 +446,22 @@ void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
       it_next.destroyCurrent(); // remove list_unpack
 
       node->eraseOutput(0);
+
+      if (strcmp(node_qual_string, "fb::equally_split") == 0 &&
+          node->outputs().size() == 1) {
+        // This captures a case of `y = fb::equally_split(x, 1, _)` where y
+        // becomes just an alias of x.
+        // If this case is found, replace y with x to avoid executing this op.
+        equally_splits_to_remove.push_back(node);
+      }
     }
   }
+
+  for (Node* node : equally_splits_to_remove) {
+    node->output(0)->replaceAllUsesWith(node->input(0));
+    node->destroy();
+  }
+
 #ifndef NDEBUG
   graph->lint();
   AliasDb db2(graph);