From: Adam Paszke Date: Wed, 5 Dec 2018 04:35:51 +0000 (-0800) Subject: Improvements for symbolic AD (#14758) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~2465 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=8dfebc16cc2dbd6e4f9fd03515428d5b8d49c4c3;p=platform%2Fupstream%2Fpytorch.git Improvements for symbolic AD (#14758) Summary: **Review only the last commit.** This commit adds a few optimizations to AD, that let us dramatically reduce the number of sizes we capture from forward. We now: - collapse chains of SumToSize - avoid capturing sizes of tensors that are captured anyway - more aggressively DCE the reverse code - run CSE on the primal code to deduplicate `aten::size` calls cc zou3519 zdevito Pull Request resolved: https://github.com/pytorch/pytorch/pull/14758 Differential Revision: D13324440 Pulled By: zou3519 fbshipit-source-id: 45ccbc13605adcef2b461840c6089d3200000c72 --- diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h index d728a65..d087803 100644 --- a/test/cpp/jit/tests.h +++ b/test/cpp/jit/tests.h @@ -835,8 +835,8 @@ void testDifferentiate(std::ostream& out = std::cout) { auto grad_spec = differentiate(graph); std::vector expected_captured_inputs = {0, 1}; - std::vector expected_captured_outputs = {1, 2, 3, 4, 5, 6, 7}; - std::vector expected_input_vjps = {0, 3}; + std::vector expected_captured_outputs = {1, 2}; + std::vector expected_input_vjps = {0, 1}; std::vector expected_output_vjps = {0, 1}; ASSERT_EQ(grad_spec.f_real_outputs, 1); ASSERT_EQ(grad_spec.df_input_captured_inputs, expected_captured_inputs); @@ -866,11 +866,11 @@ void testDifferentiateWithRequiresGrad(std::ostream& out = std::cout) { PropagateRequiresGrad(graph); auto grad_spec = differentiate(graph); - std::vector expected_input_vjps = {1, 4}; // for e and %4 = (d + a) + std::vector expected_input_vjps = {1, 3}; // for e and %4 = (d + a) std::vector expected_output_vjps = {0}; // only a requires grad ASSERT_EQ(grad_spec.f_real_outputs, 2); ASSERT_EQ(grad_spec.df_input_captured_inputs, std::vector({0})); - ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector({2, 3, 4, 5, 6, 7, 8})); + ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector({2, 3, 4, 5})); ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps); ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps); out << "testDifferentiateWithRequiresGrad\n"; diff --git a/test/expect/TestJit.test_cpp_cuda.expect b/test/expect/TestJit.test_cpp_cuda.expect index ddcac4c..743572f 100644 --- a/test/expect/TestJit.test_cpp_cuda.expect +++ b/test/expect/TestJit.test_cpp_cuda.expect @@ -93,57 +93,50 @@ with prim::DifferentiableGraph_0 = graph(%13 : Tensor testDifferentiate graph(%0 : Float(2, 3, 4) %1 : Float(2, 3, 4)) { - %23 : int[] = aten::size(%0) - %26 : int[] = aten::size(%1) %2 : Float(2, 3, 4) = aten::mul(%0, %1) - %15 : int[] = aten::size(%2) - %18 : int[] = aten::size(%0) %3 : Float(2, 3, 4) = aten::mul(%2, %0) %4 : int = prim::Constant[value=1]() %7 : int[] = aten::size(%3) - %10 : int[] = aten::size(%1) %5 : Float(2, 3, 4) = aten::add(%3, %1, %4) - return (%5, %23, %26, %2, %15, %18, %7, %10); + return (%5, %2, %7); } graph(%0 : Float(2, 3, 4) %1 : Float(2, 3, 4) %2 : Float(2, 3, 4) %3 : Float(2, 3, 4) - %4 : int[] - %5 : int[] - %6 : Float(2, 3, 4) - %7 : int[] - %8 : int[] - %9 : int[] - %10 : int[]) { - %11 : int = prim::Constant[value=1]() - %12 : Tensor, %13 : Tensor = prim::GradOf[name="aten::add"](%0) + %4 : Float(2, 3, 4) + %5 : int[]) { + %9 : int = prim::Constant[value=1]() + %6 : int[] = 
aten::size(%4) + %7 : int[] = aten::size(%3) + %8 : int[] = aten::size(%2) + %10 : Tensor, %11 : Tensor = prim::GradOf[name="aten::add"](%0) block0() { - %14 : Tensor = prim::SumToSize(%0, %9) - %15 : Float(2, 3, 4) = aten::mul(%0, %11) - %16 : Tensor = prim::SumToSize(%15, %10) - -> (%14, %16) + %12 : Tensor = prim::SumToSize(%0, %5) + %13 : Float(2, 3, 4) = aten::mul(%0, %9) + %14 : Tensor = prim::SumToSize(%13, %7) + -> (%12, %14) } - %17 : Tensor, %18 : Tensor = prim::GradOf[name="aten::mul"](%12) + %15 : Tensor, %16 : Tensor = prim::GradOf[name="aten::mul"](%10) block0() { - %19 : Tensor = aten::mul(%12, %2) - %20 : Tensor = prim::SumToSize(%19, %7) - %21 : Tensor = aten::mul(%12, %6) - %22 : Tensor = prim::SumToSize(%21, %8) - -> (%20, %22) + %17 : Tensor = aten::mul(%10, %2) + %18 : Tensor = prim::SumToSize(%17, %6) + %19 : Tensor = aten::mul(%10, %4) + %20 : Tensor = prim::SumToSize(%19, %8) + -> (%18, %20) } - %23 : Tensor = prim::AutogradAdd(%1, %17) - %24 : Tensor, %25 : Tensor = prim::GradOf[name="aten::mul"](%23) + %21 : Tensor = prim::AutogradAdd(%1, %15) + %22 : Tensor, %23 : Tensor = prim::GradOf[name="aten::mul"](%21) block0() { - %26 : Tensor = aten::mul(%23, %3) - %27 : Tensor = prim::SumToSize(%26, %4) - %28 : Tensor = aten::mul(%23, %2) - %29 : Tensor = prim::SumToSize(%28, %5) - -> (%27, %29) + %24 : Tensor = aten::mul(%21, %3) + %25 : Tensor = prim::SumToSize(%24, %8) + %26 : Tensor = aten::mul(%21, %2) + %27 : Tensor = prim::SumToSize(%26, %7) + -> (%25, %27) } - %30 : Tensor = prim::AutogradAdd(%18, %24) - %31 : Tensor = prim::AutogradAdd(%13, %25) - return (%30, %31); + %28 : Tensor = prim::AutogradAdd(%16, %22) + %29 : Tensor = prim::AutogradAdd(%11, %23) + return (%28, %29); } testDifferentiateWithRequiresGrad @@ -152,54 +145,48 @@ graph(%0 : Float(*) %2 : Float(*) = aten::mul(%1, %1) %3 : int = prim::Constant[value=1]() %4 : Float(*) = aten::add(%2, %1, %3) - %5 : int = prim::Constant[value=1]() %26 : int[] = aten::size(%4) - %29 : int[] = aten::size(%0) - %6 : Float(*) = aten::add(%4, %0, %5) - %19 : int[] = aten::size(%6) - %22 : int[] = aten::size(%0) + %6 : Float(*) = aten::add(%4, %0, %3) %7 : Float(*) = aten::mul(%6, %0) - %8 : int = prim::Constant[value=1]() %11 : int[] = aten::size(%7) %14 : int[] = aten::size(%1) - %9 : Float(*) = aten::add(%7, %1, %8) - return (%4, %9, %26, %29, %6, %19, %22, %11, %14); + %9 : Float(*) = aten::add(%7, %1, %3) + return (%4, %9, %26, %6, %11, %14); } graph(%0 : Float(*) %1 : Float(*) %2 : Float(*) %3 : int[] - %4 : int[] - %5 : Float(*) - %6 : int[] - %7 : int[] - %8 : int[] - %9 : int[]) { - %10 : int = prim::Constant[value=1]() - %12 : Tensor, %13 : Tensor = prim::GradOf[name="aten::add"](%0) + %4 : Float(*) + %5 : int[] + %6 : int[]) { + %9 : int = prim::Constant[value=1]() + %7 : int[] = aten::size(%4) + %8 : int[] = aten::size(%2) + %11 : Tensor, %12 : Tensor = prim::GradOf[name="aten::add"](%0) block0() { - %14 : Tensor = prim::SumToSize(%0, %8) - %15 : Tensor = aten::mul(%0, %10) - %16 : Tensor = prim::SumToSize(%15, %9) - -> (%14, %16) + %13 : Tensor = prim::SumToSize(%0, %5) + %14 : Tensor = aten::mul(%0, %9) + %15 : Tensor = prim::SumToSize(%14, %6) + -> (%13, %15) } - %17 : Tensor, %18 : Tensor = prim::GradOf[name="aten::mul"](%12) + %16 : Tensor, %17 : Tensor = prim::GradOf[name="aten::mul"](%11) block0() { - %19 : Tensor = aten::mul(%12, %2) - %20 : Tensor = prim::SumToSize(%19, %6) - %21 : Tensor = aten::mul(%12, %5) - %22 : Tensor = prim::SumToSize(%21, %7) - -> (%20, %22) + %18 : Tensor = aten::mul(%11, 
%2) + %19 : Tensor = prim::SumToSize(%18, %7) + %20 : Tensor = aten::mul(%11, %4) + %21 : Tensor = prim::SumToSize(%20, %8) + -> (%19, %21) } - %23 : Tensor = prim::AutogradAdd(%1, %17) - %24 : Tensor, %25 : Tensor = prim::GradOf[name="aten::add"](%23) + %22 : Tensor = prim::AutogradAdd(%1, %16) + %23 : Tensor, %24 : Tensor = prim::GradOf[name="aten::add"](%22) block0() { - %26 : Tensor = prim::SumToSize(%23, %3) - %27 : Tensor = aten::mul(%23, %10) - %28 : Tensor = prim::SumToSize(%27, %4) - -> (%26, %28) + %25 : Tensor = prim::SumToSize(%22, %3) + %26 : Tensor = aten::mul(%22, %9) + %27 : Tensor = prim::SumToSize(%26, %8) + -> (%25, %27) } - %29 : Tensor = prim::AutogradAdd(%18, %25) - return (%29); + %28 : Tensor = prim::AutogradAdd(%17, %24) + return (%28); } diff --git a/test/expect/TestScript.test_lstm_fusion_cuda-backward.expect b/test/expect/TestScript.test_lstm_fusion_cuda-backward.expect index 9ee5162..af8dedf 100644 --- a/test/expect/TestScript.test_lstm_fusion_cuda-backward.expect +++ b/test/expect/TestScript.test_lstm_fusion_cuda-backward.expect @@ -24,46 +24,44 @@ graph(%0 : Float(*, *) %outgate : Float(*, *) %24 : int[] %25 : int[] - %26 : int[] - %27 : int[] - %28 : int[] - %29 : int[] - %30 : Float(*, *) - %31 : int[] - %32 : int[]) { - %33 : Float(*, *) = aten::mul(%0, %30) - %34 : Tensor = prim::SumToSize(%33, %31) + %26 : Float(*, *)) { + %27 : int[] = aten::size(%26) + %28 : int[] = aten::size(%outgate) + %29 : int[] = aten::size(%cellgate) + %30 : int[] = aten::size(%ingate) + %31 : int[] = aten::size(%9) + %32 : int[] = aten::size(%forgetgate) + %33 : Float(*, *) = aten::mul(%0, %26) + %34 : Tensor = prim::SumToSize(%33, %28) %35 : Float(*, *) = aten::mul(%0, %outgate) - %36 : Tensor = prim::SumToSize(%35, %32) - %37 : Tensor = prim::FusionGroup_0(%1, %36, %30) - %38 : Tensor = prim::SumToSize(%37, %28) - %39 : Tensor = prim::SumToSize(%37, %29) + %36 : Tensor = prim::SumToSize(%35, %27) + %37 : Tensor = prim::FusionGroup_0(%1, %36, %26) + %38 : Tensor = prim::SumToSize(%37, %24) + %39 : Tensor = prim::SumToSize(%37, %25) %40 : Tensor = aten::mul(%39, %cellgate) - %41 : Tensor = prim::SumToSize(%40, %26) + %41 : Tensor = prim::SumToSize(%40, %30) %42 : Tensor = aten::mul(%39, %ingate) - %43 : Tensor = prim::SumToSize(%42, %27) + %43 : Tensor = prim::SumToSize(%42, %29) %44 : Tensor = aten::mul(%38, %9) - %45 : Tensor = prim::SumToSize(%44, %24) + %45 : Tensor = prim::SumToSize(%44, %32) %46 : Tensor = aten::mul(%38, %forgetgate) - %47 : Tensor = prim::SumToSize(%46, %25) + %47 : Tensor = prim::SumToSize(%46, %31) %48 : Tensor = prim::FusionGroup_1(%41, %ingate, %45, %forgetgate, %43, %cellgate, %34, %outgate) - %49 : Tensor = prim::SumToSize(%48, %18) - %50 : Tensor = prim::SumToSize(%48, %19) - %51 : Tensor = prim::SumToSize(%49, %16) - %52 : Tensor = prim::SumToSize(%49, %17) - %53 : Tensor = prim::SumToSize(%51, %14) - %54 : Tensor = prim::SumToSize(%51, %15) - %55 : Float(*, *) = aten::t(%13) - %56 : Float(*, *) = aten::mm(%54, %55) - %57 : Float(*, *) = aten::t(%10) - %58 : Float(*, *) = aten::mm(%57, %54) - %59 : Float(*, *) = aten::t(%58) - %60 : Float(*, *) = aten::t(%12) - %61 : Float(*, *) = aten::mm(%53, %60) - %62 : Float(*, *) = aten::t(%11) - %63 : Float(*, *) = aten::mm(%62, %53) - %64 : Float(*, *) = aten::t(%63) - return (%47, %50, %52, %56, %59, %61, %64); + %49 : Tensor = prim::SumToSize(%48, %19) + %50 : Tensor = prim::SumToSize(%48, %17) + %51 : Tensor = prim::SumToSize(%48, %14) + %52 : Tensor = prim::SumToSize(%48, %15) + %53 : Float(*, *) = 
aten::t(%13) + %54 : Float(*, *) = aten::mm(%52, %53) + %55 : Float(*, *) = aten::t(%10) + %56 : Float(*, *) = aten::mm(%55, %52) + %57 : Float(*, *) = aten::t(%56) + %58 : Float(*, *) = aten::t(%12) + %59 : Float(*, *) = aten::mm(%51, %58) + %60 : Float(*, *) = aten::t(%11) + %61 : Float(*, *) = aten::mm(%60, %51) + %62 : Float(*, *) = aten::t(%61) + return (%47, %49, %50, %54, %57, %59, %62); } with prim::FusionGroup_0 = graph(%0 : Float(*, *) %1 : Tensor diff --git a/test/expect/TestScript.test_lstm_fusion_cuda-forward.expect b/test/expect/TestScript.test_lstm_fusion_cuda-forward.expect index 825e802..2933e1b 100644 --- a/test/expect/TestScript.test_lstm_fusion_cuda-forward.expect +++ b/test/expect/TestScript.test_lstm_fusion_cuda-forward.expect @@ -28,16 +28,14 @@ with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *) %17 : Tensor, %18 : Tensor, %19 : Tensor, %20 : Tensor = prim::ListUnpack(%16) %21 : int[] = prim::BroadcastSizes(%11, %12) %22 : int[] = prim::BroadcastSizes(%21, %13) - %23 : int[] = aten::size(%0) - %hy : Float(*, *), %25 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %20, %19, %18, %17) - %31 : int[] = aten::size(%25) - %32 : int[] = aten::size(%outgate.1) - %33 : int[] = aten::size(%cellgate.1) - %34 : int[] = aten::size(%forgetgate.1) - %35 : int[] = aten::size(%ingate.1) - %36 : int[] = prim::BroadcastSizes(%34, %23) - %37 : int[] = prim::BroadcastSizes(%35, %33) - return (%hy, %cy, %7, %9, %11, %12, %21, %13, %22, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %34, %23, %35, %33, %36, %37, %25, %32, %31); + %hy : Float(*, *), %24 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %20, %19, %18, %17) + %30 : int[] = aten::size(%0) + %31 : int[] = aten::size(%cellgate.1) + %32 : int[] = aten::size(%forgetgate.1) + %33 : int[] = aten::size(%ingate.1) + %34 : int[] = prim::BroadcastSizes(%32, %30) + %35 : int[] = prim::BroadcastSizes(%33, %31) + return (%hy, %cy, %7, %9, %11, %12, %21, %13, %22, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %34, %35, %24); } with prim::FusionGroup_0 = graph(%0 : Float(*, *) %1 : Tensor diff --git a/test/expect/TestScript.test_milstm_fusion_cuda-backward.expect b/test/expect/TestScript.test_milstm_fusion_cuda-backward.expect index c53cb88..90ed94d 100644 --- a/test/expect/TestScript.test_milstm_fusion_cuda-backward.expect +++ b/test/expect/TestScript.test_milstm_fusion_cuda-backward.expect @@ -20,80 +20,75 @@ graph(%0 : Float(*, *) %Wx : Float(*, *) %20 : Float(*, *) %Uz : Float(*, *) - %22 : int[] + %22 : Float(*, *) %23 : int[] - %24 : Float(*, *) + %24 : int[] %25 : int[] %26 : int[] %27 : int[] %28 : int[] - %29 : int[] - %30 : int[] - %31 : int[] - %32 : int[] - %33 : int[] - %34 : int[] - %35 : int[] - %36 : int[] %ingate : Float(*, *) %forgetgate : Float(*, *) %cellgate : Float(*, *) %outgate : Float(*, *) - %41 : int[] - %42 : int[] - %43 : int[] - %44 : int[] - %45 : int[] - %46 : int[] - %47 : Float(*, *) - %48 : int[] - %49 : int[]) { - %50 : int = prim::Constant[value=1]() - %51 : Float(*, *) = aten::mul(%0, %47) - %52 : Tensor = prim::SumToSize(%51, %48) - %53 : Float(*, *) = aten::mul(%0, %outgate) - %54 : Tensor = prim::SumToSize(%53, %49) - %55 : Tensor = prim::FusionGroup_0(%1, %54, %47) - %56 : Tensor = prim::SumToSize(%55, %45) - %57 : Tensor = prim::SumToSize(%55, 
%46) - %58 : Tensor = aten::mul(%57, %cellgate) - %59 : Tensor = prim::SumToSize(%58, %43) - %60 : Tensor = aten::mul(%57, %ingate) - %61 : Tensor = prim::SumToSize(%60, %44) - %62 : Tensor = aten::mul(%56, %12) - %63 : Tensor = prim::SumToSize(%62, %41) - %64 : Tensor = prim::FusionGroup_1(%59, %ingate, %63, %forgetgate, %61, %cellgate, %52, %outgate) - %65 : Tensor = prim::SumToSize(%64, %35) - %66 : Tensor = prim::SumToSize(%64, %36) - %67 : Tensor = prim::SumToSize(%65, %33) - %68 : Tensor = prim::SumToSize(%65, %34) - %69 : Tensor = aten::mul(%68, %Uz) - %70 : Tensor = prim::SumToSize(%69, %31) - %71 : Tensor = aten::mul(%68, %13) - %72 : Tensor = prim::SumToSize(%71, %32) - %73 : Tensor = prim::SumToSize(%67, %29) - %74 : Tensor = prim::SumToSize(%67, %30) - %75 : Tensor = aten::mul(%74, %Wx) - %76 : Tensor = prim::SumToSize(%75, %27) - %77 : Tensor = aten::mul(%74, %14) - %78 : Tensor = prim::SumToSize(%77, %28) - %79 : Tensor = aten::mul(%73, %Uz) - %80 : Tensor = prim::SumToSize(%79, %25) - %81 : Tensor = aten::mul(%73, %24) - %82 : Tensor = prim::SumToSize(%81, %26) - %83 : Tensor = aten::add(%72, %82, %50) - %84 : Tensor = aten::mul(%80, %Wx) - %85 : Tensor = prim::SumToSize(%84, %22) - %86 : Tensor = aten::mul(%80, %15) - %87 : Tensor = prim::SumToSize(%86, %23) - %88 : Tensor = aten::add(%78, %87, %50) - %89 : Float(*, *) = aten::t(%16) - %90 : Float(*, *) = aten::mm(%89, %83) - %91 : Float(*, *) = aten::t(%90) - %92 : Float(*, *) = aten::t(%17) - %93 : Float(*, *) = aten::mm(%92, %88) - %94 : Float(*, *) = aten::t(%93) - return (%66, %70, %76, %85, %91, %94); + %33 : int[] + %34 : int[] + %35 : Float(*, *)) { + %36 : int = prim::Constant[value=1]() + %37 : int[] = aten::size(%35) + %38 : int[] = aten::size(%outgate) + %39 : int[] = aten::size(%cellgate) + %40 : int[] = aten::size(%ingate) + %41 : int[] = aten::size(%forgetgate) + %42 : int[] = aten::size(%13) + %43 : int[] = aten::size(%14) + %44 : int[] = aten::size(%Uz) + %45 : int[] = aten::size(%22) + %46 : int[] = aten::size(%Wx) + %47 : int[] = aten::size(%15) + %48 : Float(*, *) = aten::mul(%0, %35) + %49 : Tensor = prim::SumToSize(%48, %38) + %50 : Float(*, *) = aten::mul(%0, %outgate) + %51 : Tensor = prim::SumToSize(%50, %37) + %52 : Tensor = prim::FusionGroup_0(%1, %51, %35) + %53 : Tensor = prim::SumToSize(%52, %33) + %54 : Tensor = prim::SumToSize(%52, %34) + %55 : Tensor = aten::mul(%54, %cellgate) + %56 : Tensor = prim::SumToSize(%55, %40) + %57 : Tensor = aten::mul(%54, %ingate) + %58 : Tensor = prim::SumToSize(%57, %39) + %59 : Tensor = aten::mul(%53, %12) + %60 : Tensor = prim::SumToSize(%59, %41) + %61 : Tensor = prim::FusionGroup_1(%56, %ingate, %60, %forgetgate, %58, %cellgate, %49, %outgate) + %62 : Tensor = prim::SumToSize(%61, %28) + %63 : Tensor = prim::SumToSize(%61, %26) + %64 : Tensor = aten::mul(%63, %Uz) + %65 : Tensor = prim::SumToSize(%64, %42) + %66 : Tensor = aten::mul(%63, %13) + %67 : Tensor = prim::SumToSize(%66, %44) + %68 : Tensor = prim::SumToSize(%61, %23) + %69 : Tensor = prim::SumToSize(%61, %24) + %70 : Tensor = aten::mul(%69, %Wx) + %71 : Tensor = prim::SumToSize(%70, %43) + %72 : Tensor = aten::mul(%69, %14) + %73 : Tensor = prim::SumToSize(%72, %46) + %74 : Tensor = aten::mul(%68, %Uz) + %75 : Tensor = prim::SumToSize(%74, %45) + %76 : Tensor = aten::mul(%68, %22) + %77 : Tensor = prim::SumToSize(%76, %44) + %78 : Tensor = aten::add(%67, %77, %36) + %79 : Tensor = aten::mul(%75, %Wx) + %80 : Tensor = prim::SumToSize(%79, %47) + %81 : Tensor = aten::mul(%75, %15) + %82 : Tensor = 
prim::SumToSize(%81, %46) + %83 : Tensor = aten::add(%73, %82, %36) + %84 : Float(*, *) = aten::t(%16) + %85 : Float(*, *) = aten::mm(%84, %78) + %86 : Float(*, *) = aten::t(%85) + %87 : Float(*, *) = aten::t(%17) + %88 : Float(*, *) = aten::mm(%87, %83) + %89 : Float(*, *) = aten::t(%88) + return (%62, %65, %71, %80, %86, %89); } with prim::FusionGroup_0 = graph(%0 : Float(*, *) %1 : Tensor diff --git a/test/expect/TestScript.test_milstm_fusion_cuda-forward.expect b/test/expect/TestScript.test_milstm_fusion_cuda-forward.expect index 1c4e483..c53dbb3 100644 --- a/test/expect/TestScript.test_milstm_fusion_cuda-forward.expect +++ b/test/expect/TestScript.test_milstm_fusion_cuda-forward.expect @@ -23,32 +23,29 @@ with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *) %Wx.1 : Float(*, *) = aten::mm(%7, %9) %11 : Float(*, *) = aten::t(%6) %Uz.1 : Float(*, *) = aten::mm(%5, %11) - %13 : int[] = aten::size(%4) - %14 : int[] = aten::size(%Wx.1) - %15 : Float(*, *) = aten::mul(%4, %Wx.1) - %16 : int[] = aten::size(%15) - %17 : int[] = aten::size(%Uz.1) - %18 : int[] = aten::size(%3) - %19 : int[] = aten::size(%2) - %20 : int[] = aten::size(%1) - %21 : Tensor[] = prim::ListConstruct(%1, %2, %Uz.1, %15, %3, %Wx.1) - %22 : Tensor[] = aten::broadcast_tensors(%21) - %23 : Tensor, %24 : Tensor, %25 : Tensor, %26 : Tensor, %27 : Tensor, %28 : Tensor = prim::ListUnpack(%22) - %29 : int[] = prim::BroadcastSizes(%18, %14) - %30 : int[] = prim::BroadcastSizes(%16, %17) - %31 : int[] = prim::BroadcastSizes(%19, %17) - %32 : int[] = prim::BroadcastSizes(%30, %29) - %33 : int[] = prim::BroadcastSizes(%32, %31) - %34 : int[] = aten::size(%0) - %hy : Float(*, *), %36 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %28, %27, %26, %25, %24, %23) - %42 : int[] = aten::size(%36) - %43 : int[] = aten::size(%outgate.1) - %44 : int[] = aten::size(%cellgate.1) - %45 : int[] = aten::size(%forgetgate.1) - %46 : int[] = aten::size(%ingate.1) - %47 : int[] = prim::BroadcastSizes(%45, %34) - %48 : int[] = prim::BroadcastSizes(%46, %44) - return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %13, %14, %15, %16, %17, %18, %14, %30, %29, %19, %17, %32, %31, %33, %20, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %45, %34, %46, %44, %47, %48, %36, %43, %42); + %13 : Float(*, *) = aten::mul(%4, %Wx.1) + %14 : int[] = aten::size(%1) + %15 : Tensor[] = prim::ListConstruct(%1, %2, %Uz.1, %13, %3, %Wx.1) + %16 : Tensor[] = aten::broadcast_tensors(%15) + %17 : Tensor, %18 : Tensor, %19 : Tensor, %20 : Tensor, %21 : Tensor, %22 : Tensor = prim::ListUnpack(%16) + %23 : int[] = aten::size(%3) + %24 : int[] = aten::size(%Wx.1) + %25 : int[] = prim::BroadcastSizes(%23, %24) + %26 : int[] = aten::size(%13) + %27 : int[] = aten::size(%Uz.1) + %28 : int[] = prim::BroadcastSizes(%26, %27) + %29 : int[] = aten::size(%2) + %30 : int[] = prim::BroadcastSizes(%29, %27) + %31 : int[] = prim::BroadcastSizes(%28, %25) + %32 : int[] = prim::BroadcastSizes(%31, %30) + %hy : Float(*, *), %34 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %22, %21, %20, %19, %18, %17) + %40 : int[] = aten::size(%0) + %41 : int[] = aten::size(%cellgate.1) + %42 : int[] = aten::size(%forgetgate.1) + %43 : int[] = aten::size(%ingate.1) + %44 : int[] = prim::BroadcastSizes(%42, %40) + %45 : int[] = prim::BroadcastSizes(%43, %41) + return (%hy, %cy, 
%9, %Wx.1, %11, %Uz.1, %13, %28, %25, %31, %30, %32, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %44, %45, %34); } with prim::FusionGroup_0 = graph(%0 : Float(*, *) %1 : Tensor diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp index 1e16334..aa42855 100644 --- a/torch/csrc/jit/autodiff.cpp +++ b/torch/csrc/jit/autodiff.cpp @@ -1,5 +1,6 @@ #include "torch/csrc/jit/autodiff.h" +#include "torch/csrc/jit/passes/common_subexpression_elimination.h" #include "torch/csrc/jit/passes/dead_code_elimination.h" #include "torch/csrc/jit/passes/constant_pooling.h" #include "torch/csrc/jit/symbolic_variable.h" @@ -632,6 +633,32 @@ static ReverseDetails addReverseInline(Gradient& grad_desc) { return ReverseDetails(std::move(grad_map), reverse_block); } +// Returns a topologically-sorted list of values produced in f, and used in its reverse program. +static value_list getReverseCaptures(Gradient& grad_desc) { + auto & graph = *grad_desc.f; + auto primal_block = graph.block(); + + value_set reverse_captures_set; + value_list reverse_captures; // Invariant: topo sorted + auto check_uses = [&](Value *v) { + for (auto use : v->uses()) { + if (use.user->owningBlock() == primal_block) + continue; + if (/* bool unseen = */ reverse_captures_set.emplace(v).second) { + reverse_captures.push_back(v); + } + } + }; + for (Value * input : graph.inputs()) { + check_uses(input); + } + for (Node * node : graph.nodes()) { + for (Value * output : node->outputs()) + check_uses(output); + } + return reverse_captures; +} + // Any temporary value from the primal graphs needs to be captured for later use in the // reverse graph, to avoid costly recomputations. However, a lot of the nodes we have // in our graphs are simply constants, which are cheap to execute and replicate, and so @@ -662,6 +689,73 @@ static void liftConstants(Gradient& grad_desc, ReverseDetails& rev_info) { } } +static void deduplicateSizeCaptures(Gradient& grad_desc, ReverseDetails& rev_info) { + Block * primal_block = grad_desc.f->block(); + const auto usedOnlyInReverse = [primal_block](Value * v) { + const auto & uses = v->uses(); + return std::all_of(uses.begin(), uses.end(), + [primal_block](const Use& u) { return u.user->owningBlock() != primal_block; }); + }; + auto captures = getReverseCaptures(grad_desc); + value_set capture_set (captures.begin(), captures.end()); + for (Value * capture : captures) { + Node * node = capture->node(); + if (!node->matches("aten::size(Tensor self) -> int[]")) { + continue; + } + if (usedOnlyInReverse(capture) && capture_set.count(node->input())) { + WithInsertPoint insert_guard { *rev_info.reverse_block->nodes().begin() }; + capture->replaceAllUsesWith(SymbolicVariable(node->input()).size()); + node->destroy(); + } + } +} + +static void eliminateDeadCode(ReverseDetails& rev_info) { + // addReverseInline has to call gradientForNode if *any* of the inputs + // require grad, but it will emit vjps for *all* inputs. Use DCE to remove + // unnecessary nodes. Additionally, requires_grad() on intermediates is an + // overapproximation of the real state, so we might have emitted some + // gradients, only to realize that they were unnecessary once we reach a + // point that doesn't require grad. + // Of course, we need to filter out corresponding entries of grad_map, because + // we don't want to accidentally access freed pointers later. 
+ auto dead_nodes = FindDeadNodes(rev_info.reverse_block); + if (dead_nodes.empty()) { + return; + } + std::vector to_erase; + for (auto & entry : rev_info.grad_map) { + if (dead_nodes.count(entry.second->node()) > 0) { + to_erase.push_back(entry.first); + } + } + for (Value * v : to_erase) { + rev_info.grad_map.erase(v); + } + std::vector sorted_dead_nodes(dead_nodes.begin(), dead_nodes.end()); + std::sort(sorted_dead_nodes.begin(), sorted_dead_nodes.end(), [](Node* a, Node* b) { return a->isAfter(b); }); + for (Node * n : sorted_dead_nodes) { + n->destroy(); + } +} + +static void Optimize(Gradient& grad_desc, ReverseDetails& rev_info) { + // TODO: we are sometimes emitting expressions like SumToSize(SumToSize(x, s1), s2), + // which are equivalent to SumToSize(x, s2), and could save us some captures, but I'm + // not 100% sure how to optimize this at this stage, since we don't know which + // GradOf blocks will be stitched together to form the derivative. I guess a smart + // analysis could implement this, but I didn't have time before the 1.0 release, + // so I put this only as a peephole optimization. + liftConstants(grad_desc, rev_info); + // We generally add a lot of aten::size calls (for derivatives of broadcasting + // operators), and they often end up duplicated, and would get captured multiple + // times. Make sure we deduplicate them before lifting. + EliminateCommonSubexpression(grad_desc.f); + deduplicateSizeCaptures(grad_desc, rev_info); + eliminateDeadCode(rev_info); +} + // Takes a grad_desc.f returned from `addReverseInline` and splits off the // reverse_block into its own graph, storing it in df. // All intermediates needed in the second stage are added to @@ -681,24 +775,8 @@ static void lambdaLiftReverse(Gradient& grad_desc, ReverseDetails& rev_info) { // and used in df. They will need to be added as inputs of the df // and some of them may also need to be appended as outputs of f if // they are not already an input or an output of f - value_set reverse_captures_set; - value_list reverse_captures; // Invariant: topo sorted - auto check_uses = [&](Value *v) { - for (auto use : v->uses()) { - if (use.user->owningBlock() == primal_block) - continue; - if (/* bool unseen = */ reverse_captures_set.emplace(v).second) { - reverse_captures.push_back(v); - } - } - }; - for (Value * input : graph.inputs()) { - check_uses(input); - } - for (Node * node : graph.nodes()) { - for (Value * output : node->outputs()) - check_uses(output); - } + // Invariant: topo sorted + value_list reverse_captures = getReverseCaptures(grad_desc); // -------------------------------------------------------------------------- // 2. Prepare input/outputs lists for f and df @@ -750,7 +828,12 @@ static void lambdaLiftReverse(Gradient& grad_desc, ReverseDetails& rev_info) { for (size_t i = grad_desc.f_real_outputs; i < graph.outputs().size(); ++i) { Value * tmp = graph.outputs().at(i); // Add VJP inputs only for intermediates that actually required grad. - if (!tmp->requires_grad()) continue; + // Note that we check the contents of the grad_map instead of tmp->requires_grad(), + // becuase it's actually a more faithful source. tmp->requires_grad() is really an + // overapproximation (i.e. it can have false positives), while the gradients we will + // emit for this value can get DCE-d in the optimization pass (because it has no + // influence on the real f's outputs that we differentiate). 
+ if (rev_info.grad_map.count(tmp) == 0) continue; Value * tmp_vjp_in = reverse_block->addInput()->setType(tmp->type()); Value * tmp_vjp_prev = rev_info.grad_map.at(tmp); // This is quite weird because we can't first make a sum and then replace all uses @@ -786,6 +869,7 @@ static void lambdaLiftReverse(Gradient& grad_desc, ReverseDetails& rev_info) { reverse_block->owningNode()->destroy(); } + Gradient differentiate(std::shared_ptr& graph) { Gradient grad_desc; // Take ownership of the graph @@ -798,18 +882,13 @@ Gradient differentiate(std::shared_ptr& graph) { WithInsertPoint guard(grad_desc.f->block()); // Fills in df_input_vjps and df_output_vjps auto rev_info = addReverseInline(grad_desc); - // Lift constants captured for the reverse graph into it - liftConstants(grad_desc, rev_info); + Optimize(grad_desc, rev_info); // Fills in f, df, f_real_outputs, df_input_captures, // modifies df_input_vjps (new vjps are added for temporaries) lambdaLiftReverse(grad_desc, rev_info); // It's possible the we've cloned the same constants many times, so // de-duplicate them ConstantPooling(grad_desc.df); - // addReverseInline has to call gradientForNode if *any* of the inputs - // require grad, but it will emit vjps for *all* inputs. Use DCE to remove - // unnecessary nodes. - EliminateDeadCode(grad_desc.df->block()); return grad_desc; } diff --git a/torch/csrc/jit/passes/dead_code_elimination.cpp b/torch/csrc/jit/passes/dead_code_elimination.cpp index ddcdd10..637f39e 100644 --- a/torch/csrc/jit/passes/dead_code_elimination.cpp +++ b/torch/csrc/jit/passes/dead_code_elimination.cpp @@ -11,7 +11,8 @@ class DeadCodeEliminator { public: explicit DeadCodeEliminator(std::shared_ptr graph) : aliasDb_(AliasAnalysis(graph)) {} - DeadCodeEliminator(){}; + DeadCodeEliminator(bool collect_only = false) + : collect_only_(collect_only) {} // The algorithm is an inverse mark-and-sweep. Starting from the return node, // we mark "live" nodes that are necessary for the output. Nodes that have @@ -28,6 +29,12 @@ class DeadCodeEliminator { sweep(block, recurse); } + // *_once, because it should be called only once per run() call. + std::unordered_set get_dead_once() { + JIT_ASSERT(collect_only_); + return std::move(dead_nodes_); + } + private: void setLastWildcard() { if (!aliasDb_) { @@ -100,8 +107,10 @@ class DeadCodeEliminator { auto node = *it; // note these occur before the recursion because we want to uncover // dead code in the blocks used to calculate the output - removeDeadIfOutputs(node); - removeDeadLoopOutputs(node); + if (!collect_only_) { + removeDeadIfOutputs(node); + removeDeadLoopOutputs(node); + } if (recurse) { for (Block* block : node->blocks()) { sweep(block, true); @@ -110,10 +119,28 @@ class DeadCodeEliminator { // TODO(suo): We shouldn't really have to check whether a node has uses, // since the mark algorithm should do that. 
But currently, the marking // doesn't reach loop counters in certain cases (see TestScript.test_pass) - if (!marked_.count(node) && !node->hasUses()) { - it.destroyCurrent(); + if (!marked_.count(node) && !hasUsesOutsideDeadNodes(node)) { + if (collect_only_) { + dead_nodes_.insert(node); + } else { + it.destroyCurrent(); + } + } + } + } + + bool hasUsesOutsideDeadNodes(Node * n) { + if (!collect_only_) { + return n->hasUses(); + } + for (Value * output : n->outputs()) { + for (const Use & u : output->uses()) { + if (dead_nodes_.count(u.user) == 0) { + return true; + } } } + return false; } void markAndEnqueue(Node* n) { @@ -213,15 +240,23 @@ class DeadCodeEliminator { std::list workQueue_; c10::optional lastWildcard_; + bool collect_only_ = false; + std::unordered_set dead_nodes_; // Will be filled iff collect_only_ is true }; void EliminateDeadCode(const std::shared_ptr& graph) { - DeadCodeEliminator(graph).run(graph->block(), true); + DeadCodeEliminator(graph).run(graph->block(), /*recurse=*/true); } void EliminateDeadCode(Block* block, bool recurse) { DeadCodeEliminator().run(block, recurse); } +std::unordered_set FindDeadNodes(Block* block, bool recurse) { + DeadCodeEliminator eliminator(/*collect_only=*/true); + eliminator.run(block, recurse); + return eliminator.get_dead_once(); +} + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/dead_code_elimination.h b/torch/csrc/jit/passes/dead_code_elimination.h index 39aec2f..0e908eb 100644 --- a/torch/csrc/jit/passes/dead_code_elimination.h +++ b/torch/csrc/jit/passes/dead_code_elimination.h @@ -13,4 +13,6 @@ namespace torch { namespace jit { TORCH_API void EliminateDeadCode(const std::shared_ptr& graph); TORCH_API void EliminateDeadCode(Block *block, bool recurse=true); +TORCH_API std::unordered_set FindDeadNodes(Block *block, bool recurse=true); + }} diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp index ab440f2..ee9ba8a 100644 --- a/torch/csrc/jit/passes/graph_fuser.cpp +++ b/torch/csrc/jit/passes/graph_fuser.cpp @@ -808,6 +808,9 @@ struct GraphFuser { [](const Use& u) { return u.user->matches("aten::size(Tensor self) -> int[]"); }); } + // Builds up expressions that compute shapes of all intermediates (and outputs) + // of the fusion group, based on the sizes of inputs. You should run DCE to remove + // those that you end up not using. 
std::unordered_map buildShapeExpressions(Node * fusion_group) { WithInsertPoint insert_guard { fusion_group->next() }; std::unordered_map shape_of; @@ -852,6 +855,8 @@ struct GraphFuser { sizes_node->i_(attr::chunks, n->i(attr::chunks)); Value * regular_size = sizes_node->outputs().at(0); Value * last_size = sizes_node->outputs().at(1); + regular_size->setType(ListType::ofInts()); + last_size->setType(ListType::ofInts()); auto outputs = n->outputs(); for (Value * o : outputs.slice(0, outputs.size() - 1)) { shape_of.emplace(o, regular_size); diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp index c7fcccd..db91f4b 100644 --- a/torch/csrc/jit/passes/peephole.cpp +++ b/torch/csrc/jit/passes/peephole.cpp @@ -120,11 +120,18 @@ void PeepholeOptimizeImpl(Block * block, bool addmm_fusion_enabled) { node->get(attr::other)->toDouble() == 0) { node->output()->replaceAllUsesWith(node->input(0)); } - } else if(node->kind() == prim::TensorToNum || node->kind() == prim::ImplicitTensorToNum) { + } else if (node->kind() == prim::TensorToNum || node->kind() == prim::ImplicitTensorToNum) { Node* input_node = node->input()->node(); if (input_node->kind() == prim::NumToTensor) { node->output()->replaceAllUsesWith(input_node->input()); } + } else if (node->matches("prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)")) { + auto uses = node->output()->uses(); + for (Use u : uses) { + if (u.user->matches("prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)")) { + u.user->replaceInput(0, node->inputs().at(0)); + } + } } } }
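The prim::SumToSize peephole added above rewires any SumToSize whose input is itself a SumToSize so that it reads the original tensor directly; the inner call then becomes dead, and DCE can drop it along with the aten::size value it would otherwise capture. The rewrite relies on the equivalence noted in the TODO inside Optimize(): SumToSize(SumToSize(x, s1), s2) produces the same result as SumToSize(x, s2). Below is a minimal sketch of that equivalence using the public Tensor.sum_to_size API; the shapes and variable names are made up for illustration and are not taken from the tests.

    import torch

    x = torch.randn(2, 3, 4)

    # Chained reduction, the kind of pattern emitted when gradients flow
    # through several broadcasting ops in a row: reduce to an intermediate
    # broadcast shape first, then to a smaller one.
    inner = x.sum_to_size(3, 4)        # sums out the leading dim of size 2
    chained = inner.sum_to_size(1, 4)  # then reduces dim 0 (size 3) down to 1

    # Direct reduction to the final shape, as the collapsed graph computes it.
    direct = x.sum_to_size(1, 4)

    assert torch.allclose(chained, direct)

Together with the size-capture deduplication and the more aggressive DCE in the reverse block, collapsing these chains is what shrinks the lists of sizes captured from the forward graph, as reflected in the shorter return lists of the updated test expects above.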