# the same group; they should each be a separate DiffGraph
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2)
+ def test_mutation_subgraph_inlining(self):
+ # cannot move a node which has writers into a differentiable subgraph,
+ # bc CSE might lose context that it has writers
+
+ def fn(x):
+ a = x.t()
+ a = a + 1
+ c = x.t()
+ c = c + 1
+ e = a + c
+ b = a.add_(x)
+ d = c.add_(x)
+ return e, b, d
+
+ fn_script = torch.jit.script(fn)
+ outs1 = fn_script(torch.tensor(0.5, requires_grad=True))
+ outs2 = fn(torch.tensor(0.5, requires_grad=True))
+ for i in range(len(outs1)):
+ self.assertEqual(outs1[i], outs2[i])
+ graph = fn_script.graph_for(torch.tensor(0.5, requires_grad=True))
+ FileCheck().check_not("DifferentiableGraph").run(graph)
+
class TestCustomOperators(JitTestCase):
return result;
}
- bool shouldConsiderForMerge(Node* node) {
+ bool shouldConsiderForMerge(Node* node, const AliasDb& aliasDb) {
// if we're already in the process of merging
if (node->kind() == prim::DifferentiableGraph) {
return true;
if (node->kind() == prim::Constant) {
return false;
}
+ // when a node which has writers is moved into a subgraph it may lose
+ // context and CSE could merge it with another node that has writers
+ // TODO: @eellison Fix problem more generally in CSE, land PR #18500
+ if (aliasDb.hasWriters(node)) {
+ return false;
+ }
return isDifferentiable(node);
}
std::pair<graph_node_list::iterator, bool> scanNode(
Node* consumer,
AliasDb& aliasDb) {
- if (shouldConsiderForMerge(consumer)) {
+ if (shouldConsiderForMerge(consumer, aliasDb)) {
if (consumer->kind() != prim::DifferentiableGraph) {
consumer = SubgraphUtils::createSingletonSubgraph(
consumer, prim::DifferentiableGraph);
Node* producer,
AliasDb& aliasDb) {
AT_ASSERT(consumer->kind() == prim::DifferentiableGraph);
- bool canMerge = shouldConsiderForMerge(producer) &&
+ bool canMerge = shouldConsiderForMerge(producer, aliasDb) &&
aliasDb.moveBeforeTopologicallyValid(producer, consumer);
if (!canMerge) {
SubgraphSlicer(graph->block(), graph, threshold).run(diff_nodes);
return diff_nodes;
}
-
} // namespace jit
} // namespace torch