From aeec177833cb20e8c6177ef8dbcf02ddc37c8a32 Mon Sep 17 00:00:00 2001 From: Mike Iovine Date: Thu, 26 Aug 2021 14:09:10 -0700 Subject: [PATCH] [JIT] UseVariadicOp takes list_idx parameter (#63915) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63915 Previously, this function only worked for variadic op substitutions of the form `op(list, args) -> variadic_op(list_1, ..., list_n, args)`. This change allows for transformations of the form `op(args_0, list, args_1) -> variadic_op(args_0, list_1, ..., list_n, args_1)`. Test Plan: `buck test caffe2/test/cpp/jit:jit -- Stack Concat` (tests exercising `list_idx != 0` will be added further up in this diff stack) Reviewed By: navahgar Differential Revision: D30529729 fbshipit-source-id: 568080679c3b40bdaedee56bef2e8a5ce7985d2f --- torch/csrc/jit/passes/variadic_ops.cpp | 47 ++++++++++++++++++++++++---------- torch/csrc/jit/passes/variadic_ops.h | 12 +++++++++ 2 files changed, 46 insertions(+), 13 deletions(-) diff --git a/torch/csrc/jit/passes/variadic_ops.cpp b/torch/csrc/jit/passes/variadic_ops.cpp index 6f4d23c..a827d3a 100644 --- a/torch/csrc/jit/passes/variadic_ops.cpp +++ b/torch/csrc/jit/passes/variadic_ops.cpp @@ -14,8 +14,12 @@ class VariadicUpdater { explicit VariadicUpdater( std::shared_ptr graph, NodeKind op, - NodeKind variadic_op) - : graph_(std::move(graph)), op_(op), variadic_op_(variadic_op) {} + NodeKind variadic_op, + size_t list_idx = 0) + : graph_(std::move(graph)), + op_(op), + variadic_op_(variadic_op), + list_idx_(list_idx) {} bool run() { collectOpNodes(graph_->block()); @@ -39,21 +43,34 @@ class VariadicUpdater { } bool replaceWithVariadicOp(Node* op_node) { - if (op_node->input(0)->node()->kind() != prim::ListConstruct) { + const size_t num_inputs = op_node->inputs().size(); + TORCH_CHECK(list_idx_ < num_inputs); + if (op_node->input(list_idx_)->node()->kind() != prim::ListConstruct) { return false; } - auto list = op_node->input(0)->node(); + auto list = op_node->input(list_idx_)->node(); + const size_t list_len = list->inputs().size(); + // We do not transform ops whose list input can not be moved to the // position before op. This in turn implies that there is some mutation // of the input list before op. if (!getOrCreateAliasDb()->couldMoveBeforeTopologically(list, op_node)) { return false; } - std::vector inputs = list->inputs().vec(); - // Add non-list inputs - for (size_t i = 1; i < op_node->inputs().size(); ++i) { - inputs.push_back(op_node->input(i)); - } + + // Construct new inputs + std::vector inputs; + inputs.reserve(num_inputs + list_len - 1); + inputs.insert( + inputs.end(), + op_node->inputs().begin(), + op_node->inputs().begin() + list_idx_); + inputs.insert(inputs.end(), list->inputs().begin(), list->inputs().end()); + inputs.insert( + inputs.end(), + op_node->inputs().begin() + list_idx_ + 1, + op_node->inputs().end()); + auto var_op_node = op_node->owningGraph()->create(variadic_op_, inputs); GRAPH_UPDATE("Adding\n", *var_op_node); var_op_node->insertBefore(op_node); @@ -82,6 +99,8 @@ class VariadicUpdater { NodeKind op_; NodeKind variadic_op_; + + size_t list_idx_; }; } // namespace @@ -89,10 +108,11 @@ class VariadicUpdater { bool UseVariadicOp( const std::shared_ptr& graph, NodeKind op, - NodeKind variadic_op) { + NodeKind variadic_op, + size_t list_idx) { const std::string pass_name = std::string("variadic ") + op.toQualString(); GRAPH_DUMP("Before " + pass_name, graph); - bool changed = VariadicUpdater(graph, op, variadic_op).run(); + bool changed = VariadicUpdater(graph, op, variadic_op, list_idx).run(); if (changed) { GRAPH_DUMP("After " + pass_name, graph); } @@ -102,13 +122,14 @@ bool UseVariadicOp( bool RemoveListMutationAndUseVariadicOp( const std::shared_ptr& graph, NodeKind op, - NodeKind variadic_op) { + NodeKind variadic_op, + size_t list_idx) { bool changed_in_last_iter = true; bool changed = false; while (changed_in_last_iter) { changed_in_last_iter = RemoveListMutation(graph); changed_in_last_iter = - UseVariadicOp(graph, op, variadic_op) || changed_in_last_iter; + UseVariadicOp(graph, op, variadic_op, list_idx) || changed_in_last_iter; changed = changed || changed_in_last_iter; } return changed; diff --git a/torch/csrc/jit/passes/variadic_ops.h b/torch/csrc/jit/passes/variadic_ops.h index 20cc664..e5f6a68 100644 --- a/torch/csrc/jit/passes/variadic_ops.h +++ b/torch/csrc/jit/passes/variadic_ops.h @@ -19,5 +19,17 @@ TORCH_API bool UseVariadicStack(const std::shared_ptr& graph); TORCH_API bool RemoveListMutationAndUseVariadicStack( const std::shared_ptr& graph); +TORCH_API bool UseVariadicOp( + const std::shared_ptr& graph, + NodeKind op, + NodeKind variadic_op, + size_t list_idx = 0); + +TORCH_API bool RemoveListMutationAndUseVariadicOp( + const std::shared_ptr& graph, + NodeKind op, + NodeKind variadic_op, + size_t list_idx = 0); + } // namespace jit } // namespace torch -- 2.7.4