[JIT] Move UseVariadicCat internals (#63577)
author Mike Iovine <mikeiovine@fb.com>
Tue, 24 Aug 2021 00:26:27 +0000 (17:26 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 24 Aug 2021 00:30:36 +0000 (17:30 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63577

Since other variadic ops will need almost identical implementations, we generalize the internals of `UseVariadicCat` and move them into a common file (`torch/csrc/jit/passes/variadic_ops.cpp`).
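
After this change, `UseVariadicCat` reduces to a thin wrapper over the generalized pass (see the new `variadic_ops.cpp` below), so a future variadic op only needs to supply its own `(op, variadic_op)` pair:

```cpp
// From the new variadic_ops.cpp: the cat-specific entry point just
// forwards the (op, variadic_op) pair to the generic pass.
bool UseVariadicCat(const std::shared_ptr<Graph>& graph) {
  return UseVariadicOp(graph, aten::cat, prim::VarConcat);
}
```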

Also moved some test utilities that other variadic op tests will likely need.
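
A test for a future variadic op could then follow the same shape as the updated `ConcatOptTest` cases; a sketch, where `UseVariadicStack` is a hypothetical wrapper mirroring `UseVariadicCat` and is not part of this diff:

```cpp
// Shared helpers from test_utils.h: runGraph executes the graph on the
// interpreter; exactlyEqual compares the resulting tensor vectors.
auto orig_outputs = runGraph(graph, inputs);
ASSERT_TRUE(UseVariadicStack(graph));  // hypothetical pass, for illustration
graph->lint();
auto opt_outputs = runGraph(graph, inputs);
ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
```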

Test Plan: `buck test caffe2/test/cpp/jit:jit -- ConcatOptTest`

Reviewed By: navahgar

Differential Revision: D30409937

fbshipit-source-id: 925c11c27b58ce98cb8368d2a205e26ba66d3db9

test/cpp/jit/test_concat_opt.cpp
test/cpp/jit/test_utils.cpp
test/cpp/jit/test_utils.h
tools/build_variables.bzl
torch/csrc/jit/passes/concat_opt.cpp
torch/csrc/jit/passes/concat_opt.h
torch/csrc/jit/passes/variadic_ops.cpp [new file with mode: 0644]
torch/csrc/jit/passes/variadic_ops.h [new file with mode: 0644]
torch/csrc/jit/runtime/static/impl.cpp

diff --git a/test/cpp/jit/test_concat_opt.cpp b/test/cpp/jit/test_concat_opt.cpp
index 03c0ce6..5cb73d2 100644
--- a/test/cpp/jit/test_concat_opt.cpp
+++ b/test/cpp/jit/test_concat_opt.cpp
@@ -1,45 +1,15 @@
 #include <gtest/gtest.h>
 
+#include <test/cpp/jit/test_utils.h>
 #include <torch/csrc/jit/ir/irparser.h>
 #include <torch/csrc/jit/passes/concat_opt.h>
+#include <torch/csrc/jit/passes/variadic_ops.h>
 #include <torch/csrc/jit/runtime/interpreter.h>
 #include <torch/csrc/jit/testing/file_check.h>
 
 namespace torch {
 namespace jit {
 
-namespace {
-
-void checkOutputs(
-    const std::vector<at::Tensor>& out1,
-    const std::vector<at::Tensor>& out2) {
-  ASSERT_EQ(out1.size(), out2.size());
-  for (size_t i = 0; i < out1.size(); ++i) {
-    ASSERT_EQ(out1[i].sizes(), out2[i].sizes());
-    float max_diff = (out1[i] - out2[i]).abs().max().item<double>();
-    ASSERT_EQ(max_diff, 0);
-  }
-}
-
-std::vector<at::Tensor> runGraph(
-    std::shared_ptr<Graph> graph,
-    const std::vector<at::Tensor> inputs) {
-  std::vector<IValue> stack = fmap<IValue>(inputs);
-  Code code(graph, "test");
-  InterpreterState(code).run(stack);
-  TORCH_INTERNAL_ASSERT(!stack.empty());
-  // Graph outputs that are handled below:
-  //   * A list of Tensors.
-  //   * 1 Tensor.
-  if (stack.front().isTensorList()) {
-    return stack.front().toTensorVector();
-  }
-  TORCH_INTERNAL_ASSERT(stack.front().isTensor());
-  return {stack.front().toTensor()};
-}
-
-} // namespace
-
 TEST(ConcatOptTest, SimpleCommonInputsEliminationPrefix) {
   auto graph = std::make_shared<Graph>();
 
@@ -64,7 +34,7 @@ TEST(ConcatOptTest, SimpleCommonInputsEliminationPrefix) {
   ASSERT_TRUE(EliminateConcatCommonInputs(graph));
   graph->lint();
   auto opt_outputs = runGraph(graph, inputs);
-  checkOutputs(orig_outputs, opt_outputs);
+  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
 
   // Graph after EliminateConcatCommonInputs:
   //  graph(%0 : ...,
@@ -109,7 +79,7 @@ TEST(ConcatOptTest, SimpleCommonInputsEliminationSuffix) {
   ASSERT_TRUE(EliminateConcatCommonInputs(graph));
   graph->lint();
   auto opt_outputs = runGraph(graph, inputs);
-  checkOutputs(orig_outputs, opt_outputs);
+  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
 
   // Graph after EliminateConcatCommonInputs:
   //  graph(%0 : ...,
@@ -161,7 +131,7 @@ TEST(ConcatOptTest, CommonInputsEliminationWithDifferentOrderInputs) {
   graph->lint();
   auto opt_outputs = runGraph(graph, inputs);
 
-  checkOutputs(orig_outputs, opt_outputs);
+  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
 
   // No optimizations should have happened in this case since the inputs
   // to the `cat` are in different order.
@@ -198,7 +168,7 @@ TEST(ConcatOptTest, MoreCommonInputsElimination) {
   ASSERT_TRUE(EliminateConcatCommonInputs(graph));
   graph->lint();
   auto opt_outputs = runGraph(graph, inputs);
-  checkOutputs(orig_outputs, opt_outputs);
+  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
 
   testing::FileCheck()
       .check_count("= prim::VarConcat(%0, %1, %5)", 1, /*exactly*/ true)
@@ -233,7 +203,7 @@ TEST(ConcatOptTest, ExpandConcat) {
   graph->lint();
   auto opt_outputs = runGraph(graph, inputs);
 
-  checkOutputs(orig_outputs, opt_outputs);
+  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
 
   // After full concat optimization we should have the following graph:
   //
@@ -289,7 +259,7 @@ TEST(ConcatOptTest, ConcatWithoutResultShape) {
   graph->lint();
   auto opt_outputs = runGraph(graph, inputs);
 
-  checkOutputs(orig_outputs, opt_outputs);
+  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
 
   // No optimizations should have happened in this case since the output
   // shape of `aten::cat` is not known.
@@ -324,7 +294,7 @@ TEST(ConcatOptTest, ConcatWithoutInputShape) {
   graph->lint();
   auto opt_outputs = runGraph(graph, inputs);
 
-  checkOutputs(orig_outputs, opt_outputs);
+  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
 
   // No optimizations should have happened in this case since the shape of %5,
   // which is an input to `aten::cat`, is not known.
@@ -361,7 +331,7 @@ TEST(ConcatOptTest, UseVariadicCat) {
   graph->lint();
   auto opt_outputs = runGraph(graph, inputs);
 
-  checkOutputs(orig_outputs, opt_outputs);
+  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
 
   // After replacing `aten::cat` with `prim::VarConcat` we should have the
   // following graph:
@@ -406,7 +376,7 @@ TEST(OptimizeConcatTest, UseVariadicCatReplaceMultiple) {
   graph->lint();
   auto opt_outputs = runGraph(graph, inputs);
 
-  checkOutputs(orig_outputs, opt_outputs);
+  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
 
   // After full concat optimization we should have the following graph:
   //
@@ -446,7 +416,7 @@ TEST(ConcatOptTest, UseVariadicCatWithMultipleListUses) {
   graph->lint();
   auto opt_outputs = runGraph(graph, inputs);
 
-  checkOutputs(orig_outputs, opt_outputs);
+  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
 
   // After replacing `aten::cat` with `prim::VarConcat` we should have the
   // following graph:
@@ -488,7 +458,7 @@ TEST(ConcatOptTest, UseVariadicCatWithListMutationAfterCat) {
   ASSERT_TRUE(UseVariadicCat(graph));
   graph->lint();
   auto opt_outputs = runGraph(graph, inputs);
-  checkOutputs(orig_outputs, opt_outputs);
+  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
 
   // The input list to `aten::cat` is mutated only after `aten::cat` op. So,
   // it should have been replaced with `prim::VarConcat`. The transformed graph
@@ -534,7 +504,7 @@ TEST(ConcatOptTest, UseVariadicCatWithListMutationBeforeCat) {
     ASSERT_FALSE(UseVariadicCat(graph));
     graph->lint();
     auto opt_outputs = runGraph(graph, inputs);
-    checkOutputs(orig_outputs, opt_outputs);
+    ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
 
     // No transformation should have happened since the `prim::ListConstruct` is
     // mutated before `aten::cat`.
@@ -549,7 +519,7 @@ TEST(ConcatOptTest, UseVariadicCatWithListMutationBeforeCat) {
     ASSERT_TRUE(RemoveListMutationAndUseVariadicCat(graph));
     graph->lint();
     auto opt_outputs = runGraph(graph, inputs);
-    checkOutputs(orig_outputs, opt_outputs);
+    ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
 
     // The mutation of the list must be removed and the `aten::cat` op must
     // be replaced with the `prim::VarConcat` op in the graph. The transformed
@@ -602,7 +572,7 @@ TEST(ConcatOptTest, UseVariadicCatWithMultipleListMutations) {
   ASSERT_TRUE(RemoveListMutationAndUseVariadicCat(graph));
   graph->lint();
   auto opt_outputs = runGraph(graph, inputs);
-  checkOutputs(orig_outputs, opt_outputs);
+  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
 
   // All the mutations of the list must be removed and the `aten::cat` ops must
   // be replaced with `prim::VarConcat` ops in the graph. The transformed graph
@@ -659,7 +629,7 @@ TEST(
   ASSERT_TRUE(EliminateConcatCommonInputs(graph));
   graph->lint();
   auto opt_outputs = runGraph(graph, inputs);
-  checkOutputs(orig_outputs, opt_outputs);
+  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
 
   // After performing:
   //     * Remove list mutation
diff --git a/test/cpp/jit/test_utils.cpp b/test/cpp/jit/test_utils.cpp
index 7750ba8..27667f0 100644
--- a/test/cpp/jit/test_utils.cpp
+++ b/test/cpp/jit/test_utils.cpp
@@ -198,6 +198,7 @@ bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs) {
   }
   return diff.abs().max().item<float>() < 2e-6 * maxValue;
 }
+
 bool almostEqual(const at::Tensor& a, const at::Tensor& b) {
   return checkRtol(a - b, {a, b});
 }
@@ -206,6 +207,20 @@ bool exactlyEqual(const at::Tensor& a, const at::Tensor& b) {
   return (a - b).abs().max().item<float>() == 0.f;
 }
 
+bool exactlyEqual(
+    const std::vector<at::Tensor>& a,
+    const std::vector<at::Tensor>& b) {
+  if (a.size() != b.size()) {
+    return false;
+  }
+  for (size_t i = 0; i < a.size(); ++i) {
+    if (!exactlyEqual(a[i], b[i])) {
+      return false;
+    }
+  }
+  return true;
+}
+
 std::pair<at::Tensor, at::Tensor> lstm(
     at::Tensor input,
     at::Tensor hx,
@@ -248,5 +263,22 @@ RegisterOperators reg({
 });
 } // namespace
 
+std::vector<at::Tensor> runGraph(
+    std::shared_ptr<Graph> graph,
+    const std::vector<at::Tensor>& inputs) {
+  std::vector<IValue> stack = fmap<IValue>(inputs);
+  Code code(graph, "test");
+  InterpreterState(code).run(stack);
+  TORCH_INTERNAL_ASSERT(!stack.empty());
+  // Graph outputs that are handled below:
+  //   * A list of Tensors.
+  //   * 1 Tensor.
+  if (stack.front().isTensorList()) {
+    return stack.front().toTensorVector();
+  }
+  TORCH_INTERNAL_ASSERT(stack.front().isTensor());
+  return {stack.front().toTensor()};
+}
+
 } // namespace jit
 } // namespace torch
diff --git a/test/cpp/jit/test_utils.h b/test/cpp/jit/test_utils.h
index 676759d..5e640ae 100644
--- a/test/cpp/jit/test_utils.h
+++ b/test/cpp/jit/test_utils.h
@@ -88,6 +88,13 @@ bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs);
 bool almostEqual(const at::Tensor& a, const at::Tensor& b);
 
 bool exactlyEqual(const at::Tensor& a, const at::Tensor& b);
+bool exactlyEqual(
+    const std::vector<at::Tensor>& a,
+    const std::vector<at::Tensor>& b);
+
+std::vector<at::Tensor> runGraph(
+    std::shared_ptr<Graph> graph,
+    const std::vector<at::Tensor>& inputs);
 
 std::pair<at::Tensor, at::Tensor> lstm(
     at::Tensor input,
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index e20d973..2eabbd0 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -244,6 +244,7 @@ core_sources_full_mobile_no_backend_interface = [
     "torch/csrc/jit/passes/symbolic_shape_analysis.cpp",
     "torch/csrc/jit/passes/specialize_autogradzero.cpp",
     "torch/csrc/jit/passes/update_differentiable_graph_requires_grad.cpp",
+    "torch/csrc/jit/passes/variadic_ops.cpp",
     "torch/csrc/jit/passes/subgraph_rewrite.cpp",
     "torch/csrc/jit/passes/tensorexpr_fuser.cpp",
     "torch/csrc/jit/passes/utils/memory_dag.cpp",
diff --git a/torch/csrc/jit/passes/concat_opt.cpp b/torch/csrc/jit/passes/concat_opt.cpp
index aa2573e..81c8a67 100644
--- a/torch/csrc/jit/passes/concat_opt.cpp
+++ b/torch/csrc/jit/passes/concat_opt.cpp
@@ -497,95 +497,5 @@ void ExpandConcatAndEliminateRedundancy(const std::shared_ptr<Graph>& graph) {
   GRAPH_DUMP("After expanding Concat and eliminating redundancy", graph);
 }
 
-namespace {
-
-class VariadicCatUpdater {
- public:
-  explicit VariadicCatUpdater(std::shared_ptr<Graph> graph)
-      : graph_(std::move(graph)) {}
-
-  bool run() {
-    collectCatNodes(graph_->block());
-    bool changed = false;
-    for (auto c : cat_nodes_) {
-      changed = replaceWithVariadicCat(c) || changed;
-    }
-    return changed;
-  }
-
- private:
-  void collectCatNodes(Block* block) {
-    for (auto node : block->nodes()) {
-      if (node->kind() == aten::cat) {
-        cat_nodes_.push_back(node);
-      }
-      for (Block* b : node->blocks()) {
-        collectCatNodes(b);
-      }
-    }
-  }
-
-  bool replaceWithVariadicCat(Node* cat) {
-    if (cat->input(0)->node()->kind() != prim::ListConstruct) {
-      return false;
-    }
-    auto list = cat->input(0)->node();
-    // We do not transform cat ops whose list input can not be moved to the
-    // position before cat. This in turn implies that there is some mutation
-    // of the input list before cat.
-    if (!getOrCreateAliasDb()->couldMoveBeforeTopologically(list, cat)) {
-      return false;
-    }
-    std::vector<Value*> inputs = list->inputs().vec();
-    inputs.push_back(cat->input(1));
-    auto var_cat = cat->owningGraph()->create(prim::VarConcat, inputs);
-    GRAPH_UPDATE("Adding\n", *var_cat);
-    var_cat->insertBefore(cat);
-    GRAPH_UPDATE("Replacing\n", *cat, "with\n", *var_cat);
-    cat->output()->replaceAllUsesWith(var_cat->output());
-    GRAPH_UPDATE("Deleting\n", *cat);
-    cat->destroy();
-    if (!list->hasUses()) {
-      GRAPH_UPDATE("Deleting\n", *list);
-      list->destroy();
-    }
-    return true;
-  }
-
-  AliasDb* getOrCreateAliasDb() {
-    if (!aliasDb_) {
-      aliasDb_ = std::make_unique<AliasDb>(graph_);
-    }
-    return aliasDb_.get();
-  }
-
-  std::shared_ptr<Graph> graph_;
-  std::unique_ptr<AliasDb> aliasDb_ = nullptr;
-
-  std::vector<Node*> cat_nodes_;
-};
-
-} // namespace
-
-bool UseVariadicCat(const std::shared_ptr<Graph>& graph) {
-  GRAPH_DUMP("Before VariadicCat", graph);
-  bool changed = VariadicCatUpdater(graph).run();
-  if (changed) {
-    GRAPH_DUMP("After VariadicCat", graph);
-  }
-  return changed;
-}
-
-bool RemoveListMutationAndUseVariadicCat(const std::shared_ptr<Graph>& graph) {
-  bool changed_in_last_iter = true;
-  bool changed = false;
-  while (changed_in_last_iter) {
-    changed_in_last_iter = RemoveListMutation(graph);
-    changed_in_last_iter = changed_in_last_iter || UseVariadicCat(graph);
-    changed = changed || changed_in_last_iter;
-  }
-  return changed;
-}
-
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/passes/concat_opt.h b/torch/csrc/jit/passes/concat_opt.h
index b82dc25..ef4d943 100644
--- a/torch/csrc/jit/passes/concat_opt.h
+++ b/torch/csrc/jit/passes/concat_opt.h
@@ -13,12 +13,5 @@ TORCH_API bool EliminateConcatCommonInputs(const std::shared_ptr<Graph>& graph);
 TORCH_API void ExpandConcatAndEliminateRedundancy(
     const std::shared_ptr<Graph>& graph);
 
-// Replaces the `aten::cat` ops in the given graph with variadic cat ops.
-// Returns true if the graph is modified.
-TORCH_API bool UseVariadicCat(const std::shared_ptr<Graph>& graph);
-
-TORCH_API bool RemoveListMutationAndUseVariadicCat(
-    const std::shared_ptr<Graph>& graph);
-
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/passes/variadic_ops.cpp b/torch/csrc/jit/passes/variadic_ops.cpp
new file mode 100644
index 0000000..aeb7074
--- /dev/null
+++ b/torch/csrc/jit/passes/variadic_ops.cpp
@@ -0,0 +1,126 @@
+#include <torch/csrc/jit/passes/variadic_ops.h>
+
+#include <torch/csrc/jit/ir/alias_analysis.h>
+#include <torch/csrc/jit/jit_log.h>
+#include <torch/csrc/jit/passes/remove_mutation.h>
+
+namespace torch {
+namespace jit {
+
+namespace {
+
+class VariadicUpdater {
+ public:
+  explicit VariadicUpdater(
+      std::shared_ptr<Graph> graph,
+      NodeKind op,
+      NodeKind variadic_op)
+      : graph_(std::move(graph)), op_(op), variadic_op_(variadic_op) {}
+
+  bool run() {
+    collectOpNodes(graph_->block());
+    bool changed = false;
+    for (auto n : op_nodes_) {
+      changed |= replaceWithVariadicOp(n);
+    }
+    return changed;
+  }
+
+ private:
+  void collectOpNodes(Block* block) {
+    for (auto node : block->nodes()) {
+      if (node->kind() == op_) {
+        op_nodes_.push_back(node);
+      }
+      for (Block* b : node->blocks()) {
+        collectOpNodes(b);
+      }
+    }
+  }
+
+  bool replaceWithVariadicOp(Node* op_node) {
+    if (op_node->input(0)->node()->kind() != prim::ListConstruct) {
+      return false;
+    }
+    auto list = op_node->input(0)->node();
+    // We do not transform ops whose list input can not be moved to the
+    // position before op. This in turn implies that there is some mutation
+    // of the input list before op.
+    if (!getOrCreateAliasDb()->couldMoveBeforeTopologically(list, op_node)) {
+      return false;
+    }
+    std::vector<Value*> inputs = list->inputs().vec();
+    // Add non-list inputs
+    for (size_t i = 1; i < op_node->inputs().size(); ++i) {
+      inputs.push_back(op_node->input(i));
+    }
+    auto var_op_node = op_node->owningGraph()->create(variadic_op_, inputs);
+    GRAPH_UPDATE("Adding\n", *var_op_node);
+    var_op_node->insertBefore(op_node);
+    GRAPH_UPDATE("Replacing\n", *op_node, "with\n", *var_op_node);
+    op_node->output()->replaceAllUsesWith(var_op_node->output());
+    GRAPH_UPDATE("Deleting\n", *op_node);
+    op_node->destroy();
+    if (!list->hasUses()) {
+      GRAPH_UPDATE("Deleting\n", *list);
+      list->destroy();
+    }
+    return true;
+  }
+
+  AliasDb* getOrCreateAliasDb() {
+    if (!aliasDb_) {
+      aliasDb_ = std::make_unique<AliasDb>(graph_);
+    }
+    return aliasDb_.get();
+  }
+
+  std::shared_ptr<Graph> graph_;
+  std::unique_ptr<AliasDb> aliasDb_ = nullptr;
+
+  std::vector<Node*> op_nodes_;
+
+  NodeKind op_;
+  NodeKind variadic_op_;
+};
+
+} // namespace
+
+bool UseVariadicOp(
+    const std::shared_ptr<Graph>& graph,
+    NodeKind op,
+    NodeKind variadic_op) {
+  const std::string pass_name = std::string("variadic ") + op.toQualString();
+  GRAPH_DUMP("Before " + pass_name, graph);
+  bool changed = VariadicUpdater(graph, op, variadic_op).run();
+  if (changed) {
+    GRAPH_DUMP("After " + pass_name, graph);
+  }
+  return changed;
+}
+
+bool RemoveListMutationAndUseVariadicOp(
+    const std::shared_ptr<Graph>& graph,
+    NodeKind op,
+    NodeKind variadic_op) {
+  bool changed_in_last_iter = true;
+  bool changed = false;
+  while (changed_in_last_iter) {
+    changed_in_last_iter = RemoveListMutation(graph);
+    changed_in_last_iter =
+        UseVariadicOp(graph, op, variadic_op) || changed_in_last_iter;
+    changed = changed || changed_in_last_iter;
+  }
+  return changed;
+}
+
+bool UseVariadicCat(const std::shared_ptr<Graph>& graph) {
+  return UseVariadicOp(graph, aten::cat, prim::VarConcat);
+}
+
+bool RemoveListMutationAndUseVariadicCat(const std::shared_ptr<Graph>& graph) {
+  return RemoveListMutationAndUseVariadicOp(graph, aten::cat, prim::VarConcat);
+}
+
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/passes/variadic_ops.h b/torch/csrc/jit/passes/variadic_ops.h
new file mode 100644
index 0000000..1c52e95
--- /dev/null
+++ b/torch/csrc/jit/passes/variadic_ops.h
@@ -0,0 +1,16 @@
+#pragma once
+
+#include <torch/csrc/jit/ir/ir.h>
+
+namespace torch {
+namespace jit {
+
+// Replaces the `aten::cat` ops in the given graph with variadic cat ops.
+// Returns true if the graph is modified.
+TORCH_API bool UseVariadicCat(const std::shared_ptr<Graph>& graph);
+
+TORCH_API bool RemoveListMutationAndUseVariadicCat(
+    const std::shared_ptr<Graph>& graph);
+
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp
index 1ee69a6..4219be5 100644
--- a/torch/csrc/jit/runtime/static/impl.cpp
+++ b/torch/csrc/jit/runtime/static/impl.cpp
@@ -9,11 +9,11 @@
 #include <caffe2/core/timer.h>
 #include <torch/csrc/jit/ir/alias_analysis.h>
 #include <torch/csrc/jit/passes/canonicalize.h>
-#include <torch/csrc/jit/passes/concat_opt.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/passes/freeze_module.h>
 #include <torch/csrc/jit/passes/remove_mutation.h>
 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
+#include <torch/csrc/jit/passes/variadic_ops.h>
 #include <torch/csrc/jit/runtime/static/ops.h>
 #include <torch/csrc/jit/runtime/static/passes.h>
 #include <torch/csrc/jit/runtime/vararg_functions.h>