Reenable all forward-pass fusions that worked before the AD fix (#14558)
authorAdam Paszke <adam.paszke@gmail.com>
Tue, 4 Dec 2018 23:40:41 +0000 (15:40 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 4 Dec 2018 23:43:37 +0000 (15:43 -0800)
Summary:
Dealing with so many `aten::size` calls (in particular calls on elements computed inside fusion groups) requires us to do some extra graph processing in the fuser (to compute the sizes by explicit broadcasts, instead of writing the intermediate tensors only to check their size). This restores the forward expects of LSTM and MiLSTM to a single big kernel. Unfortunately the backward is much harder, because as long as we can't prove that the reductions are unnecessary (or if we can't distribute them over the op), we will not be able to fuse them.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14558

Differential Revision: D13321748

Pulled By: zou3519

fbshipit-source-id: c04fc2f70d106d2bfb56206b5aec517a93b79d1f

aten/src/ATen/core/interned_strings.h
test/expect/TestScript.test_lstm_fusion_cuda-forward.expect
test/expect/TestScript.test_milstm_fusion_cuda-forward.expect
test/test_jit.py
torch/csrc/jit/fuser/executor.cpp
torch/csrc/jit/passes/alias_analysis.cpp
torch/csrc/jit/passes/graph_fuser.cpp
torch/csrc/jit/passes/python_print.cpp
torch/csrc/jit/register_prim_ops.cpp

index 477d392..ec0e045 100644 (file)
@@ -20,7 +20,9 @@ namespace c10 {
   _(namespaces, namespaces)        \
   _(prim, Assign)                  \
   _(prim, BroadcastingChunk)       \
+  _(prim, BroadcastSizes)          \
   _(prim, Constant)                \
+  _(prim, ChunkSizes)              \
   _(prim, None)                    \
   _(prim, Drop)                    \
   _(prim, Eval)                    \
index a10d0c3..825e802 100644 (file)
@@ -19,59 +19,68 @@ with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *)
   %8 : Float(*, *) = aten::mm(%5, %7)
   %9 : Float(*, *) = aten::t(%4)
   %10 : Float(*, *) = aten::mm(%3, %9)
-  %11 : int = prim::Constant[value=1]()
-  %12 : int[] = aten::size(%8)
-  %13 : int[] = aten::size(%10)
-  %14 : int[] = aten::size(%2)
-  %15 : Float(*, *), %16 : Float(*, *) = prim::FusionGroup_0(%2, %8, %10)
-  %17 : int[] = aten::size(%16)
-  %18 : int[] = aten::size(%15)
-  %19 : int[] = aten::size(%1)
-  %20 : Tensor[] = prim::ListConstruct(%15, %1)
-  %21 : Tensor[] = aten::broadcast_tensors(%20)
-  %22 : Tensor, %23 : Tensor = prim::ListUnpack(%21)
-  %24 : int[] = aten::size(%0)
-  %hy : Float(*, *), %26 : Float(*, *), %cy : Float(*, *), %28 : Float(*, *), %29 : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1(%0, %23, %22)
-  %34 : int[] = aten::size(%ingate.1)
-  %35 : int[] = aten::size(%forgetgate.1)
-  %36 : int[] = aten::size(%cellgate.1)
-  %37 : int[] = aten::size(%outgate.1)
-  %38 : int[] = aten::size(%29)
-  %39 : int[] = aten::size(%28)
-  %40 : int[] = aten::size(%26)
-  return (%hy, %cy, %7, %9, %12, %13, %17, %14, %18, %19, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %35, %24, %34, %36, %38, %39, %26, %37, %40);
+  %11 : int[] = aten::size(%8)
+  %12 : int[] = aten::size(%10)
+  %13 : int[] = aten::size(%2)
+  %14 : int[] = aten::size(%1)
+  %15 : Tensor[] = prim::ListConstruct(%1, %2, %8, %10)
+  %16 : Tensor[] = aten::broadcast_tensors(%15)
+  %17 : Tensor, %18 : Tensor, %19 : Tensor, %20 : Tensor = prim::ListUnpack(%16)
+  %21 : int[] = prim::BroadcastSizes(%11, %12)
+  %22 : int[] = prim::BroadcastSizes(%21, %13)
+  %23 : int[] = aten::size(%0)
+  %hy : Float(*, *), %25 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %20, %19, %18, %17)
+  %31 : int[] = aten::size(%25)
+  %32 : int[] = aten::size(%outgate.1)
+  %33 : int[] = aten::size(%cellgate.1)
+  %34 : int[] = aten::size(%forgetgate.1)
+  %35 : int[] = aten::size(%ingate.1)
+  %36 : int[] = prim::BroadcastSizes(%34, %23)
+  %37 : int[] = prim::BroadcastSizes(%35, %33)
+  return (%hy, %cy, %7, %9, %11, %12, %21, %13, %22, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %34, %23, %35, %33, %36, %37, %25, %32, %31);
 }
-with prim::FusionGroup_0 = graph(%0 : Float(*)
-      %1 : Float(*, *)
-      %2 : Float(*, *)) {
-  %3 : int = prim::Constant[value=1]()
-  %4 : Float(*, *) = aten::add(%1, %2, %3)
-  %5 : int = prim::Constant[value=1]()
-  %6 : Float(*, *) = aten::add(%4, %0, %5)
-  return (%6, %4);
-}
-with prim::FusionGroup_1 = graph(%0 : Float(*, *)
+with prim::FusionGroup_0 = graph(%0 : Float(*, *)
       %1 : Tensor
-      %2 : Tensor) {
-  %3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
-  %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
-  %11 : int = prim::Constant[value=1]()
-  %12 : Float(*, *) = aten::add(%3, %7, %11)
-  %13 : int = prim::Constant[value=1]()
-  %14 : Float(*, *) = aten::add(%4, %8, %13)
-  %15 : int = prim::Constant[value=1]()
-  %16 : Float(*, *) = aten::add(%5, %9, %15)
-  %17 : int = prim::Constant[value=1]()
-  %18 : Float(*, *) = aten::add(%6, %10, %17)
-  %ingate.1 : Float(*, *) = aten::sigmoid(%12)
-  %forgetgate.1 : Float(*, *) = aten::sigmoid(%14)
-  %cellgate.1 : Float(*, *) = aten::tanh(%16)
-  %outgate.1 : Float(*, *) = aten::sigmoid(%18)
-  %23 : Float(*, *) = aten::mul(%forgetgate.1, %0)
-  %24 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
+      %2 : Tensor
+      %3 : Tensor
+      %4 : Tensor) {
+  %5 : Float(*, *), %6 : Float(*, *), %7 : Float(*, *), %8 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%4)
+  %9 : Float(*, *), %10 : Float(*, *), %11 : Float(*, *), %12 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%3)
+  %13 : Float(*, *), %14 : Float(*, *), %15 : Float(*, *), %16 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
+  %17 : Float(*, *), %18 : Float(*, *), %19 : Float(*, *), %20 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
+  %21 : int = prim::Constant[value=1]()
+  %22 : Float(*, *) = aten::add(%13, %17, %21)
+  %23 : int = prim::Constant[value=1]()
+  %24 : Float(*, *) = aten::add(%14, %18, %23)
   %25 : int = prim::Constant[value=1]()
-  %cy : Float(*, *) = aten::add(%23, %24, %25)
-  %27 : Float(*, *) = aten::tanh(%cy)
-  %hy : Float(*, *) = aten::mul(%outgate.1, %27)
-  return (%hy, %27, %cy, %24, %23, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
+  %26 : Float(*, *) = aten::add(%15, %19, %25)
+  %27 : int = prim::Constant[value=1]()
+  %28 : Float(*, *) = aten::add(%16, %20, %27)
+  %29 : int = prim::Constant[value=1]()
+  %30 : Float(*, *) = aten::add(%22, %9, %29)
+  %31 : int = prim::Constant[value=1]()
+  %32 : Float(*, *) = aten::add(%24, %10, %31)
+  %33 : int = prim::Constant[value=1]()
+  %34 : Float(*, *) = aten::add(%26, %11, %33)
+  %35 : int = prim::Constant[value=1]()
+  %36 : Float(*, *) = aten::add(%28, %12, %35)
+  %37 : int = prim::Constant[value=1]()
+  %38 : Float(*, *) = aten::add(%30, %5, %37)
+  %39 : int = prim::Constant[value=1]()
+  %40 : Float(*, *) = aten::add(%32, %6, %39)
+  %41 : int = prim::Constant[value=1]()
+  %42 : Float(*, *) = aten::add(%34, %7, %41)
+  %43 : int = prim::Constant[value=1]()
+  %44 : Float(*, *) = aten::add(%36, %8, %43)
+  %ingate.1 : Float(*, *) = aten::sigmoid(%38)
+  %forgetgate.1 : Float(*, *) = aten::sigmoid(%40)
+  %cellgate.1 : Float(*, *) = aten::tanh(%42)
+  %outgate.1 : Float(*, *) = aten::sigmoid(%44)
+  %49 : Float(*, *) = aten::mul(%forgetgate.1, %0)
+  %50 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
+  %51 : int = prim::Constant[value=1]()
+  %cy : Float(*, *) = aten::add(%49, %50, %51)
+  %53 : Float(*, *) = aten::tanh(%cy)
+  %hy : Float(*, *) = aten::mul(%outgate.1, %53)
+  return (%hy, %53, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
 }
index 2e2726c..1c4e483 100644 (file)
@@ -25,69 +25,89 @@ with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *)
   %Uz.1 : Float(*, *) = aten::mm(%5, %11)
   %13 : int[] = aten::size(%4)
   %14 : int[] = aten::size(%Wx.1)
-  %15 : int[] = aten::size(%Uz.1)
-  %16 : int[] = aten::size(%3)
-  %17 : int = prim::Constant[value=1]()
-  %18 : int[] = aten::size(%2)
-  %19 : Float(*, *), %20 : Float(*, *), %21 : Float(*, *), %22 : Float(*, *), %23 : Float(*, *), %24 : Float(*, *) = prim::FusionGroup_0(%2, %Uz.1, %3, %Wx.1, %4)
-  %25 : int[] = aten::size(%24)
-  %26 : int[] = aten::size(%23)
-  %27 : int[] = aten::size(%22)
-  %28 : int[] = aten::size(%21)
-  %29 : int[] = aten::size(%20)
-  %30 : int[] = aten::size(%19)
-  %31 : int[] = aten::size(%1)
-  %32 : Tensor[] = prim::ListConstruct(%19, %1)
-  %33 : Tensor[] = aten::broadcast_tensors(%32)
-  %34 : Tensor, %35 : Tensor = prim::ListUnpack(%33)
-  %36 : int[] = aten::size(%0)
-  %hy : Float(*, *), %38 : Float(*, *), %cy : Float(*, *), %40 : Float(*, *), %41 : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1(%0, %35, %34)
+  %15 : Float(*, *) = aten::mul(%4, %Wx.1)
+  %16 : int[] = aten::size(%15)
+  %17 : int[] = aten::size(%Uz.1)
+  %18 : int[] = aten::size(%3)
+  %19 : int[] = aten::size(%2)
+  %20 : int[] = aten::size(%1)
+  %21 : Tensor[] = prim::ListConstruct(%1, %2, %Uz.1, %15, %3, %Wx.1)
+  %22 : Tensor[] = aten::broadcast_tensors(%21)
+  %23 : Tensor, %24 : Tensor, %25 : Tensor, %26 : Tensor, %27 : Tensor, %28 : Tensor = prim::ListUnpack(%22)
+  %29 : int[] = prim::BroadcastSizes(%18, %14)
+  %30 : int[] = prim::BroadcastSizes(%16, %17)
+  %31 : int[] = prim::BroadcastSizes(%19, %17)
+  %32 : int[] = prim::BroadcastSizes(%30, %29)
+  %33 : int[] = prim::BroadcastSizes(%32, %31)
+  %34 : int[] = aten::size(%0)
+  %hy : Float(*, *), %36 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %28, %27, %26, %25, %24, %23)
+  %42 : int[] = aten::size(%36)
+  %43 : int[] = aten::size(%outgate.1)
+  %44 : int[] = aten::size(%cellgate.1)
+  %45 : int[] = aten::size(%forgetgate.1)
   %46 : int[] = aten::size(%ingate.1)
-  %47 : int[] = aten::size(%forgetgate.1)
-  %48 : int[] = aten::size(%cellgate.1)
-  %49 : int[] = aten::size(%outgate.1)
-  %50 : int[] = aten::size(%41)
-  %51 : int[] = aten::size(%40)
-  %52 : int[] = aten::size(%38)
-  return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %13, %14, %24, %25, %15, %16, %14, %26, %27, %18, %15, %28, %29, %30, %31, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %47, %36, %46, %48, %50, %51, %38, %49, %52);
+  %47 : int[] = prim::BroadcastSizes(%45, %34)
+  %48 : int[] = prim::BroadcastSizes(%46, %44)
+  return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %13, %14, %15, %16, %17, %18, %14, %30, %29, %19, %17, %32, %31, %33, %20, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %45, %34, %46, %44, %47, %48, %36, %43, %42);
 }
-with prim::FusionGroup_0 = graph(%0 : Float(*)
-      %1 : Float(*, *)
-      %2 : Float(*)
-      %3 : Float(*, *)
-      %4 : Float(*)) {
-  %5 : Float(*, *) = aten::mul(%4, %3)
-  %6 : Float(*, *) = aten::mul(%5, %1)
-  %7 : Float(*, *) = aten::mul(%2, %3)
-  %8 : int = prim::Constant[value=1]()
-  %9 : Float(*, *) = aten::add(%6, %7, %8)
-  %10 : Float(*, *) = aten::mul(%0, %1)
-  %11 : int = prim::Constant[value=1]()
-  %12 : Float(*, *) = aten::add(%9, %10, %11)
-  return (%12, %10, %9, %7, %6, %5);
-}
-with prim::FusionGroup_1 = graph(%0 : Float(*, *)
+with prim::FusionGroup_0 = graph(%0 : Float(*, *)
       %1 : Tensor
-      %2 : Tensor) {
-  %3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
-  %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
-  %11 : int = prim::Constant[value=1]()
-  %12 : Float(*, *) = aten::add(%3, %7, %11)
-  %13 : int = prim::Constant[value=1]()
-  %14 : Float(*, *) = aten::add(%4, %8, %13)
-  %15 : int = prim::Constant[value=1]()
-  %16 : Float(*, *) = aten::add(%5, %9, %15)
-  %17 : int = prim::Constant[value=1]()
-  %18 : Float(*, *) = aten::add(%6, %10, %17)
-  %ingate.1 : Float(*, *) = aten::sigmoid(%12)
-  %forgetgate.1 : Float(*, *) = aten::sigmoid(%14)
-  %cellgate.1 : Float(*, *) = aten::tanh(%16)
-  %outgate.1 : Float(*, *) = aten::sigmoid(%18)
-  %23 : Float(*, *) = aten::mul(%forgetgate.1, %0)
-  %24 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
-  %25 : int = prim::Constant[value=1]()
-  %cy : Float(*, *) = aten::add(%23, %24, %25)
-  %27 : Float(*, *) = aten::tanh(%cy)
-  %hy : Float(*, *) = aten::mul(%outgate.1, %27)
-  return (%hy, %27, %cy, %24, %23, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
+      %2 : Tensor
+      %3 : Tensor
+      %4 : Tensor
+      %5 : Tensor
+      %6 : Tensor) {
+  %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%6)
+  %11 : Float(*, *), %12 : Float(*, *), %13 : Float(*, *), %14 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%5)
+  %15 : Float(*, *), %16 : Float(*, *), %17 : Float(*, *), %18 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%4)
+  %19 : Float(*, *), %20 : Float(*, *), %21 : Float(*, *), %22 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%3)
+  %23 : Float(*, *), %24 : Float(*, *), %25 : Float(*, *), %26 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
+  %27 : Float(*, *), %28 : Float(*, *), %29 : Float(*, *), %30 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
+  %31 : Float(*, *) = aten::mul(%23, %27)
+  %32 : Float(*, *) = aten::mul(%24, %28)
+  %33 : Float(*, *) = aten::mul(%25, %29)
+  %34 : Float(*, *) = aten::mul(%26, %30)
+  %35 : Float(*, *) = aten::mul(%19, %15)
+  %36 : Float(*, *) = aten::mul(%20, %16)
+  %37 : Float(*, *) = aten::mul(%21, %17)
+  %38 : Float(*, *) = aten::mul(%22, %18)
+  %39 : Float(*, *) = aten::mul(%11, %15)
+  %40 : Float(*, *) = aten::mul(%12, %16)
+  %41 : Float(*, *) = aten::mul(%13, %17)
+  %42 : Float(*, *) = aten::mul(%14, %18)
+  %43 : int = prim::Constant[value=1]()
+  %44 : Float(*, *) = aten::add(%35, %31, %43)
+  %45 : int = prim::Constant[value=1]()
+  %46 : Float(*, *) = aten::add(%36, %32, %45)
+  %47 : int = prim::Constant[value=1]()
+  %48 : Float(*, *) = aten::add(%37, %33, %47)
+  %49 : int = prim::Constant[value=1]()
+  %50 : Float(*, *) = aten::add(%38, %34, %49)
+  %51 : int = prim::Constant[value=1]()
+  %52 : Float(*, *) = aten::add(%44, %39, %51)
+  %53 : int = prim::Constant[value=1]()
+  %54 : Float(*, *) = aten::add(%46, %40, %53)
+  %55 : int = prim::Constant[value=1]()
+  %56 : Float(*, *) = aten::add(%48, %41, %55)
+  %57 : int = prim::Constant[value=1]()
+  %58 : Float(*, *) = aten::add(%50, %42, %57)
+  %59 : int = prim::Constant[value=1]()
+  %60 : Float(*, *) = aten::add(%52, %7, %59)
+  %61 : int = prim::Constant[value=1]()
+  %62 : Float(*, *) = aten::add(%54, %8, %61)
+  %63 : int = prim::Constant[value=1]()
+  %64 : Float(*, *) = aten::add(%56, %9, %63)
+  %65 : int = prim::Constant[value=1]()
+  %66 : Float(*, *) = aten::add(%58, %10, %65)
+  %ingate.1 : Float(*, *) = aten::sigmoid(%60)
+  %forgetgate.1 : Float(*, *) = aten::sigmoid(%62)
+  %cellgate.1 : Float(*, *) = aten::tanh(%64)
+  %outgate.1 : Float(*, *) = aten::sigmoid(%66)
+  %71 : Float(*, *) = aten::mul(%forgetgate.1, %0)
+  %72 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
+  %73 : int = prim::Constant[value=1]()
+  %cy : Float(*, *) = aten::add(%71, %72, %73)
+  %75 : Float(*, *) = aten::tanh(%cy)
+  %hy : Float(*, *) = aten::mul(%outgate.1, %75)
+  return (%hy, %75, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
 }
index 063ac01..a71e31f 100644 (file)
@@ -4308,9 +4308,8 @@ a")
         inputs = get_lstm_inputs('cuda', training=True)
         module = self.checkScript(LSTMCellS, inputs)
         forward_graph = module.graph_for(*inputs)
-        with self.assertRaises(AssertionError):
-            self.assertGraphContainsExactly(
-                forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
+        self.assertGraphContainsExactly(
+            forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
         self.assertExpectedGraph(forward_graph, subname='forward')
 
         hy, cy = module(*inputs)
@@ -4324,9 +4323,8 @@ a")
         inputs = get_milstm_inputs('cuda', training=True)
         module = self.checkScript(MiLSTMCell, inputs)
         forward_graph = module.graph_for(*inputs)
-        with self.assertRaises(AssertionError):
-            self.assertGraphContainsExactly(
-                forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
+        self.assertGraphContainsExactly(
+            forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
         self.assertExpectedGraph(forward_graph, subname='forward')
 
         hy, cy = module(*inputs)
index 2149936..99b01ac 100644 (file)
@@ -28,16 +28,13 @@ static c10::optional<std::vector<int64_t>> getMapSize(
 , at::TensorList args
 , at::IntList arg_subset) {
 
-  int64_t dim_after_broadcast = 0;
-  for (const auto arg_idx : arg_subset) {
-    dim_after_broadcast = std::max(dim_after_broadcast, args[arg_idx].dim());
-  }
   // TODO: this keeps reallocating map_size at every iteration, but we know
   // exactly how much storage do we need, so this could be fixed in-place at
   // every step. We're just missing a few functions for ATen, but the fix
   // should be straightforward.
   // Note: left unitialized since empty shape is broadcastable to any shape
   std::vector<int64_t> map_size;
+  map_size.reserve(8);
   for (const auto arg_idx : arg_subset) {
     auto& arg = args.at(arg_idx);
     auto& chunk_desc = spec.inputChunks().at(arg_idx);
index 5953aac..a47fdcc 100644 (file)
@@ -279,6 +279,8 @@ void AliasDb::analyze(Node* node) {
     case prim::MMTreeReduce:
     case prim::MMBatchSide:
     case prim::None:
+    case prim::BroadcastSizes:
+    case prim::ChunkSizes:
       return analyzeCreator(node);
     case prim::TupleUnpack:
     case prim::TupleIndex:
index 85717a5..ab440f2 100644 (file)
@@ -2,6 +2,7 @@
 
 #include "torch/csrc/jit/passes/alias_analysis.h"
 #include "torch/csrc/jit/passes/common_subexpression_elimination.h"
+#include "torch/csrc/jit/passes/dead_code_elimination.h"
 #include "torch/csrc/jit/symbolic_variable.h"
 #include "torch/csrc/jit/fuser/interface.h"
 #include "torch/csrc/jit/operator.h"
@@ -112,6 +113,14 @@ bool isSimpleMap(Node *node) {
   return true;
 }
 
+Value * broadcastSizes(at::ArrayRef<Value*> sizes) {
+  JIT_ASSERT(!sizes.empty());
+  Graph * graph = sizes[0]->owningGraph();
+  Node * broadcast_n = graph->insertNode(graph->create(prim::BroadcastSizes, sizes));
+  broadcast_n->output()->setType(ListType::ofInts());
+  return broadcast_n->output();
+}
+
 struct GraphFuser {
   Block * block_;
   std::shared_ptr<Graph> graph_;
@@ -159,11 +168,15 @@ struct GraphFuser {
     return isFusableCatNode(node) || node->kind() == prim::FusedConcat;
   }
 
-  bool allUsersAreThisConsumer(Node * consumer, Value * producer) {
+  bool calculatesSize(Node * node) {
+    return node->matches("aten::size(Tensor self) -> int[]");
+  }
+
+  bool allUsersAreThisConsumerOrCalcSizes(Node * consumer, Value * producer) {
     auto defining_node = producer->node();
     for(auto o : defining_node->outputs()) {
       for(auto u : o->uses()) {
-        if(u.user != consumer)
+        if(u.user != consumer && !calculatesSize(u.user))
           return false;
       }
     }
@@ -610,7 +623,7 @@ struct GraphFuser {
         chunk->inputs().end(),
         [&](Value * producer_for_chunk) {
           return isFusable(producer_for_chunk->node()) &&
-              allUsersAreThisConsumer(chunk, producer_for_chunk);
+              allUsersAreThisConsumerOrCalcSizes(chunk, producer_for_chunk);
         });
     if (it == chunk->inputs().end()) {
       return false;
@@ -707,6 +720,22 @@ struct GraphFuser {
     for (size_t i = 0; i < nchunks; i++) {
       bchunk->eraseOutput(nchunks * producer_index);
     }
+
+    // The output of producer_for_chunk_node could have been used in some aten::size
+    // operators, so we need to clean those up as well (we simply broadcast all its tensor inputs).
+    auto size_calc_uses = producer_for_chunk_node->output()->uses();
+    if (!size_calc_uses.empty()) {
+      auto tensor_inputs = filter(producer_for_chunk_node->inputs(),
+                                  [](Value * v) { return v->type()->isSubtypeOf(DynamicType::get()); });
+      auto tensor_sizes = fmap(tensor_inputs,
+                               [](Value * v) { return v->owningGraph()->insert(aten::size, {v}); });
+      JIT_ASSERT(!tensor_sizes.empty());
+      Value * output_size = tensor_sizes.size() == 1 ? tensor_sizes[0] : broadcastSizes(tensor_sizes);
+      for (Use u : size_calc_uses) {
+        u.user->output()->replaceAllUsesWith(output_size);
+        u.user->destroy();
+      }
+    }
     producer_for_chunk_node->destroy();
     return true;
   }
@@ -773,6 +802,98 @@ struct GraphFuser {
     }
   }
 
+  bool usedOnlyInSize(Value * v) {
+    const auto & uses = v->uses();
+    return std::all_of(uses.begin(), uses.end(),
+                       [](const Use& u) { return u.user->matches("aten::size(Tensor self) -> int[]"); });
+  }
+
+  std::unordered_map<Value*, Value*> buildShapeExpressions(Node * fusion_group) {
+    WithInsertPoint insert_guard { fusion_group->next() };
+    std::unordered_map<Value*, Value*> shape_of;
+
+    Graph * graph = fusion_group->owningGraph();
+    auto subgraph = fusion_group->g(attr::Subgraph);
+
+    auto inputs = fusion_group->inputs();
+    auto sinputs = subgraph->inputs();
+    JIT_ASSERT(inputs.size() == sinputs.size());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+      shape_of[sinputs[i]] = graph->insert(aten::size, {inputs[i]});
+    }
+
+    // When we have a guarantee that an output won't be removed, because it's
+    // used in expressions that don't involve size checks, we can use its size
+    // instead of computing a long chain of broadcasts, starting from the beginning
+    // of the kernel.
+    auto outputs = fusion_group->outputs();
+    auto soutputs = subgraph->outputs();
+    JIT_ASSERT(outputs.size() == soutputs.size());
+    for (size_t i = 0; i < outputs.size(); ++i) {
+      if (usedOnlyInSize(outputs[i])) continue;
+      shape_of[soutputs[i]] = graph->insert(aten::size, {outputs[i]});
+    }
+
+    for (Node * n : subgraph->nodes()) {
+      // XXX: Use of shape_of.emplace is crucial to the output shape optimization!
+      if (n->kind() == prim::FusedConcat) {
+        // This is a bit more involved, because we have to account for the case
+        // when inputs have different shapes, but fortunately those tensors are
+        // always outputs, and so we can simply avoid replacing their queries,
+        // because it won't help us.
+        continue;
+      }
+      if (n->kind() == prim::Constant) {
+        continue;
+      }
+      if (n->kind() == prim::ConstantChunk) {
+        Node * sizes_node = graph->insertNode(graph->create(prim::ChunkSizes, shape_of.at(n->input()), 2));
+        sizes_node->i_(attr::dim, n->i(attr::dim));
+        sizes_node->i_(attr::chunks, n->i(attr::chunks));
+        Value * regular_size = sizes_node->outputs().at(0);
+        Value * last_size = sizes_node->outputs().at(1);
+        auto outputs = n->outputs();
+        for (Value * o : outputs.slice(0, outputs.size() - 1)) {
+          shape_of.emplace(o, regular_size);
+        }
+        shape_of.emplace(outputs.at(outputs.size() - 1), last_size);
+        continue;
+      }
+      auto tensor_inputs = filter(n->inputs(),
+                                  [](Value * v) { return v->type()->isSubtypeOf(DynamicType::get()); });
+      auto shapes = fmap(tensor_inputs, [&](Value * v) { return shape_of.at(v); });
+      JIT_ASSERT(!shapes.empty());
+      shape_of.emplace(n->output(), shapes.size() == 1 ? shapes[0] : broadcastSizes(shapes));
+    }
+    return shape_of;
+  }
+
+  void removeOutputsUsedOnlyInSize(Node * fusion_group) {
+    if (fusion_group->kind() != prim::FusionGroup) return;
+    auto subgraph = fusion_group->g(attr::Subgraph);
+
+    auto shape_of = buildShapeExpressions(fusion_group);
+    auto outputs = fusion_group->outputs().vec();
+    auto soutputs = subgraph->outputs().vec();
+    // XXX: Iterating in this order is not only good for performance reasons!
+    // It is also crucial for correctness (i has to reflect the current true
+    // index of outputs[i])!
+    for (int64_t i = static_cast<int64_t>(outputs.size()) - 1; i >= 0; --i) {
+      auto output = outputs[i];
+      auto soutput = soutputs[i];
+      if (usedOnlyInSize(output) && shape_of.count(soutput) > 0) {
+        auto uses = output->uses();
+        for (Use u : uses) {
+          JIT_ASSERT(u.user->matches("aten::size(Tensor self) -> int[]"));
+          u.user->output()->replaceAllUsesWith(shape_of.at(soutput));
+          u.user->destroy();
+        }
+        fusion_group->eraseOutput(i);
+        subgraph->eraseOutput(i);
+      }
+    }
+  }
+
   void run() {
     // Run the pass until no changes are made.
     // This is neccessary, because the algorithm can miss out on certain fusion
@@ -809,6 +930,12 @@ struct GraphFuser {
     for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
       it = scanNodeForChunks(*it);
     }
+
+    // Remove outputs that have been added only because we need their size
+    for (Node * n : block_->nodes()) {
+      removeOutputsUsedOnlyInSize(n);
+    }
+
     for (Node * node : block_->nodes()) {
       for (Block * sub_block : node->blocks()) {
         GraphFuser(sub_block, graph_).run();
@@ -817,6 +944,56 @@ struct GraphFuser {
   }
 };
 
+void PeepholeOptimizeShapeExpressions(Block * block) {
+  auto nodes = block->nodes();
+  for (auto it = nodes.begin(); it != nodes.end(); ++it) {
+    Node * node = *it;
+    for (Block * subblock : node->blocks()) {
+      PeepholeOptimizeShapeExpressions(subblock);
+    }
+    if (node->kind() == prim::BroadcastSizes) {
+      // Remove no-op broadcasts.
+      if (node->inputs().size() == 1) {
+        node->output()->replaceAllUsesWith(node->input());
+        it.destroyCurrent();
+        continue;
+      }
+      // Deduplicate inputs, but use their unique() values to ensure
+      // this process only depends on the graph.
+      std::map<size_t, Value*> unique_to_value;
+      for (Value * input : node->inputs()) {
+        unique_to_value.emplace(input->unique(), input);
+      }
+      if (unique_to_value.size() != node->inputs().size()) {
+        std::vector<Value*> inputs;
+        for (auto & entry : unique_to_value) {
+          inputs.push_back(entry.second);
+        }
+        if (inputs.size() == 1) {
+          node->output()->replaceAllUsesWith(inputs[0]);
+        } else {
+          WithInsertPoint insert_guard { node };
+          node->output()->replaceAllUsesWith(broadcastSizes(inputs));
+        }
+        it.destroyCurrent();
+        --it; // Revisit the node with deduplicated inputs
+        continue;
+      }
+      // Remove compose simple chains of broadcasts into a single node.
+      const auto & uses = node->output()->uses();
+      if (uses.size() == 1 && uses[0].user->kind() == prim::BroadcastSizes) {
+        Node * user = uses[0].user;
+        user->removeInput(uses[0].offset);
+        // NB: we don't care about deduplication in here, as we will visit user later.
+        for (Value * i : node->inputs()) {
+          user->addInput(i);
+        }
+        it.destroyCurrent();
+      }
+    }
+  }
+}
+
 } // anonymous namespace
 
 void FuseGraph(std::shared_ptr<Graph>& graph) {
@@ -826,6 +1003,11 @@ void FuseGraph(std::shared_ptr<Graph>& graph) {
   GraphFuser(graph->block(), graph).run();
   // After FuseGraph some common subexpressions may come back
   EliminateCommonSubexpression(graph);
+  // We might have emitted a fair amount of useless shape propagating code, so
+  // remove it
+  EliminateDeadCode(graph);
+  // Improve the quality of shape propagation code that was left
+  PeepholeOptimizeShapeExpressions(graph->block());
 
   #endif
 }
index 311b360..7f7e761 100644 (file)
@@ -1012,6 +1012,8 @@ TORCH_API bool printerHasSpecialCaseFor(Symbol sym) {
     prim::AutogradAdd, // temporarily inserted by autograd
     prim::ConstantChunk, // optimization pass adds it
     prim::DifferentiableGraph, // optimization pass adds it
+    prim::BroadcastSizes, // optimization pass (fuser) adds it
+    prim::ChunkSizes, // optimization pass (fuser) adds it
     prim::Drop, // used in interpreter only
     prim::FusedConcat, // optimization pass adds it
     prim::FusionGroup, // optimization pass adds it
index 72ba0cc..7ed54c0 100644 (file)
@@ -12,6 +12,7 @@
 #include "torch/csrc/variable_tensor_functions.h"
 
 #include <ATen/ExpandUtils.h>
+#include <ATen/WrapDimUtils.h>
 #include <c10/util/SmallVector.h>
 
 #include <exception>
@@ -287,6 +288,48 @@ RegisterOperators reg({
           };
         }),
     Operator(
+        prim::BroadcastSizes,
+        [](const Node* node) -> Operation {
+          size_t num_inputs = node->inputs().size();
+          return [num_inputs](Stack& stack) {
+            std::vector<int64_t> size;
+            size.reserve(8);
+            for (size_t i = 0; i < num_inputs; ++i) {
+              size = at::infer_size(size, peek(stack, i, num_inputs).toIntList()->elements());
+            }
+            drop(stack, num_inputs);
+            push(stack, std::move(size));
+            return 0;
+          };
+        }),
+    Operator(
+        prim::ChunkSizes,
+        [](const Node* node) -> Operation {
+          int64_t raw_dim = node->i(attr::dim);
+          int64_t chunks = node->i(attr::chunks);
+          return [raw_dim, chunks](Stack& stack) {
+            Shared<IntList> sizes_l;
+            pop(stack, sizes_l);
+            const auto & shape = sizes_l->elements();
+            std::vector<int64_t> regular_shape = shape;
+            std::vector<int64_t> last_shape = shape;
+            int64_t dim = at::maybe_wrap_dim(raw_dim, shape.size());
+            AT_CHECK(dim < regular_shape.size(), "Dimension out of range for chunk");
+            int64_t split_size = (regular_shape[dim] + chunks - 1) / chunks;
+            regular_shape[dim] = split_size;
+            if (shape[dim] % chunks == 0) {
+              last_shape[dim] = split_size;
+            } else {
+              int64_t num_splits = std::max<int64_t>((shape[dim] + split_size - 1) / split_size, 1);
+              last_shape[dim] = split_size - (split_size * num_splits - shape[dim]);
+              JIT_ASSERT(last_shape[dim] >= 0);
+            }
+            push(stack, std::move(regular_shape));
+            push(stack, std::move(last_shape));
+            return 0;
+          };
+        }),
+    Operator(
         "prim::RaiseException(str msg) -> ()",
         [](const Node* node) -> Operation {
           return [](Stack& stack) {