Reenable all forward-pass fusions that worked before the AD fix (#14558)
authorAdam Paszke <adam.paszke@gmail.com>
Tue, 4 Dec 2018 23:40:41 +0000 (15:40 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 4 Dec 2018 23:43:37 +0000 (15:43 -0800)
Summary:
Dealing with so many `aten::size` calls (in particular calls on elements computed inside fusion groups) requires us to do some extra graph processing in the fuser (to compute the sizes by explicit broadcasts, instead of writing the intermediate tensors only to check their size). This restores the forward expects of LSTM and MiLSTM to a single big kernel. Unfortunately the backward is much harder, because as long as we can't prove that the reductions are unnecessary (or if we can't distribute them over the op), we will not be able to fuse them.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14558

Differential Revision: D13321748

Pulled By: zou3519

fbshipit-source-id: c04fc2f70d106d2bfb56206b5aec517a93b79d1f

aten/src/ATen/core/interned_strings.h
test/expect/TestScript.test_lstm_fusion_cuda-forward.expect
test/expect/TestScript.test_milstm_fusion_cuda-forward.expect
test/test_jit.py
torch/csrc/jit/fuser/executor.cpp
torch/csrc/jit/passes/alias_analysis.cpp
torch/csrc/jit/passes/graph_fuser.cpp
torch/csrc/jit/passes/python_print.cpp
torch/csrc/jit/register_prim_ops.cpp

index 477d3927eacd064d60990f1b75ab2860343d7c52..ec0e04531458524a37835c9d8fdd7ebe273c16b9 100644 (file)
@@ -20,7 +20,9 @@ namespace c10 {
   _(namespaces, namespaces)        \
   _(prim, Assign)                  \
   _(prim, BroadcastingChunk)       \
+  _(prim, BroadcastSizes)          \
   _(prim, Constant)                \
+  _(prim, ChunkSizes)              \
   _(prim, None)                    \
   _(prim, Drop)                    \
   _(prim, Eval)                    \
index a10d0c3f09c7b52900417ad9b8c1a3f3f1710945..825e8029b37789bc539edda0b55c11395584d6fb 100644 (file)
@@ -19,59 +19,68 @@ with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *)
   %8 : Float(*, *) = aten::mm(%5, %7)
   %9 : Float(*, *) = aten::t(%4)
   %10 : Float(*, *) = aten::mm(%3, %9)
-  %11 : int = prim::Constant[value=1]()
-  %12 : int[] = aten::size(%8)
-  %13 : int[] = aten::size(%10)
-  %14 : int[] = aten::size(%2)
-  %15 : Float(*, *), %16 : Float(*, *) = prim::FusionGroup_0(%2, %8, %10)
-  %17 : int[] = aten::size(%16)
-  %18 : int[] = aten::size(%15)
-  %19 : int[] = aten::size(%1)
-  %20 : Tensor[] = prim::ListConstruct(%15, %1)
-  %21 : Tensor[] = aten::broadcast_tensors(%20)
-  %22 : Tensor, %23 : Tensor = prim::ListUnpack(%21)
-  %24 : int[] = aten::size(%0)
-  %hy : Float(*, *), %26 : Float(*, *), %cy : Float(*, *), %28 : Float(*, *), %29 : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1(%0, %23, %22)
-  %34 : int[] = aten::size(%ingate.1)
-  %35 : int[] = aten::size(%forgetgate.1)
-  %36 : int[] = aten::size(%cellgate.1)
-  %37 : int[] = aten::size(%outgate.1)
-  %38 : int[] = aten::size(%29)
-  %39 : int[] = aten::size(%28)
-  %40 : int[] = aten::size(%26)
-  return (%hy, %cy, %7, %9, %12, %13, %17, %14, %18, %19, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %35, %24, %34, %36, %38, %39, %26, %37, %40);
+  %11 : int[] = aten::size(%8)
+  %12 : int[] = aten::size(%10)
+  %13 : int[] = aten::size(%2)
+  %14 : int[] = aten::size(%1)
+  %15 : Tensor[] = prim::ListConstruct(%1, %2, %8, %10)
+  %16 : Tensor[] = aten::broadcast_tensors(%15)
+  %17 : Tensor, %18 : Tensor, %19 : Tensor, %20 : Tensor = prim::ListUnpack(%16)
+  %21 : int[] = prim::BroadcastSizes(%11, %12)
+  %22 : int[] = prim::BroadcastSizes(%21, %13)
+  %23 : int[] = aten::size(%0)
+  %hy : Float(*, *), %25 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %20, %19, %18, %17)
+  %31 : int[] = aten::size(%25)
+  %32 : int[] = aten::size(%outgate.1)
+  %33 : int[] = aten::size(%cellgate.1)
+  %34 : int[] = aten::size(%forgetgate.1)
+  %35 : int[] = aten::size(%ingate.1)
+  %36 : int[] = prim::BroadcastSizes(%34, %23)
+  %37 : int[] = prim::BroadcastSizes(%35, %33)
+  return (%hy, %cy, %7, %9, %11, %12, %21, %13, %22, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %34, %23, %35, %33, %36, %37, %25, %32, %31);
 }
-with prim::FusionGroup_0 = graph(%0 : Float(*)
-      %1 : Float(*, *)
-      %2 : Float(*, *)) {
-  %3 : int = prim::Constant[value=1]()
-  %4 : Float(*, *) = aten::add(%1, %2, %3)
-  %5 : int = prim::Constant[value=1]()
-  %6 : Float(*, *) = aten::add(%4, %0, %5)
-  return (%6, %4);
-}
-with prim::FusionGroup_1 = graph(%0 : Float(*, *)
+with prim::FusionGroup_0 = graph(%0 : Float(*, *)
       %1 : Tensor
-      %2 : Tensor) {
-  %3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
-  %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
-  %11 : int = prim::Constant[value=1]()
-  %12 : Float(*, *) = aten::add(%3, %7, %11)
-  %13 : int = prim::Constant[value=1]()
-  %14 : Float(*, *) = aten::add(%4, %8, %13)
-  %15 : int = prim::Constant[value=1]()
-  %16 : Float(*, *) = aten::add(%5, %9, %15)
-  %17 : int = prim::Constant[value=1]()
-  %18 : Float(*, *) = aten::add(%6, %10, %17)
-  %ingate.1 : Float(*, *) = aten::sigmoid(%12)
-  %forgetgate.1 : Float(*, *) = aten::sigmoid(%14)
-  %cellgate.1 : Float(*, *) = aten::tanh(%16)
-  %outgate.1 : Float(*, *) = aten::sigmoid(%18)
-  %23 : Float(*, *) = aten::mul(%forgetgate.1, %0)
-  %24 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
+      %2 : Tensor
+      %3 : Tensor
+      %4 : Tensor) {
+  %5 : Float(*, *), %6 : Float(*, *), %7 : Float(*, *), %8 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%4)
+  %9 : Float(*, *), %10 : Float(*, *), %11 : Float(*, *), %12 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%3)
+  %13 : Float(*, *), %14 : Float(*, *), %15 : Float(*, *), %16 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
+  %17 : Float(*, *), %18 : Float(*, *), %19 : Float(*, *), %20 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
+  %21 : int = prim::Constant[value=1]()
+  %22 : Float(*, *) = aten::add(%13, %17, %21)
+  %23 : int = prim::Constant[value=1]()
+  %24 : Float(*, *) = aten::add(%14, %18, %23)
   %25 : int = prim::Constant[value=1]()
-  %cy : Float(*, *) = aten::add(%23, %24, %25)
-  %27 : Float(*, *) = aten::tanh(%cy)
-  %hy : Float(*, *) = aten::mul(%outgate.1, %27)
-  return (%hy, %27, %cy, %24, %23, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
+  %26 : Float(*, *) = aten::add(%15, %19, %25)
+  %27 : int = prim::Constant[value=1]()
+  %28 : Float(*, *) = aten::add(%16, %20, %27)
+  %29 : int = prim::Constant[value=1]()
+  %30 : Float(*, *) = aten::add(%22, %9, %29)
+  %31 : int = prim::Constant[value=1]()
+  %32 : Float(*, *) = aten::add(%24, %10, %31)
+  %33 : int = prim::Constant[value=1]()
+  %34 : Float(*, *) = aten::add(%26, %11, %33)
+  %35 : int = prim::Constant[value=1]()
+  %36 : Float(*, *) = aten::add(%28, %12, %35)
+  %37 : int = prim::Constant[value=1]()
+  %38 : Float(*, *) = aten::add(%30, %5, %37)
+  %39 : int = prim::Constant[value=1]()
+  %40 : Float(*, *) = aten::add(%32, %6, %39)
+  %41 : int = prim::Constant[value=1]()
+  %42 : Float(*, *) = aten::add(%34, %7, %41)
+  %43 : int = prim::Constant[value=1]()
+  %44 : Float(*, *) = aten::add(%36, %8, %43)
+  %ingate.1 : Float(*, *) = aten::sigmoid(%38)
+  %forgetgate.1 : Float(*, *) = aten::sigmoid(%40)
+  %cellgate.1 : Float(*, *) = aten::tanh(%42)
+  %outgate.1 : Float(*, *) = aten::sigmoid(%44)
+  %49 : Float(*, *) = aten::mul(%forgetgate.1, %0)
+  %50 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
+  %51 : int = prim::Constant[value=1]()
+  %cy : Float(*, *) = aten::add(%49, %50, %51)
+  %53 : Float(*, *) = aten::tanh(%cy)
+  %hy : Float(*, *) = aten::mul(%outgate.1, %53)
+  return (%hy, %53, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
 }
index 2e2726c7900616c4c9d1f8ea91bb56308214a0cc..1c4e4832a6fe7a2af5928dcc64230fd13246858d 100644 (file)
@@ -25,69 +25,89 @@ with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *)
   %Uz.1 : Float(*, *) = aten::mm(%5, %11)
   %13 : int[] = aten::size(%4)
   %14 : int[] = aten::size(%Wx.1)
-  %15 : int[] = aten::size(%Uz.1)
-  %16 : int[] = aten::size(%3)
-  %17 : int = prim::Constant[value=1]()
-  %18 : int[] = aten::size(%2)
-  %19 : Float(*, *), %20 : Float(*, *), %21 : Float(*, *), %22 : Float(*, *), %23 : Float(*, *), %24 : Float(*, *) = prim::FusionGroup_0(%2, %Uz.1, %3, %Wx.1, %4)
-  %25 : int[] = aten::size(%24)
-  %26 : int[] = aten::size(%23)
-  %27 : int[] = aten::size(%22)
-  %28 : int[] = aten::size(%21)
-  %29 : int[] = aten::size(%20)
-  %30 : int[] = aten::size(%19)
-  %31 : int[] = aten::size(%1)
-  %32 : Tensor[] = prim::ListConstruct(%19, %1)
-  %33 : Tensor[] = aten::broadcast_tensors(%32)
-  %34 : Tensor, %35 : Tensor = prim::ListUnpack(%33)
-  %36 : int[] = aten::size(%0)
-  %hy : Float(*, *), %38 : Float(*, *), %cy : Float(*, *), %40 : Float(*, *), %41 : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1(%0, %35, %34)
+  %15 : Float(*, *) = aten::mul(%4, %Wx.1)
+  %16 : int[] = aten::size(%15)
+  %17 : int[] = aten::size(%Uz.1)
+  %18 : int[] = aten::size(%3)
+  %19 : int[] = aten::size(%2)
+  %20 : int[] = aten::size(%1)
+  %21 : Tensor[] = prim::ListConstruct(%1, %2, %Uz.1, %15, %3, %Wx.1)
+  %22 : Tensor[] = aten::broadcast_tensors(%21)
+  %23 : Tensor, %24 : Tensor, %25 : Tensor, %26 : Tensor, %27 : Tensor, %28 : Tensor = prim::ListUnpack(%22)
+  %29 : int[] = prim::BroadcastSizes(%18, %14)
+  %30 : int[] = prim::BroadcastSizes(%16, %17)
+  %31 : int[] = prim::BroadcastSizes(%19, %17)
+  %32 : int[] = prim::BroadcastSizes(%30, %29)
+  %33 : int[] = prim::BroadcastSizes(%32, %31)
+  %34 : int[] = aten::size(%0)
+  %hy : Float(*, *), %36 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %28, %27, %26, %25, %24, %23)
+  %42 : int[] = aten::size(%36)
+  %43 : int[] = aten::size(%outgate.1)
+  %44 : int[] = aten::size(%cellgate.1)
+  %45 : int[] = aten::size(%forgetgate.1)
   %46 : int[] = aten::size(%ingate.1)
-  %47 : int[] = aten::size(%forgetgate.1)
-  %48 : int[] = aten::size(%cellgate.1)
-  %49 : int[] = aten::size(%outgate.1)
-  %50 : int[] = aten::size(%41)
-  %51 : int[] = aten::size(%40)
-  %52 : int[] = aten::size(%38)
-  return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %13, %14, %24, %25, %15, %16, %14, %26, %27, %18, %15, %28, %29, %30, %31, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %47, %36, %46, %48, %50, %51, %38, %49, %52);
+  %47 : int[] = prim::BroadcastSizes(%45, %34)
+  %48 : int[] = prim::BroadcastSizes(%46, %44)
+  return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %13, %14, %15, %16, %17, %18, %14, %30, %29, %19, %17, %32, %31, %33, %20, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %45, %34, %46, %44, %47, %48, %36, %43, %42);
 }
-with prim::FusionGroup_0 = graph(%0 : Float(*)
-      %1 : Float(*, *)
-      %2 : Float(*)
-      %3 : Float(*, *)
-      %4 : Float(*)) {
-  %5 : Float(*, *) = aten::mul(%4, %3)
-  %6 : Float(*, *) = aten::mul(%5, %1)
-  %7 : Float(*, *) = aten::mul(%2, %3)
-  %8 : int = prim::Constant[value=1]()
-  %9 : Float(*, *) = aten::add(%6, %7, %8)
-  %10 : Float(*, *) = aten::mul(%0, %1)
-  %11 : int = prim::Constant[value=1]()
-  %12 : Float(*, *) = aten::add(%9, %10, %11)
-  return (%12, %10, %9, %7, %6, %5);
-}
-with prim::FusionGroup_1 = graph(%0 : Float(*, *)
+with prim::FusionGroup_0 = graph(%0 : Float(*, *)
       %1 : Tensor
-      %2 : Tensor) {
-  %3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
-  %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
-  %11 : int = prim::Constant[value=1]()
-  %12 : Float(*, *) = aten::add(%3, %7, %11)
-  %13 : int = prim::Constant[value=1]()
-  %14 : Float(*, *) = aten::add(%4, %8, %13)
-  %15 : int = prim::Constant[value=1]()
-  %16 : Float(*, *) = aten::add(%5, %9, %15)
-  %17 : int = prim::Constant[value=1]()
-  %18 : Float(*, *) = aten::add(%6, %10, %17)
-  %ingate.1 : Float(*, *) = aten::sigmoid(%12)
-  %forgetgate.1 : Float(*, *) = aten::sigmoid(%14)
-  %cellgate.1 : Float(*, *) = aten::tanh(%16)
-  %outgate.1 : Float(*, *) = aten::sigmoid(%18)
-  %23 : Float(*, *) = aten::mul(%forgetgate.1, %0)
-  %24 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
-  %25 : int = prim::Constant[value=1]()
-  %cy : Float(*, *) = aten::add(%23, %24, %25)
-  %27 : Float(*, *) = aten::tanh(%cy)
-  %hy : Float(*, *) = aten::mul(%outgate.1, %27)
-  return (%hy, %27, %cy, %24, %23, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
+      %2 : Tensor
+      %3 : Tensor
+      %4 : Tensor
+      %5 : Tensor
+      %6 : Tensor) {
+  %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%6)
+  %11 : Float(*, *), %12 : Float(*, *), %13 : Float(*, *), %14 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%5)
+  %15 : Float(*, *), %16 : Float(*, *), %17 : Float(*, *), %18 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%4)
+  %19 : Float(*, *), %20 : Float(*, *), %21 : Float(*, *), %22 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%3)
+  %23 : Float(*, *), %24 : Float(*, *), %25 : Float(*, *), %26 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
+  %27 : Float(*, *), %28 : Float(*, *), %29 : Float(*, *), %30 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
+  %31 : Float(*, *) = aten::mul(%23, %27)
+  %32 : Float(*, *) = aten::mul(%24, %28)
+  %33 : Float(*, *) = aten::mul(%25, %29)
+  %34 : Float(*, *) = aten::mul(%26, %30)
+  %35 : Float(*, *) = aten::mul(%19, %15)
+  %36 : Float(*, *) = aten::mul(%20, %16)
+  %37 : Float(*, *) = aten::mul(%21, %17)
+  %38 : Float(*, *) = aten::mul(%22, %18)
+  %39 : Float(*, *) = aten::mul(%11, %15)
+  %40 : Float(*, *) = aten::mul(%12, %16)
+  %41 : Float(*, *) = aten::mul(%13, %17)
+  %42 : Float(*, *) = aten::mul(%14, %18)
+  %43 : int = prim::Constant[value=1]()
+  %44 : Float(*, *) = aten::add(%35, %31, %43)
+  %45 : int = prim::Constant[value=1]()
+  %46 : Float(*, *) = aten::add(%36, %32, %45)
+  %47 : int = prim::Constant[value=1]()
+  %48 : Float(*, *) = aten::add(%37, %33, %47)
+  %49 : int = prim::Constant[value=1]()
+  %50 : Float(*, *) = aten::add(%38, %34, %49)
+  %51 : int = prim::Constant[value=1]()
+  %52 : Float(*, *) = aten::add(%44, %39, %51)
+  %53 : int = prim::Constant[value=1]()
+  %54 : Float(*, *) = aten::add(%46, %40, %53)
+  %55 : int = prim::Constant[value=1]()
+  %56 : Float(*, *) = aten::add(%48, %41, %55)
+  %57 : int = prim::Constant[value=1]()
+  %58 : Float(*, *) = aten::add(%50, %42, %57)
+  %59 : int = prim::Constant[value=1]()
+  %60 : Float(*, *) = aten::add(%52, %7, %59)
+  %61 : int = prim::Constant[value=1]()
+  %62 : Float(*, *) = aten::add(%54, %8, %61)
+  %63 : int = prim::Constant[value=1]()
+  %64 : Float(*, *) = aten::add(%56, %9, %63)
+  %65 : int = prim::Constant[value=1]()
+  %66 : Float(*, *) = aten::add(%58, %10, %65)
+  %ingate.1 : Float(*, *) = aten::sigmoid(%60)
+  %forgetgate.1 : Float(*, *) = aten::sigmoid(%62)
+  %cellgate.1 : Float(*, *) = aten::tanh(%64)
+  %outgate.1 : Float(*, *) = aten::sigmoid(%66)
+  %71 : Float(*, *) = aten::mul(%forgetgate.1, %0)
+  %72 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
+  %73 : int = prim::Constant[value=1]()
+  %cy : Float(*, *) = aten::add(%71, %72, %73)
+  %75 : Float(*, *) = aten::tanh(%cy)
+  %hy : Float(*, *) = aten::mul(%outgate.1, %75)
+  return (%hy, %75, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
 }
index 063ac013532a22a1abcdf285b0e6f3ef9b8b4b31..a71e31f8914ed3786f3c465cc7873b4af31690aa 100644 (file)
@@ -4308,9 +4308,8 @@ a")
         inputs = get_lstm_inputs('cuda', training=True)
         module = self.checkScript(LSTMCellS, inputs)
         forward_graph = module.graph_for(*inputs)
-        with self.assertRaises(AssertionError):
-            self.assertGraphContainsExactly(
-                forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
+        self.assertGraphContainsExactly(
+            forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
         self.assertExpectedGraph(forward_graph, subname='forward')
 
         hy, cy = module(*inputs)
@@ -4324,9 +4323,8 @@ a")
         inputs = get_milstm_inputs('cuda', training=True)
         module = self.checkScript(MiLSTMCell, inputs)
         forward_graph = module.graph_for(*inputs)
-        with self.assertRaises(AssertionError):
-            self.assertGraphContainsExactly(
-                forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
+        self.assertGraphContainsExactly(
+            forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
         self.assertExpectedGraph(forward_graph, subname='forward')
 
         hy, cy = module(*inputs)
index 2149936659a35ab457eb1c684b46f51c7cd15613..99b01aca4233e27f2466ca0abda52f769a332ecb 100644 (file)
@@ -28,16 +28,13 @@ static c10::optional<std::vector<int64_t>> getMapSize(
 , at::TensorList args
 , at::IntList arg_subset) {
 
-  int64_t dim_after_broadcast = 0;
-  for (const auto arg_idx : arg_subset) {
-    dim_after_broadcast = std::max(dim_after_broadcast, args[arg_idx].dim());
-  }
   // TODO: this keeps reallocating map_size at every iteration, but we know
   // exactly how much storage do we need, so this could be fixed in-place at
   // every step. We're just missing a few functions for ATen, but the fix
   // should be straightforward.
   // Note: left unitialized since empty shape is broadcastable to any shape
   std::vector<int64_t> map_size;
+  map_size.reserve(8);
   for (const auto arg_idx : arg_subset) {
     auto& arg = args.at(arg_idx);
     auto& chunk_desc = spec.inputChunks().at(arg_idx);
index 5953aac77bf67ceaaef40104cbc79e80d9a88f42..a47fdccff0cf057a2ef6a6ffe62082b8d0deeb59 100644 (file)
@@ -279,6 +279,8 @@ void AliasDb::analyze(Node* node) {
     case prim::MMTreeReduce:
     case prim::MMBatchSide:
     case prim::None:
+    case prim::BroadcastSizes:
+    case prim::ChunkSizes:
       return analyzeCreator(node);
     case prim::TupleUnpack:
     case prim::TupleIndex:
index 85717a553c7dc62cc0595f0f4fad9a260c7868ca..ab440f2084610ddd29a772bee76022e4249d050e 100644 (file)
@@ -2,6 +2,7 @@
 
 #include "torch/csrc/jit/passes/alias_analysis.h"
 #include "torch/csrc/jit/passes/common_subexpression_elimination.h"
+#include "torch/csrc/jit/passes/dead_code_elimination.h"
 #include "torch/csrc/jit/symbolic_variable.h"
 #include "torch/csrc/jit/fuser/interface.h"
 #include "torch/csrc/jit/operator.h"
@@ -112,6 +113,14 @@ bool isSimpleMap(Node *node) {
   return true;
 }
 
+Value * broadcastSizes(at::ArrayRef<Value*> sizes) {
+  JIT_ASSERT(!sizes.empty());
+  Graph * graph = sizes[0]->owningGraph();
+  Node * broadcast_n = graph->insertNode(graph->create(prim::BroadcastSizes, sizes));
+  broadcast_n->output()->setType(ListType::ofInts());
+  return broadcast_n->output();
+}
+
 struct GraphFuser {
   Block * block_;
   std::shared_ptr<Graph> graph_;
@@ -159,11 +168,15 @@ struct GraphFuser {
     return isFusableCatNode(node) || node->kind() == prim::FusedConcat;
   }
 
-  bool allUsersAreThisConsumer(Node * consumer, Value * producer) {
+  bool calculatesSize(Node * node) {
+    return node->matches("aten::size(Tensor self) -> int[]");
+  }
+
+  bool allUsersAreThisConsumerOrCalcSizes(Node * consumer, Value * producer) {
     auto defining_node = producer->node();
     for(auto o : defining_node->outputs()) {
       for(auto u : o->uses()) {
-        if(u.user != consumer)
+        if(u.user != consumer && !calculatesSize(u.user))
           return false;
       }
     }
@@ -610,7 +623,7 @@ struct GraphFuser {
         chunk->inputs().end(),
         [&](Value * producer_for_chunk) {
           return isFusable(producer_for_chunk->node()) &&
-              allUsersAreThisConsumer(chunk, producer_for_chunk);
+              allUsersAreThisConsumerOrCalcSizes(chunk, producer_for_chunk);
         });
     if (it == chunk->inputs().end()) {
       return false;
@@ -707,6 +720,22 @@ struct GraphFuser {
     for (size_t i = 0; i < nchunks; i++) {
       bchunk->eraseOutput(nchunks * producer_index);
     }
+
+    // The output of producer_for_chunk_node could have been used in some aten::size
+    // operators, so we need to clean those up as well (we simply broadcast all its tensor inputs).
+    auto size_calc_uses = producer_for_chunk_node->output()->uses();
+    if (!size_calc_uses.empty()) {
+      auto tensor_inputs = filter(producer_for_chunk_node->inputs(),
+                                  [](Value * v) { return v->type()->isSubtypeOf(DynamicType::get()); });
+      auto tensor_sizes = fmap(tensor_inputs,
+                               [](Value * v) { return v->owningGraph()->insert(aten::size, {v}); });
+      JIT_ASSERT(!tensor_sizes.empty());
+      Value * output_size = tensor_sizes.size() == 1 ? tensor_sizes[0] : broadcastSizes(tensor_sizes);
+      for (Use u : size_calc_uses) {
+        u.user->output()->replaceAllUsesWith(output_size);
+        u.user->destroy();
+      }
+    }
     producer_for_chunk_node->destroy();
     return true;
   }
@@ -773,6 +802,98 @@ struct GraphFuser {
     }
   }
 
+  bool usedOnlyInSize(Value * v) {
+    const auto & uses = v->uses();
+    return std::all_of(uses.begin(), uses.end(),
+                       [](const Use& u) { return u.user->matches("aten::size(Tensor self) -> int[]"); });
+  }
+
+  std::unordered_map<Value*, Value*> buildShapeExpressions(Node * fusion_group) {
+    WithInsertPoint insert_guard { fusion_group->next() };
+    std::unordered_map<Value*, Value*> shape_of;
+
+    Graph * graph = fusion_group->owningGraph();
+    auto subgraph = fusion_group->g(attr::Subgraph);
+
+    auto inputs = fusion_group->inputs();
+    auto sinputs = subgraph->inputs();
+    JIT_ASSERT(inputs.size() == sinputs.size());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+      shape_of[sinputs[i]] = graph->insert(aten::size, {inputs[i]});
+    }
+
+    // When we have a guarantee that an output won't be removed, because it's
+    // used in expressions that don't involve size checks, we can use its size
+    // instead of computing a long chain of broadcasts, starting from the beginning
+    // of the kernel.
+    auto outputs = fusion_group->outputs();
+    auto soutputs = subgraph->outputs();
+    JIT_ASSERT(outputs.size() == soutputs.size());
+    for (size_t i = 0; i < outputs.size(); ++i) {
+      if (usedOnlyInSize(outputs[i])) continue;
+      shape_of[soutputs[i]] = graph->insert(aten::size, {outputs[i]});
+    }
+
+    for (Node * n : subgraph->nodes()) {
+      // XXX: Use of shape_of.emplace is crucial to the output shape optimization!
+      if (n->kind() == prim::FusedConcat) {
+        // This is a bit more involved, because we have to account for the case
+        // when inputs have different shapes, but fortunately those tensors are
+        // always outputs, and so we can simply avoid replacing their queries,
+        // because it won't help us.
+        continue;
+      }
+      if (n->kind() == prim::Constant) {
+        continue;
+      }
+      if (n->kind() == prim::ConstantChunk) {
+        Node * sizes_node = graph->insertNode(graph->create(prim::ChunkSizes, shape_of.at(n->input()), 2));
+        sizes_node->i_(attr::dim, n->i(attr::dim));
+        sizes_node->i_(attr::chunks, n->i(attr::chunks));
+        Value * regular_size = sizes_node->outputs().at(0);
+        Value * last_size = sizes_node->outputs().at(1);
+        auto outputs = n->outputs();
+        for (Value * o : outputs.slice(0, outputs.size() - 1)) {
+          shape_of.emplace(o, regular_size);
+        }
+        shape_of.emplace(outputs.at(outputs.size() - 1), last_size);
+        continue;
+      }
+      auto tensor_inputs = filter(n->inputs(),
+                                  [](Value * v) { return v->type()->isSubtypeOf(DynamicType::get()); });
+      auto shapes = fmap(tensor_inputs, [&](Value * v) { return shape_of.at(v); });
+      JIT_ASSERT(!shapes.empty());
+      shape_of.emplace(n->output(), shapes.size() == 1 ? shapes[0] : broadcastSizes(shapes));
+    }
+    return shape_of;
+  }
+
+  void removeOutputsUsedOnlyInSize(Node * fusion_group) {
+    if (fusion_group->kind() != prim::FusionGroup) return;
+    auto subgraph = fusion_group->g(attr::Subgraph);
+
+    auto shape_of = buildShapeExpressions(fusion_group);
+    auto outputs = fusion_group->outputs().vec();
+    auto soutputs = subgraph->outputs().vec();
+    // XXX: Iterating in this order is not only good for performance reasons!
+    // It is also crucial for correctness (i has to reflect the current true
+    // index of outputs[i])!
+    for (int64_t i = static_cast<int64_t>(outputs.size()) - 1; i >= 0; --i) {
+      auto output = outputs[i];
+      auto soutput = soutputs[i];
+      if (usedOnlyInSize(output) && shape_of.count(soutput) > 0) {
+        auto uses = output->uses();
+        for (Use u : uses) {
+          JIT_ASSERT(u.user->matches("aten::size(Tensor self) -> int[]"));
+          u.user->output()->replaceAllUsesWith(shape_of.at(soutput));
+          u.user->destroy();
+        }
+        fusion_group->eraseOutput(i);
+        subgraph->eraseOutput(i);
+      }
+    }
+  }
+
   void run() {
     // Run the pass until no changes are made.
     // This is neccessary, because the algorithm can miss out on certain fusion
@@ -809,6 +930,12 @@ struct GraphFuser {
     for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
       it = scanNodeForChunks(*it);
     }
+
+    // Remove outputs that have been added only because we need their size
+    for (Node * n : block_->nodes()) {
+      removeOutputsUsedOnlyInSize(n);
+    }
+
     for (Node * node : block_->nodes()) {
       for (Block * sub_block : node->blocks()) {
         GraphFuser(sub_block, graph_).run();
@@ -817,6 +944,56 @@ struct GraphFuser {
   }
 };
 
+void PeepholeOptimizeShapeExpressions(Block * block) {
+  auto nodes = block->nodes();
+  for (auto it = nodes.begin(); it != nodes.end(); ++it) {
+    Node * node = *it;
+    for (Block * subblock : node->blocks()) {
+      PeepholeOptimizeShapeExpressions(subblock);
+    }
+    if (node->kind() == prim::BroadcastSizes) {
+      // Remove no-op broadcasts.
+      if (node->inputs().size() == 1) {
+        node->output()->replaceAllUsesWith(node->input());
+        it.destroyCurrent();
+        continue;
+      }
+      // Deduplicate inputs, but use their unique() values to ensure
+      // this process only depends on the graph.
+      std::map<size_t, Value*> unique_to_value;
+      for (Value * input : node->inputs()) {
+        unique_to_value.emplace(input->unique(), input);
+      }
+      if (unique_to_value.size() != node->inputs().size()) {
+        std::vector<Value*> inputs;
+        for (auto & entry : unique_to_value) {
+          inputs.push_back(entry.second);
+        }
+        if (inputs.size() == 1) {
+          node->output()->replaceAllUsesWith(inputs[0]);
+        } else {
+          WithInsertPoint insert_guard { node };
+          node->output()->replaceAllUsesWith(broadcastSizes(inputs));
+        }
+        it.destroyCurrent();
+        --it; // Revisit the node with deduplicated inputs
+        continue;
+      }
+      // Remove compose simple chains of broadcasts into a single node.
+      const auto & uses = node->output()->uses();
+      if (uses.size() == 1 && uses[0].user->kind() == prim::BroadcastSizes) {
+        Node * user = uses[0].user;
+        user->removeInput(uses[0].offset);
+        // NB: we don't care about deduplication in here, as we will visit user later.
+        for (Value * i : node->inputs()) {
+          user->addInput(i);
+        }
+        it.destroyCurrent();
+      }
+    }
+  }
+}
+
 } // anonymous namespace
 
 void FuseGraph(std::shared_ptr<Graph>& graph) {
@@ -826,6 +1003,11 @@ void FuseGraph(std::shared_ptr<Graph>& graph) {
   GraphFuser(graph->block(), graph).run();
   // After FuseGraph some common subexpressions may come back
   EliminateCommonSubexpression(graph);
+  // We might have emitted a fair amount of useless shape propagating code, so
+  // remove it
+  EliminateDeadCode(graph);
+  // Improve the quality of shape propagation code that was left
+  PeepholeOptimizeShapeExpressions(graph->block());
 
   #endif
 }
index 311b36046617cdd4edd1f6f8d35b94aa7f8d0f80..7f7e761ce43a8620cf70e3afdc50561bff6a9294 100644 (file)
@@ -1012,6 +1012,8 @@ TORCH_API bool printerHasSpecialCaseFor(Symbol sym) {
     prim::AutogradAdd, // temporarily inserted by autograd
     prim::ConstantChunk, // optimization pass adds it
     prim::DifferentiableGraph, // optimization pass adds it
+    prim::BroadcastSizes, // optimization pass (fuser) adds it
+    prim::ChunkSizes, // optimization pass (fuser) adds it
     prim::Drop, // used in interpreter only
     prim::FusedConcat, // optimization pass adds it
     prim::FusionGroup, // optimization pass adds it
index 72ba0cc7c8ed4f5765ea5172577a4492fd170ae4..7ed54c07f559275ce252ba95bf2fdff30694cc84 100644 (file)
@@ -12,6 +12,7 @@
 #include "torch/csrc/variable_tensor_functions.h"
 
 #include <ATen/ExpandUtils.h>
+#include <ATen/WrapDimUtils.h>
 #include <c10/util/SmallVector.h>
 
 #include <exception>
@@ -286,6 +287,48 @@ RegisterOperators reg({
             return 0;
           };
         }),
+    Operator(
+        prim::BroadcastSizes,
+        [](const Node* node) -> Operation {
+          size_t num_inputs = node->inputs().size();
+          return [num_inputs](Stack& stack) {
+            std::vector<int64_t> size;
+            size.reserve(8);
+            for (size_t i = 0; i < num_inputs; ++i) {
+              size = at::infer_size(size, peek(stack, i, num_inputs).toIntList()->elements());
+            }
+            drop(stack, num_inputs);
+            push(stack, std::move(size));
+            return 0;
+          };
+        }),
+    Operator(
+        prim::ChunkSizes,
+        [](const Node* node) -> Operation {
+          int64_t raw_dim = node->i(attr::dim);
+          int64_t chunks = node->i(attr::chunks);
+          return [raw_dim, chunks](Stack& stack) {
+            Shared<IntList> sizes_l;
+            pop(stack, sizes_l);
+            const auto & shape = sizes_l->elements();
+            std::vector<int64_t> regular_shape = shape;
+            std::vector<int64_t> last_shape = shape;
+            int64_t dim = at::maybe_wrap_dim(raw_dim, shape.size());
+            AT_CHECK(dim < regular_shape.size(), "Dimension out of range for chunk");
+            int64_t split_size = (regular_shape[dim] + chunks - 1) / chunks;
+            regular_shape[dim] = split_size;
+            if (shape[dim] % chunks == 0) {
+              last_shape[dim] = split_size;
+            } else {
+              int64_t num_splits = std::max<int64_t>((shape[dim] + split_size - 1) / split_size, 1);
+              last_shape[dim] = split_size - (split_size * num_splits - shape[dim]);
+              JIT_ASSERT(last_shape[dim] >= 0);
+            }
+            push(stack, std::move(regular_shape));
+            push(stack, std::move(last_shape));
+            return 0;
+          };
+        }),
     Operator(
         "prim::RaiseException(str msg) -> ()",
         [](const Node* node) -> Operation {