From d902774cadd085c89bd27391d1a3c5a8488235de Mon Sep 17 00:00:00 2001 From: eellison Date: Tue, 23 Apr 2019 20:31:36 -0700 Subject: [PATCH] Dont introduce aliasing in CSE or Constant Pooling (#19576) Summary: We can't introduce aliasing to a graph output, since they may be mutated after. Pull Request resolved: https://github.com/pytorch/pytorch/pull/19576 Differential Revision: D15057734 Pulled By: eellison fbshipit-source-id: 33594c05d985a0c58edebd6252e1ee2c0efb6f0e --- test/cpp/jit/test_alias_analysis.h | 16 +++++-- test/cpp/jit/test_constant_pooling.h | 33 ++++++++++---- test/test_jit.py | 50 ++++++++++----------- torch/csrc/jit/autodiff.cpp | 10 ++++- torch/csrc/jit/passes/alias_analysis.cpp | 52 +++++++++++++++++++--- torch/csrc/jit/passes/alias_analysis.h | 9 +++- .../passes/common_subexpression_elimination.cpp | 17 +++++++ torch/csrc/jit/passes/constant_pooling.cpp | 23 +++++++--- .../csrc/jit/passes/create_autodiff_subgraphs.cpp | 12 ++--- torch/csrc/jit/passes/utils/memory_dag.cpp | 42 ++++++++++++----- torch/csrc/jit/passes/utils/memory_dag.h | 8 ++++ 11 files changed, 198 insertions(+), 74 deletions(-) diff --git a/test/cpp/jit/test_alias_analysis.h b/test/cpp/jit/test_alias_analysis.h index 87bdcec..9d121c4 100644 --- a/test/cpp/jit/test_alias_analysis.h +++ b/test/cpp/jit/test_alias_analysis.h @@ -507,7 +507,7 @@ void testContainerAliasing() { &*graph); auto node_iter = graph->block()->nodes().begin(); - node_iter++; // string + auto str_node = node_iter++; // string Node* ten_node = *node_iter++; AliasDb aliasDb(graph); @@ -515,6 +515,8 @@ void testContainerAliasing() { for (auto out : graph->outputs()) { AT_ASSERT(aliasDb.mayContainAlias(ten_node->output(), out)); } + AT_ASSERT(aliasDb.mayContainAlias({ten_node->output()}, graph->outputs())); + AT_ASSERT(!aliasDb.mayContainAlias(str_node->output(), graph->outputs())); } { @@ -533,13 +535,13 @@ void testContainerAliasing() { auto node_iter = graph->block()->nodes().begin(); node_iter++; // string - Node* ten_node = *node_iter++; + Node* int_node = *node_iter++; AliasDb aliasDb(graph); AT_ASSERT(graph->outputs().size() == 3); // primitive values don't need to alias container for (auto out : graph->outputs()) { - AT_ASSERT(!aliasDb.mayContainAlias(ten_node->output(), out)); + AT_ASSERT(!aliasDb.mayContainAlias(int_node->output(), out)); } } @@ -561,6 +563,7 @@ void testContainerAliasing() { for (auto input : graph->inputs()) { AT_ASSERT(aliasDb.mayContainAlias(input, tuple_node->output())); } + AT_ASSERT(aliasDb.mayContainAlias(graph->inputs(), graph->outputs())); } // Test tuple that doesn't come from construct @@ -648,6 +651,13 @@ graph(): AT_ASSERT(aliasDb.mayContainAlias(first_ten->output(), tup_node->output())); AT_ASSERT( !aliasDb.mayContainAlias(second_ten->output(), tup_node->output())); + + std::vector first_st = {first_ten->output()}; + std::vector second_st = {second_ten->output()}; + std::vector tup_st = {tup_node->output()}; + AT_ASSERT(aliasDb.mayContainAlias(first_st, tup_st)); + AT_ASSERT(!aliasDb.mayContainAlias(first_st, second_st)); + AT_ASSERT(!aliasDb.mayContainAlias(second_st, tup_st)); } } diff --git a/test/cpp/jit/test_constant_pooling.h b/test/cpp/jit/test_constant_pooling.h index 9a566bb..e8d0da2 100644 --- a/test/cpp/jit/test_constant_pooling.h +++ b/test/cpp/jit/test_constant_pooling.h @@ -34,16 +34,16 @@ graph(): script::parseIR( R"IR( graph(%cond : Tensor): - %a : string = prim::Constant[value="bcd"]() + %a : str = prim::Constant[value="bcd"]() %3 : bool = prim::Bool(%cond) - %b : string = prim::If(%3) + %b : str = prim::If(%3) block0(): - %b.1 : string = prim::Constant[value="abc"]() + %b.1 : str = prim::Constant[value="abc"]() -> (%b.1) block1(): - %b.2 : string = prim::Constant[value="abc"]() + %b.2 : str = prim::Constant[value="abc"]() -> (%b.2) - %7 : (string, string) = prim::TupleConstruct(%a, %b) + %7 : (str, str) = prim::TupleConstruct(%a, %b) return (%7) )IR", &*graph); @@ -69,8 +69,8 @@ graph(): %y : Tensor = aten::tensor(%3, %10, %7, %15) %9 : int[] = prim::ListConstruct(%1, %2) %z : Tensor = aten::tensor(%9, %10, %7, %15) - %14 : (Tensor, Tensor) = prim::TupleConstruct(%x, %y) - return (%14) + %f = prim::Print(%x, %y, %z) + return (%1) )IR", &*graph); // three tensors created - two different devices among the three @@ -82,7 +82,24 @@ graph(): ->check_count("Long(2) = prim::Constant", 1, /*exactly*/ true) ->run(*graph); } + // don't create aliasing of graph outputs in constant pooling + { + auto graph = std::make_shared(); + script::parseIR( + R"IR( +graph(%cond : Tensor): + %a : Tensor = prim::Constant() + %b : Tensor = prim::Constant() + %c : Tensor = prim::Constant() + %1 = prim::Print(%c) + return (%a, %b) + )IR", + &*graph); + ConstantPooling(graph); + testing::FileCheck() + .check_count("prim::Constant", 2, /*exactly*/ true) + ->run(*graph); + } } - } // namespace jit } // namespace torch diff --git a/test/test_jit.py b/test/test_jit.py index 1cb922b..a246a42 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -1139,22 +1139,40 @@ class TestJit(JitTestCase): self.assertExportImport(trace, (x, y)) + def test_cse_not_introduce_aliasing(self): + @torch.jit.script + def tensor_alias_outputs(x): + return x + x, x + x + + self.run_pass('cse', tensor_alias_outputs.graph) + FileCheck().check_count("aten::add", 2).run(tensor_alias_outputs.graph) + + @torch.jit.script + def ints_alias_outputs(x): + # type: (int) -> Tuple[int, int] + return x + x, x + x + + # non-aliasing types can be CSEd + self.run_pass('cse', ints_alias_outputs.graph) + FileCheck().check_count("aten::add", 1, exactly=True).run(ints_alias_outputs.graph) + def test_recursive_cse(self): input_str = """ graph(%x : Tensor, - %y : Tensor): + %y : Tensor, + %20 : int): %2 : int = prim::Constant[value=1]() %3 : Tensor = aten::add(%x, %y, %2) - %4 : Tensor = aten::gt(%3, %x) + %4 : int = aten::add(%2, %20) %5 : bool = prim::Bool(%4) - %z : Tensor = prim::If(%5) + %z : int = prim::If(%5) # CHECK: block block0(): # CHECK-NOT: aten::add - %z.1 : Tensor = aten::add(%x, %y, %2) + %z.1 : int = aten::add(%2, %20) -> (%z.1) block1(): - -> (%x) + -> (%2) return (%z) """ graph = parse_ir(input_str) @@ -12793,28 +12811,6 @@ class TestAutodiffSubgraphSlicing(JitTestCase): # the same group; they should each be a separate DiffGraph self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2) - def test_mutation_subgraph_inlining(self): - # cannot move a node which has writers into a differentiable subgraph, - # bc CSE might lose context that it has writers - - def fn(x): - a = x.t() - a = a + 1 - c = x.t() - c = c + 1 - e = a + c - b = a.add_(x) - d = c.add_(x) - return e, b, d - - fn_script = torch.jit.script(fn) - outs1 = fn_script(torch.tensor(0.5, requires_grad=True)) - outs2 = fn(torch.tensor(0.5, requires_grad=True)) - for i in range(len(outs1)): - self.assertEqual(outs1[i], outs2[i]) - graph = fn_script.graph_for(torch.tensor(0.5, requires_grad=True)) - FileCheck().check_not("DifferentiableGraph").run(graph) - class TestCustomOperators(JitTestCase): diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp index 922681e..5120a04 100644 --- a/torch/csrc/jit/autodiff.cpp +++ b/torch/csrc/jit/autodiff.cpp @@ -804,7 +804,14 @@ static void lambdaLiftReverse(Gradient& grad_desc, ReverseDetails& rev_info) { // we create an incorrect sum that doesn't use prev vjp, replace uses, and // fix the sum. Value* new_vjp = createAutogradAdd(tmp_vjp_in, tmp_vjp_in); - new_vjp->node()->moveAfter(tmp_vjp_prev->node()); + if (tmp_vjp_prev->node()->kind() == prim::Param) { + // can't move a node after a block param node + new_vjp->node()->moveBefore( + *tmp_vjp_prev->node()->owningBlock()->nodes().begin()); + } else { + new_vjp->node()->moveAfter(tmp_vjp_prev->node()); + } + tmp_vjp_prev->replaceAllUsesWith(new_vjp); new_vjp->node()->replaceInput(1, tmp_vjp_prev); grad_desc.df_input_vjps.emplace_back(i); @@ -859,6 +866,5 @@ Gradient differentiate(std::shared_ptr& graph) { ConstantPooling(grad_desc.df); return grad_desc; } - } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/alias_analysis.cpp b/torch/csrc/jit/passes/alias_analysis.cpp index 99581fa..ab4c33f 100644 --- a/torch/csrc/jit/passes/alias_analysis.cpp +++ b/torch/csrc/jit/passes/alias_analysis.cpp @@ -384,10 +384,13 @@ void AliasDb::analyzeImpl(Node* node) { return analyzeWait(node); case prim::TupleConstruct: return analyzeTupleConstruct(node); + case prim::GradOf: + return analyzeGradOf(node); case prim::Constant: case prim::DictConstruct: case prim::ListConstruct: case prim::AutogradZero: + case prim::AutogradAdd: case prim::FusedConcat: case prim::MMTreeReduce: case prim::MMBatchSide: @@ -594,6 +597,12 @@ void AliasDb::analyzeLoop(Node* node) { mapAliases(node->outputs(), blockOutputs); } +void AliasDb::analyzeGradOf(Node* node) { + const auto grad_of_block = node->blocks().at(0); + analyze(grad_of_block); + mapAliases(node->outputs(), grad_of_block->outputs()); +} + void AliasDb::analyzeSubgraph(Node* node) { const auto subgraph = node->g(attr::Subgraph).get(); @@ -704,6 +713,11 @@ void AliasDb::analyzeWait(Node* node) { } void AliasDb::analyzeTupleConstruct(Node* node) { + // Because we currently mark all Tuples as needing annotation + // (even those containing just prmitive types), an element needs to be created + // for TupleConstruct. When that changes we can create an element + // only if it contains elements which need annotation + getOrCreateElement(node->output()); for (const auto& input : node->inputs()) { if (shouldAnnotate(input)) { addToContainedElements(input, node->output()); @@ -831,16 +845,40 @@ bool AliasDb::cannotCheckAliasContainment(const Value* elem) const { return false; } -bool AliasDb::mayContainAlias(const Value* a, const Value* b) const { - if (!shouldAnnotate(a) || !shouldAnnotate(b)) { - return false; +bool AliasDb::mayContainAlias(Value* a, Value* b) const { + const std::vector a_vec = {a}; + const std::vector b_vec = {b}; + + return mayContainAlias(a_vec, b_vec); +} + +bool AliasDb::mayContainAlias( + const at::ArrayRef& a, + const at::ArrayRef& b) const { + std::vector a_elements; + for (const auto& val : a) { + if (cannotCheckAliasContainment(val)) { + return true; + } + if (shouldAnnotate(val)) { + a_elements.push_back(elementMap_.at(val)); + } } - if (cannotCheckAliasContainment(a) || cannotCheckAliasContainment(b)) { - return true; + if (a_elements.size() == 0) { + return false; } - return memoryDAG_->mayContainAlias(elementMap_.at(a), elementMap_.at(b)); + std::vector b_elements; + for (const auto& val : b) { + if (cannotCheckAliasContainment(val)) { + return true; + } + if (shouldAnnotate(val)) { + b_elements.push_back(elementMap_.at(val)); + } + } + return memoryDAG_->mayContainAlias(a_elements, b_elements); } // Make each value in the `from` list point to its partner in the `to` list @@ -1241,6 +1279,7 @@ TORCH_API bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) { prim::TupleConstruct, prim::AutogradZero, prim::FusedConcat, + prim::GradOf, prim::MMTreeReduce, prim::MMBatchSide, prim::BroadcastSizes, @@ -1256,6 +1295,7 @@ TORCH_API bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) { prim::BroadcastingChunk, prim::fork, prim::CreateObject, + prim::AutogradAdd, prim::GetAttr, prim::SetAttr, aten::wait, diff --git a/torch/csrc/jit/passes/alias_analysis.h b/torch/csrc/jit/passes/alias_analysis.h index e173e91..7de15cd 100644 --- a/torch/csrc/jit/passes/alias_analysis.h +++ b/torch/csrc/jit/passes/alias_analysis.h @@ -47,7 +47,13 @@ class AliasDb { // Does `a` and `b` potentially share a memory location or do either // hold in memory any element that exists in the other - bool mayContainAlias(const Value* a, const Value* b) const; + bool mayContainAlias(Value* a, Value* b) const; + + // Do any values in group `a` share a memory location or hold in memory + // any element that exists in group `b` + bool mayContainAlias( + const at::ArrayRef& a, + const at::ArrayRef& b) const; // Do `a` and `b` potentially share a memory location? bool mayAlias(const Value* a, const Value* b) const; @@ -189,6 +195,7 @@ class AliasDb { void analyzeBroadcastingChunk(Node* node); void analyzeFork(Node* node); void analyzeWait(Node* node); + void analyzeGradOf(Node* node); void analyzeSetAttr(Node* node); void analyzeTupleConstruct(Node* node); void analyzeCustomOp(Node* node); diff --git a/torch/csrc/jit/passes/common_subexpression_elimination.cpp b/torch/csrc/jit/passes/common_subexpression_elimination.cpp index 2d082ce..2d6f897 100644 --- a/torch/csrc/jit/passes/common_subexpression_elimination.cpp +++ b/torch/csrc/jit/passes/common_subexpression_elimination.cpp @@ -9,6 +9,7 @@ namespace torch { namespace jit { namespace { + // The function implements common subexpression elimination. // Since the nodes are visited in topological order, one pass is enough. void EliminateCommonSubexpression( @@ -42,7 +43,15 @@ void EliminateCommonSubexpression( // Check for CSE opportunities in the parent block. auto parent_lookup = parent_lookup_fn(node); + auto g_out = node->owningGraph()->outputs(); if (parent_lookup) { + // since the graph outputs may be mutated after they are returned, + // don't introduce new aliasing among graph outputs + if (aliasDb.mayContainAlias(node->outputs(), g_out) && + aliasDb.mayContainAlias(parent_lookup->outputs(), g_out)) { + continue; + } + node->replaceAllUsesWith(parent_lookup); it.destroyCurrent(); continue; @@ -53,6 +62,14 @@ void EliminateCommonSubexpression( if (!subit.second) { // Subexpression exists, replace the uses of node, and destroy it. auto existing = *subit.first; + + // don't introduce new aliasing among graph outputs + if (aliasDb.mayContainAlias( + node->outputs(), node->owningGraph()->outputs()) && + aliasDb.mayContainAlias(existing->outputs(), g_out)) { + continue; + } + node->replaceAllUsesWith(existing); // Destroy the node. it.destroyCurrent(); diff --git a/torch/csrc/jit/passes/constant_pooling.cpp b/torch/csrc/jit/passes/constant_pooling.cpp index 5421c8c..e577b10 100644 --- a/torch/csrc/jit/passes/constant_pooling.cpp +++ b/torch/csrc/jit/passes/constant_pooling.cpp @@ -1,7 +1,8 @@ +#include #include #include #include -#include +#include #include namespace torch { @@ -13,7 +14,8 @@ namespace { // Move all constants to the beginning of the graph, and deduplicate void ConstantPooling( Block* block, - std::unordered_set& constants) { + std::unordered_set& constants, + const AliasDb& aliasDb) { for (auto it = block->nodes().begin(); it != block->nodes().end();) { auto node = *it; // node may be moved to a different block so advance iterator now @@ -21,7 +23,7 @@ void ConstantPooling( if (!node->blocks().empty()) { // Traverse sub-blocks. for (auto block : node->blocks()) { - ConstantPooling(block, constants); + ConstantPooling(block, constants, aliasDb); } continue; } @@ -35,6 +37,16 @@ void ConstantPooling( if (!subit.second) { // constant exists, replace the uses of node, and destroy it. auto existing = *subit.first; + + // since the graph outputs may be mutated after they are returned, + // don't introduce new aliasing among graph outputs + if (aliasDb.mayContainAlias( + node->outputs(), node->owningGraph()->outputs()) && + aliasDb.mayContainAlias( + existing->outputs(), node->owningGraph()->outputs())) { + continue; + } + node->replaceAllUsesWith(existing); node->destroy(); continue; @@ -46,13 +58,12 @@ void ConstantPooling( node->moveBefore(first_node); } } - } // anonymous namespace void ConstantPooling(const std::shared_ptr& graph) { + AliasDb aliasDb(graph); std::unordered_set constants; - ConstantPooling(graph->block(), constants); + ConstantPooling(graph->block(), constants, aliasDb); } - } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp index b7a55b8..cd814a3 100644 --- a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp +++ b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp @@ -110,7 +110,7 @@ class SubgraphSlicer { return result; } - bool shouldConsiderForMerge(Node* node, const AliasDb& aliasDb) { + bool shouldConsiderForMerge(Node* node) { // if we're already in the process of merging if (node->kind() == prim::DifferentiableGraph) { return true; @@ -118,19 +118,13 @@ class SubgraphSlicer { if (node->kind() == prim::Constant) { return false; } - // when a node which has writers is moved into a subgraph it may lose - // context and CSE could merge it with another node that has writers - // TODO: @eellison Fix problem more generally in CSE, land PR #18500 - if (aliasDb.hasWriters(node)) { - return false; - } return isDifferentiable(node); } std::pair scanNode( Node* consumer, AliasDb& aliasDb) { - if (shouldConsiderForMerge(consumer, aliasDb)) { + if (shouldConsiderForMerge(consumer)) { if (consumer->kind() != prim::DifferentiableGraph) { consumer = SubgraphUtils::createSingletonSubgraph( consumer, prim::DifferentiableGraph); @@ -155,7 +149,7 @@ class SubgraphSlicer { Node* producer, AliasDb& aliasDb) { AT_ASSERT(consumer->kind() == prim::DifferentiableGraph); - bool canMerge = shouldConsiderForMerge(producer, aliasDb) && + bool canMerge = shouldConsiderForMerge(producer) && aliasDb.moveBeforeTopologicallyValid(producer, consumer); if (!canMerge) { diff --git a/torch/csrc/jit/passes/utils/memory_dag.cpp b/torch/csrc/jit/passes/utils/memory_dag.cpp index 6b56588..1dc74c5 100644 --- a/torch/csrc/jit/passes/utils/memory_dag.cpp +++ b/torch/csrc/jit/passes/utils/memory_dag.cpp @@ -2,7 +2,6 @@ #include #include -#include #include namespace torch { @@ -16,10 +15,9 @@ bool MemoryDAG::mayAlias(const Element* a, const Element* b) const { return mayAliasImpl(a, b); } -bool MemoryDAG::mayAliasImpl(const Element* a, const Element* b) const { - const auto aMemLoc = a->getMemoryLocations(); - const auto bMemLoc = b->getMemoryLocations(); - +bool MemoryDAG::memoryLocationOverlap( + const std::unordered_set& aMemLoc, + const std::unordered_set& bMemLoc) const { // XXX: This could be more efficiently done as a bitwise AND on two bitfields // that represent memory location membership. If these comparisons end up // being a bottleneck, consider implementing it that way. @@ -30,9 +28,17 @@ bool MemoryDAG::mayAliasImpl(const Element* a, const Element* b) const { } } } + return false; } +bool MemoryDAG::mayAliasImpl(const Element* a, const Element* b) const { + const auto aMemLoc = a->getMemoryLocations(); + const auto bMemLoc = b->getMemoryLocations(); + + return memoryLocationOverlap(aMemLoc, bMemLoc); +} + bool MemoryDAG::mayContainAlias(const Element* a, const Element* b) const { return mayContainAliasImpl(a, b); } @@ -67,15 +73,27 @@ bool MemoryDAG::mayContainAliasImpl(const Element* a, const Element* b) const { collectAllContainedMemoryLocations(a, all_a_mlocs); collectAllContainedMemoryLocations(b, all_b_mlocs); - for (const auto a_mem : all_a_mlocs) { - for (const auto b_mem : all_b_mlocs) { - if (a_mem == b_mem) { - return true; - } - } + return memoryLocationOverlap(all_a_mlocs, all_b_mlocs); +} + +bool MemoryDAG::mayContainAlias( + const at::ArrayRef& a, + const at::ArrayRef& b) const { + if (a.size() == 0 || b.size() == 0) { + return false; } - return false; + std::unordered_set all_a_mlocs; + for (const auto& elem : a) { + collectAllContainedMemoryLocations(elem, all_a_mlocs); + } + + std::unordered_set all_b_mlocs; + for (const auto& elem : b) { + collectAllContainedMemoryLocations(elem, all_b_mlocs); + } + + return memoryLocationOverlap(all_a_mlocs, all_b_mlocs); } // Make `v` point at `to`. diff --git a/torch/csrc/jit/passes/utils/memory_dag.h b/torch/csrc/jit/passes/utils/memory_dag.h index 76193cc..cffb3eb 100644 --- a/torch/csrc/jit/passes/utils/memory_dag.h +++ b/torch/csrc/jit/passes/utils/memory_dag.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -45,6 +46,10 @@ class MemoryDAG { bool mayContainAlias(const Element* a, const Element* b) const; bool mayContainAlias(Element* a, Element* b) const; + bool mayContainAlias( + const at::ArrayRef& a, + const at::ArrayRef& b) const; + // Do any values in group `a` potentially share a memory location with any // value in group `b`? // @@ -86,6 +91,9 @@ class MemoryDAG { } private: + bool memoryLocationOverlap( + const std::unordered_set& a, + const std::unordered_set& b) const; bool mayAliasImpl(const Element* a, const Element* b) const; bool mayContainAliasImpl(const Element* contained, const Element* container) const; -- 2.7.4