[JIT] Improve BatchMM mutability handling (#65097)
authorDavid Berard <dberard@fb.com>
Thu, 16 Sep 2021 17:44:33 +0000 (10:44 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 16 Sep 2021 17:46:14 +0000 (10:46 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65097

Previously, BatchMM would skip any block that contained any mutable
operator. Now it only avoids batching operations whose inputs or
outputs are ever mutated. Specifically: consider a tree of ADD, T,
and MM nodes rooted at an ADD node. If any input or output of any
node in the tree is ever mutated, then the entire tree is ignored
by BatchMM.

Test Plan: python test/test_jit.py TestBatchMM

Reviewed By: eellison

Differential Revision: D30973515

Pulled By: davidberard98

fbshipit-source-id: 9d836faa1ef0c9e3fefe0ffc0bd265f275471f48

test/jit/test_batch_mm.py [new file with mode: 0644]
test/test_jit.py
torch/csrc/jit/passes/batch_mm.cpp
torch/csrc/jit/python/init.cpp

diff --git a/test/jit/test_batch_mm.py b/test/jit/test_batch_mm.py
new file mode 100644 (file)
index 0000000..f04eef8
--- /dev/null
@@ -0,0 +1,288 @@
+import torch
+from torch.testing import FileCheck
+from torch.testing._internal.jit_utils import JitTestCase
+
+if __name__ == "__main__":
+    raise RuntimeError(
+        "This test file is not meant to be run directly, use:\n\n"
+        "\tpython test/test_jit.py TESTNAME\n\n"
+        "instead."
+    )
+
+
+class TestBatchMM(JitTestCase):
+    @staticmethod
+    def _get_test_tensors(n: int):
+        return [
+            torch.tensor([[1 + x, 2 + x, 3 + x], [4 + x, 5 + x, 6 + x]])
+            if x % 2 == 0
+            else torch.tensor([[1 + x, 2 + x], [3 + x, 4 + x], [5 + x, 6 + x]])
+            for x in range(n)
+        ]
+
+    def test_batch_mm_no_mutation(self):
+        def test_batch_mm(
+            T1: torch.Tensor,
+            T2: torch.Tensor,
+            T3: torch.Tensor,
+            T4: torch.Tensor,
+            T5: torch.Tensor,
+            T6: torch.Tensor,
+            T7: torch.Tensor,
+            T8: torch.Tensor,
+        ):
+            return (
+                torch.mm(T1, T2)
+                + torch.mm(T3, T4)
+                + torch.mm(T5, T6)
+                + torch.mm(T7, T8)
+            )
+
+        test_batch_mm_scripted = torch.jit.script(test_batch_mm)
+
+        tensors = TestBatchMM._get_test_tensors(8)
+        expected = test_batch_mm(*tensors)
+
+        FileCheck().check_count("aten::mm", 4, exactly=True).run(
+            test_batch_mm_scripted.graph
+        )
+        self.run_pass("batch_mm", test_batch_mm_scripted.graph)
+        FileCheck().check_count("prim::MMTreeReduce", 1, exactly=True).run(
+            test_batch_mm_scripted.graph
+        )
+
+        actual = test_batch_mm_scripted(*tensors)
+        self.assertEqual(expected, actual, atol=1e-9, rtol=1e-9)
+
+
+    def test_batch_mm_permitted_mutation(self):
+        def test_batch_mm(
+            T1: torch.Tensor,
+            T2: torch.Tensor,
+            T3: torch.Tensor,
+            T4: torch.Tensor,
+            T5: torch.Tensor,
+            T6: torch.Tensor,
+            T7: torch.Tensor,
+            T8: torch.Tensor,
+        ):
+            result = {}
+            result["product"] = (
+                torch.mm(T1, T2)
+                + torch.mm(T3, T4)
+                + torch.mm(T5, T6)
+                + torch.mm(T7, T8)
+            )
+            result["constant"] = torch.tensor([42.0])
+            return result
+
+        test_batch_mm_scripted = torch.jit.script(test_batch_mm)
+
+        tensors = TestBatchMM._get_test_tensors(8)
+        expected = test_batch_mm(*tensors)
+
+        FileCheck().check_count("aten::mm", 4, exactly=True).run(
+            test_batch_mm_scripted.graph
+        )
+        self.run_pass("batch_mm", test_batch_mm_scripted.graph)
+        FileCheck().check_count("prim::MMTreeReduce", 1, exactly=True).run(
+            test_batch_mm_scripted.graph
+        )
+
+        actual = test_batch_mm_scripted(*tensors)
+        self.assertEqual(expected, actual, atol=1e-9, rtol=1e-9)
+
+    def test_batch_mm_prohibited_mutation(self):
+        @torch.jit.script
+        def test_batch_mm(n: int):
+            T1 = torch.zeros((n, n))
+            T2 = torch.zeros((n, n))
+            T3 = torch.zeros((n, n))
+            T4 = torch.zeros((n, n))
+            T5 = torch.zeros((n, n))
+            T6 = torch.zeros((n, n))
+            T7 = torch.zeros((n, n))
+            T8 = torch.zeros((n, n))
+            torch.relu_(T1)
+            result = (
+                torch.mm(T1, T2)
+                + torch.mm(T3, T4)
+                + torch.mm(T5, T6)
+                + torch.mm(T7, T8)
+            )
+            return result
+
+        FileCheck().check_count("aten::mm", 4, exactly=True).run(test_batch_mm.graph)
+        self.run_pass("batch_mm", test_batch_mm.graph)
+        FileCheck().check_count("aten::mm", 4, exactly=True).check_not(
+            "prim::MMTreeReduce"
+        ).run(test_batch_mm.graph)
+
+    def test_batch_mm_prohibited_mutation_multiple_adds(self):
+        @torch.jit.script
+        def test_batch_mm(n: int):
+            T1 = torch.zeros((n, n))
+            T2 = torch.zeros((n, n))
+            T3 = torch.zeros((n, n))
+            T4 = torch.zeros((n, n))
+            T5 = torch.zeros((n, n))
+            T6 = torch.zeros((n, n))
+            T7 = torch.zeros((n, n))
+            T8 = torch.zeros((n, n))
+            T9 = torch.zeros((n, n))
+            T10 = torch.zeros((n, n))
+            torch.relu_(T1)
+            result = {}
+            result["no_mutated_parameters"] = (
+                torch.mm(T2, T3)
+                + torch.mm(T4, T5)
+                + torch.mm(T6, T7)
+                + torch.mm(T8, T9)
+            )
+            result["all_parameters"] = (
+                torch.mm(T1, T2)
+                + torch.mm(T3, T4)
+                + torch.mm(T5, T6)
+                + torch.mm(T7, T8)
+                + torch.mm(T9, T10)
+            )
+            return result
+
+        self.run_pass("batch_mm", test_batch_mm.graph)
+        FileCheck().check_count("prim::MMTreeReduce", 1, exactly=True).check_count(
+            "aten::mm", 5, exactly=True
+        ).run(test_batch_mm.graph)
+
+    def test_batch_mm_prohibited_mutation_if_node(self):
+        @torch.jit.script
+        def test_batch_mm(n: int, use_t1: bool):
+            T1 = torch.zeros((n, n))
+            T2 = torch.zeros((n, n))
+            T3 = torch.zeros((n, n))
+            T4 = torch.zeros((n, n))
+            T5 = torch.zeros((n, n))
+            T6 = torch.zeros((n, n))
+            T7 = torch.zeros((n, n))
+            T8 = torch.zeros((n, n))
+            T9 = torch.zeros((n, n))
+            T10 = torch.zeros((n, n))
+            if use_t1:
+                torch.relu_(T1)
+                return (
+                    torch.mm(T1, T2)
+                    + torch.mm(T3, T4)
+                    + torch.mm(T5, T6)
+                    + torch.mm(T7, T8)
+                    + torch.mm(T9, T10)
+                )
+            else:
+                return (
+                    torch.mm(T2, T3)
+                    + torch.mm(T4, T5)
+                    + torch.mm(T6, T7)
+                    + torch.mm(T8, T9)
+                )
+
+        self.run_pass("batch_mm", test_batch_mm.graph)
+        FileCheck().check_count("aten::mm", 5, exactly=True).check_count(
+            "prim::MMTreeReduce", 1, exactly=True
+        ).run(test_batch_mm.graph)
+
+    def test_batch_mm_side_permitted_mutation(self):
+        @torch.jit.script
+        def test_batch_mm(n: int):
+            result = {}
+            A = torch.zeros((n, n))
+            T1 = torch.zeros((n, n))
+            T2 = torch.zeros((n, n))
+            T3 = torch.zeros((n, n))
+            T4 = torch.zeros((n, n))
+            T5 = torch.zeros((n, n))
+            T6 = torch.zeros((n, n))
+            T7 = torch.zeros((n, n))
+            T8 = torch.zeros((n, n))
+            result["T1"] = torch.mm(A, T1)
+            result["T2"] = torch.mm(A, T2)
+            result["T3"] = torch.mm(A, T3)
+            result["T4"] = torch.mm(A, T4)
+            result["T5"] = torch.mm(A, T5)
+            result["T6"] = torch.mm(A, T6)
+            result["T7"] = torch.mm(A, T7)
+            result["T8"] = torch.mm(A, T8)
+            return result
+
+        FileCheck().check_count("aten::mm", 8, exactly=True).run(test_batch_mm.graph)
+        self.run_pass("batch_mm", test_batch_mm.graph)
+        FileCheck().check_count("prim::MMBatchSide", 1, exactly=True).check_not(
+            "aten::mm"
+        ).run(test_batch_mm.graph)
+
+    def test_batch_mm_side_prohibited_mutation_uncommon_side(self):
+        @torch.jit.script
+        def test_batch_mm(n: int):
+            A = torch.zeros((n, n))
+            T1 = torch.zeros((n, n))
+            T2 = torch.zeros((n, n))
+            T3 = torch.zeros((n, n))
+            T4 = torch.zeros((n, n))
+            T5 = torch.zeros((n, n))
+            T6 = torch.zeros((n, n))
+            T7 = torch.zeros((n, n))
+            T8 = torch.zeros((n, n))
+            T9 = torch.zeros((n, n))
+            T10 = torch.zeros((n, n))
+            torch.relu_(T1)
+            result = {}
+            result["T1"] = torch.mm(A, T1)
+            result["T2"] = torch.mm(A, T2)
+            result["T3"] = torch.mm(A, T3)
+            result["T4"] = torch.mm(A, T4)
+            result["T5"] = torch.mm(A, T5)
+            result["T6"] = torch.mm(A, T6)
+            result["T7"] = torch.mm(A, T7)
+            result["T8"] = torch.mm(A, T8)
+            result["T9"] = torch.mm(A, T9)
+            result["T10"] = torch.mm(A, T10)
+            return result
+
+        FileCheck().check_count("aten::mm", 10, exactly=True).run(test_batch_mm.graph)
+        self.run_pass("batch_mm", test_batch_mm.graph)
+
+        FileCheck().check_count("aten::mm", 1, exactly=True).run(test_batch_mm.graph)
+        FileCheck().check_count("prim::MMBatchSide", 1, exactly=True).run(
+            test_batch_mm.graph
+        )
+
+    def test_batch_mm_side_prohibited_mutation_common_side(self):
+        @torch.jit.script
+        def test_batch_mm(n: int):
+            A = torch.zeros((n, n))
+            T1 = torch.zeros((n, n))
+            T2 = torch.zeros((n, n))
+            T3 = torch.zeros((n, n))
+            T4 = torch.zeros((n, n))
+            T5 = torch.zeros((n, n))
+            T6 = torch.zeros((n, n))
+            T7 = torch.zeros((n, n))
+            T8 = torch.zeros((n, n))
+            T9 = torch.zeros((n, n))
+            T10 = torch.zeros((n, n))
+            torch.relu_(A)
+            result = {}
+            result["T1"] = torch.mm(A, T1)
+            result["T2"] = torch.mm(A, T2)
+            result["T3"] = torch.mm(A, T3)
+            result["T4"] = torch.mm(A, T4)
+            result["T5"] = torch.mm(A, T5)
+            result["T6"] = torch.mm(A, T6)
+            result["T7"] = torch.mm(A, T7)
+            result["T8"] = torch.mm(A, T8)
+            result["T9"] = torch.mm(A, T9)
+            result["T10"] = torch.mm(A, T10)
+            return result
+
+        FileCheck().check_count("aten::mm", 10, exactly=True).run(test_batch_mm.graph)
+        self.run_pass("batch_mm", test_batch_mm.graph)
+        FileCheck().check_count("aten::mm", 10, exactly=True).check_not(
+            "prim::MMBatchSide"
+        ).run(test_batch_mm.graph)
index a6589ad..cac6618 100644 (file)
@@ -64,6 +64,7 @@ from jit.test_aten_pow import TestAtenPow  # noqa: F401
 from jit.test_optimize_for_mobile_preserve_debug_info import TestOptimizeForMobilePreserveDebugInfo  # noqa: F401
 from jit.test_union import TestUnion  # noqa: F401
 from jit.test_models import MnistNet
+from jit.test_batch_mm import TestBatchMM  # noqa: F401
 
 # Torch
 from torch import Tensor
index 944e278..5e9ebea 100644 (file)
@@ -9,6 +9,7 @@
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/passes/peephole.h>
 #include <torch/csrc/jit/runtime/custom_operator.h>
+#include <torch/csrc/jit/runtime/graph_iterator.h>
 
 #include <ATen/ATen.h>
 #include <algorithm>
@@ -249,22 +250,26 @@ struct TreeToken {
 
 enum class Side { LHS, RHS };
 
-void BatchMMTreeReduce(Block* block) {
+void BatchMMTreeReduce(Block* block, AliasDb& alias_db) {
   auto graph = block->owningGraph();
 
   // Look for trees in the block
   std::unordered_map<Node*, TreeToken> tokens;
   for (auto node : block->nodes()) {
-    if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
+    if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor") &&
+        !alias_db.hasWriters(node)) {
       tokens[node] = TreeToken::mm(node);
-    } else if (node->matches("aten::t(Tensor self) -> Tensor")) {
+    } else if (
+        node->matches("aten::t(Tensor self) -> Tensor") &&
+        !alias_db.hasWriters(node)) {
       auto input_it = tokens.find(node->input()->node());
       if (input_it != tokens.end()) {
         tokens[node] = TreeToken::transpose(node, input_it->second);
       }
     } else if (
         node->matches(
-            "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
+            "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") &&
+        !alias_db.hasWriters(node)) {
       Node* lhs = node->inputs()[0]->node();
       Node* rhs = node->inputs()[1]->node();
       auto lhs_it = tokens.find(lhs);
@@ -285,7 +290,7 @@ void BatchMMTreeReduce(Block* block) {
       }
     } else {
       for (auto block : node->blocks()) {
-        BatchMMTreeReduce(block);
+        BatchMMTreeReduce(block, alias_db);
       }
     }
   }
@@ -394,7 +399,8 @@ std::pair<std::vector<Node*>, std::vector<Node*>> gatherIndependentMMUses(
   std::vector<Node*> rhses; // Like above, but rhs
   for (Use u : value->uses()) {
     if (u.user->owningBlock() == block &&
-        u.user->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
+        u.user->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor") &&
+        !alias_db.hasWriters(u.user)) {
       if (u.offset == 0 && u.user->inputs()[1] != value) {
         lhses.push_back(u.user);
       } else if (u.offset == 1 && u.user->inputs()[0] != value) {
@@ -432,7 +438,8 @@ void BatchMMSide(Block* block, AliasDb& alias_db) {
 
   std::unordered_set<Value*> considered_values;
   for (Node* node : block->nodes()) {
-    if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
+    if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor") &&
+        !alias_db.hasWriters(node)) {
       for (Value* input : node->inputs()) {
         if (/*bool not_inserted = */ !considered_values.emplace(input).second) {
           continue;
@@ -465,13 +472,23 @@ bool hasMutableOperators(Block* block) {
   return false;
 }
 
+bool hasMMOperators(std::shared_ptr<Graph>& graph) {
+  DepthFirstGraphNodeIterator it(graph);
+  Node* n = nullptr;
+  while ((n = it.next()) != nullptr) {
+    if (n->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
+      return true;
+    }
+  }
+  return false;
+}
+
 void BatchMM(std::shared_ptr<Graph>& graph) {
-  if (hasMutableOperators(graph->block())) {
-    // TODO(suo): make BatchMM mutability-safe
+  if (!hasMMOperators(graph)) {
     return;
   }
   AliasDb alias_db(graph);
-  BatchMMTreeReduce(graph->block());
+  BatchMMTreeReduce(graph->block(), alias_db);
   BatchMMSide(graph->block(), alias_db);
   EliminateDeadCode(graph);
   // It's possible that transpose rearrangements have created sequences of
index 35197e4..c4120d1 100644 (file)
@@ -10,6 +10,7 @@
 #include <torch/csrc/jit/frontend/tracer.h>
 #include <torch/csrc/jit/ir/irparser.h>
 #include <torch/csrc/jit/jit_log.h>
+#include <torch/csrc/jit/passes/batch_mm.h>
 #include <torch/csrc/jit/passes/canonicalize.h>
 #include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
@@ -893,6 +894,7 @@ void initJITBindings(PyObject* module) {
             }
             return retval;
           })
+      .def("_jit_pass_batch_mm", BatchMM)
       .def("_jit_decay_packed_param_input_types", [](Graph& g) {
         for (Value* i : g.inputs()) {
           if (i->type() ==