// c10::AliasAnalysisKind::PURE_FUNCTION to make alias analysis work.
void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
auto nodes = graph->nodes();
+ std::vector<Node*> equally_splits_to_remove;
for (auto it = nodes.begin(); it != nodes.end(); ++it) {
Node* node = *it;
const char* node_qual_string = node->kind().toQualString();
it_next.destroyCurrent(); // remove list_unpack
node->eraseOutput(0);
+
+ if (strcmp(node_qual_string, "fb::equally_split") == 0 &&
+ node->outputs().size() == 1) {
+ // This captures a case of `y = fb::equally_split(x, 1, _)` where y
+ // becomes just an alias of x.
+ // If this case is found, replace y with x to avoid executing this op.
+ equally_splits_to_remove.push_back(node);
+ }
}
}
+
+ for (Node* node : equally_splits_to_remove) {
+ node->output(0)->replaceAllUsesWith(node->input(0));
+ node->destroy();
+ }
+
#ifndef NDEBUG
graph->lint();
AliasDb db2(graph);