JIT_TEST(SchemaParser)
JIT_TEST(TopologicalIndex)
JIT_TEST(TopologicalMove)
+JIT_TEST(SubgraphUtils)
#define JIT_TEST_CUDA(name) \
TEST(JitTest, name##_CUDA) { \
testSchemaParser();
testTopologicalIndex();
testTopologicalMove();
+ testSubgraphUtils();
return out.str();
}
#include "torch/csrc/jit/interpreter.h"
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/operator.h"
+#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
#include "torch/csrc/jit/passes/constant_propagation.h"
#include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/passes/lower_grad_of.h"
#include "torch/csrc/jit/passes/requires_grad_analysis.h"
#include "torch/csrc/jit/passes/shape_analysis.h"
+#include "torch/csrc/jit/passes/utils/subgraph_utils.h"
#include "torch/csrc/jit/symbolic_variable.h"
#include "torch/csrc/jit/tracer.h"
#include "torch/csrc/utils/hash.h"
void testCreateAutodiffSubgraphs(std::ostream& out = std::cout) {
auto graph = build_lstm();
- CreateAutodiffSubgraphs(*graph, /*threshold=*/2);
+ CreateAutodiffSubgraphs(graph, /*threshold=*/2);
out << "testCreateAutodiffSubgraphs\n";
out << *graph << "\n";
}
+void testSubgraphUtils() {
+ auto graph = build_lstm();
+ EliminateCommonSubexpression(graph);
+
+ std::vector<Node*> originalNodes(
+ graph->nodes().begin(), graph->nodes().end());
+
+ // Merge everything into a single subgraph
+ bool first = true;
+ Node* subgraph = nullptr;
+ for (auto it = graph->nodes().rbegin(); it != graph->nodes().rend();) {
+ if (first) {
+ subgraph = SubgraphUtils::createSingletonSubgraph(
+ *it, prim::DifferentiableGraph);
+ it = ++subgraph->reverseIterator();
+ first = false;
+ }
+
+ SubgraphUtils::mergeNodeIntoSubgraph(*it, subgraph);
+ it = ++subgraph->reverseIterator();
+ }
+
+ // Unmerge and compare with original node listing
+ SubgraphUtils::unmergeSubgraph(subgraph);
+ EliminateCommonSubexpression(graph);
+
+ std::vector<Node*> newNodes(graph->nodes().begin(), graph->nodes().end());
+ ASSERT_EQ(originalNodes.size(), newNodes.size());
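+
+ // A stricter check (a sketch; assumes the merge/unmerge round-trip
+ // preserves node order, which holds for this linear LSTM graph):
+ // node kinds should match pairwise, not just in count.
+ for (size_t i = 0; i < originalNodes.size(); ++i) {
+ ASSERT_EQ(originalNodes[i]->kind(), newNodes[i]->kind());
+ }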
+}
+
autograd::Variable var(at::Type& t, at::IntList sizes, bool requires_grad) {
return autograd::make_variable(at::rand(sizes, t.options()), requires_grad);
}
%3 : Dynamic
%4 : Dynamic) {
%7 : int = prim::Constant[value=1]()
- %19 : int = prim::Constant[value=1]()
- %23 : Dynamic, %24 : Dynamic = prim::DifferentiableGraph_0(%0, %3, %1, %4, %2)
- return (%24, %23);
+ %23 : Dynamic, %24 : Dynamic = prim::DifferentiableGraph_0(%2, %1, %4, %0, %3)
+ return (%23, %24);
}
-with prim::DifferentiableGraph_0 = graph(%1 : Dynamic
- %2 : Dynamic
- %4 : Dynamic
- %5 : Dynamic
- %17 : Dynamic) {
- %0 : Dynamic = aten::mm(%1, %2)
- %3 : Dynamic = aten::mm(%4, %5)
- %7 : int = prim::Constant[value=1]()
- %6 : Dynamic = aten::add(%0, %3, %7)
- %8 : Dynamic, %9 : Dynamic, %10 : Dynamic, %11 : Dynamic = prim::ConstantChunk[chunks=4, dim=1](%6)
- %12 : Dynamic = aten::sigmoid(%8)
- %13 : Dynamic = aten::sigmoid(%11)
- %14 : Dynamic = aten::tanh(%10)
- %15 : Dynamic = aten::sigmoid(%9)
- %16 : Dynamic = aten::mul(%15, %17)
- %18 : Dynamic = aten::mul(%12, %14)
- %20 : int = prim::Constant[value=1]()
- %19 : Dynamic = aten::add(%16, %18, %20)
- %21 : Dynamic = aten::tanh(%19)
- %22 : Dynamic = aten::mul(%13, %21)
- return (%19, %22);
+with prim::DifferentiableGraph_0 = graph(%13 : Dynamic
+ %32 : Dynamic
+ %33 : Dynamic
+ %35 : Dynamic
+ %36 : Dynamic) {
+ %37 : Dynamic = aten::mm(%35, %36)
+ %34 : Dynamic = aten::mm(%32, %33)
+ %30 : int = prim::Constant[value=1]()
+ %31 : Dynamic = aten::add(%37, %34, %30)
+ %24 : Dynamic, %25 : Dynamic, %26 : Dynamic, %27 : Dynamic = prim::ConstantChunk[chunks=4, dim=1](%31)
+ %22 : Dynamic = aten::sigmoid(%24)
+ %20 : Dynamic = aten::sigmoid(%27)
+ %18 : Dynamic = aten::tanh(%26)
+ %16 : Dynamic = aten::sigmoid(%25)
+ %14 : Dynamic = aten::mul(%16, %13)
+ %11 : Dynamic = aten::mul(%22, %18)
+ %8 : Dynamic = aten::add(%14, %11, %30)
+ %4 : Dynamic = aten::tanh(%8)
+ %2 : Dynamic = aten::mul(%20, %4)
+ return (%2, %8);
}
testDifferentiate
%cellgate : Float(*, *)
%outgate : Float(*, *)
%18 : Float(*, *)) {
- %19 : Float(*, *), %20 : Float(*, *) = prim::FusionGroup_0(%forgetgate, %ingate, %cellgate, %outgate, %11, %0, %18, %1)
+ %19 : Float(*, *), %20 : Float(*, *) = prim::FusionGroup_0(%forgetgate, %ingate, %cellgate, %outgate, %9, %1, %18, %0)
%21 : Float(*, *) = aten::t(%13)
%22 : Float(*, *) = aten::mm(%20, %21)
%23 : Float(*, *) = aten::t(%10)
%25 : Float(*, *) = aten::t(%24)
%26 : Float(*, *) = aten::t(%12)
%27 : Float(*, *) = aten::mm(%20, %26)
- %28 : Float(*, *) = aten::t(%9)
+ %28 : Float(*, *) = aten::t(%11)
%29 : Float(*, *) = aten::mm(%28, %20)
%30 : Float(*, *) = aten::t(%29)
- return (%30, %27, %25, %22, %20, %20, %19);
+ return (%19, %20, %20, %22, %25, %27, %30);
}
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%1 : Float(*, *)
%w_hh : Float(*, *)
%b_ih : Float(*)
%b_hh : Float(*)) {
- %7 : Float(*, *), %8 : Float(*, *) = prim::DifferentiableGraph_0(%w_ih, %x, %w_hh, %hx, %b_ih, %b_hh, %cx)
- return (%8, %7);
+ %hy : Float(*, *), %cy : Float(*, *) = prim::DifferentiableGraph_0(%cx, %b_hh, %b_ih, %hx, %w_hh, %x, %w_ih)
+ return (%hy, %cy);
}
with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *)
- %1 : Float(*, *)
- %2 : Float(*, *)
+ %1 : Float(*)
+ %2 : Float(*)
%3 : Float(*, *)
- %4 : Float(*)
- %5 : Float(*)
+ %4 : Float(*, *)
+ %5 : Float(*, *)
%6 : Float(*, *)) {
- %7 : Float(*, *) = aten::t(%0)
- %8 : Float(*, *) = aten::mm(%1, %7)
- %9 : Float(*, *) = aten::t(%2)
+ %7 : Float(*, *) = aten::t(%6)
+ %8 : Float(*, *) = aten::mm(%5, %7)
+ %9 : Float(*, *) = aten::t(%4)
%10 : Float(*, *) = aten::mm(%3, %9)
%11 : int = prim::Constant[value=1]()
- %12 : Float(*, *) = prim::FusionGroup_0(%4, %8, %10)
- %13 : Dynamic[] = prim::ListConstruct(%12, %5)
+ %12 : Float(*, *) = prim::FusionGroup_0(%2, %8, %10)
+ %13 : Dynamic[] = prim::ListConstruct(%12, %1)
%14 : Dynamic[] = aten::broadcast_tensors(%13)
%15 : Dynamic, %16 : Dynamic = prim::ListUnpack(%14)
- %hy : Float(*, *), %18 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1(%6, %16, %15)
- return (%cy, %hy, %7, %9, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %18);
+ %hy : Float(*, *), %18 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1(%0, %16, %15)
+ return (%hy, %cy, %7, %9, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %18);
}
with prim::FusionGroup_0 = graph(%0 : Float(*)
%1 : Float(*, *)
%10 : Undefined
%11 : Undefined
%12 : Float(*, *)
- %13 : Float(*, *)
+ %13 : Float(*)
%14 : Float(*)
%15 : Float(*)
- %16 : Float(*)
+ %16 : Float(*, *)
%17 : Float(*, *)
%18 : Float(*, *)
%Wx : Float(*, *)
%cellgate : Float(*, *)
%outgate : Float(*, *)
%27 : Float(*, *)) {
- %28 : Float(*, *) = prim::FusionGroup_0(%ingate, %forgetgate, %cellgate, %outgate, %17, %0, %27, %1)
+ %28 : Float(*, *) = prim::FusionGroup_0(%ingate, %forgetgate, %cellgate, %outgate, %12, %1, %27, %0)
%29 : Float(*, *) = aten::mul(%28, %Uz)
%30 : Float(*, *) = aten::mul(%28, %Wx)
- %31 : Float(*, *) = prim::FusionGroup_1(%28, %22, %16)
- %32 : Float(*, *), %33 : Float(*, *) = prim::FusionGroup_2(%Wx, %14, %28, %Uz, %15)
- %34 : Float(*, *) = aten::t(%13)
+ %31 : Float(*, *) = prim::FusionGroup_1(%28, %22, %13)
+ %32 : Float(*, *), %33 : Float(*, *) = prim::FusionGroup_2(%Wx, %15, %28, %Uz, %14)
+ %34 : Float(*, *) = aten::t(%16)
%35 : Float(*, *) = aten::mm(%34, %31)
%36 : Float(*, *) = aten::t(%35)
- %37 : Float(*, *) = aten::t(%12)
+ %37 : Float(*, *) = aten::t(%17)
%38 : Float(*, *) = aten::mm(%37, %33)
%39 : Float(*, *) = aten::t(%38)
- return (%39, %36, %32, %30, %29, %28);
+ return (%28, %29, %30, %32, %36, %39);
}
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%1 : Float(*, *)
%beta_i : Float(*)
%beta_h : Float(*)
%bias : Float(*)) {
- %9 : Float(*, *), %10 : Float(*, *) = prim::DifferentiableGraph_0(%w_ih, %x, %w_hh, %hx, %alpha, %beta_i, %beta_h, %bias, %cx)
- return (%10, %9);
+ %hy : Float(*, *), %cy : Float(*, *) = prim::DifferentiableGraph_0(%cx, %bias, %beta_h, %beta_i, %alpha, %hx, %w_hh, %x, %w_ih)
+ return (%hy, %cy);
}
with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *)
- %1 : Float(*, *)
- %2 : Float(*, *)
- %3 : Float(*, *)
+ %1 : Float(*)
+ %2 : Float(*)
+ %3 : Float(*)
%4 : Float(*)
- %5 : Float(*)
- %6 : Float(*)
- %7 : Float(*)
+ %5 : Float(*, *)
+ %6 : Float(*, *)
+ %7 : Float(*, *)
%8 : Float(*, *)) {
- %9 : Float(*, *) = aten::t(%0)
- %Wx.1 : Float(*, *) = aten::mm(%1, %9)
- %11 : Float(*, *) = aten::t(%2)
- %Uz.1 : Float(*, *) = aten::mm(%3, %11)
+ %9 : Float(*, *) = aten::t(%8)
+ %Wx.1 : Float(*, *) = aten::mm(%7, %9)
+ %11 : Float(*, *) = aten::t(%6)
+ %Uz.1 : Float(*, *) = aten::mm(%5, %11)
%13 : int = prim::Constant[value=1]()
- %14 : Float(*, *), %15 : Float(*, *) = prim::FusionGroup_0(%6, %Uz.1, %5, %Wx.1, %4)
- %16 : Dynamic[] = prim::ListConstruct(%14, %7)
+ %14 : Float(*, *), %15 : Float(*, *) = prim::FusionGroup_0(%2, %Uz.1, %3, %Wx.1, %4)
+ %16 : Dynamic[] = prim::ListConstruct(%14, %1)
%17 : Dynamic[] = aten::broadcast_tensors(%16)
%18 : Dynamic, %19 : Dynamic = prim::ListUnpack(%17)
- %hy : Float(*, *), %21 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1(%8, %19, %18)
- return (%cy, %hy, %9, %Wx.1, %11, %Uz.1, %15, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %21);
+ %hy : Float(*, *), %21 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1(%0, %19, %18)
+ return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %15, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %21);
}
with prim::FusionGroup_0 = graph(%0 : Float(*)
%1 : Float(*, *)
${TORCH_SRC_DIR}/csrc/jit/passes/requires_grad_analysis.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/specialize_undef.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/python_print.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/passes/utils/subgraph_utils.cpp
${TORCH_SRC_DIR}/csrc/jit/fuser/interface.cpp
${TORCH_SRC_DIR}/csrc/jit/register_prim_ops.cpp
${TORCH_SRC_DIR}/csrc/jit/register_special_ops.cpp
// Phase 5. Apply non-differentiable optimizations to the graphs we've found
// (or the whole graph if we know we won't need its derivative).
if (needsGradient(opt_graph)) {
- auto diff_nodes = CreateAutodiffSubgraphs(*opt_graph, autodiffSubgraphNodeThreshold);
+ auto diff_nodes = CreateAutodiffSubgraphs(opt_graph, autodiffSubgraphNodeThreshold);
for (Node * dnode : diff_nodes) {
auto diff_graph = std::move(dnode->g(attr::Subgraph));
Gradient gradient = differentiate(diff_graph);
return ConstantPropagation(g);
})
.def("_jit_pass_erase_shape_information", EraseShapeInformation)
- .def("_jit_pass_create_autodiff_subgraphs", [](Graph& graph) {
+ .def("_jit_pass_create_autodiff_subgraphs", [](std::shared_ptr<Graph> graph) {
CreateAutodiffSubgraphs(graph);
})
.def("_jit_run_cpp_tests", [] {
TORCH_API Node* createUndefined();
TORCH_API Node* createNoneGenerator();
TORCH_API Node* createFusionGroup();
+ TORCH_API Node* createDifferentiableSubgraph();
TORCH_API Node* createTuple(at::ArrayRef<Value*> values);
TORCH_API Node* createTupleUnpack(Value * v);
TORCH_API Node* createTupleIndex(Value * tup, int64_t index);
#include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
-#include "torch/csrc/jit/ir.h"
-#include "torch/csrc/jit/autodiff.h"
#include "torch/csrc/jit/assertions.h"
-#include "torch/csrc/jit/dynamic_dag.h"
-
-#include <cstddef>
-#include <limits>
-
-namespace torch { namespace jit {
+#include "torch/csrc/jit/autodiff.h"
+#include "torch/csrc/jit/ir.h"
+#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
+#include "torch/csrc/jit/passes/utils/subgraph_utils.h"
-struct Graph;
+namespace torch {
+namespace jit {
namespace {
-// Move nodes that exist in graph g into a 'group_node_kind' node.
-// All inputs shared by the nodes become inputs to the new node.
-// Outputs from 'nodes' are redirected to outputs of the new node,
-// and the original nodes are removed.
-// prereq: it is topologically valid to place the new node
-// right before nodes[0] (i.e. it will not create cycles and all uses of
-// new node will be after this position).
-// prereq: nodes are in topological order
-Node* mergeNodes(Block * block, Symbol group_node_kind, ArrayRef<Node*> nodes) {
- JIT_ASSERT(nodes.size() > 0);
- std::unordered_map<Value*, Value*> value_map;
- Graph * graph = block->owningGraph();
-
- auto new_graph = std::make_shared<Graph>();
- Node * group_node = graph->create(group_node_kind, 0);
- group_node->g_(attr::Subgraph, new_graph);
-
- auto getOrCreateInput = [&](Value * v) {
- if(value_map.count(v) > 0) {
- return value_map[v];
- }
- if (auto value = toIValue(v)) {
- Value * nv = new_graph->insertConstant(*value);
- value_map[v] = nv;
- return nv;
- }
- Value * nv = new_graph->addInput()->setType(v->type());
- group_node->addInput(v);
- value_map[v] = nv;
- return nv;
- };
- std::unordered_set<Node*> group_set(nodes.begin(), nodes.end());
- for(auto n : nodes) {
- auto nn = new_graph->appendNode(new_graph->createClone(n, getOrCreateInput));
- for(size_t i = 0; i < nn->outputs().size(); ++i) {
- auto old_output = n->outputs()[i];
- auto new_output = nn->outputs()[i];
- value_map[old_output] = new_output;
- std::vector<Use> to_replace;
- for(auto u : old_output->uses()) {
- // Uses within the set do not need to be made outputs
- if(group_set.count(u.user) > 0)
- continue;
- // Other uses do, but we
- // cannot replace them here or we invalid the uses list iterator
- to_replace.push_back(u);
- }
- if(to_replace.size() > 0) {
- new_graph->registerOutput(new_output);
- Value * external_output = group_node->addOutput()->setType(old_output->type());
- for(auto u : to_replace) {
- u.user->replaceInput(u.offset, external_output);
- }
+class SubgraphSlicer {
+ public:
+ SubgraphSlicer(Block* block, size_t minSubgraphSize)
+ : block_(block), minSubgraphSize_(minSubgraphSize) {}
+
+ void run(std::vector<Node*>& diffGraphs) {
+ // We need to run the slicer multiple times in order to get all merge
+ // opportunities. This is because moveBeforeTopologicallyValid may reorder
+ // nodes to be AFTER the current iteration point. In order to properly
+ // consider those nodes for merging, we need to run the pass until no
+ // changes have been made.
+ //
+ // Example:
+ // c = f(a, b)
+ // d = f(c)
+ // e = f(d) <- iter is here, moving upward
+ // After c.moveBeforeTopologicallyValid(e), we have:
+ // c = f(a, b)
+ // e = f(d) <- iter still here
+ // d = f(c) <- this node was moved to the other side.
+ bool any_changed = true;
+ while (any_changed) {
+ any_changed = false;
+ for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
+ bool changed;
+ std::tie(it, changed) = scanNode(*it);
+ any_changed |= changed;
}
}
- }
- group_node->insertBefore(nodes[0]);
- // delete backward, so that nodes are use-free before deletion
- for(size_t i = nodes.size(); i > 0; --i) {
- nodes[i - 1]->destroy();
- }
- JIT_ASSERT(isDifferentiable(*new_graph));
- return group_node;
-}
-bool shouldConsiderForMerge(detail::Vertex<Node*>* v) {
- if (v->data.size() >= 2) {
- return true;
- }
- JIT_ASSERT(v->data.size() == 1);
- auto * node = *v->data.begin();
- if (node->kind() == prim::Constant) {
- return false;
- }
- return isDifferentiable(node);
-}
+ // Done constructing subgraphs. Do some post-processing cleanup:
+ // 1. Run CSE to delete redundant constant nodes.
+ // 2. Re-inline subgraphs that turned out to be too small.
+ auto curNode = *block_->nodes().rbegin();
+ while (curNode != *block_->nodes().rend()) {
+ for (auto subBlock : curNode->blocks()) {
+ SubgraphSlicer(subBlock, minSubgraphSize_).run(diffGraphs);
+ }
-static detail::DynamicDAG<Node*> make_dependency_graph(Block * block) {
- detail::DynamicDAG<Node*> dag;
- std::unordered_map<Node*,detail::Vertex<Node*>*> node_to_vertex;
- // NB: the block's param and return nodes are not in the dependency graph.
- for (Node * node : block->nodes()) {
- node_to_vertex[node] = dag.newVertex(node);
- }
- for (auto * node : block->nodes()) {
- for (auto * v : node->outputs()) {
- for (auto & use : v->uses()) {
- // [Determine data dependencies]
- // Consider the following code:
- // y = f(x)
- // if k:
- // w += y
- // z = g(y)
- // This produces a dependency graph with 3 vertices:
- // (0: f) (1: if k ...) (2: g)
- // We need to peek into the if Node* to determine its data dependencies
- // (the body depends on the output of f, so Vertex 1 depends on Vertex 0).
- // For each Use of y, we find an owning node of y that is a part of the
- // dependency graph (in this case, the Vertex containing the if Node*)
- // and then record the dependency.
- auto * owning_node = use.user;
- if (owning_node == block->return_node()) {
- // The return node is not in the dag. Carry on.
- continue;
- }
- while (true) {
- auto search = node_to_vertex.find(owning_node);
- if (search == node_to_vertex.end()) {
- owning_node = owning_node->owningBlock()->owningNode();
- JIT_ASSERT(owning_node != nullptr);
- continue;
- }
- // NB: DynamicDAG is a simple graph (no multi-edges).
- // addEdge is a no-op if the edge already exists.
- dag.addEdge(node_to_vertex[node], search->second);
- break;
+ // Save the previous node, since we might delete `curNode` in the next block
+ auto prevNode = curNode->prev();
+ if (curNode->kind() == prim::DifferentiableGraph) {
+ // Inlining nodes may cause some subexpressions to come back in the
+ // subgraphs (for example, repeatedly copying constants in will generate
+ // redundant prim::Constants). Run CSE to clean them up.
+ EliminateCommonSubexpression(curNode->g(attr::Subgraph));
+
+ if (!inlineIfTooSmall(curNode)) {
+ diffGraphs.push_back(curNode);
}
}
+ curNode = prevNode;
}
+ // Run CSE one more time to eliminate duplicates that may have occurred
+ // while re-inlining subgraphs.
+ EliminateCommonSubexpression(block_);
}
- return dag;
-}
-static void find_differentiable_groups(
- detail::DynamicDAG<Node*>& dep_graph,
- size_t distance_threshold=256,
- size_t producer_edge_threshold=16) {
- // A Vertex contains a Node* or a differentiable group of Node*.
- // Perform graph contraction on dep_graph: contract two vertices(x, y) if
- // the following conditions hold:
- // - x, y can be merged to form a differentiable group
- // - the contraction would not invalidate the dag (it creates no cycles).
+ private:
+ // Inline this node's group subgraph into the outer graph if it's smaller
+ // than the specified minimum size.
//
- // This performs a greedy algorithm. This greedy algorithm considers
- // dep_graph vertices in reverse topological order by reverse iterating through
- // ord indices. For a certain ord, we attempt to merge the vertex at that ord
- // with each of its parents. If the vertex at the ord cannot be merged with any
- // of its parents, then we move on to a smaller ord and repeat.
- //
- // Each contractEdge call is effectively constant because we limit the size
- // of the affected region (via the distance_threshold) and the fan in/fan out
- // via producer_edge_threshold.
- // In addition, each sort of in_edges is bounded by producer_edge threshold.
- // This makes the complexity of find_differential_groups effectively O(V + E).
-
- // Iterate in reverse topological order
- int64_t ord = dep_graph.max_size() - 1;
- for (int64_t ord = dep_graph.max_size() - 1; ord >= 0; --ord) {
- if (!dep_graph.at(ord)) continue;
-
- auto* consumer = dep_graph.at(ord).value();
- if (!shouldConsiderForMerge(consumer)) continue;
-
- // To bound the complexity of the sort. Makes the algorithm less optimal.
- if (consumer->in_edges().size() > producer_edge_threshold) continue;
-
- // Iterate through consumer->in_edges() in reverse topological order.
- // sort is performed once per ord in dep_graph and once per contraction.
- // There can be at most dep_graph.max_size() contractions, so
- // we do at most 2 * dep_graph.max_size() sorts.
- consumer->in_edges().sort();
+ // Returns true if an inlining has occurred, false otherwise.
+ bool inlineIfTooSmall(Node* n) {
+ JIT_ASSERT(n->kind() == prim::DifferentiableGraph);
+ auto subgraph = SubgraphUtils::getSubgraph(n);
+ size_t i = 0;
+ for (auto it = subgraph->nodes().begin(); it != subgraph->nodes().end();
+ ++it) {
+ if (++i >= minSubgraphSize_) {
+ return false;
+ }
+ }
- for (auto it = consumer->in_edges().rbegin(); it != consumer->in_edges().rend(); ++it) {
- auto * producer = *it;
- // The distance threshold makes this algorithm "not optimal": it will miss
- // some possible contraction opportunities, but it hopefully lets us:
- // 1) preserve locality of tensors. We don't want to keep them alive for too long.
- // 2) Help bound the computation complexity for contractEdge
- if (consumer->ord - producer->ord > distance_threshold) continue;
- if (!shouldConsiderForMerge(producer)) continue;
+ SubgraphUtils::unmergeSubgraph(n);
+ return true;
+ }
- // If the edge contraction is successful, dep_graph.at(ord) may have changed
- // as well as consumer->in_edges() so we break out of this loop
- if (dep_graph.contractEdge(producer, consumer)) {
- // Stay at the current ord until we are done considering the vertex
- // at this ord for contraction
- ++ord;
- break;
+ value_list sortReverseTopological(ArrayRef<Value*> inputs) {
+ value_list result;
+ for (auto i : inputs) {
+ if (i->node()->owningBlock() == block_) {
+ result.push_back(i);
}
}
+ // Sort in reverse topological order
+ std::sort(result.begin(), result.end(), [&](Value* a, Value* b) {
+ return a->node()->isAfter(b->node());
+ });
+ return result;
}
-}
-
-static void reorder_according_to_dag(Block * block, const detail::DynamicDAG<Node*>& dep_graph) {
- for (size_t ord = 0; ord < dep_graph.max_size(); ++ord) {
- const auto& vertex = dep_graph.at(ord);
- if (!vertex.has_value()) continue;
- auto& nodes = vertex.value()->data;
- for (Node* node : nodes) {
- // Move all nodes according to the topological order in dep_graph. A lot
- // of the moves are unnecessary but this is a quick & easy solution.
- node->moveBefore(block->return_node());
+ bool shouldConsiderForMerge(Node* node) {
+ // If we're already in the process of merging into this node, keep going.
+ if (node->kind() == prim::DifferentiableGraph) {
+ return true;
+ }
+ if (node->kind() == prim::Constant) {
+ return false;
}
+ return isDifferentiable(node);
}
-}
-
-static void merge_differentiable_groups(
- Block * block,
- const detail::DynamicDAG<Node*>& dep_graph,
- size_t size_threshold,
- std::vector<Node*>& diff_graphs) {
- for (size_t ord = 0; ord < dep_graph.max_size(); ++ord) {
- const auto& vertex = dep_graph.at(ord);
- if (!vertex) continue;
- if (!shouldConsiderForMerge(vertex.value())) continue;
- auto& nodes = vertex.value()->data;
- if (nodes.size() < size_threshold) continue;
+ std::pair<graph_node_list::iterator, bool> scanNode(Node* consumer) {
+ if (shouldConsiderForMerge(consumer)) {
+ if (consumer->kind() != prim::DifferentiableGraph) {
+ consumer = SubgraphUtils::createSingletonSubgraph(
+ consumer, prim::DifferentiableGraph);
+ }
+ auto inputs = sortReverseTopological(consumer->inputs());
+ for (auto input : inputs) {
+ if (auto group = tryMerge(consumer, input->node())) {
+ // We successfully merged, so the new group's `inputs` may have
+ // changed; rescan the new group for more merging opportunities.
+ return std::make_pair(group.value()->reverseIterator(), true);
+ }
+ }
+ }
- diff_graphs.push_back(mergeNodes(block, prim::DifferentiableGraph, nodes));
+ return std::make_pair(++consumer->reverseIterator(), false);
}
-}
-void CreateAutodiffSubgraphsPK(
- Block * block,
- size_t size_threshold,
- std::vector<Node*>& diff_graphs) {
- for (auto * node : block->nodes()) {
- // Find subgraphs to run this on recursively.
- if (isDifferentiable(node)) continue;
- for (auto * sub_block : node->blocks()) {
- CreateAutodiffSubgraphsPK(sub_block, size_threshold, diff_graphs);
+ // Try to merge `producer` into `consumer`. If successful, this destroys
+ // `producer` and returns the `consumer` group.
+ c10::optional<Node*> tryMerge(Node* consumer, Node* producer) {
+ JIT_ASSERT(consumer->kind() == prim::DifferentiableGraph);
+ bool canMerge = shouldConsiderForMerge(producer) &&
+ producer->moveBeforeTopologicallyValid(consumer);
+
+ if (!canMerge) {
+ return c10::nullopt;
}
- }
- auto dep_graph = make_dependency_graph(block);
- find_differentiable_groups(dep_graph);
- reorder_according_to_dag(block, dep_graph);
- merge_differentiable_groups(block, dep_graph, size_threshold, diff_graphs);
-}
+ SubgraphUtils::mergeNodeIntoSubgraph(producer, consumer);
+
+ return consumer;
+ }
+ Block* block_;
+ size_t minSubgraphSize_;
+};
} // anonymous namespace
-std::vector<Node*> CreateAutodiffSubgraphs(Graph & graph, size_t threshold) {
+std::vector<Node*> CreateAutodiffSubgraphs(
+ std::shared_ptr<Graph> graph,
+ size_t threshold) {
std::vector<Node*> diff_nodes;
- CreateAutodiffSubgraphsPK(graph.block(), threshold, diff_nodes);
+ SubgraphSlicer(graph->block(), threshold).run(diff_nodes);
return diff_nodes;
}
-}}
+} // namespace jit
+} // namespace torch
// subgraphs that are differentiable by the jit's autodiff passes
// threshold - minimum number of nodes that will appear in a block
// returns all differentiable blocks that have been found
-TORCH_API std::vector<Node*> CreateAutodiffSubgraphs(Graph & graph, size_t threshold = 2);
-
+TORCH_API std::vector<Node*> CreateAutodiffSubgraphs(
+ std::shared_ptr<Graph> graph,
+ size_t threshold = 2);
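+
+// Example (a sketch, not part of the API): slice a graph and walk the
+// resulting groups.
+//   std::shared_ptr<Graph> g = ...; // some graph with differentiable ops
+//   for (Node* group : CreateAutodiffSubgraphs(g, /*threshold=*/2)) {
+//     // group->kind() == prim::DifferentiableGraph
+//     // group->g(attr::Subgraph) holds the sliced-out nodes
+//   }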
}}
--- /dev/null
+#include "subgraph_utils.h"
+
+namespace torch {
+namespace jit {
+namespace SubgraphUtils {
+namespace {
+bool isSubgraphNodeKind(Symbol s) {
+ return s == prim::DifferentiableGraph || s == prim::FusionGroup;
+}
+
+bool isSubgraphNodeKind(Node* n) {
+ return isSubgraphNodeKind(n->kind());
+}
+
+// Combine the nodes of two subgraphs. The nodes end up in `mergeTo`,
+// and `mergeFrom` is destroyed.
+void mergeSubgraph(Node* mergeTo, Node* mergeFrom) {
+ const auto nodes = unmergeSubgraph(mergeFrom);
+ for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) {
+ mergeNodeIntoSubgraph(*it, mergeTo);
+ }
+}
+} // namespace
+
+std::shared_ptr<Graph> getSubgraph(Node* n) {
+ JIT_ASSERT(isSubgraphNodeKind(n));
+ return n->g(attr::Subgraph);
+}
+
+std::vector<Node*> unmergeSubgraph(Node* subgraphNode) {
+ JIT_ASSERT(subgraphNode->kind() == prim::DifferentiableGraph);
+ auto outerGraph = subgraphNode->owningGraph();
+
+ std::vector<Node*> temporary_nodes;
+ auto subgraph = getSubgraph(subgraphNode);
+
+ // Initialize a map of inner graph values to outer graph values
+ std::unordered_map<const Value*, Value*> innerToOuter;
+ const auto innerInputs = subgraph->inputs();
+ const auto outerInputs = subgraphNode->inputs();
+ for (size_t i = 0; i < innerInputs.size(); ++i) {
+ innerToOuter[innerInputs[i]] = outerInputs[i];
+ }
+
+ // Clone all nodes
+ for (auto inner : subgraph->nodes()) {
+ Node* outer = outerGraph->createClone(
+ inner, [&](Value* k) -> Value* { return innerToOuter.at(k); });
+ outer->insertBefore(subgraphNode);
+ temporary_nodes.emplace_back(outer);
+ const auto innerOutputs = inner->outputs();
+ const auto outerOutputs = outer->outputs();
+ for (size_t i = 0; i < innerOutputs.size(); ++i) {
+ innerToOuter[innerOutputs[i]] = outerOutputs[i];
+ }
+ }
+
+ // Replace uses of group outputs and destroy the group
+ const auto subgraphOutputs = subgraph->outputs();
+ for (size_t i = 0; i < subgraphOutputs.size(); ++i) {
+ const auto outerOutput = innerToOuter.at(subgraphOutputs[i]);
+ subgraphNode->outputs()[i]->replaceAllUsesWith(outerOutput);
+ }
+ subgraphNode->destroy();
+
+ return temporary_nodes;
+}
+
+void mergeNodeIntoSubgraph(Node* toMerge, Node* subgraphNode) {
+ JIT_ASSERT(isSubgraphNodeKind(subgraphNode));
+ if (isSubgraphNodeKind(toMerge)) {
+ return mergeSubgraph(subgraphNode, toMerge);
+ }
+
+ auto subgraph = getSubgraph(subgraphNode);
+
+ // Map from values in the surrounding graph to inputs in the subgraph
+ std::unordered_map<Value*, Value*> inputsMap;
+
+ JIT_ASSERT(subgraphNode->inputs().size() == subgraph->inputs().size());
+ size_t idx = 0;
+ for (auto input : subgraphNode->inputs()) {
+ inputsMap[input] = subgraph->inputs()[idx];
+ idx++;
+ }
+
+ // Add `toMerge`'s inputs to the group node's input list if we don't
+ // already have them
+ WithInsertPoint guard(*subgraph->nodes().begin());
+ for (auto input : toMerge->inputs()) {
+ if (inputsMap.count(input) == 0) {
+ // Clone constants inside the subgraph instead of referencing them, to
+ // enable more optimizations
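+ // (for example, passes that run inside the subgraph can then fold or
+ // CSE them, which they could not do across the subgraph boundary).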
+ if (auto value = toIValue(input)) {
+ auto nv = subgraph->insertConstant(*value);
+ inputsMap[input] = nv;
+ } else {
+ // The common case: this is a regular input, so just register it with
+ // the group node and inner subgraph
+ subgraphNode->addInput(input);
+ auto inputToGraph = subgraph->addInput();
+ inputToGraph->setType(input->type());
+ inputsMap[input] = inputToGraph;
+ }
+ }
+ }
+
+ // Merge the node into the graph
+ auto mergedNode = subgraph->insertNode(
+ subgraph->createClone(toMerge, [&](Value* v) { return inputsMap[v]; }));
+
+ // If `toMerge`'s outputs were inputs to the group node, remove them since
+ // we just merged `toMerge` in.
+ //
+ // i.e.,
+ // x = f(w); group(x, y, z) becomes group(w, y, z).
+ // x, y, z = f(w); group(x, y, z) becomes group(w).
+ auto inputs = subgraphNode->inputs();
+ for (size_t i = 0; i < toMerge->outputs().size(); ++i) {
+ auto it = std::find(inputs.begin(), inputs.end(), toMerge->outputs()[i]);
+ if (it != inputs.end()) {
+ size_t p = it - inputs.begin();
+ subgraphNode->removeInput(p);
+ subgraph->inputs()[p]->replaceAllUsesWith(mergedNode->outputs()[i]);
+ subgraph->eraseInput(p);
+ }
+ }
+
+ // Add `toMerge`'s outputs to the group node and inner subgraph outputs.
+ for (size_t i = 0; i < toMerge->outputs().size(); i++) {
+ auto oldOutput = toMerge->outputs()[i];
+
+ // Only register the output in the group node if it's actually used
+ // outside the subgraph.
+ const auto hasUsesOutsideSubgraph = std::any_of(
+ oldOutput->uses().cbegin(),
+ oldOutput->uses().cend(),
+ [&](const Use& use) { return use.user->isAfter(subgraphNode); });
+
+ if (hasUsesOutsideSubgraph) {
+ auto newOutput = mergedNode->outputs()[i];
+ subgraph->registerOutput(newOutput);
+ auto groupOutput = subgraphNode->addOutput();
+ groupOutput->copyMetadata(oldOutput);
+ oldOutput->replaceAllUsesWith(groupOutput);
+ }
+ }
+
+ // Remove the original node now that the merge is complete
+ toMerge->destroy();
+}
+
+Node* createSingletonSubgraph(Node* n, Symbol subgraphKind) {
+ JIT_ASSERT(isSubgraphNodeKind(subgraphKind));
+ auto graph = n->owningGraph();
+ auto subgraph = graph->create(subgraphKind, 0);
+ subgraph->g_(attr::Subgraph, std::make_shared<Graph>(graph->current_scope()));
+ subgraph->insertBefore(n);
+ mergeNodeIntoSubgraph(n, subgraph);
+ return subgraph;
+}
+} // namespace SubgraphUtils
+} // namespace jit
+} // namespace torch
--- /dev/null
+#pragma once
+
+#include "torch/csrc/jit/ir.h"
+
+namespace torch {
+namespace jit {
+
+// Utilities for dealing with nodes that contain subgraphs.
+//
+// They handle the complexity of editing inputs/outputs as you merge nodes in
+// and out of subgraphs.
+namespace SubgraphUtils {
+
+// Create a new subgraph node that contains only `n`. The new subgraph will have
+// `subgraphKind` as its type.
+//
+// `n` is destroyed.
+//
+// Returns the new subgraph node.
+Node* createSingletonSubgraph(Node* n, Symbol subgraphKind);
+
+// Merge a node into a subgraph node. If `toMerge` is also a subgraph, the
+// subgraphs are merged.
+// `toMerge` is destroyed.
+void mergeNodeIntoSubgraph(Node* toMerge, Node* subgraphNode);
+
+// Move nodes from a subgraph node to the outer graph.
+// `subgraphNode` is destroyed.
+std::vector<Node*> unmergeSubgraph(Node* subgraphNode);
+
+// Convenience accessor for a subgraph node's inner graph.
+std::shared_ptr<Graph> getSubgraph(Node* n);
+
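+// Example usage (a sketch; assumes `a` and `b` are adjacent nodes in the
+// same block, with `b` consuming `a`'s output):
+//   Node* group = createSingletonSubgraph(b, prim::DifferentiableGraph);
+//   mergeNodeIntoSubgraph(a, group); // `a` now lives inside the subgraph
+//   std::vector<Node*> restored = unmergeSubgraph(group); // undoes both
+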
+} // namespace SubgraphUtils
+} // namespace jit
+} // namespace torch