_(namespaces, namespaces) \
_(prim, Assign) \
_(prim, BroadcastingChunk) \
+ _(prim, BroadcastSizes) \
_(prim, Constant) \
+ _(prim, ChunkSizes) \
_(prim, None) \
_(prim, Drop) \
_(prim, Eval) \
%8 : Float(*, *) = aten::mm(%5, %7)
%9 : Float(*, *) = aten::t(%4)
%10 : Float(*, *) = aten::mm(%3, %9)
- %11 : int = prim::Constant[value=1]()
- %12 : int[] = aten::size(%8)
- %13 : int[] = aten::size(%10)
- %14 : int[] = aten::size(%2)
- %15 : Float(*, *), %16 : Float(*, *) = prim::FusionGroup_0(%2, %8, %10)
- %17 : int[] = aten::size(%16)
- %18 : int[] = aten::size(%15)
- %19 : int[] = aten::size(%1)
- %20 : Tensor[] = prim::ListConstruct(%15, %1)
- %21 : Tensor[] = aten::broadcast_tensors(%20)
- %22 : Tensor, %23 : Tensor = prim::ListUnpack(%21)
- %24 : int[] = aten::size(%0)
- %hy : Float(*, *), %26 : Float(*, *), %cy : Float(*, *), %28 : Float(*, *), %29 : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1(%0, %23, %22)
- %34 : int[] = aten::size(%ingate.1)
- %35 : int[] = aten::size(%forgetgate.1)
- %36 : int[] = aten::size(%cellgate.1)
- %37 : int[] = aten::size(%outgate.1)
- %38 : int[] = aten::size(%29)
- %39 : int[] = aten::size(%28)
- %40 : int[] = aten::size(%26)
- return (%hy, %cy, %7, %9, %12, %13, %17, %14, %18, %19, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %35, %24, %34, %36, %38, %39, %26, %37, %40);
+ %11 : int[] = aten::size(%8)
+ %12 : int[] = aten::size(%10)
+ %13 : int[] = aten::size(%2)
+ %14 : int[] = aten::size(%1)
+ %15 : Tensor[] = prim::ListConstruct(%1, %2, %8, %10)
+ %16 : Tensor[] = aten::broadcast_tensors(%15)
+ %17 : Tensor, %18 : Tensor, %19 : Tensor, %20 : Tensor = prim::ListUnpack(%16)
+ %21 : int[] = prim::BroadcastSizes(%11, %12)
+ %22 : int[] = prim::BroadcastSizes(%21, %13)
+ %23 : int[] = aten::size(%0)
+ %hy : Float(*, *), %25 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %20, %19, %18, %17)
+ %31 : int[] = aten::size(%25)
+ %32 : int[] = aten::size(%outgate.1)
+ %33 : int[] = aten::size(%cellgate.1)
+ %34 : int[] = aten::size(%forgetgate.1)
+ %35 : int[] = aten::size(%ingate.1)
+ %36 : int[] = prim::BroadcastSizes(%34, %23)
+ %37 : int[] = prim::BroadcastSizes(%35, %33)
+ return (%hy, %cy, %7, %9, %11, %12, %21, %13, %22, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %34, %23, %35, %33, %36, %37, %25, %32, %31);
}
-with prim::FusionGroup_0 = graph(%0 : Float(*)
- %1 : Float(*, *)
- %2 : Float(*, *)) {
- %3 : int = prim::Constant[value=1]()
- %4 : Float(*, *) = aten::add(%1, %2, %3)
- %5 : int = prim::Constant[value=1]()
- %6 : Float(*, *) = aten::add(%4, %0, %5)
- return (%6, %4);
-}
-with prim::FusionGroup_1 = graph(%0 : Float(*, *)
+with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%1 : Tensor
- %2 : Tensor) {
- %3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
- %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
- %11 : int = prim::Constant[value=1]()
- %12 : Float(*, *) = aten::add(%3, %7, %11)
- %13 : int = prim::Constant[value=1]()
- %14 : Float(*, *) = aten::add(%4, %8, %13)
- %15 : int = prim::Constant[value=1]()
- %16 : Float(*, *) = aten::add(%5, %9, %15)
- %17 : int = prim::Constant[value=1]()
- %18 : Float(*, *) = aten::add(%6, %10, %17)
- %ingate.1 : Float(*, *) = aten::sigmoid(%12)
- %forgetgate.1 : Float(*, *) = aten::sigmoid(%14)
- %cellgate.1 : Float(*, *) = aten::tanh(%16)
- %outgate.1 : Float(*, *) = aten::sigmoid(%18)
- %23 : Float(*, *) = aten::mul(%forgetgate.1, %0)
- %24 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
+ %2 : Tensor
+ %3 : Tensor
+ %4 : Tensor) {
+ %5 : Float(*, *), %6 : Float(*, *), %7 : Float(*, *), %8 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%4)
+ %9 : Float(*, *), %10 : Float(*, *), %11 : Float(*, *), %12 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%3)
+ %13 : Float(*, *), %14 : Float(*, *), %15 : Float(*, *), %16 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
+ %17 : Float(*, *), %18 : Float(*, *), %19 : Float(*, *), %20 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
+ %21 : int = prim::Constant[value=1]()
+ %22 : Float(*, *) = aten::add(%13, %17, %21)
+ %23 : int = prim::Constant[value=1]()
+ %24 : Float(*, *) = aten::add(%14, %18, %23)
%25 : int = prim::Constant[value=1]()
- %cy : Float(*, *) = aten::add(%23, %24, %25)
- %27 : Float(*, *) = aten::tanh(%cy)
- %hy : Float(*, *) = aten::mul(%outgate.1, %27)
- return (%hy, %27, %cy, %24, %23, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
+ %26 : Float(*, *) = aten::add(%15, %19, %25)
+ %27 : int = prim::Constant[value=1]()
+ %28 : Float(*, *) = aten::add(%16, %20, %27)
+ %29 : int = prim::Constant[value=1]()
+ %30 : Float(*, *) = aten::add(%22, %9, %29)
+ %31 : int = prim::Constant[value=1]()
+ %32 : Float(*, *) = aten::add(%24, %10, %31)
+ %33 : int = prim::Constant[value=1]()
+ %34 : Float(*, *) = aten::add(%26, %11, %33)
+ %35 : int = prim::Constant[value=1]()
+ %36 : Float(*, *) = aten::add(%28, %12, %35)
+ %37 : int = prim::Constant[value=1]()
+ %38 : Float(*, *) = aten::add(%30, %5, %37)
+ %39 : int = prim::Constant[value=1]()
+ %40 : Float(*, *) = aten::add(%32, %6, %39)
+ %41 : int = prim::Constant[value=1]()
+ %42 : Float(*, *) = aten::add(%34, %7, %41)
+ %43 : int = prim::Constant[value=1]()
+ %44 : Float(*, *) = aten::add(%36, %8, %43)
+ %ingate.1 : Float(*, *) = aten::sigmoid(%38)
+ %forgetgate.1 : Float(*, *) = aten::sigmoid(%40)
+ %cellgate.1 : Float(*, *) = aten::tanh(%42)
+ %outgate.1 : Float(*, *) = aten::sigmoid(%44)
+ %49 : Float(*, *) = aten::mul(%forgetgate.1, %0)
+ %50 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
+ %51 : int = prim::Constant[value=1]()
+ %cy : Float(*, *) = aten::add(%49, %50, %51)
+ %53 : Float(*, *) = aten::tanh(%cy)
+ %hy : Float(*, *) = aten::mul(%outgate.1, %53)
+ return (%hy, %53, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
}
%Uz.1 : Float(*, *) = aten::mm(%5, %11)
%13 : int[] = aten::size(%4)
%14 : int[] = aten::size(%Wx.1)
- %15 : int[] = aten::size(%Uz.1)
- %16 : int[] = aten::size(%3)
- %17 : int = prim::Constant[value=1]()
- %18 : int[] = aten::size(%2)
- %19 : Float(*, *), %20 : Float(*, *), %21 : Float(*, *), %22 : Float(*, *), %23 : Float(*, *), %24 : Float(*, *) = prim::FusionGroup_0(%2, %Uz.1, %3, %Wx.1, %4)
- %25 : int[] = aten::size(%24)
- %26 : int[] = aten::size(%23)
- %27 : int[] = aten::size(%22)
- %28 : int[] = aten::size(%21)
- %29 : int[] = aten::size(%20)
- %30 : int[] = aten::size(%19)
- %31 : int[] = aten::size(%1)
- %32 : Tensor[] = prim::ListConstruct(%19, %1)
- %33 : Tensor[] = aten::broadcast_tensors(%32)
- %34 : Tensor, %35 : Tensor = prim::ListUnpack(%33)
- %36 : int[] = aten::size(%0)
- %hy : Float(*, *), %38 : Float(*, *), %cy : Float(*, *), %40 : Float(*, *), %41 : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1(%0, %35, %34)
+ %15 : Float(*, *) = aten::mul(%4, %Wx.1)
+ %16 : int[] = aten::size(%15)
+ %17 : int[] = aten::size(%Uz.1)
+ %18 : int[] = aten::size(%3)
+ %19 : int[] = aten::size(%2)
+ %20 : int[] = aten::size(%1)
+ %21 : Tensor[] = prim::ListConstruct(%1, %2, %Uz.1, %15, %3, %Wx.1)
+ %22 : Tensor[] = aten::broadcast_tensors(%21)
+ %23 : Tensor, %24 : Tensor, %25 : Tensor, %26 : Tensor, %27 : Tensor, %28 : Tensor = prim::ListUnpack(%22)
+ %29 : int[] = prim::BroadcastSizes(%18, %14)
+ %30 : int[] = prim::BroadcastSizes(%16, %17)
+ %31 : int[] = prim::BroadcastSizes(%19, %17)
+ %32 : int[] = prim::BroadcastSizes(%30, %29)
+ %33 : int[] = prim::BroadcastSizes(%32, %31)
+ %34 : int[] = aten::size(%0)
+ %hy : Float(*, *), %36 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %28, %27, %26, %25, %24, %23)
+ %42 : int[] = aten::size(%36)
+ %43 : int[] = aten::size(%outgate.1)
+ %44 : int[] = aten::size(%cellgate.1)
+ %45 : int[] = aten::size(%forgetgate.1)
%46 : int[] = aten::size(%ingate.1)
- %47 : int[] = aten::size(%forgetgate.1)
- %48 : int[] = aten::size(%cellgate.1)
- %49 : int[] = aten::size(%outgate.1)
- %50 : int[] = aten::size(%41)
- %51 : int[] = aten::size(%40)
- %52 : int[] = aten::size(%38)
- return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %13, %14, %24, %25, %15, %16, %14, %26, %27, %18, %15, %28, %29, %30, %31, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %47, %36, %46, %48, %50, %51, %38, %49, %52);
+ %47 : int[] = prim::BroadcastSizes(%45, %34)
+ %48 : int[] = prim::BroadcastSizes(%46, %44)
+ return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %13, %14, %15, %16, %17, %18, %14, %30, %29, %19, %17, %32, %31, %33, %20, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %45, %34, %46, %44, %47, %48, %36, %43, %42);
}
-with prim::FusionGroup_0 = graph(%0 : Float(*)
- %1 : Float(*, *)
- %2 : Float(*)
- %3 : Float(*, *)
- %4 : Float(*)) {
- %5 : Float(*, *) = aten::mul(%4, %3)
- %6 : Float(*, *) = aten::mul(%5, %1)
- %7 : Float(*, *) = aten::mul(%2, %3)
- %8 : int = prim::Constant[value=1]()
- %9 : Float(*, *) = aten::add(%6, %7, %8)
- %10 : Float(*, *) = aten::mul(%0, %1)
- %11 : int = prim::Constant[value=1]()
- %12 : Float(*, *) = aten::add(%9, %10, %11)
- return (%12, %10, %9, %7, %6, %5);
-}
-with prim::FusionGroup_1 = graph(%0 : Float(*, *)
+with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%1 : Tensor
- %2 : Tensor) {
- %3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
- %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
- %11 : int = prim::Constant[value=1]()
- %12 : Float(*, *) = aten::add(%3, %7, %11)
- %13 : int = prim::Constant[value=1]()
- %14 : Float(*, *) = aten::add(%4, %8, %13)
- %15 : int = prim::Constant[value=1]()
- %16 : Float(*, *) = aten::add(%5, %9, %15)
- %17 : int = prim::Constant[value=1]()
- %18 : Float(*, *) = aten::add(%6, %10, %17)
- %ingate.1 : Float(*, *) = aten::sigmoid(%12)
- %forgetgate.1 : Float(*, *) = aten::sigmoid(%14)
- %cellgate.1 : Float(*, *) = aten::tanh(%16)
- %outgate.1 : Float(*, *) = aten::sigmoid(%18)
- %23 : Float(*, *) = aten::mul(%forgetgate.1, %0)
- %24 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
- %25 : int = prim::Constant[value=1]()
- %cy : Float(*, *) = aten::add(%23, %24, %25)
- %27 : Float(*, *) = aten::tanh(%cy)
- %hy : Float(*, *) = aten::mul(%outgate.1, %27)
- return (%hy, %27, %cy, %24, %23, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
+ %2 : Tensor
+ %3 : Tensor
+ %4 : Tensor
+ %5 : Tensor
+ %6 : Tensor) {
+ %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%6)
+ %11 : Float(*, *), %12 : Float(*, *), %13 : Float(*, *), %14 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%5)
+ %15 : Float(*, *), %16 : Float(*, *), %17 : Float(*, *), %18 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%4)
+ %19 : Float(*, *), %20 : Float(*, *), %21 : Float(*, *), %22 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%3)
+ %23 : Float(*, *), %24 : Float(*, *), %25 : Float(*, *), %26 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
+ %27 : Float(*, *), %28 : Float(*, *), %29 : Float(*, *), %30 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
+ %31 : Float(*, *) = aten::mul(%23, %27)
+ %32 : Float(*, *) = aten::mul(%24, %28)
+ %33 : Float(*, *) = aten::mul(%25, %29)
+ %34 : Float(*, *) = aten::mul(%26, %30)
+ %35 : Float(*, *) = aten::mul(%19, %15)
+ %36 : Float(*, *) = aten::mul(%20, %16)
+ %37 : Float(*, *) = aten::mul(%21, %17)
+ %38 : Float(*, *) = aten::mul(%22, %18)
+ %39 : Float(*, *) = aten::mul(%11, %15)
+ %40 : Float(*, *) = aten::mul(%12, %16)
+ %41 : Float(*, *) = aten::mul(%13, %17)
+ %42 : Float(*, *) = aten::mul(%14, %18)
+ %43 : int = prim::Constant[value=1]()
+ %44 : Float(*, *) = aten::add(%35, %31, %43)
+ %45 : int = prim::Constant[value=1]()
+ %46 : Float(*, *) = aten::add(%36, %32, %45)
+ %47 : int = prim::Constant[value=1]()
+ %48 : Float(*, *) = aten::add(%37, %33, %47)
+ %49 : int = prim::Constant[value=1]()
+ %50 : Float(*, *) = aten::add(%38, %34, %49)
+ %51 : int = prim::Constant[value=1]()
+ %52 : Float(*, *) = aten::add(%44, %39, %51)
+ %53 : int = prim::Constant[value=1]()
+ %54 : Float(*, *) = aten::add(%46, %40, %53)
+ %55 : int = prim::Constant[value=1]()
+ %56 : Float(*, *) = aten::add(%48, %41, %55)
+ %57 : int = prim::Constant[value=1]()
+ %58 : Float(*, *) = aten::add(%50, %42, %57)
+ %59 : int = prim::Constant[value=1]()
+ %60 : Float(*, *) = aten::add(%52, %7, %59)
+ %61 : int = prim::Constant[value=1]()
+ %62 : Float(*, *) = aten::add(%54, %8, %61)
+ %63 : int = prim::Constant[value=1]()
+ %64 : Float(*, *) = aten::add(%56, %9, %63)
+ %65 : int = prim::Constant[value=1]()
+ %66 : Float(*, *) = aten::add(%58, %10, %65)
+ %ingate.1 : Float(*, *) = aten::sigmoid(%60)
+ %forgetgate.1 : Float(*, *) = aten::sigmoid(%62)
+ %cellgate.1 : Float(*, *) = aten::tanh(%64)
+ %outgate.1 : Float(*, *) = aten::sigmoid(%66)
+ %71 : Float(*, *) = aten::mul(%forgetgate.1, %0)
+ %72 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
+ %73 : int = prim::Constant[value=1]()
+ %cy : Float(*, *) = aten::add(%71, %72, %73)
+ %75 : Float(*, *) = aten::tanh(%cy)
+ %hy : Float(*, *) = aten::mul(%outgate.1, %75)
+ return (%hy, %75, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
}
inputs = get_lstm_inputs('cuda', training=True)
module = self.checkScript(LSTMCellS, inputs)
forward_graph = module.graph_for(*inputs)
- with self.assertRaises(AssertionError):
- self.assertGraphContainsExactly(
- forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
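+ # With size computations moved out of the kernel, the whole forward pass
+ # now fuses into a single FusionGroup.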
+ self.assertGraphContainsExactly(
+ forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
self.assertExpectedGraph(forward_graph, subname='forward')
hy, cy = module(*inputs)
inputs = get_milstm_inputs('cuda', training=True)
module = self.checkScript(MiLSTMCell, inputs)
forward_graph = module.graph_for(*inputs)
- with self.assertRaises(AssertionError):
- self.assertGraphContainsExactly(
- forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
+ self.assertGraphContainsExactly(
+ forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
self.assertExpectedGraph(forward_graph, subname='forward')
hy, cy = module(*inputs)
, at::TensorList args
, at::IntList arg_subset) {
- int64_t dim_after_broadcast = 0;
- for (const auto arg_idx : arg_subset) {
- dim_after_broadcast = std::max(dim_after_broadcast, args[arg_idx].dim());
- }
// TODO: this keeps reallocating map_size at every iteration, but we know
// exactly how much storage we need, so this could be fixed in-place at
// every step. We're just missing a few ATen functions, but the fix
// should be straightforward.
// Note: left uninitialized, since the empty shape is broadcastable to any shape
std::vector<int64_t> map_size;
+ map_size.reserve(8);
for (const auto arg_idx : arg_subset) {
auto& arg = args.at(arg_idx);
auto& chunk_desc = spec.inputChunks().at(arg_idx);
case prim::MMTreeReduce:
case prim::MMBatchSide:
case prim::None:
+ case prim::BroadcastSizes:
+ case prim::ChunkSizes:
return analyzeCreator(node);
case prim::TupleUnpack:
case prim::TupleIndex:
#include "torch/csrc/jit/passes/alias_analysis.h"
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
+#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/symbolic_variable.h"
#include "torch/csrc/jit/fuser/interface.h"
#include "torch/csrc/jit/operator.h"
return true;
}
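+// Helper that emits a prim::BroadcastSizes node computing, at runtime, the
+// broadcast of the given size lists (e.g. sizes [5, 1] and [1, 4] broadcast
+// to [5, 4]).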
+Value * broadcastSizes(at::ArrayRef<Value*> sizes) {
+ JIT_ASSERT(!sizes.empty());
+ Graph * graph = sizes[0]->owningGraph();
+ Node * broadcast_n = graph->insertNode(graph->create(prim::BroadcastSizes, sizes));
+ broadcast_n->output()->setType(ListType::ofInts());
+ return broadcast_n->output();
+}
+
struct GraphFuser {
Block * block_;
std::shared_ptr<Graph> graph_;
return isFusableCatNode(node) || node->kind() == prim::FusedConcat;
}
- bool allUsersAreThisConsumer(Node * consumer, Value * producer) {
+ bool calculatesSize(Node * node) {
+ return node->matches("aten::size(Tensor self) -> int[]");
+ }
+
+ bool allUsersAreThisConsumerOrCalcSizes(Node * consumer, Value * producer) {
auto defining_node = producer->node();
for(auto o : defining_node->outputs()) {
for(auto u : o->uses()) {
- if(u.user != consumer)
+ if(u.user != consumer && !calculatesSize(u.user))
return false;
}
}
chunk->inputs().end(),
[&](Value * producer_for_chunk) {
return isFusable(producer_for_chunk->node()) &&
- allUsersAreThisConsumer(chunk, producer_for_chunk);
+ allUsersAreThisConsumerOrCalcSizes(chunk, producer_for_chunk);
});
if (it == chunk->inputs().end()) {
return false;
for (size_t i = 0; i < nchunks; i++) {
bchunk->eraseOutput(nchunks * producer_index);
}
+
+ // The output of producer_for_chunk_node could have been used in some aten::size
+ // operators, so we need to clean those up as well (by broadcasting the sizes
+ // of all its tensor inputs).
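+ // This is valid because the producer is a fusable pointwise op, whose output
+ // size is exactly the broadcast of its tensor inputs' sizes.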
+ auto size_calc_uses = producer_for_chunk_node->output()->uses();
+ if (!size_calc_uses.empty()) {
+ auto tensor_inputs = filter(producer_for_chunk_node->inputs(),
+ [](Value * v) { return v->type()->isSubtypeOf(DynamicType::get()); });
+ auto tensor_sizes = fmap(tensor_inputs,
+ [](Value * v) { return v->owningGraph()->insert(aten::size, {v}); });
+ JIT_ASSERT(!tensor_sizes.empty());
+ Value * output_size = tensor_sizes.size() == 1 ? tensor_sizes[0] : broadcastSizes(tensor_sizes);
+ for (Use u : size_calc_uses) {
+ u.user->output()->replaceAllUsesWith(output_size);
+ u.user->destroy();
+ }
+ }
producer_for_chunk_node->destroy();
return true;
}
}
}
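+ // True iff every use of v is an aten::size call, i.e. only v's shape
+ // (not its data) is ever observed.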
+ bool usedOnlyInSize(Value * v) {
+ const auto & uses = v->uses();
+ return std::all_of(uses.begin(), uses.end(),
+ [](const Use& u) { return u.user->matches("aten::size(Tensor self) -> int[]"); });
+ }
+
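+ // Builds a map from Values in the fusion subgraph to Values in the outer
+ // graph that compute their sizes, so that shape queries can be answered
+ // without keeping extra kernel outputs alive.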
+ std::unordered_map<Value*, Value*> buildShapeExpressions(Node * fusion_group) {
+ WithInsertPoint insert_guard { fusion_group->next() };
+ std::unordered_map<Value*, Value*> shape_of;
+
+ Graph * graph = fusion_group->owningGraph();
+ auto subgraph = fusion_group->g(attr::Subgraph);
+
+ auto inputs = fusion_group->inputs();
+ auto sinputs = subgraph->inputs();
+ JIT_ASSERT(inputs.size() == sinputs.size());
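+ // Seed the map: the size of each subgraph input is just the size of the
+ // corresponding fusion group input, queried in the outer graph.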
+ for (size_t i = 0; i < inputs.size(); ++i) {
+ shape_of[sinputs[i]] = graph->insert(aten::size, {inputs[i]});
+ }
+
+ // When we have a guarantee that an output won't be removed, because it's
+ // used in expressions that don't involve size checks, we can use its size
+ // instead of computing a long chain of broadcasts, starting from the beginning
+ // of the kernel.
+ auto outputs = fusion_group->outputs();
+ auto soutputs = subgraph->outputs();
+ JIT_ASSERT(outputs.size() == soutputs.size());
+ for (size_t i = 0; i < outputs.size(); ++i) {
+ if (usedOnlyInSize(outputs[i])) continue;
+ shape_of[soutputs[i]] = graph->insert(aten::size, {outputs[i]});
+ }
+
+ for (Node * n : subgraph->nodes()) {
+ // XXX: Use of shape_of.emplace is crucial to the output shape optimization!
+ if (n->kind() == prim::FusedConcat) {
+ // This is a bit more involved, because we have to account for the case
+ // when inputs have different shapes. Fortunately those tensors are always
+ // outputs, so we can simply skip replacing their size queries; doing so
+ // wouldn't help us anyway.
+ continue;
+ }
+ if (n->kind() == prim::Constant) {
+ continue;
+ }
+ if (n->kind() == prim::ConstantChunk) {
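+ // prim::ChunkSizes splits the input size into two lists: the size of
+ // every chunk but the last, and the size of the (possibly smaller) last chunk.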
+ Node * sizes_node = graph->insertNode(graph->create(prim::ChunkSizes, shape_of.at(n->input()), 2));
+ sizes_node->i_(attr::dim, n->i(attr::dim));
+ sizes_node->i_(attr::chunks, n->i(attr::chunks));
+ Value * regular_size = sizes_node->outputs().at(0);
+ Value * last_size = sizes_node->outputs().at(1);
+ auto outputs = n->outputs();
+ for (Value * o : outputs.slice(0, outputs.size() - 1)) {
+ shape_of.emplace(o, regular_size);
+ }
+ shape_of.emplace(outputs.at(outputs.size() - 1), last_size);
+ continue;
+ }
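+ // Default case: the remaining fusable ops are pointwise, so the output
+ // size is the broadcast of the sizes of all tensor inputs.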
+ auto tensor_inputs = filter(n->inputs(),
+ [](Value * v) { return v->type()->isSubtypeOf(DynamicType::get()); });
+ auto shapes = fmap(tensor_inputs, [&](Value * v) { return shape_of.at(v); });
+ JIT_ASSERT(!shapes.empty());
+ shape_of.emplace(n->output(), shapes.size() == 1 ? shapes[0] : broadcastSizes(shapes));
+ }
+ return shape_of;
+ }
+
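+ // Rewrites aten::size calls on fusion group outputs that are used only for
+ // their shape to use the precomputed size expressions, then erases those
+ // now-dead outputs from the group and its subgraph.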
+ void removeOutputsUsedOnlyInSize(Node * fusion_group) {
+ if (fusion_group->kind() != prim::FusionGroup) return;
+ auto subgraph = fusion_group->g(attr::Subgraph);
+
+ auto shape_of = buildShapeExpressions(fusion_group);
+ auto outputs = fusion_group->outputs().vec();
+ auto soutputs = subgraph->outputs().vec();
+ // XXX: Iterating in this order is not only good for performance reasons!
+ // It is also crucial for correctness (i has to reflect the current true
+ // index of outputs[i])!
+ for (int64_t i = static_cast<int64_t>(outputs.size()) - 1; i >= 0; --i) {
+ auto output = outputs[i];
+ auto soutput = soutputs[i];
+ if (usedOnlyInSize(output) && shape_of.count(soutput) > 0) {
+ auto uses = output->uses();
+ for (Use u : uses) {
+ JIT_ASSERT(u.user->matches("aten::size(Tensor self) -> int[]"));
+ u.user->output()->replaceAllUsesWith(shape_of.at(soutput));
+ u.user->destroy();
+ }
+ fusion_group->eraseOutput(i);
+ subgraph->eraseOutput(i);
+ }
+ }
+ }
+
void run() {
// Run the pass until no changes are made.
// This is necessary, because the algorithm can miss out on certain fusion
// opportunities if run only once.
for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
it = scanNodeForChunks(*it);
}
+
+ // Remove outputs that have been added only because we need their size
+ for (Node * n : block_->nodes()) {
+ removeOutputsUsedOnlyInSize(n);
+ }
+
for (Node * node : block_->nodes()) {
for (Block * sub_block : node->blocks()) {
GraphFuser(sub_block, graph_).run();
}
};
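+// Cleans up the shape expressions emitted by the fuser: removes no-op
+// single-input broadcasts, deduplicates broadcast inputs, and flattens
+// chains of prim::BroadcastSizes into a single node.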
+void PeepholeOptimizeShapeExpressions(Block * block) {
+ auto nodes = block->nodes();
+ for (auto it = nodes.begin(); it != nodes.end(); ++it) {
+ Node * node = *it;
+ for (Block * subblock : node->blocks()) {
+ PeepholeOptimizeShapeExpressions(subblock);
+ }
+ if (node->kind() == prim::BroadcastSizes) {
+ // Remove no-op broadcasts.
+ if (node->inputs().size() == 1) {
+ node->output()->replaceAllUsesWith(node->input());
+ it.destroyCurrent();
+ continue;
+ }
+ // Deduplicate inputs, but use their unique() values to ensure
+ // this process only depends on the graph.
+ std::map<size_t, Value*> unique_to_value;
+ for (Value * input : node->inputs()) {
+ unique_to_value.emplace(input->unique(), input);
+ }
+ if (unique_to_value.size() != node->inputs().size()) {
+ std::vector<Value*> inputs;
+ for (auto & entry : unique_to_value) {
+ inputs.push_back(entry.second);
+ }
+ if (inputs.size() == 1) {
+ node->output()->replaceAllUsesWith(inputs[0]);
+ } else {
+ WithInsertPoint insert_guard { node };
+ node->output()->replaceAllUsesWith(broadcastSizes(inputs));
+ }
+ it.destroyCurrent();
+ --it; // Revisit the node with deduplicated inputs
+ continue;
+ }
+ // Compose simple chains of broadcasts into a single node.
+ const auto & uses = node->output()->uses();
+ if (uses.size() == 1 && uses[0].user->kind() == prim::BroadcastSizes) {
+ Node * user = uses[0].user;
+ user->removeInput(uses[0].offset);
+ // NB: we don't care about deduplication here, as we will visit the user later.
+ for (Value * i : node->inputs()) {
+ user->addInput(i);
+ }
+ it.destroyCurrent();
+ }
+ }
+ }
+}
+
} // anonymous namespace
void FuseGraph(std::shared_ptr<Graph>& graph) {
GraphFuser(graph->block(), graph).run();
// After FuseGraph some common subexpressions may come back
EliminateCommonSubexpression(graph);
+ // We might have emitted a fair amount of useless shape-propagating code,
+ // so remove it.
+ EliminateDeadCode(graph);
+ // Improve the quality of shape propagation code that was left
+ PeepholeOptimizeShapeExpressions(graph->block());
#endif
}
prim::AutogradAdd, // temporarily inserted by autograd
prim::ConstantChunk, // optimization pass adds it
prim::DifferentiableGraph, // optimization pass adds it
+ prim::BroadcastSizes, // optimization pass (fuser) adds it
+ prim::ChunkSizes, // optimization pass (fuser) adds it
prim::Drop, // used in interpreter only
prim::FusedConcat, // optimization pass adds it
prim::FusionGroup, // optimization pass adds it
#include "torch/csrc/variable_tensor_functions.h"
#include <ATen/ExpandUtils.h>
+#include <ATen/WrapDimUtils.h>
#include <c10/util/SmallVector.h>
#include <exception>
};
}),
Operator(
+ prim::BroadcastSizes,
+ [](const Node* node) -> Operation {
+ size_t num_inputs = node->inputs().size();
+ return [num_inputs](Stack& stack) {
+ std::vector<int64_t> size;
+ size.reserve(8);
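+ // Fold the broadcast over all input size lists with infer_size; e.g.
+ // broadcasting [5, 1] with [1, 4] yields [5, 4].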
+ for (size_t i = 0; i < num_inputs; ++i) {
+ size = at::infer_size(size, peek(stack, i, num_inputs).toIntList()->elements());
+ }
+ drop(stack, num_inputs);
+ push(stack, std::move(size));
+ return 0;
+ };
+ }),
+ Operator(
+ prim::ChunkSizes,
+ [](const Node* node) -> Operation {
+ int64_t raw_dim = node->i(attr::dim);
+ int64_t chunks = node->i(attr::chunks);
+ return [raw_dim, chunks](Stack& stack) {
+ Shared<IntList> sizes_l;
+ pop(stack, sizes_l);
+ const auto & shape = sizes_l->elements();
+ std::vector<int64_t> regular_shape = shape;
+ std::vector<int64_t> last_shape = shape;
+ int64_t dim = at::maybe_wrap_dim(raw_dim, shape.size());
+ AT_CHECK(dim < static_cast<int64_t>(regular_shape.size()), "Dimension out of range for chunk");
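+ // chunk uses ceil division; e.g. shape[dim] == 10 with chunks == 4
+ // gives split_size == 3.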
+ int64_t split_size = (regular_shape[dim] + chunks - 1) / chunks;
+ regular_shape[dim] = split_size;
+ if (shape[dim] % chunks == 0) {
+ last_shape[dim] = split_size;
+ } else {
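+ // The last chunk gets whatever remains; continuing the example,
+ // 10 = 3 + 3 + 3 + 1, so the last chunk has size 1.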
+ int64_t num_splits = std::max<int64_t>((shape[dim] + split_size - 1) / split_size, 1);
+ last_shape[dim] = split_size - (split_size * num_splits - shape[dim]);
+ JIT_ASSERT(last_shape[dim] >= 0);
+ }
+ push(stack, std::move(regular_shape));
+ push(stack, std::move(last_shape));
+ return 0;
+ };
+ }),
+ Operator(
"prim::RaiseException(str msg) -> ()",
[](const Node* node) -> Operation {
return [](Stack& stack) {