#include <gtest/gtest.h>
+#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/concat_opt.h>
+#include <torch/csrc/jit/passes/variadic_ops.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/testing/file_check.h>
namespace torch {
namespace jit {
-namespace {
-
-void checkOutputs(
- const std::vector<at::Tensor>& out1,
- const std::vector<at::Tensor>& out2) {
- ASSERT_EQ(out1.size(), out2.size());
- for (size_t i = 0; i < out1.size(); ++i) {
- ASSERT_EQ(out1[i].sizes(), out2[i].sizes());
- float max_diff = (out1[i] - out2[i]).abs().max().item<double>();
- ASSERT_EQ(max_diff, 0);
- }
-}
-
-std::vector<at::Tensor> runGraph(
- std::shared_ptr<Graph> graph,
- const std::vector<at::Tensor> inputs) {
- std::vector<IValue> stack = fmap<IValue>(inputs);
- Code code(graph, "test");
- InterpreterState(code).run(stack);
- TORCH_INTERNAL_ASSERT(!stack.empty());
- // Graph outputs that are handled below:
- // * A list of Tensors.
- // * 1 Tensor.
- if (stack.front().isTensorList()) {
- return stack.front().toTensorVector();
- }
- TORCH_INTERNAL_ASSERT(stack.front().isTensor());
- return {stack.front().toTensor()};
-}
-
-} // namespace
-
TEST(ConcatOptTest, SimpleCommonInputsEliminationPrefix) {
auto graph = std::make_shared<Graph>();
ASSERT_TRUE(EliminateConcatCommonInputs(graph));
graph->lint();
auto opt_outputs = runGraph(graph, inputs);
- checkOutputs(orig_outputs, opt_outputs);
+ ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
// Graph after EliminateConcatCommonInputs:
// graph(%0 : ...,
ASSERT_TRUE(EliminateConcatCommonInputs(graph));
graph->lint();
auto opt_outputs = runGraph(graph, inputs);
- checkOutputs(orig_outputs, opt_outputs);
+ ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
// Graph after EliminateConcatCommonInputs:
// graph(%0 : ...,
graph->lint();
auto opt_outputs = runGraph(graph, inputs);
- checkOutputs(orig_outputs, opt_outputs);
+ ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
// No optimizations should have happened in this case since the inputs
// to the `cat` are in a different order.
ASSERT_TRUE(EliminateConcatCommonInputs(graph));
graph->lint();
auto opt_outputs = runGraph(graph, inputs);
- checkOutputs(orig_outputs, opt_outputs);
+ ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
testing::FileCheck()
.check_count("= prim::VarConcat(%0, %1, %5)", 1, /*exactly*/ true)
graph->lint();
auto opt_outputs = runGraph(graph, inputs);
- checkOutputs(orig_outputs, opt_outputs);
+ ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
// After full concat optimization we should have the following graph:
//
graph->lint();
auto opt_outputs = runGraph(graph, inputs);
- checkOutputs(orig_outputs, opt_outputs);
+ ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
// No optimizations should have happened in this case since the output
// shape of `aten::cat` is not known.
graph->lint();
auto opt_outputs = runGraph(graph, inputs);
- checkOutputs(orig_outputs, opt_outputs);
+ ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
// No optimizations should have happened in this case since the shape of %5,
// which is an input to `aten::cat`, is not known.
graph->lint();
auto opt_outputs = runGraph(graph, inputs);
- checkOutputs(orig_outputs, opt_outputs);
+ ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
// After replacing `aten::cat` with `prim::VarConcat` we should have the
// following graph:
graph->lint();
auto opt_outputs = runGraph(graph, inputs);
- checkOutputs(orig_outputs, opt_outputs);
+ ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
// After full concat optimization we should have the following graph:
//
graph->lint();
auto opt_outputs = runGraph(graph, inputs);
- checkOutputs(orig_outputs, opt_outputs);
+ ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
// After replacing `aten::cat` with `prim::VarConcat` we should have the
// following graph:
ASSERT_TRUE(UseVariadicCat(graph));
graph->lint();
auto opt_outputs = runGraph(graph, inputs);
- checkOutputs(orig_outputs, opt_outputs);
+ ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
// The input list to `aten::cat` is mutated only after the `aten::cat` op,
// so it should have been replaced with `prim::VarConcat`. The transformed graph
ASSERT_FALSE(UseVariadicCat(graph));
graph->lint();
auto opt_outputs = runGraph(graph, inputs);
- checkOutputs(orig_outputs, opt_outputs);
+ ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
// No transformation should have happened since the `prim::ListConstruct` is
// mutated before `aten::cat`.
ASSERT_TRUE(RemoveListMutationAndUseVariadicCat(graph));
graph->lint();
auto opt_outputs = runGraph(graph, inputs);
- checkOutputs(orig_outputs, opt_outputs);
+ ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
// The mutation of the list must be removed and the `aten::cat` op must
// be replaced with the `prim::VarConcat` op in the graph. The transformed
ASSERT_TRUE(RemoveListMutationAndUseVariadicCat(graph));
graph->lint();
auto opt_outputs = runGraph(graph, inputs);
- checkOutputs(orig_outputs, opt_outputs);
+ ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
// All the mutations of the list must be removed and the `aten::cat` ops must
// be replaced with `prim::VarConcat` ops in the graph. The transformed graph
ASSERT_TRUE(EliminateConcatCommonInputs(graph));
graph->lint();
auto opt_outputs = runGraph(graph, inputs);
- checkOutputs(orig_outputs, opt_outputs);
+ ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
// After performing:
// * Remove list mutation
}
return diff.abs().max().item<float>() < 2e-6 * maxValue;
}
+
bool almostEqual(const at::Tensor& a, const at::Tensor& b) {
return checkRtol(a - b, {a, b});
}
return (a - b).abs().max().item<float>() == 0.f;
}
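+
+// Element-wise exact equality of two tensor lists; a convenience overload
+// used by the optimization tests to compare graph outputs before and
+// after a pass.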
+bool exactlyEqual(
+ const std::vector<at::Tensor>& a,
+ const std::vector<at::Tensor>& b) {
+ if (a.size() != b.size()) {
+ return false;
+ }
+ for (size_t i = 0; i < a.size(); ++i) {
+ if (!exactlyEqual(a[i], b[i])) {
+ return false;
+ }
+ }
+ return true;
+}
+
std::pair<at::Tensor, at::Tensor> lstm(
at::Tensor input,
at::Tensor hx,
});
} // namespace
+std::vector<at::Tensor> runGraph(
+ std::shared_ptr<Graph> graph,
+ const std::vector<at::Tensor>& inputs) {
+ std::vector<IValue> stack = fmap<IValue>(inputs);
+ Code code(graph, "test");
+ InterpreterState(code).run(stack);
+ TORCH_INTERNAL_ASSERT(!stack.empty());
+  // Graph outputs handled below:
+  // * A list of Tensors.
+  // * A single Tensor.
+ if (stack.front().isTensorList()) {
+ return stack.front().toTensorVector();
+ }
+ TORCH_INTERNAL_ASSERT(stack.front().isTensor());
+ return {stack.front().toTensor()};
+}
+
} // namespace jit
} // namespace torch
bool almostEqual(const at::Tensor& a, const at::Tensor& b);
bool exactlyEqual(const at::Tensor& a, const at::Tensor& b);
+bool exactlyEqual(
+ const std::vector<at::Tensor>& a,
+ const std::vector<at::Tensor>& b);
+
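+// Runs the given graph on the given tensor inputs with the JIT
+// interpreter and returns its output, which must be either a single
+// tensor or a list of tensors.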
+std::vector<at::Tensor> runGraph(
+ std::shared_ptr<Graph> graph,
+ const std::vector<at::Tensor>& inputs);
std::pair<at::Tensor, at::Tensor> lstm(
at::Tensor input,
"torch/csrc/jit/passes/symbolic_shape_analysis.cpp",
"torch/csrc/jit/passes/specialize_autogradzero.cpp",
"torch/csrc/jit/passes/update_differentiable_graph_requires_grad.cpp",
+ "torch/csrc/jit/passes/variadic_ops.cpp",
"torch/csrc/jit/passes/subgraph_rewrite.cpp",
"torch/csrc/jit/passes/tensorexpr_fuser.cpp",
"torch/csrc/jit/passes/utils/memory_dag.cpp",
GRAPH_DUMP("After expanding Concat and eliminating redundancy", graph);
}
-namespace {
-
-class VariadicCatUpdater {
- public:
- explicit VariadicCatUpdater(std::shared_ptr<Graph> graph)
- : graph_(std::move(graph)) {}
-
- bool run() {
- collectCatNodes(graph_->block());
- bool changed = false;
- for (auto c : cat_nodes_) {
- changed = replaceWithVariadicCat(c) || changed;
- }
- return changed;
- }
-
- private:
- void collectCatNodes(Block* block) {
- for (auto node : block->nodes()) {
- if (node->kind() == aten::cat) {
- cat_nodes_.push_back(node);
- }
- for (Block* b : node->blocks()) {
- collectCatNodes(b);
- }
- }
- }
-
- bool replaceWithVariadicCat(Node* cat) {
- if (cat->input(0)->node()->kind() != prim::ListConstruct) {
- return false;
- }
- auto list = cat->input(0)->node();
- // We do not transform cat ops whose list input can not be moved to the
- // position before cat. This in turn implies that there is some mutation
- // of the input list before cat.
- if (!getOrCreateAliasDb()->couldMoveBeforeTopologically(list, cat)) {
- return false;
- }
- std::vector<Value*> inputs = list->inputs().vec();
- inputs.push_back(cat->input(1));
- auto var_cat = cat->owningGraph()->create(prim::VarConcat, inputs);
- GRAPH_UPDATE("Adding\n", *var_cat);
- var_cat->insertBefore(cat);
- GRAPH_UPDATE("Replacing\n", *cat, "with\n", *var_cat);
- cat->output()->replaceAllUsesWith(var_cat->output());
- GRAPH_UPDATE("Deleting\n", *cat);
- cat->destroy();
- if (!list->hasUses()) {
- GRAPH_UPDATE("Deleting\n", *list);
- list->destroy();
- }
- return true;
- }
-
- AliasDb* getOrCreateAliasDb() {
- if (!aliasDb_) {
- aliasDb_ = std::make_unique<AliasDb>(graph_);
- }
- return aliasDb_.get();
- }
-
- std::shared_ptr<Graph> graph_;
- std::unique_ptr<AliasDb> aliasDb_ = nullptr;
-
- std::vector<Node*> cat_nodes_;
-};
-
-} // namespace
-
-bool UseVariadicCat(const std::shared_ptr<Graph>& graph) {
- GRAPH_DUMP("Before VariadicCat", graph);
- bool changed = VariadicCatUpdater(graph).run();
- if (changed) {
- GRAPH_DUMP("After VariadicCat", graph);
- }
- return changed;
-}
-
-bool RemoveListMutationAndUseVariadicCat(const std::shared_ptr<Graph>& graph) {
- bool changed_in_last_iter = true;
- bool changed = false;
- while (changed_in_last_iter) {
- changed_in_last_iter = RemoveListMutation(graph);
- changed_in_last_iter = changed_in_last_iter || UseVariadicCat(graph);
- changed = changed || changed_in_last_iter;
- }
- return changed;
-}
-
} // namespace jit
} // namespace torch
TORCH_API void ExpandConcatAndEliminateRedundancy(
const std::shared_ptr<Graph>& graph);
-// Replaces the `aten::cat` ops in the given graph with variadic cat ops.
-// Returns true if the graph is modified.
-TORCH_API bool UseVariadicCat(const std::shared_ptr<Graph>& graph);
-
-TORCH_API bool RemoveListMutationAndUseVariadicCat(
- const std::shared_ptr<Graph>& graph);
-
} // namespace jit
} // namespace torch
--- /dev/null
+++ b/torch/csrc/jit/passes/variadic_ops.cpp
+#include <torch/csrc/jit/passes/variadic_ops.h>
+
+#include <torch/csrc/jit/ir/alias_analysis.h>
+#include <torch/csrc/jit/jit_log.h>
+#include <torch/csrc/jit/passes/remove_mutation.h>
+
+namespace torch {
+namespace jit {
+
+namespace {
+
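+// Rewrites each node of kind `op` whose first input is a
+// prim::ListConstruct into the corresponding `variadic_op`, which takes
+// the list elements (plus any trailing non-list arguments) directly as
+// inputs.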
+class VariadicUpdater {
+ public:
+ explicit VariadicUpdater(
+ std::shared_ptr<Graph> graph,
+ NodeKind op,
+ NodeKind variadic_op)
+ : graph_(std::move(graph)), op_(op), variadic_op_(variadic_op) {}
+
+ bool run() {
+ collectOpNodes(graph_->block());
+ bool changed = false;
+ for (auto n : op_nodes_) {
+ changed |= replaceWithVariadicOp(n);
+ }
+ return changed;
+ }
+
+ private:
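+  // Recursively collect all nodes of kind `op_`, including those nested
+  // inside sub-blocks (e.g. the bodies of prim::If and prim::Loop).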
+ void collectOpNodes(Block* block) {
+ for (auto node : block->nodes()) {
+ if (node->kind() == op_) {
+ op_nodes_.push_back(node);
+ }
+ for (Block* b : node->blocks()) {
+ collectOpNodes(b);
+ }
+ }
+ }
+
+ bool replaceWithVariadicOp(Node* op_node) {
+ if (op_node->input(0)->node()->kind() != prim::ListConstruct) {
+ return false;
+ }
+ auto list = op_node->input(0)->node();
+    // We do not transform ops whose list input cannot be moved to the
+    // position right before the op. This in turn implies that there is
+    // some mutation of the input list before the op.
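+    // For example (illustrative IR), the list construction below cannot
+    // be moved to just before the cat because of the intervening
+    // mutation, so the node is left untouched:
+    //   %l : Tensor[] = prim::ListConstruct(%a, %b)
+    //   %l2 : Tensor[] = aten::append(%l, %c)
+    //   %r : Tensor = aten::cat(%l, %dim)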
+ if (!getOrCreateAliasDb()->couldMoveBeforeTopologically(list, op_node)) {
+ return false;
+ }
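+    // The variadic op takes the list elements directly, followed by the
+    // op's remaining arguments: e.g. aten::cat(%list, %dim) with
+    // %list = prim::ListConstruct(%a, %b) becomes
+    // prim::VarConcat(%a, %b, %dim).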
+ std::vector<Value*> inputs = list->inputs().vec();
+ // Add non-list inputs
+ for (size_t i = 1; i < op_node->inputs().size(); ++i) {
+ inputs.push_back(op_node->input(i));
+ }
+ auto var_op_node = op_node->owningGraph()->create(variadic_op_, inputs);
+ GRAPH_UPDATE("Adding\n", *var_op_node);
+ var_op_node->insertBefore(op_node);
+ GRAPH_UPDATE("Replacing\n", *op_node, "with\n", *var_op_node);
+ op_node->output()->replaceAllUsesWith(var_op_node->output());
+ GRAPH_UPDATE("Deleting\n", *op_node);
+ op_node->destroy();
+ if (!list->hasUses()) {
+ GRAPH_UPDATE("Deleting\n", *list);
+ list->destroy();
+ }
+ return true;
+ }
+
+ AliasDb* getOrCreateAliasDb() {
+ if (!aliasDb_) {
+ aliasDb_ = std::make_unique<AliasDb>(graph_);
+ }
+ return aliasDb_.get();
+ }
+
+ std::shared_ptr<Graph> graph_;
+ std::unique_ptr<AliasDb> aliasDb_ = nullptr;
+
+ std::vector<Node*> op_nodes_;
+
+ NodeKind op_;
+ NodeKind variadic_op_;
+};
+
+} // namespace
+
+bool UseVariadicOp(
+ const std::shared_ptr<Graph>& graph,
+ NodeKind op,
+ NodeKind variadic_op) {
+ const std::string pass_name = std::string("variadic ") + op.toQualString();
+ GRAPH_DUMP("Before " + pass_name, graph);
+ bool changed = VariadicUpdater(graph, op, variadic_op).run();
+ if (changed) {
+ GRAPH_DUMP("After " + pass_name, graph);
+ }
+ return changed;
+}
+
+bool RemoveListMutationAndUseVariadicOp(
+ const std::shared_ptr<Graph>& graph,
+ NodeKind op,
+ NodeKind variadic_op) {
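+  // Iterate until a fixpoint: removing list mutations can make more list
+  // inputs movable to just before their ops, which in turn enables the
+  // variadic transformation.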
+ bool changed_in_last_iter = true;
+ bool changed = false;
+ while (changed_in_last_iter) {
+ changed_in_last_iter = RemoveListMutation(graph);
+ changed_in_last_iter =
+ UseVariadicOp(graph, op, variadic_op) || changed_in_last_iter;
+ changed = changed || changed_in_last_iter;
+ }
+ return changed;
+}
+
+bool UseVariadicCat(const std::shared_ptr<Graph>& graph) {
+ return UseVariadicOp(graph, aten::cat, prim::VarConcat);
+}
+
+bool RemoveListMutationAndUseVariadicCat(const std::shared_ptr<Graph>& graph) {
+ return RemoveListMutationAndUseVariadicOp(graph, aten::cat, prim::VarConcat);
+}
+
+} // namespace jit
+} // namespace torch
--- /dev/null
+++ b/torch/csrc/jit/passes/variadic_ops.h
+#pragma once
+
+#include <torch/csrc/jit/ir/ir.h>
+
+namespace torch {
+namespace jit {
+
+// Replaces the `aten::cat` ops in the given graph with variadic cat ops.
+// Returns true if the graph is modified.
+TORCH_API bool UseVariadicCat(const std::shared_ptr<Graph>& graph);
+
+TORCH_API bool RemoveListMutationAndUseVariadicCat(
+ const std::shared_ptr<Graph>& graph);
+
+} // namespace jit
+} // namespace torch
#include <caffe2/core/timer.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/passes/canonicalize.h>
-#include <torch/csrc/jit/passes/concat_opt.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
+#include <torch/csrc/jit/passes/variadic_ops.h>
#include <torch/csrc/jit/runtime/static/ops.h>
#include <torch/csrc/jit/runtime/static/passes.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>