Improvements for symbolic AD (#14758)
authorAdam Paszke <adam.paszke@gmail.com>
Wed, 5 Dec 2018 04:35:51 +0000 (20:35 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 5 Dec 2018 04:38:21 +0000 (20:38 -0800)
Summary:
**Review only the last commit.**

This commit adds a few optimizations to AD, that let us dramatically
reduce the number of sizes we capture from forward.

We now:
- collapse chains of SumToSize
- avoid capturing sizes of tensors that are captured anyway
- more aggressively DCE the reverse code
- run CSE on the primal code to deduplicate `aten::size` calls

cc zou3519 zdevito
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14758

Differential Revision: D13324440

Pulled By: zou3519

fbshipit-source-id: 45ccbc13605adcef2b461840c6089d3200000c72

test/cpp/jit/tests.h
test/expect/TestJit.test_cpp_cuda.expect
test/expect/TestScript.test_lstm_fusion_cuda-backward.expect
test/expect/TestScript.test_lstm_fusion_cuda-forward.expect
test/expect/TestScript.test_milstm_fusion_cuda-backward.expect
test/expect/TestScript.test_milstm_fusion_cuda-forward.expect
torch/csrc/jit/autodiff.cpp
torch/csrc/jit/passes/dead_code_elimination.cpp
torch/csrc/jit/passes/dead_code_elimination.h
torch/csrc/jit/passes/graph_fuser.cpp
torch/csrc/jit/passes/peephole.cpp

index d728a65..d087803 100644 (file)
@@ -835,8 +835,8 @@ void testDifferentiate(std::ostream& out = std::cout) {
 
   auto grad_spec = differentiate(graph);
   std::vector<size_t> expected_captured_inputs = {0, 1};
-  std::vector<size_t> expected_captured_outputs = {1, 2, 3, 4, 5, 6, 7};
-  std::vector<size_t> expected_input_vjps = {0, 3};
+  std::vector<size_t> expected_captured_outputs = {1, 2};
+  std::vector<size_t> expected_input_vjps = {0, 1};
   std::vector<size_t> expected_output_vjps = {0, 1};
   ASSERT_EQ(grad_spec.f_real_outputs, 1);
   ASSERT_EQ(grad_spec.df_input_captured_inputs, expected_captured_inputs);
@@ -866,11 +866,11 @@ void testDifferentiateWithRequiresGrad(std::ostream& out = std::cout) {
   PropagateRequiresGrad(graph);
 
   auto grad_spec = differentiate(graph);
-  std::vector<size_t> expected_input_vjps = {1, 4}; // for e and %4 = (d + a)
+  std::vector<size_t> expected_input_vjps = {1, 3}; // for e and %4 = (d + a)
   std::vector<size_t> expected_output_vjps = {0}; // only a requires grad
   ASSERT_EQ(grad_spec.f_real_outputs, 2);
   ASSERT_EQ(grad_spec.df_input_captured_inputs, std::vector<size_t>({0}));
-  ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector<size_t>({2, 3, 4, 5, 6, 7, 8}));
+  ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector<size_t>({2, 3, 4, 5}));
   ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps);
   ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps);
   out << "testDifferentiateWithRequiresGrad\n";
index ddcac4c..743572f 100644 (file)
@@ -93,57 +93,50 @@ with prim::DifferentiableGraph_0 = graph(%13 : Tensor
 testDifferentiate
 graph(%0 : Float(2, 3, 4)
       %1 : Float(2, 3, 4)) {
-  %23 : int[] = aten::size(%0)
-  %26 : int[] = aten::size(%1)
   %2 : Float(2, 3, 4) = aten::mul(%0, %1)
-  %15 : int[] = aten::size(%2)
-  %18 : int[] = aten::size(%0)
   %3 : Float(2, 3, 4) = aten::mul(%2, %0)
   %4 : int = prim::Constant[value=1]()
   %7 : int[] = aten::size(%3)
-  %10 : int[] = aten::size(%1)
   %5 : Float(2, 3, 4) = aten::add(%3, %1, %4)
-  return (%5, %23, %26, %2, %15, %18, %7, %10);
+  return (%5, %2, %7);
 }
 graph(%0 : Float(2, 3, 4)
       %1 : Float(2, 3, 4)
       %2 : Float(2, 3, 4)
       %3 : Float(2, 3, 4)
-      %4 : int[]
-      %5 : int[]
-      %6 : Float(2, 3, 4)
-      %7 : int[]
-      %8 : int[]
-      %9 : int[]
-      %10 : int[]) {
-  %11 : int = prim::Constant[value=1]()
-  %12 : Tensor, %13 : Tensor = prim::GradOf[name="aten::add"](%0)
+      %4 : Float(2, 3, 4)
+      %5 : int[]) {
+  %9 : int = prim::Constant[value=1]()
+  %6 : int[] = aten::size(%4)
+  %7 : int[] = aten::size(%3)
+  %8 : int[] = aten::size(%2)
+  %10 : Tensor, %11 : Tensor = prim::GradOf[name="aten::add"](%0)
     block0() {
-      %14 : Tensor = prim::SumToSize(%0, %9)
-      %15 : Float(2, 3, 4) = aten::mul(%0, %11)
-      %16 : Tensor = prim::SumToSize(%15, %10)
-      -> (%14, %16)
+      %12 : Tensor = prim::SumToSize(%0, %5)
+      %13 : Float(2, 3, 4) = aten::mul(%0, %9)
+      %14 : Tensor = prim::SumToSize(%13, %7)
+      -> (%12, %14)
     }
-  %17 : Tensor, %18 : Tensor = prim::GradOf[name="aten::mul"](%12)
+  %15 : Tensor, %16 : Tensor = prim::GradOf[name="aten::mul"](%10)
     block0() {
-      %19 : Tensor = aten::mul(%12, %2)
-      %20 : Tensor = prim::SumToSize(%19, %7)
-      %21 : Tensor = aten::mul(%12, %6)
-      %22 : Tensor = prim::SumToSize(%21, %8)
-      -> (%20, %22)
+      %17 : Tensor = aten::mul(%10, %2)
+      %18 : Tensor = prim::SumToSize(%17, %6)
+      %19 : Tensor = aten::mul(%10, %4)
+      %20 : Tensor = prim::SumToSize(%19, %8)
+      -> (%18, %20)
     }
-  %23 : Tensor = prim::AutogradAdd(%1, %17)
-  %24 : Tensor, %25 : Tensor = prim::GradOf[name="aten::mul"](%23)
+  %21 : Tensor = prim::AutogradAdd(%1, %15)
+  %22 : Tensor, %23 : Tensor = prim::GradOf[name="aten::mul"](%21)
     block0() {
-      %26 : Tensor = aten::mul(%23, %3)
-      %27 : Tensor = prim::SumToSize(%26, %4)
-      %28 : Tensor = aten::mul(%23, %2)
-      %29 : Tensor = prim::SumToSize(%28, %5)
-      -> (%27, %29)
+      %24 : Tensor = aten::mul(%21, %3)
+      %25 : Tensor = prim::SumToSize(%24, %8)
+      %26 : Tensor = aten::mul(%21, %2)
+      %27 : Tensor = prim::SumToSize(%26, %7)
+      -> (%25, %27)
     }
-  %30 : Tensor = prim::AutogradAdd(%18, %24)
-  %31 : Tensor = prim::AutogradAdd(%13, %25)
-  return (%30, %31);
+  %28 : Tensor = prim::AutogradAdd(%16, %22)
+  %29 : Tensor = prim::AutogradAdd(%11, %23)
+  return (%28, %29);
 }
 
 testDifferentiateWithRequiresGrad
@@ -152,54 +145,48 @@ graph(%0 : Float(*)
   %2 : Float(*) = aten::mul(%1, %1)
   %3 : int = prim::Constant[value=1]()
   %4 : Float(*) = aten::add(%2, %1, %3)
-  %5 : int = prim::Constant[value=1]()
   %26 : int[] = aten::size(%4)
-  %29 : int[] = aten::size(%0)
-  %6 : Float(*) = aten::add(%4, %0, %5)
-  %19 : int[] = aten::size(%6)
-  %22 : int[] = aten::size(%0)
+  %6 : Float(*) = aten::add(%4, %0, %3)
   %7 : Float(*) = aten::mul(%6, %0)
-  %8 : int = prim::Constant[value=1]()
   %11 : int[] = aten::size(%7)
   %14 : int[] = aten::size(%1)
-  %9 : Float(*) = aten::add(%7, %1, %8)
-  return (%4, %9, %26, %29, %6, %19, %22, %11, %14);
+  %9 : Float(*) = aten::add(%7, %1, %3)
+  return (%4, %9, %26, %6, %11, %14);
 }
 graph(%0 : Float(*)
       %1 : Float(*)
       %2 : Float(*)
       %3 : int[]
-      %4 : int[]
-      %5 : Float(*)
-      %6 : int[]
-      %7 : int[]
-      %8 : int[]
-      %9 : int[]) {
-  %10 : int = prim::Constant[value=1]()
-  %12 : Tensor, %13 : Tensor = prim::GradOf[name="aten::add"](%0)
+      %4 : Float(*)
+      %5 : int[]
+      %6 : int[]) {
+  %9 : int = prim::Constant[value=1]()
+  %7 : int[] = aten::size(%4)
+  %8 : int[] = aten::size(%2)
+  %11 : Tensor, %12 : Tensor = prim::GradOf[name="aten::add"](%0)
     block0() {
-      %14 : Tensor = prim::SumToSize(%0, %8)
-      %15 : Tensor = aten::mul(%0, %10)
-      %16 : Tensor = prim::SumToSize(%15, %9)
-      -> (%14, %16)
+      %13 : Tensor = prim::SumToSize(%0, %5)
+      %14 : Tensor = aten::mul(%0, %9)
+      %15 : Tensor = prim::SumToSize(%14, %6)
+      -> (%13, %15)
     }
-  %17 : Tensor, %18 : Tensor = prim::GradOf[name="aten::mul"](%12)
+  %16 : Tensor, %17 : Tensor = prim::GradOf[name="aten::mul"](%11)
     block0() {
-      %19 : Tensor = aten::mul(%12, %2)
-      %20 : Tensor = prim::SumToSize(%19, %6)
-      %21 : Tensor = aten::mul(%12, %5)
-      %22 : Tensor = prim::SumToSize(%21, %7)
-      -> (%20, %22)
+      %18 : Tensor = aten::mul(%11, %2)
+      %19 : Tensor = prim::SumToSize(%18, %7)
+      %20 : Tensor = aten::mul(%11, %4)
+      %21 : Tensor = prim::SumToSize(%20, %8)
+      -> (%19, %21)
     }
-  %23 : Tensor = prim::AutogradAdd(%1, %17)
-  %24 : Tensor, %25 : Tensor = prim::GradOf[name="aten::add"](%23)
+  %22 : Tensor = prim::AutogradAdd(%1, %16)
+  %23 : Tensor, %24 : Tensor = prim::GradOf[name="aten::add"](%22)
     block0() {
-      %26 : Tensor = prim::SumToSize(%23, %3)
-      %27 : Tensor = aten::mul(%23, %10)
-      %28 : Tensor = prim::SumToSize(%27, %4)
-      -> (%26, %28)
+      %25 : Tensor = prim::SumToSize(%22, %3)
+      %26 : Tensor = aten::mul(%22, %9)
+      %27 : Tensor = prim::SumToSize(%26, %8)
+      -> (%25, %27)
     }
-  %29 : Tensor = prim::AutogradAdd(%18, %25)
-  return (%29);
+  %28 : Tensor = prim::AutogradAdd(%17, %24)
+  return (%28);
 }
 
index 9ee5162..af8dedf 100644 (file)
@@ -24,46 +24,44 @@ graph(%0 : Float(*, *)
       %outgate : Float(*, *)
       %24 : int[]
       %25 : int[]
-      %26 : int[]
-      %27 : int[]
-      %28 : int[]
-      %29 : int[]
-      %30 : Float(*, *)
-      %31 : int[]
-      %32 : int[]) {
-  %33 : Float(*, *) = aten::mul(%0, %30)
-  %34 : Tensor = prim::SumToSize(%33, %31)
+      %26 : Float(*, *)) {
+  %27 : int[] = aten::size(%26)
+  %28 : int[] = aten::size(%outgate)
+  %29 : int[] = aten::size(%cellgate)
+  %30 : int[] = aten::size(%ingate)
+  %31 : int[] = aten::size(%9)
+  %32 : int[] = aten::size(%forgetgate)
+  %33 : Float(*, *) = aten::mul(%0, %26)
+  %34 : Tensor = prim::SumToSize(%33, %28)
   %35 : Float(*, *) = aten::mul(%0, %outgate)
-  %36 : Tensor = prim::SumToSize(%35, %32)
-  %37 : Tensor = prim::FusionGroup_0(%1, %36, %30)
-  %38 : Tensor = prim::SumToSize(%37, %28)
-  %39 : Tensor = prim::SumToSize(%37, %29)
+  %36 : Tensor = prim::SumToSize(%35, %27)
+  %37 : Tensor = prim::FusionGroup_0(%1, %36, %26)
+  %38 : Tensor = prim::SumToSize(%37, %24)
+  %39 : Tensor = prim::SumToSize(%37, %25)
   %40 : Tensor = aten::mul(%39, %cellgate)
-  %41 : Tensor = prim::SumToSize(%40, %26)
+  %41 : Tensor = prim::SumToSize(%40, %30)
   %42 : Tensor = aten::mul(%39, %ingate)
-  %43 : Tensor = prim::SumToSize(%42, %27)
+  %43 : Tensor = prim::SumToSize(%42, %29)
   %44 : Tensor = aten::mul(%38, %9)
-  %45 : Tensor = prim::SumToSize(%44, %24)
+  %45 : Tensor = prim::SumToSize(%44, %32)
   %46 : Tensor = aten::mul(%38, %forgetgate)
-  %47 : Tensor = prim::SumToSize(%46, %25)
+  %47 : Tensor = prim::SumToSize(%46, %31)
   %48 : Tensor = prim::FusionGroup_1(%41, %ingate, %45, %forgetgate, %43, %cellgate, %34, %outgate)
-  %49 : Tensor = prim::SumToSize(%48, %18)
-  %50 : Tensor = prim::SumToSize(%48, %19)
-  %51 : Tensor = prim::SumToSize(%49, %16)
-  %52 : Tensor = prim::SumToSize(%49, %17)
-  %53 : Tensor = prim::SumToSize(%51, %14)
-  %54 : Tensor = prim::SumToSize(%51, %15)
-  %55 : Float(*, *) = aten::t(%13)
-  %56 : Float(*, *) = aten::mm(%54, %55)
-  %57 : Float(*, *) = aten::t(%10)
-  %58 : Float(*, *) = aten::mm(%57, %54)
-  %59 : Float(*, *) = aten::t(%58)
-  %60 : Float(*, *) = aten::t(%12)
-  %61 : Float(*, *) = aten::mm(%53, %60)
-  %62 : Float(*, *) = aten::t(%11)
-  %63 : Float(*, *) = aten::mm(%62, %53)
-  %64 : Float(*, *) = aten::t(%63)
-  return (%47, %50, %52, %56, %59, %61, %64);
+  %49 : Tensor = prim::SumToSize(%48, %19)
+  %50 : Tensor = prim::SumToSize(%48, %17)
+  %51 : Tensor = prim::SumToSize(%48, %14)
+  %52 : Tensor = prim::SumToSize(%48, %15)
+  %53 : Float(*, *) = aten::t(%13)
+  %54 : Float(*, *) = aten::mm(%52, %53)
+  %55 : Float(*, *) = aten::t(%10)
+  %56 : Float(*, *) = aten::mm(%55, %52)
+  %57 : Float(*, *) = aten::t(%56)
+  %58 : Float(*, *) = aten::t(%12)
+  %59 : Float(*, *) = aten::mm(%51, %58)
+  %60 : Float(*, *) = aten::t(%11)
+  %61 : Float(*, *) = aten::mm(%60, %51)
+  %62 : Float(*, *) = aten::t(%61)
+  return (%47, %49, %50, %54, %57, %59, %62);
 }
 with prim::FusionGroup_0 = graph(%0 : Float(*, *)
       %1 : Tensor
index 825e802..2933e1b 100644 (file)
@@ -28,16 +28,14 @@ with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *)
   %17 : Tensor, %18 : Tensor, %19 : Tensor, %20 : Tensor = prim::ListUnpack(%16)
   %21 : int[] = prim::BroadcastSizes(%11, %12)
   %22 : int[] = prim::BroadcastSizes(%21, %13)
-  %23 : int[] = aten::size(%0)
-  %hy : Float(*, *), %25 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %20, %19, %18, %17)
-  %31 : int[] = aten::size(%25)
-  %32 : int[] = aten::size(%outgate.1)
-  %33 : int[] = aten::size(%cellgate.1)
-  %34 : int[] = aten::size(%forgetgate.1)
-  %35 : int[] = aten::size(%ingate.1)
-  %36 : int[] = prim::BroadcastSizes(%34, %23)
-  %37 : int[] = prim::BroadcastSizes(%35, %33)
-  return (%hy, %cy, %7, %9, %11, %12, %21, %13, %22, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %34, %23, %35, %33, %36, %37, %25, %32, %31);
+  %hy : Float(*, *), %24 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %20, %19, %18, %17)
+  %30 : int[] = aten::size(%0)
+  %31 : int[] = aten::size(%cellgate.1)
+  %32 : int[] = aten::size(%forgetgate.1)
+  %33 : int[] = aten::size(%ingate.1)
+  %34 : int[] = prim::BroadcastSizes(%32, %30)
+  %35 : int[] = prim::BroadcastSizes(%33, %31)
+  return (%hy, %cy, %7, %9, %11, %12, %21, %13, %22, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %34, %35, %24);
 }
 with prim::FusionGroup_0 = graph(%0 : Float(*, *)
       %1 : Tensor
index c53cb88..90ed94d 100644 (file)
@@ -20,80 +20,75 @@ graph(%0 : Float(*, *)
       %Wx : Float(*, *)
       %20 : Float(*, *)
       %Uz : Float(*, *)
-      %22 : int[]
+      %22 : Float(*, *)
       %23 : int[]
-      %24 : Float(*, *)
+      %24 : int[]
       %25 : int[]
       %26 : int[]
       %27 : int[]
       %28 : int[]
-      %29 : int[]
-      %30 : int[]
-      %31 : int[]
-      %32 : int[]
-      %33 : int[]
-      %34 : int[]
-      %35 : int[]
-      %36 : int[]
       %ingate : Float(*, *)
       %forgetgate : Float(*, *)
       %cellgate : Float(*, *)
       %outgate : Float(*, *)
-      %41 : int[]
-      %42 : int[]
-      %43 : int[]
-      %44 : int[]
-      %45 : int[]
-      %46 : int[]
-      %47 : Float(*, *)
-      %48 : int[]
-      %49 : int[]) {
-  %50 : int = prim::Constant[value=1]()
-  %51 : Float(*, *) = aten::mul(%0, %47)
-  %52 : Tensor = prim::SumToSize(%51, %48)
-  %53 : Float(*, *) = aten::mul(%0, %outgate)
-  %54 : Tensor = prim::SumToSize(%53, %49)
-  %55 : Tensor = prim::FusionGroup_0(%1, %54, %47)
-  %56 : Tensor = prim::SumToSize(%55, %45)
-  %57 : Tensor = prim::SumToSize(%55, %46)
-  %58 : Tensor = aten::mul(%57, %cellgate)
-  %59 : Tensor = prim::SumToSize(%58, %43)
-  %60 : Tensor = aten::mul(%57, %ingate)
-  %61 : Tensor = prim::SumToSize(%60, %44)
-  %62 : Tensor = aten::mul(%56, %12)
-  %63 : Tensor = prim::SumToSize(%62, %41)
-  %64 : Tensor = prim::FusionGroup_1(%59, %ingate, %63, %forgetgate, %61, %cellgate, %52, %outgate)
-  %65 : Tensor = prim::SumToSize(%64, %35)
-  %66 : Tensor = prim::SumToSize(%64, %36)
-  %67 : Tensor = prim::SumToSize(%65, %33)
-  %68 : Tensor = prim::SumToSize(%65, %34)
-  %69 : Tensor = aten::mul(%68, %Uz)
-  %70 : Tensor = prim::SumToSize(%69, %31)
-  %71 : Tensor = aten::mul(%68, %13)
-  %72 : Tensor = prim::SumToSize(%71, %32)
-  %73 : Tensor = prim::SumToSize(%67, %29)
-  %74 : Tensor = prim::SumToSize(%67, %30)
-  %75 : Tensor = aten::mul(%74, %Wx)
-  %76 : Tensor = prim::SumToSize(%75, %27)
-  %77 : Tensor = aten::mul(%74, %14)
-  %78 : Tensor = prim::SumToSize(%77, %28)
-  %79 : Tensor = aten::mul(%73, %Uz)
-  %80 : Tensor = prim::SumToSize(%79, %25)
-  %81 : Tensor = aten::mul(%73, %24)
-  %82 : Tensor = prim::SumToSize(%81, %26)
-  %83 : Tensor = aten::add(%72, %82, %50)
-  %84 : Tensor = aten::mul(%80, %Wx)
-  %85 : Tensor = prim::SumToSize(%84, %22)
-  %86 : Tensor = aten::mul(%80, %15)
-  %87 : Tensor = prim::SumToSize(%86, %23)
-  %88 : Tensor = aten::add(%78, %87, %50)
-  %89 : Float(*, *) = aten::t(%16)
-  %90 : Float(*, *) = aten::mm(%89, %83)
-  %91 : Float(*, *) = aten::t(%90)
-  %92 : Float(*, *) = aten::t(%17)
-  %93 : Float(*, *) = aten::mm(%92, %88)
-  %94 : Float(*, *) = aten::t(%93)
-  return (%66, %70, %76, %85, %91, %94);
+      %33 : int[]
+      %34 : int[]
+      %35 : Float(*, *)) {
+  %36 : int = prim::Constant[value=1]()
+  %37 : int[] = aten::size(%35)
+  %38 : int[] = aten::size(%outgate)
+  %39 : int[] = aten::size(%cellgate)
+  %40 : int[] = aten::size(%ingate)
+  %41 : int[] = aten::size(%forgetgate)
+  %42 : int[] = aten::size(%13)
+  %43 : int[] = aten::size(%14)
+  %44 : int[] = aten::size(%Uz)
+  %45 : int[] = aten::size(%22)
+  %46 : int[] = aten::size(%Wx)
+  %47 : int[] = aten::size(%15)
+  %48 : Float(*, *) = aten::mul(%0, %35)
+  %49 : Tensor = prim::SumToSize(%48, %38)
+  %50 : Float(*, *) = aten::mul(%0, %outgate)
+  %51 : Tensor = prim::SumToSize(%50, %37)
+  %52 : Tensor = prim::FusionGroup_0(%1, %51, %35)
+  %53 : Tensor = prim::SumToSize(%52, %33)
+  %54 : Tensor = prim::SumToSize(%52, %34)
+  %55 : Tensor = aten::mul(%54, %cellgate)
+  %56 : Tensor = prim::SumToSize(%55, %40)
+  %57 : Tensor = aten::mul(%54, %ingate)
+  %58 : Tensor = prim::SumToSize(%57, %39)
+  %59 : Tensor = aten::mul(%53, %12)
+  %60 : Tensor = prim::SumToSize(%59, %41)
+  %61 : Tensor = prim::FusionGroup_1(%56, %ingate, %60, %forgetgate, %58, %cellgate, %49, %outgate)
+  %62 : Tensor = prim::SumToSize(%61, %28)
+  %63 : Tensor = prim::SumToSize(%61, %26)
+  %64 : Tensor = aten::mul(%63, %Uz)
+  %65 : Tensor = prim::SumToSize(%64, %42)
+  %66 : Tensor = aten::mul(%63, %13)
+  %67 : Tensor = prim::SumToSize(%66, %44)
+  %68 : Tensor = prim::SumToSize(%61, %23)
+  %69 : Tensor = prim::SumToSize(%61, %24)
+  %70 : Tensor = aten::mul(%69, %Wx)
+  %71 : Tensor = prim::SumToSize(%70, %43)
+  %72 : Tensor = aten::mul(%69, %14)
+  %73 : Tensor = prim::SumToSize(%72, %46)
+  %74 : Tensor = aten::mul(%68, %Uz)
+  %75 : Tensor = prim::SumToSize(%74, %45)
+  %76 : Tensor = aten::mul(%68, %22)
+  %77 : Tensor = prim::SumToSize(%76, %44)
+  %78 : Tensor = aten::add(%67, %77, %36)
+  %79 : Tensor = aten::mul(%75, %Wx)
+  %80 : Tensor = prim::SumToSize(%79, %47)
+  %81 : Tensor = aten::mul(%75, %15)
+  %82 : Tensor = prim::SumToSize(%81, %46)
+  %83 : Tensor = aten::add(%73, %82, %36)
+  %84 : Float(*, *) = aten::t(%16)
+  %85 : Float(*, *) = aten::mm(%84, %78)
+  %86 : Float(*, *) = aten::t(%85)
+  %87 : Float(*, *) = aten::t(%17)
+  %88 : Float(*, *) = aten::mm(%87, %83)
+  %89 : Float(*, *) = aten::t(%88)
+  return (%62, %65, %71, %80, %86, %89);
 }
 with prim::FusionGroup_0 = graph(%0 : Float(*, *)
       %1 : Tensor
index 1c4e483..c53dbb3 100644 (file)
@@ -23,32 +23,29 @@ with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *)
   %Wx.1 : Float(*, *) = aten::mm(%7, %9)
   %11 : Float(*, *) = aten::t(%6)
   %Uz.1 : Float(*, *) = aten::mm(%5, %11)
-  %13 : int[] = aten::size(%4)
-  %14 : int[] = aten::size(%Wx.1)
-  %15 : Float(*, *) = aten::mul(%4, %Wx.1)
-  %16 : int[] = aten::size(%15)
-  %17 : int[] = aten::size(%Uz.1)
-  %18 : int[] = aten::size(%3)
-  %19 : int[] = aten::size(%2)
-  %20 : int[] = aten::size(%1)
-  %21 : Tensor[] = prim::ListConstruct(%1, %2, %Uz.1, %15, %3, %Wx.1)
-  %22 : Tensor[] = aten::broadcast_tensors(%21)
-  %23 : Tensor, %24 : Tensor, %25 : Tensor, %26 : Tensor, %27 : Tensor, %28 : Tensor = prim::ListUnpack(%22)
-  %29 : int[] = prim::BroadcastSizes(%18, %14)
-  %30 : int[] = prim::BroadcastSizes(%16, %17)
-  %31 : int[] = prim::BroadcastSizes(%19, %17)
-  %32 : int[] = prim::BroadcastSizes(%30, %29)
-  %33 : int[] = prim::BroadcastSizes(%32, %31)
-  %34 : int[] = aten::size(%0)
-  %hy : Float(*, *), %36 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %28, %27, %26, %25, %24, %23)
-  %42 : int[] = aten::size(%36)
-  %43 : int[] = aten::size(%outgate.1)
-  %44 : int[] = aten::size(%cellgate.1)
-  %45 : int[] = aten::size(%forgetgate.1)
-  %46 : int[] = aten::size(%ingate.1)
-  %47 : int[] = prim::BroadcastSizes(%45, %34)
-  %48 : int[] = prim::BroadcastSizes(%46, %44)
-  return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %13, %14, %15, %16, %17, %18, %14, %30, %29, %19, %17, %32, %31, %33, %20, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %45, %34, %46, %44, %47, %48, %36, %43, %42);
+  %13 : Float(*, *) = aten::mul(%4, %Wx.1)
+  %14 : int[] = aten::size(%1)
+  %15 : Tensor[] = prim::ListConstruct(%1, %2, %Uz.1, %13, %3, %Wx.1)
+  %16 : Tensor[] = aten::broadcast_tensors(%15)
+  %17 : Tensor, %18 : Tensor, %19 : Tensor, %20 : Tensor, %21 : Tensor, %22 : Tensor = prim::ListUnpack(%16)
+  %23 : int[] = aten::size(%3)
+  %24 : int[] = aten::size(%Wx.1)
+  %25 : int[] = prim::BroadcastSizes(%23, %24)
+  %26 : int[] = aten::size(%13)
+  %27 : int[] = aten::size(%Uz.1)
+  %28 : int[] = prim::BroadcastSizes(%26, %27)
+  %29 : int[] = aten::size(%2)
+  %30 : int[] = prim::BroadcastSizes(%29, %27)
+  %31 : int[] = prim::BroadcastSizes(%28, %25)
+  %32 : int[] = prim::BroadcastSizes(%31, %30)
+  %hy : Float(*, *), %34 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %22, %21, %20, %19, %18, %17)
+  %40 : int[] = aten::size(%0)
+  %41 : int[] = aten::size(%cellgate.1)
+  %42 : int[] = aten::size(%forgetgate.1)
+  %43 : int[] = aten::size(%ingate.1)
+  %44 : int[] = prim::BroadcastSizes(%42, %40)
+  %45 : int[] = prim::BroadcastSizes(%43, %41)
+  return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %13, %28, %25, %31, %30, %32, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %44, %45, %34);
 }
 with prim::FusionGroup_0 = graph(%0 : Float(*, *)
       %1 : Tensor
index 1e16334..aa42855 100644 (file)
@@ -1,5 +1,6 @@
 #include "torch/csrc/jit/autodiff.h"
 
+#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
 #include "torch/csrc/jit/passes/dead_code_elimination.h"
 #include "torch/csrc/jit/passes/constant_pooling.h"
 #include "torch/csrc/jit/symbolic_variable.h"
@@ -632,6 +633,32 @@ static ReverseDetails addReverseInline(Gradient& grad_desc) {
   return ReverseDetails(std::move(grad_map), reverse_block);
 }
 
+// Returns a topologically-sorted list of values produced in f, and used in its reverse program.
+static value_list getReverseCaptures(Gradient& grad_desc) {
+  auto & graph = *grad_desc.f;
+  auto primal_block = graph.block();
+
+  value_set reverse_captures_set;
+  value_list reverse_captures; // Invariant: topo sorted
+  auto check_uses = [&](Value *v) {
+    for (auto use : v->uses()) {
+      if (use.user->owningBlock() == primal_block)
+        continue;
+      if (/* bool unseen = */ reverse_captures_set.emplace(v).second) {
+        reverse_captures.push_back(v);
+      }
+    }
+  };
+  for (Value * input : graph.inputs()) {
+    check_uses(input);
+  }
+  for (Node * node : graph.nodes()) {
+    for (Value * output : node->outputs())
+      check_uses(output);
+  }
+  return reverse_captures;
+}
+
 // Any temporary value from the primal graphs needs to be captured for later use in the
 // reverse graph, to avoid costly recomputations. However, a lot of the nodes we have
 // in our graphs are simply constants, which are cheap to execute and replicate, and so
@@ -662,6 +689,73 @@ static void liftConstants(Gradient& grad_desc, ReverseDetails& rev_info) {
   }
 }
 
+static void deduplicateSizeCaptures(Gradient& grad_desc, ReverseDetails& rev_info) {
+  Block * primal_block = grad_desc.f->block();
+  const auto usedOnlyInReverse = [primal_block](Value * v) {
+    const auto & uses = v->uses();
+    return std::all_of(uses.begin(), uses.end(),
+                       [primal_block](const Use& u) { return u.user->owningBlock() != primal_block; });
+  };
+  auto captures = getReverseCaptures(grad_desc);
+  value_set capture_set (captures.begin(), captures.end());
+  for (Value * capture : captures) {
+    Node * node = capture->node();
+    if (!node->matches("aten::size(Tensor self) -> int[]")) {
+      continue;
+    }
+    if (usedOnlyInReverse(capture) && capture_set.count(node->input())) {
+      WithInsertPoint insert_guard { *rev_info.reverse_block->nodes().begin() };
+      capture->replaceAllUsesWith(SymbolicVariable(node->input()).size());
+      node->destroy();
+    }
+  }
+}
+
+static void eliminateDeadCode(ReverseDetails& rev_info) {
+  // addReverseInline has to call gradientForNode if *any* of the inputs
+  // require grad, but it will emit vjps for *all* inputs. Use DCE to remove
+  // unnecessary nodes. Additionally, requires_grad() on intermediates is an
+  // overapproximation of the real state, so we might have emitted some
+  // gradients, only to realize that they were unnecessary once we reach a
+  // point that doesn't require grad.
+  // Of course, we need to filter out corresponding entries of grad_map, because
+  // we don't want to accidentally access freed pointers later.
+  auto dead_nodes = FindDeadNodes(rev_info.reverse_block);
+  if (dead_nodes.empty()) {
+    return;
+  }
+  std::vector<Value*> to_erase;
+  for (auto & entry : rev_info.grad_map) {
+    if (dead_nodes.count(entry.second->node()) > 0) {
+      to_erase.push_back(entry.first);
+    }
+  }
+  for (Value * v : to_erase) {
+    rev_info.grad_map.erase(v);
+  }
+  std::vector<Node*> sorted_dead_nodes(dead_nodes.begin(), dead_nodes.end());
+  std::sort(sorted_dead_nodes.begin(), sorted_dead_nodes.end(), [](Node* a, Node* b) { return a->isAfter(b); });
+  for (Node * n : sorted_dead_nodes) {
+    n->destroy();
+  }
+}
+
+static void Optimize(Gradient& grad_desc, ReverseDetails& rev_info) {
+  // TODO: we are sometimes emitting expressions like SumToSize(SumToSize(x, s1), s2),
+  // which are equivalent to SumToSize(x, s2), and could save us some captures, but I'm
+  // not 100% sure how to optimize this at this stage, since we don't know which
+  // GradOf blocks will be stitched together to form the derivative. I guess a smart
+  // analysis could implement this, but I didn't have time before the 1.0 release,
+  // so I put this only as a peephole optimization.
+  liftConstants(grad_desc, rev_info);
+  // We generally add a lot of aten::size calls (for derivatives of broadcasting
+  // operators), and they often end up duplicated, and would get captured multiple
+  // times. Make sure we deduplicate them before lifting.
+  EliminateCommonSubexpression(grad_desc.f);
+  deduplicateSizeCaptures(grad_desc, rev_info);
+  eliminateDeadCode(rev_info);
+}
+
 // Takes a grad_desc.f returned from `addReverseInline` and splits off the
 // reverse_block into its own graph, storing it in df.
 // All intermediates needed in the second stage are added to
@@ -681,24 +775,8 @@ static void lambdaLiftReverse(Gradient& grad_desc, ReverseDetails& rev_info) {
   // and used in df. They will need to be added as inputs of the df
   // and some of them may also need to be appended as outputs of f if
   // they are not already an input or an output of f
-  value_set reverse_captures_set;
-  value_list reverse_captures; // Invariant: topo sorted
-  auto check_uses = [&](Value *v) {
-    for (auto use : v->uses()) {
-      if (use.user->owningBlock() == primal_block)
-        continue;
-      if (/* bool unseen = */ reverse_captures_set.emplace(v).second) {
-        reverse_captures.push_back(v);
-      }
-    }
-  };
-  for (Value * input : graph.inputs()) {
-    check_uses(input);
-  }
-  for (Node * node : graph.nodes()) {
-    for (Value * output : node->outputs())
-      check_uses(output);
-  }
+  // Invariant: topo sorted
+  value_list reverse_captures = getReverseCaptures(grad_desc);
 
   // --------------------------------------------------------------------------
   // 2. Prepare input/outputs lists for f and df
@@ -750,7 +828,12 @@ static void lambdaLiftReverse(Gradient& grad_desc, ReverseDetails& rev_info) {
   for (size_t i = grad_desc.f_real_outputs; i < graph.outputs().size(); ++i) {
     Value * tmp = graph.outputs().at(i);
     // Add VJP inputs only for intermediates that actually required grad.
-    if (!tmp->requires_grad()) continue;
+    // Note that we check the contents of the grad_map instead of tmp->requires_grad(),
+    // becuase it's actually a more faithful source. tmp->requires_grad() is really an
+    // overapproximation (i.e. it can have false positives), while the gradients we will
+    // emit for this value can get DCE-d in the optimization pass (because it has no
+    // influence on the real f's outputs that we differentiate).
+    if (rev_info.grad_map.count(tmp) == 0) continue;
     Value * tmp_vjp_in = reverse_block->addInput()->setType(tmp->type());
     Value * tmp_vjp_prev = rev_info.grad_map.at(tmp);
     // This is quite weird because we can't first make a sum and then replace all uses
@@ -786,6 +869,7 @@ static void lambdaLiftReverse(Gradient& grad_desc, ReverseDetails& rev_info) {
   reverse_block->owningNode()->destroy();
 }
 
+
 Gradient differentiate(std::shared_ptr<Graph>& graph) {
   Gradient grad_desc;
   // Take ownership of the graph
@@ -798,18 +882,13 @@ Gradient differentiate(std::shared_ptr<Graph>& graph) {
   WithInsertPoint guard(grad_desc.f->block());
   // Fills in df_input_vjps and df_output_vjps
   auto rev_info = addReverseInline(grad_desc);
-  // Lift constants captured for the reverse graph into it
-  liftConstants(grad_desc, rev_info);
+  Optimize(grad_desc, rev_info);
   // Fills in f, df, f_real_outputs, df_input_captures,
   // modifies df_input_vjps (new vjps are added for temporaries)
   lambdaLiftReverse(grad_desc, rev_info);
   // It's possible the we've cloned the same constants many times, so
   // de-duplicate them
   ConstantPooling(grad_desc.df);
-  // addReverseInline has to call gradientForNode if *any* of the inputs
-  // require grad, but it will emit vjps for *all* inputs. Use DCE to remove
-  // unnecessary nodes.
-  EliminateDeadCode(grad_desc.df->block());
   return grad_desc;
 }
 
index ddcdd10..637f39e 100644 (file)
@@ -11,7 +11,8 @@ class DeadCodeEliminator {
  public:
   explicit DeadCodeEliminator(std::shared_ptr<Graph> graph)
       : aliasDb_(AliasAnalysis(graph)) {}
-  DeadCodeEliminator(){};
+  DeadCodeEliminator(bool collect_only = false)
+      : collect_only_(collect_only) {}
 
   // The algorithm is an inverse mark-and-sweep. Starting from the return node,
   // we mark "live" nodes that are necessary for the output. Nodes that have
@@ -28,6 +29,12 @@ class DeadCodeEliminator {
     sweep(block, recurse);
   }
 
+  // *_once, because it should be called only once per run() call.
+  std::unordered_set<Node*> get_dead_once() {
+    JIT_ASSERT(collect_only_);
+    return std::move(dead_nodes_);
+  }
+
  private:
   void setLastWildcard() {
     if (!aliasDb_) {
@@ -100,8 +107,10 @@ class DeadCodeEliminator {
       auto node = *it;
       // note these occur before the recursion because we want to uncover
       // dead code in the blocks used to calculate the output
-      removeDeadIfOutputs(node);
-      removeDeadLoopOutputs(node);
+      if (!collect_only_) {
+        removeDeadIfOutputs(node);
+        removeDeadLoopOutputs(node);
+      }
       if (recurse) {
         for (Block* block : node->blocks()) {
           sweep(block, true);
@@ -110,10 +119,28 @@ class DeadCodeEliminator {
       // TODO(suo): We shouldn't really have to check whether a node has uses,
       // since the mark algorithm should do that. But currently, the marking
       // doesn't reach loop counters in certain cases (see TestScript.test_pass)
-      if (!marked_.count(node) && !node->hasUses()) {
-        it.destroyCurrent();
+      if (!marked_.count(node) && !hasUsesOutsideDeadNodes(node)) {
+        if (collect_only_) {
+          dead_nodes_.insert(node);
+        } else {
+          it.destroyCurrent();
+        }
+      }
+    }
+  }
+
+  bool hasUsesOutsideDeadNodes(Node * n) {
+    if (!collect_only_) {
+      return n->hasUses();
+    }
+    for (Value * output : n->outputs()) {
+      for (const Use & u : output->uses()) {
+        if (dead_nodes_.count(u.user) == 0) {
+          return true;
+        }
       }
     }
+    return false;
   }
 
   void markAndEnqueue(Node* n) {
@@ -213,15 +240,23 @@ class DeadCodeEliminator {
   std::list<Node*> workQueue_;
   c10::optional<const Node*> lastWildcard_;
 
+  bool collect_only_ = false;
+  std::unordered_set<Node*> dead_nodes_; // Will be filled iff collect_only_ is true
 };
 
 void EliminateDeadCode(const std::shared_ptr<Graph>& graph) {
-  DeadCodeEliminator(graph).run(graph->block(), true);
+  DeadCodeEliminator(graph).run(graph->block(), /*recurse=*/true);
 }
 
 void EliminateDeadCode(Block* block, bool recurse) {
   DeadCodeEliminator().run(block, recurse);
 }
 
+std::unordered_set<Node*> FindDeadNodes(Block* block, bool recurse) {
+  DeadCodeEliminator eliminator(/*collect_only=*/true);
+  eliminator.run(block, recurse);
+  return eliminator.get_dead_once();
+}
+
 } // namespace jit
 } // namespace torch
index 39aec2f..0e908eb 100644 (file)
@@ -13,4 +13,6 @@ namespace torch { namespace jit {
 TORCH_API void EliminateDeadCode(const std::shared_ptr<Graph>& graph);
 TORCH_API void EliminateDeadCode(Block *block, bool recurse=true);
 
+TORCH_API std::unordered_set<Node*> FindDeadNodes(Block *block, bool recurse=true);
+
 }}
index ab440f2..ee9ba8a 100644 (file)
@@ -808,6 +808,9 @@ struct GraphFuser {
                        [](const Use& u) { return u.user->matches("aten::size(Tensor self) -> int[]"); });
   }
 
+  // Builds up expressions that compute shapes of all intermediates (and outputs)
+  // of the fusion group, based on the sizes of inputs. You should run DCE to remove
+  // those that you end up not using.
   std::unordered_map<Value*, Value*> buildShapeExpressions(Node * fusion_group) {
     WithInsertPoint insert_guard { fusion_group->next() };
     std::unordered_map<Value*, Value*> shape_of;
@@ -852,6 +855,8 @@ struct GraphFuser {
         sizes_node->i_(attr::chunks, n->i(attr::chunks));
         Value * regular_size = sizes_node->outputs().at(0);
         Value * last_size = sizes_node->outputs().at(1);
+        regular_size->setType(ListType::ofInts());
+        last_size->setType(ListType::ofInts());
         auto outputs = n->outputs();
         for (Value * o : outputs.slice(0, outputs.size() - 1)) {
           shape_of.emplace(o, regular_size);
index c7fcccd..db91f4b 100644 (file)
@@ -120,11 +120,18 @@ void PeepholeOptimizeImpl(Block * block, bool addmm_fusion_enabled) {
           node->get<at::Scalar>(attr::other)->toDouble() == 0) {
         node->output()->replaceAllUsesWith(node->input(0));
       }
-    } else if(node->kind() == prim::TensorToNum || node->kind() == prim::ImplicitTensorToNum) {
+    } else if (node->kind() == prim::TensorToNum || node->kind() == prim::ImplicitTensorToNum) {
       Node* input_node = node->input()->node();
       if (input_node->kind() == prim::NumToTensor) {
         node->output()->replaceAllUsesWith(input_node->input());
       }
+    } else if (node->matches("prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)")) {
+      auto uses = node->output()->uses();
+      for (Use u : uses) {
+        if (u.user->matches("prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)")) {
+          u.user->replaceInput(0, node->inputs().at(0));
+        }
+      }
     }
   }
 }