--- /dev/null
+import torch
+from torch.testing import FileCheck
+from torch.testing._internal.jit_utils import JitTestCase
+
+if __name__ == "__main__":
+ raise RuntimeError(
+ "This test file is not meant to be run directly, use:\n\n"
+ "\tpython test/test_jit.py TESTNAME\n\n"
+ "instead."
+ )
+
+
+class TestBatchMM(JitTestCase):
+    """Tests for the TorchScript ``batch_mm`` graph pass.
+
+    The pass fuses chains of ``aten::mm`` into ``prim::MMTreeReduce`` (summed
+    products) or ``prim::MMBatchSide`` (many products sharing one operand),
+    and must skip operands that have in-place writers.
+    """
+
+    @staticmethod
+    def _get_test_tensors(n: int):
+        # Build n integer tensors with alternating shapes: 2x3 for even x,
+        # 3x2 for odd x, so consecutive pairs are conformable for torch.mm
+        # and each mm in a chain produces a 2x2 result.
+        return [
+            torch.tensor([[1 + x, 2 + x, 3 + x], [4 + x, 5 + x, 6 + x]])
+            if x % 2 == 0
+            else torch.tensor([[1 + x, 2 + x], [3 + x, 4 + x], [5 + x, 6 + x]])
+            for x in range(n)
+        ]
+
+    def test_batch_mm_no_mutation(self):
+        # A sum of four mm products with no mutation anywhere: the pass
+        # should replace all four aten::mm nodes with one prim::MMTreeReduce
+        # without changing the numerical result.
+        def test_batch_mm(
+            T1: torch.Tensor,
+            T2: torch.Tensor,
+            T3: torch.Tensor,
+            T4: torch.Tensor,
+            T5: torch.Tensor,
+            T6: torch.Tensor,
+            T7: torch.Tensor,
+            T8: torch.Tensor,
+        ):
+            return (
+                torch.mm(T1, T2)
+                + torch.mm(T3, T4)
+                + torch.mm(T5, T6)
+                + torch.mm(T7, T8)
+            )
+
+        test_batch_mm_scripted = torch.jit.script(test_batch_mm)
+
+        tensors = TestBatchMM._get_test_tensors(8)
+        expected = test_batch_mm(*tensors)  # eager-mode reference result
+
+        # Before the pass: the scripted graph contains exactly four mm nodes.
+        FileCheck().check_count("aten::mm", 4, exactly=True).run(
+            test_batch_mm_scripted.graph
+        )
+        self.run_pass("batch_mm", test_batch_mm_scripted.graph)
+        # After the pass: a single fused tree-reduce node.
+        FileCheck().check_count("prim::MMTreeReduce", 1, exactly=True).run(
+            test_batch_mm_scripted.graph
+        )
+
+        actual = test_batch_mm_scripted(*tensors)
+        # The optimized graph must match eager execution numerically.
+        self.assertEqual(expected, actual, atol=1e-9, rtol=1e-9)
+
+
+    def test_batch_mm_permitted_mutation(self):
+        # Writing results into a Python dict is container mutation, not a
+        # tensor write, so it must not prevent fusion of the mm chain.
+        def test_batch_mm(
+            T1: torch.Tensor,
+            T2: torch.Tensor,
+            T3: torch.Tensor,
+            T4: torch.Tensor,
+            T5: torch.Tensor,
+            T6: torch.Tensor,
+            T7: torch.Tensor,
+            T8: torch.Tensor,
+        ):
+            result = {}
+            result["product"] = (
+                torch.mm(T1, T2)
+                + torch.mm(T3, T4)
+                + torch.mm(T5, T6)
+                + torch.mm(T7, T8)
+            )
+            result["constant"] = torch.tensor([42.0])
+            return result
+
+        test_batch_mm_scripted = torch.jit.script(test_batch_mm)
+
+        tensors = TestBatchMM._get_test_tensors(8)
+        expected = test_batch_mm(*tensors)  # eager-mode reference result
+
+        # Before the pass: four mm nodes.
+        FileCheck().check_count("aten::mm", 4, exactly=True).run(
+            test_batch_mm_scripted.graph
+        )
+        self.run_pass("batch_mm", test_batch_mm_scripted.graph)
+        # After the pass: the chain is fused despite the dict stores.
+        FileCheck().check_count("prim::MMTreeReduce", 1, exactly=True).run(
+            test_batch_mm_scripted.graph
+        )
+
+        actual = test_batch_mm_scripted(*tensors)
+        self.assertEqual(expected, actual, atol=1e-9, rtol=1e-9)
+
+    def test_batch_mm_prohibited_mutation(self):
+        # torch.relu_ writes T1 in place; T1 feeds one of the mm ops, so the
+        # pass must refuse to fuse and leave all four aten::mm nodes intact.
+        @torch.jit.script
+        def test_batch_mm(n: int):
+            T1 = torch.zeros((n, n))
+            T2 = torch.zeros((n, n))
+            T3 = torch.zeros((n, n))
+            T4 = torch.zeros((n, n))
+            T5 = torch.zeros((n, n))
+            T6 = torch.zeros((n, n))
+            T7 = torch.zeros((n, n))
+            T8 = torch.zeros((n, n))
+            torch.relu_(T1)  # in-place write: T1 has a writer in the alias DB
+            result = (
+                torch.mm(T1, T2)
+                + torch.mm(T3, T4)
+                + torch.mm(T5, T6)
+                + torch.mm(T7, T8)
+            )
+            return result
+
+        FileCheck().check_count("aten::mm", 4, exactly=True).run(test_batch_mm.graph)
+        self.run_pass("batch_mm", test_batch_mm.graph)
+        # Graph is unchanged: still four mm ops, no fused node introduced.
+        FileCheck().check_count("aten::mm", 4, exactly=True).check_not(
+            "prim::MMTreeReduce"
+        ).run(test_batch_mm.graph)
+
+    def test_batch_mm_prohibited_mutation_multiple_adds(self):
+        # Two independent mm chains: the one whose operands are never mutated
+        # should fuse, while the chain that consumes the mutated T1 must stay
+        # as its five plain aten::mm nodes.
+        @torch.jit.script
+        def test_batch_mm(n: int):
+            T1 = torch.zeros((n, n))
+            T2 = torch.zeros((n, n))
+            T3 = torch.zeros((n, n))
+            T4 = torch.zeros((n, n))
+            T5 = torch.zeros((n, n))
+            T6 = torch.zeros((n, n))
+            T7 = torch.zeros((n, n))
+            T8 = torch.zeros((n, n))
+            T9 = torch.zeros((n, n))
+            T10 = torch.zeros((n, n))
+            torch.relu_(T1)  # in-place write: only T1's chain is tainted
+            result = {}
+            result["no_mutated_parameters"] = (
+                torch.mm(T2, T3)
+                + torch.mm(T4, T5)
+                + torch.mm(T6, T7)
+                + torch.mm(T8, T9)
+            )
+            result["all_parameters"] = (
+                torch.mm(T1, T2)
+                + torch.mm(T3, T4)
+                + torch.mm(T5, T6)
+                + torch.mm(T7, T8)
+                + torch.mm(T9, T10)
+            )
+            return result
+
+        self.run_pass("batch_mm", test_batch_mm.graph)
+        # One fused tree for the clean chain; five unfused mm ops remain.
+        FileCheck().check_count("prim::MMTreeReduce", 1, exactly=True).check_count(
+            "aten::mm", 5, exactly=True
+        ).run(test_batch_mm.graph)
+
+    def test_batch_mm_prohibited_mutation_if_node(self):
+        # Mutation inside one branch of an if-node: the branch that reads the
+        # mutated T1 keeps its five plain mm ops, while the mutation-free
+        # else-branch chain is fused into a single prim::MMTreeReduce.
+        @torch.jit.script
+        def test_batch_mm(n: int, use_t1: bool):
+            T1 = torch.zeros((n, n))
+            T2 = torch.zeros((n, n))
+            T3 = torch.zeros((n, n))
+            T4 = torch.zeros((n, n))
+            T5 = torch.zeros((n, n))
+            T6 = torch.zeros((n, n))
+            T7 = torch.zeros((n, n))
+            T8 = torch.zeros((n, n))
+            T9 = torch.zeros((n, n))
+            T10 = torch.zeros((n, n))
+            if use_t1:
+                torch.relu_(T1)  # write happens only on this path
+                return (
+                    torch.mm(T1, T2)
+                    + torch.mm(T3, T4)
+                    + torch.mm(T5, T6)
+                    + torch.mm(T7, T8)
+                    + torch.mm(T9, T10)
+                )
+            else:
+                return (
+                    torch.mm(T2, T3)
+                    + torch.mm(T4, T5)
+                    + torch.mm(T6, T7)
+                    + torch.mm(T8, T9)
+                )
+
+        self.run_pass("batch_mm", test_batch_mm.graph)
+        # Five mm ops survive (the tainted branch); one fused node (the other).
+        FileCheck().check_count("aten::mm", 5, exactly=True).check_count(
+            "prim::MMTreeReduce", 1, exactly=True
+        ).run(test_batch_mm.graph)
+
+    def test_batch_mm_side_permitted_mutation(self):
+        # Eight mm products sharing a common left operand A, with no tensor
+        # mutation (only dict stores): all of them should collapse into one
+        # prim::MMBatchSide node with no aten::mm left in the graph.
+        @torch.jit.script
+        def test_batch_mm(n: int):
+            result = {}
+            A = torch.zeros((n, n))  # common side shared by every mm
+            T1 = torch.zeros((n, n))
+            T2 = torch.zeros((n, n))
+            T3 = torch.zeros((n, n))
+            T4 = torch.zeros((n, n))
+            T5 = torch.zeros((n, n))
+            T6 = torch.zeros((n, n))
+            T7 = torch.zeros((n, n))
+            T8 = torch.zeros((n, n))
+            result["T1"] = torch.mm(A, T1)
+            result["T2"] = torch.mm(A, T2)
+            result["T3"] = torch.mm(A, T3)
+            result["T4"] = torch.mm(A, T4)
+            result["T5"] = torch.mm(A, T5)
+            result["T6"] = torch.mm(A, T6)
+            result["T7"] = torch.mm(A, T7)
+            result["T8"] = torch.mm(A, T8)
+            return result
+
+        FileCheck().check_count("aten::mm", 8, exactly=True).run(test_batch_mm.graph)
+        self.run_pass("batch_mm", test_batch_mm.graph)
+        # Fully fused: one batch-side node, zero remaining mm ops.
+        FileCheck().check_count("prim::MMBatchSide", 1, exactly=True).check_not(
+            "aten::mm"
+        ).run(test_batch_mm.graph)
+
+    def test_batch_mm_side_prohibited_mutation_uncommon_side(self):
+        # Mutating T1 (a non-shared operand) poisons only its own product:
+        # the mm using T1 stays, the other nine fuse into one MMBatchSide.
+        @torch.jit.script
+        def test_batch_mm(n: int):
+            A = torch.zeros((n, n))  # common side shared by every mm
+            T1 = torch.zeros((n, n))
+            T2 = torch.zeros((n, n))
+            T3 = torch.zeros((n, n))
+            T4 = torch.zeros((n, n))
+            T5 = torch.zeros((n, n))
+            T6 = torch.zeros((n, n))
+            T7 = torch.zeros((n, n))
+            T8 = torch.zeros((n, n))
+            T9 = torch.zeros((n, n))
+            T10 = torch.zeros((n, n))
+            torch.relu_(T1)  # in-place write on one uncommon-side operand
+            result = {}
+            result["T1"] = torch.mm(A, T1)
+            result["T2"] = torch.mm(A, T2)
+            result["T3"] = torch.mm(A, T3)
+            result["T4"] = torch.mm(A, T4)
+            result["T5"] = torch.mm(A, T5)
+            result["T6"] = torch.mm(A, T6)
+            result["T7"] = torch.mm(A, T7)
+            result["T8"] = torch.mm(A, T8)
+            result["T9"] = torch.mm(A, T9)
+            result["T10"] = torch.mm(A, T10)
+            return result
+
+        FileCheck().check_count("aten::mm", 10, exactly=True).run(test_batch_mm.graph)
+        self.run_pass("batch_mm", test_batch_mm.graph)
+
+        # Exactly one mm survives (the one consuming mutated T1) ...
+        FileCheck().check_count("aten::mm", 1, exactly=True).run(test_batch_mm.graph)
+        # ... and the remaining nine products fused into one batch-side node.
+        FileCheck().check_count("prim::MMBatchSide", 1, exactly=True).run(
+            test_batch_mm.graph
+        )
+
+    def test_batch_mm_side_prohibited_mutation_common_side(self):
+        # Mutating A — the operand shared by every product — taints all ten
+        # mm ops, so the pass must perform no batch-side fusion at all.
+        @torch.jit.script
+        def test_batch_mm(n: int):
+            A = torch.zeros((n, n))  # common side shared by every mm
+            T1 = torch.zeros((n, n))
+            T2 = torch.zeros((n, n))
+            T3 = torch.zeros((n, n))
+            T4 = torch.zeros((n, n))
+            T5 = torch.zeros((n, n))
+            T6 = torch.zeros((n, n))
+            T7 = torch.zeros((n, n))
+            T8 = torch.zeros((n, n))
+            T9 = torch.zeros((n, n))
+            T10 = torch.zeros((n, n))
+            torch.relu_(A)  # in-place write on the shared side
+            result = {}
+            result["T1"] = torch.mm(A, T1)
+            result["T2"] = torch.mm(A, T2)
+            result["T3"] = torch.mm(A, T3)
+            result["T4"] = torch.mm(A, T4)
+            result["T5"] = torch.mm(A, T5)
+            result["T6"] = torch.mm(A, T6)
+            result["T7"] = torch.mm(A, T7)
+            result["T8"] = torch.mm(A, T8)
+            result["T9"] = torch.mm(A, T9)
+            result["T10"] = torch.mm(A, T10)
+            return result
+
+        FileCheck().check_count("aten::mm", 10, exactly=True).run(test_batch_mm.graph)
+        self.run_pass("batch_mm", test_batch_mm.graph)
+        # Graph unchanged: all ten mm ops remain, no fused node introduced.
+        FileCheck().check_count("aten::mm", 10, exactly=True).check_not(
+            "prim::MMBatchSide"
+        ).run(test_batch_mm.graph)
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
+#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <ATen/ATen.h>
#include <algorithm>
enum class Side { LHS, RHS };
-void BatchMMTreeReduce(Block* block) {
+void BatchMMTreeReduce(Block* block, AliasDb& alias_db) {
auto graph = block->owningGraph();
// Look for trees in the block
std::unordered_map<Node*, TreeToken> tokens;
for (auto node : block->nodes()) {
- if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
+ if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor") &&
+ !alias_db.hasWriters(node)) {
tokens[node] = TreeToken::mm(node);
- } else if (node->matches("aten::t(Tensor self) -> Tensor")) {
+ } else if (
+ node->matches("aten::t(Tensor self) -> Tensor") &&
+ !alias_db.hasWriters(node)) {
auto input_it = tokens.find(node->input()->node());
if (input_it != tokens.end()) {
tokens[node] = TreeToken::transpose(node, input_it->second);
}
} else if (
node->matches(
- "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
+ "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") &&
+ !alias_db.hasWriters(node)) {
Node* lhs = node->inputs()[0]->node();
Node* rhs = node->inputs()[1]->node();
auto lhs_it = tokens.find(lhs);
}
} else {
for (auto block : node->blocks()) {
- BatchMMTreeReduce(block);
+ BatchMMTreeReduce(block, alias_db);
}
}
}
std::vector<Node*> rhses; // Like above, but rhs
for (Use u : value->uses()) {
if (u.user->owningBlock() == block &&
- u.user->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
+ u.user->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor") &&
+ !alias_db.hasWriters(u.user)) {
if (u.offset == 0 && u.user->inputs()[1] != value) {
lhses.push_back(u.user);
} else if (u.offset == 1 && u.user->inputs()[0] != value) {
std::unordered_set<Value*> considered_values;
for (Node* node : block->nodes()) {
- if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
+ if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor") &&
+ !alias_db.hasWriters(node)) {
for (Value* input : node->inputs()) {
if (/*bool not_inserted = */ !considered_values.emplace(input).second) {
continue;
return false;
}
+// Cheap pre-scan used by BatchMM to bail out early: walks every node in the
+// graph (including nested blocks, via the depth-first iterator) and reports
+// whether any aten::mm node is present.
+bool hasMMOperators(std::shared_ptr<Graph>& graph) {
+  DepthFirstGraphNodeIterator it(graph);
+  Node* n = nullptr;
+  while ((n = it.next()) != nullptr) {
+    if (n->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
+      return true;
+    }
+  }
+  return false;
+}
+
void BatchMM(std::shared_ptr<Graph>& graph) {
- if (hasMutableOperators(graph->block())) {
- // TODO(suo): make BatchMM mutability-safe
+ if (!hasMMOperators(graph)) {
return;
}
AliasDb alias_db(graph);
- BatchMMTreeReduce(graph->block());
+ BatchMMTreeReduce(graph->block(), alias_db);
BatchMMSide(graph->block(), alias_db);
EliminateDeadCode(graph);
// It's possible that transpose rearrangements have created sequences of