auto grad_spec = differentiate(graph);
std::vector<size_t> expected_captured_inputs = {0, 1};
- std::vector<size_t> expected_captured_outputs = {1, 2, 3, 4, 5, 6, 7};
- std::vector<size_t> expected_input_vjps = {0, 3};
+ std::vector<size_t> expected_captured_outputs = {1, 2};
+ std::vector<size_t> expected_input_vjps = {0, 1};
std::vector<size_t> expected_output_vjps = {0, 1};
ASSERT_EQ(grad_spec.f_real_outputs, 1);
ASSERT_EQ(grad_spec.df_input_captured_inputs, expected_captured_inputs);
PropagateRequiresGrad(graph);
auto grad_spec = differentiate(graph);
- std::vector<size_t> expected_input_vjps = {1, 4}; // for e and %4 = (d + a)
+ std::vector<size_t> expected_input_vjps = {1, 3}; // for e and %4 = (d + a)
std::vector<size_t> expected_output_vjps = {0}; // only a requires grad
ASSERT_EQ(grad_spec.f_real_outputs, 2);
ASSERT_EQ(grad_spec.df_input_captured_inputs, std::vector<size_t>({0}));
- ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector<size_t>({2, 3, 4, 5, 6, 7, 8}));
+ ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector<size_t>({2, 3, 4, 5}));
ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps);
ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps);
out << "testDifferentiateWithRequiresGrad\n";
testDifferentiate
graph(%0 : Float(2, 3, 4)
%1 : Float(2, 3, 4)) {
- %23 : int[] = aten::size(%0)
- %26 : int[] = aten::size(%1)
%2 : Float(2, 3, 4) = aten::mul(%0, %1)
- %15 : int[] = aten::size(%2)
- %18 : int[] = aten::size(%0)
%3 : Float(2, 3, 4) = aten::mul(%2, %0)
%4 : int = prim::Constant[value=1]()
%7 : int[] = aten::size(%3)
- %10 : int[] = aten::size(%1)
%5 : Float(2, 3, 4) = aten::add(%3, %1, %4)
- return (%5, %23, %26, %2, %15, %18, %7, %10);
+ return (%5, %2, %7);
}
graph(%0 : Float(2, 3, 4)
%1 : Float(2, 3, 4)
%2 : Float(2, 3, 4)
%3 : Float(2, 3, 4)
- %4 : int[]
- %5 : int[]
- %6 : Float(2, 3, 4)
- %7 : int[]
- %8 : int[]
- %9 : int[]
- %10 : int[]) {
- %11 : int = prim::Constant[value=1]()
- %12 : Tensor, %13 : Tensor = prim::GradOf[name="aten::add"](%0)
+ %4 : Float(2, 3, 4)
+ %5 : int[]) {
+ %9 : int = prim::Constant[value=1]()
+ %6 : int[] = aten::size(%4)
+ %7 : int[] = aten::size(%3)
+ %8 : int[] = aten::size(%2)
+ %10 : Tensor, %11 : Tensor = prim::GradOf[name="aten::add"](%0)
block0() {
- %14 : Tensor = prim::SumToSize(%0, %9)
- %15 : Float(2, 3, 4) = aten::mul(%0, %11)
- %16 : Tensor = prim::SumToSize(%15, %10)
- -> (%14, %16)
+ %12 : Tensor = prim::SumToSize(%0, %5)
+ %13 : Float(2, 3, 4) = aten::mul(%0, %9)
+ %14 : Tensor = prim::SumToSize(%13, %7)
+ -> (%12, %14)
}
- %17 : Tensor, %18 : Tensor = prim::GradOf[name="aten::mul"](%12)
+ %15 : Tensor, %16 : Tensor = prim::GradOf[name="aten::mul"](%10)
block0() {
- %19 : Tensor = aten::mul(%12, %2)
- %20 : Tensor = prim::SumToSize(%19, %7)
- %21 : Tensor = aten::mul(%12, %6)
- %22 : Tensor = prim::SumToSize(%21, %8)
- -> (%20, %22)
+ %17 : Tensor = aten::mul(%10, %2)
+ %18 : Tensor = prim::SumToSize(%17, %6)
+ %19 : Tensor = aten::mul(%10, %4)
+ %20 : Tensor = prim::SumToSize(%19, %8)
+ -> (%18, %20)
}
- %23 : Tensor = prim::AutogradAdd(%1, %17)
- %24 : Tensor, %25 : Tensor = prim::GradOf[name="aten::mul"](%23)
+ %21 : Tensor = prim::AutogradAdd(%1, %15)
+ %22 : Tensor, %23 : Tensor = prim::GradOf[name="aten::mul"](%21)
block0() {
- %26 : Tensor = aten::mul(%23, %3)
- %27 : Tensor = prim::SumToSize(%26, %4)
- %28 : Tensor = aten::mul(%23, %2)
- %29 : Tensor = prim::SumToSize(%28, %5)
- -> (%27, %29)
+ %24 : Tensor = aten::mul(%21, %3)
+ %25 : Tensor = prim::SumToSize(%24, %8)
+ %26 : Tensor = aten::mul(%21, %2)
+ %27 : Tensor = prim::SumToSize(%26, %7)
+ -> (%25, %27)
}
- %30 : Tensor = prim::AutogradAdd(%18, %24)
- %31 : Tensor = prim::AutogradAdd(%13, %25)
- return (%30, %31);
+ %28 : Tensor = prim::AutogradAdd(%16, %22)
+ %29 : Tensor = prim::AutogradAdd(%11, %23)
+ return (%28, %29);
}
testDifferentiateWithRequiresGrad
%2 : Float(*) = aten::mul(%1, %1)
%3 : int = prim::Constant[value=1]()
%4 : Float(*) = aten::add(%2, %1, %3)
- %5 : int = prim::Constant[value=1]()
%26 : int[] = aten::size(%4)
- %29 : int[] = aten::size(%0)
- %6 : Float(*) = aten::add(%4, %0, %5)
- %19 : int[] = aten::size(%6)
- %22 : int[] = aten::size(%0)
+ %6 : Float(*) = aten::add(%4, %0, %3)
%7 : Float(*) = aten::mul(%6, %0)
- %8 : int = prim::Constant[value=1]()
%11 : int[] = aten::size(%7)
%14 : int[] = aten::size(%1)
- %9 : Float(*) = aten::add(%7, %1, %8)
- return (%4, %9, %26, %29, %6, %19, %22, %11, %14);
+ %9 : Float(*) = aten::add(%7, %1, %3)
+ return (%4, %9, %26, %6, %11, %14);
}
graph(%0 : Float(*)
%1 : Float(*)
%2 : Float(*)
%3 : int[]
- %4 : int[]
- %5 : Float(*)
- %6 : int[]
- %7 : int[]
- %8 : int[]
- %9 : int[]) {
- %10 : int = prim::Constant[value=1]()
- %12 : Tensor, %13 : Tensor = prim::GradOf[name="aten::add"](%0)
+ %4 : Float(*)
+ %5 : int[]
+ %6 : int[]) {
+ %9 : int = prim::Constant[value=1]()
+ %7 : int[] = aten::size(%4)
+ %8 : int[] = aten::size(%2)
+ %11 : Tensor, %12 : Tensor = prim::GradOf[name="aten::add"](%0)
block0() {
- %14 : Tensor = prim::SumToSize(%0, %8)
- %15 : Tensor = aten::mul(%0, %10)
- %16 : Tensor = prim::SumToSize(%15, %9)
- -> (%14, %16)
+ %13 : Tensor = prim::SumToSize(%0, %5)
+ %14 : Tensor = aten::mul(%0, %9)
+ %15 : Tensor = prim::SumToSize(%14, %6)
+ -> (%13, %15)
}
- %17 : Tensor, %18 : Tensor = prim::GradOf[name="aten::mul"](%12)
+ %16 : Tensor, %17 : Tensor = prim::GradOf[name="aten::mul"](%11)
block0() {
- %19 : Tensor = aten::mul(%12, %2)
- %20 : Tensor = prim::SumToSize(%19, %6)
- %21 : Tensor = aten::mul(%12, %5)
- %22 : Tensor = prim::SumToSize(%21, %7)
- -> (%20, %22)
+ %18 : Tensor = aten::mul(%11, %2)
+ %19 : Tensor = prim::SumToSize(%18, %7)
+ %20 : Tensor = aten::mul(%11, %4)
+ %21 : Tensor = prim::SumToSize(%20, %8)
+ -> (%19, %21)
}
- %23 : Tensor = prim::AutogradAdd(%1, %17)
- %24 : Tensor, %25 : Tensor = prim::GradOf[name="aten::add"](%23)
+ %22 : Tensor = prim::AutogradAdd(%1, %16)
+ %23 : Tensor, %24 : Tensor = prim::GradOf[name="aten::add"](%22)
block0() {
- %26 : Tensor = prim::SumToSize(%23, %3)
- %27 : Tensor = aten::mul(%23, %10)
- %28 : Tensor = prim::SumToSize(%27, %4)
- -> (%26, %28)
+ %25 : Tensor = prim::SumToSize(%22, %3)
+ %26 : Tensor = aten::mul(%22, %9)
+ %27 : Tensor = prim::SumToSize(%26, %8)
+ -> (%25, %27)
}
- %29 : Tensor = prim::AutogradAdd(%18, %25)
- return (%29);
+ %28 : Tensor = prim::AutogradAdd(%17, %24)
+ return (%28);
}
%outgate : Float(*, *)
%24 : int[]
%25 : int[]
- %26 : int[]
- %27 : int[]
- %28 : int[]
- %29 : int[]
- %30 : Float(*, *)
- %31 : int[]
- %32 : int[]) {
- %33 : Float(*, *) = aten::mul(%0, %30)
- %34 : Tensor = prim::SumToSize(%33, %31)
+ %26 : Float(*, *)) {
+ %27 : int[] = aten::size(%26)
+ %28 : int[] = aten::size(%outgate)
+ %29 : int[] = aten::size(%cellgate)
+ %30 : int[] = aten::size(%ingate)
+ %31 : int[] = aten::size(%9)
+ %32 : int[] = aten::size(%forgetgate)
+ %33 : Float(*, *) = aten::mul(%0, %26)
+ %34 : Tensor = prim::SumToSize(%33, %28)
%35 : Float(*, *) = aten::mul(%0, %outgate)
- %36 : Tensor = prim::SumToSize(%35, %32)
- %37 : Tensor = prim::FusionGroup_0(%1, %36, %30)
- %38 : Tensor = prim::SumToSize(%37, %28)
- %39 : Tensor = prim::SumToSize(%37, %29)
+ %36 : Tensor = prim::SumToSize(%35, %27)
+ %37 : Tensor = prim::FusionGroup_0(%1, %36, %26)
+ %38 : Tensor = prim::SumToSize(%37, %24)
+ %39 : Tensor = prim::SumToSize(%37, %25)
%40 : Tensor = aten::mul(%39, %cellgate)
- %41 : Tensor = prim::SumToSize(%40, %26)
+ %41 : Tensor = prim::SumToSize(%40, %30)
%42 : Tensor = aten::mul(%39, %ingate)
- %43 : Tensor = prim::SumToSize(%42, %27)
+ %43 : Tensor = prim::SumToSize(%42, %29)
%44 : Tensor = aten::mul(%38, %9)
- %45 : Tensor = prim::SumToSize(%44, %24)
+ %45 : Tensor = prim::SumToSize(%44, %32)
%46 : Tensor = aten::mul(%38, %forgetgate)
- %47 : Tensor = prim::SumToSize(%46, %25)
+ %47 : Tensor = prim::SumToSize(%46, %31)
%48 : Tensor = prim::FusionGroup_1(%41, %ingate, %45, %forgetgate, %43, %cellgate, %34, %outgate)
- %49 : Tensor = prim::SumToSize(%48, %18)
- %50 : Tensor = prim::SumToSize(%48, %19)
- %51 : Tensor = prim::SumToSize(%49, %16)
- %52 : Tensor = prim::SumToSize(%49, %17)
- %53 : Tensor = prim::SumToSize(%51, %14)
- %54 : Tensor = prim::SumToSize(%51, %15)
- %55 : Float(*, *) = aten::t(%13)
- %56 : Float(*, *) = aten::mm(%54, %55)
- %57 : Float(*, *) = aten::t(%10)
- %58 : Float(*, *) = aten::mm(%57, %54)
- %59 : Float(*, *) = aten::t(%58)
- %60 : Float(*, *) = aten::t(%12)
- %61 : Float(*, *) = aten::mm(%53, %60)
- %62 : Float(*, *) = aten::t(%11)
- %63 : Float(*, *) = aten::mm(%62, %53)
- %64 : Float(*, *) = aten::t(%63)
- return (%47, %50, %52, %56, %59, %61, %64);
+ %49 : Tensor = prim::SumToSize(%48, %19)
+ %50 : Tensor = prim::SumToSize(%48, %17)
+ %51 : Tensor = prim::SumToSize(%48, %14)
+ %52 : Tensor = prim::SumToSize(%48, %15)
+ %53 : Float(*, *) = aten::t(%13)
+ %54 : Float(*, *) = aten::mm(%52, %53)
+ %55 : Float(*, *) = aten::t(%10)
+ %56 : Float(*, *) = aten::mm(%55, %52)
+ %57 : Float(*, *) = aten::t(%56)
+ %58 : Float(*, *) = aten::t(%12)
+ %59 : Float(*, *) = aten::mm(%51, %58)
+ %60 : Float(*, *) = aten::t(%11)
+ %61 : Float(*, *) = aten::mm(%60, %51)
+ %62 : Float(*, *) = aten::t(%61)
+ return (%47, %49, %50, %54, %57, %59, %62);
}
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%1 : Tensor
%17 : Tensor, %18 : Tensor, %19 : Tensor, %20 : Tensor = prim::ListUnpack(%16)
%21 : int[] = prim::BroadcastSizes(%11, %12)
%22 : int[] = prim::BroadcastSizes(%21, %13)
- %23 : int[] = aten::size(%0)
- %hy : Float(*, *), %25 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %20, %19, %18, %17)
- %31 : int[] = aten::size(%25)
- %32 : int[] = aten::size(%outgate.1)
- %33 : int[] = aten::size(%cellgate.1)
- %34 : int[] = aten::size(%forgetgate.1)
- %35 : int[] = aten::size(%ingate.1)
- %36 : int[] = prim::BroadcastSizes(%34, %23)
- %37 : int[] = prim::BroadcastSizes(%35, %33)
- return (%hy, %cy, %7, %9, %11, %12, %21, %13, %22, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %34, %23, %35, %33, %36, %37, %25, %32, %31);
+ %hy : Float(*, *), %24 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %20, %19, %18, %17)
+ %30 : int[] = aten::size(%0)
+ %31 : int[] = aten::size(%cellgate.1)
+ %32 : int[] = aten::size(%forgetgate.1)
+ %33 : int[] = aten::size(%ingate.1)
+ %34 : int[] = prim::BroadcastSizes(%32, %30)
+ %35 : int[] = prim::BroadcastSizes(%33, %31)
+ return (%hy, %cy, %7, %9, %11, %12, %21, %13, %22, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %34, %35, %24);
}
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%1 : Tensor
%Wx : Float(*, *)
%20 : Float(*, *)
%Uz : Float(*, *)
- %22 : int[]
+ %22 : Float(*, *)
%23 : int[]
- %24 : Float(*, *)
+ %24 : int[]
%25 : int[]
%26 : int[]
%27 : int[]
%28 : int[]
- %29 : int[]
- %30 : int[]
- %31 : int[]
- %32 : int[]
- %33 : int[]
- %34 : int[]
- %35 : int[]
- %36 : int[]
%ingate : Float(*, *)
%forgetgate : Float(*, *)
%cellgate : Float(*, *)
%outgate : Float(*, *)
- %41 : int[]
- %42 : int[]
- %43 : int[]
- %44 : int[]
- %45 : int[]
- %46 : int[]
- %47 : Float(*, *)
- %48 : int[]
- %49 : int[]) {
- %50 : int = prim::Constant[value=1]()
- %51 : Float(*, *) = aten::mul(%0, %47)
- %52 : Tensor = prim::SumToSize(%51, %48)
- %53 : Float(*, *) = aten::mul(%0, %outgate)
- %54 : Tensor = prim::SumToSize(%53, %49)
- %55 : Tensor = prim::FusionGroup_0(%1, %54, %47)
- %56 : Tensor = prim::SumToSize(%55, %45)
- %57 : Tensor = prim::SumToSize(%55, %46)
- %58 : Tensor = aten::mul(%57, %cellgate)
- %59 : Tensor = prim::SumToSize(%58, %43)
- %60 : Tensor = aten::mul(%57, %ingate)
- %61 : Tensor = prim::SumToSize(%60, %44)
- %62 : Tensor = aten::mul(%56, %12)
- %63 : Tensor = prim::SumToSize(%62, %41)
- %64 : Tensor = prim::FusionGroup_1(%59, %ingate, %63, %forgetgate, %61, %cellgate, %52, %outgate)
- %65 : Tensor = prim::SumToSize(%64, %35)
- %66 : Tensor = prim::SumToSize(%64, %36)
- %67 : Tensor = prim::SumToSize(%65, %33)
- %68 : Tensor = prim::SumToSize(%65, %34)
- %69 : Tensor = aten::mul(%68, %Uz)
- %70 : Tensor = prim::SumToSize(%69, %31)
- %71 : Tensor = aten::mul(%68, %13)
- %72 : Tensor = prim::SumToSize(%71, %32)
- %73 : Tensor = prim::SumToSize(%67, %29)
- %74 : Tensor = prim::SumToSize(%67, %30)
- %75 : Tensor = aten::mul(%74, %Wx)
- %76 : Tensor = prim::SumToSize(%75, %27)
- %77 : Tensor = aten::mul(%74, %14)
- %78 : Tensor = prim::SumToSize(%77, %28)
- %79 : Tensor = aten::mul(%73, %Uz)
- %80 : Tensor = prim::SumToSize(%79, %25)
- %81 : Tensor = aten::mul(%73, %24)
- %82 : Tensor = prim::SumToSize(%81, %26)
- %83 : Tensor = aten::add(%72, %82, %50)
- %84 : Tensor = aten::mul(%80, %Wx)
- %85 : Tensor = prim::SumToSize(%84, %22)
- %86 : Tensor = aten::mul(%80, %15)
- %87 : Tensor = prim::SumToSize(%86, %23)
- %88 : Tensor = aten::add(%78, %87, %50)
- %89 : Float(*, *) = aten::t(%16)
- %90 : Float(*, *) = aten::mm(%89, %83)
- %91 : Float(*, *) = aten::t(%90)
- %92 : Float(*, *) = aten::t(%17)
- %93 : Float(*, *) = aten::mm(%92, %88)
- %94 : Float(*, *) = aten::t(%93)
- return (%66, %70, %76, %85, %91, %94);
+ %33 : int[]
+ %34 : int[]
+ %35 : Float(*, *)) {
+ %36 : int = prim::Constant[value=1]()
+ %37 : int[] = aten::size(%35)
+ %38 : int[] = aten::size(%outgate)
+ %39 : int[] = aten::size(%cellgate)
+ %40 : int[] = aten::size(%ingate)
+ %41 : int[] = aten::size(%forgetgate)
+ %42 : int[] = aten::size(%13)
+ %43 : int[] = aten::size(%14)
+ %44 : int[] = aten::size(%Uz)
+ %45 : int[] = aten::size(%22)
+ %46 : int[] = aten::size(%Wx)
+ %47 : int[] = aten::size(%15)
+ %48 : Float(*, *) = aten::mul(%0, %35)
+ %49 : Tensor = prim::SumToSize(%48, %38)
+ %50 : Float(*, *) = aten::mul(%0, %outgate)
+ %51 : Tensor = prim::SumToSize(%50, %37)
+ %52 : Tensor = prim::FusionGroup_0(%1, %51, %35)
+ %53 : Tensor = prim::SumToSize(%52, %33)
+ %54 : Tensor = prim::SumToSize(%52, %34)
+ %55 : Tensor = aten::mul(%54, %cellgate)
+ %56 : Tensor = prim::SumToSize(%55, %40)
+ %57 : Tensor = aten::mul(%54, %ingate)
+ %58 : Tensor = prim::SumToSize(%57, %39)
+ %59 : Tensor = aten::mul(%53, %12)
+ %60 : Tensor = prim::SumToSize(%59, %41)
+ %61 : Tensor = prim::FusionGroup_1(%56, %ingate, %60, %forgetgate, %58, %cellgate, %49, %outgate)
+ %62 : Tensor = prim::SumToSize(%61, %28)
+ %63 : Tensor = prim::SumToSize(%61, %26)
+ %64 : Tensor = aten::mul(%63, %Uz)
+ %65 : Tensor = prim::SumToSize(%64, %42)
+ %66 : Tensor = aten::mul(%63, %13)
+ %67 : Tensor = prim::SumToSize(%66, %44)
+ %68 : Tensor = prim::SumToSize(%61, %23)
+ %69 : Tensor = prim::SumToSize(%61, %24)
+ %70 : Tensor = aten::mul(%69, %Wx)
+ %71 : Tensor = prim::SumToSize(%70, %43)
+ %72 : Tensor = aten::mul(%69, %14)
+ %73 : Tensor = prim::SumToSize(%72, %46)
+ %74 : Tensor = aten::mul(%68, %Uz)
+ %75 : Tensor = prim::SumToSize(%74, %45)
+ %76 : Tensor = aten::mul(%68, %22)
+ %77 : Tensor = prim::SumToSize(%76, %44)
+ %78 : Tensor = aten::add(%67, %77, %36)
+ %79 : Tensor = aten::mul(%75, %Wx)
+ %80 : Tensor = prim::SumToSize(%79, %47)
+ %81 : Tensor = aten::mul(%75, %15)
+ %82 : Tensor = prim::SumToSize(%81, %46)
+ %83 : Tensor = aten::add(%73, %82, %36)
+ %84 : Float(*, *) = aten::t(%16)
+ %85 : Float(*, *) = aten::mm(%84, %78)
+ %86 : Float(*, *) = aten::t(%85)
+ %87 : Float(*, *) = aten::t(%17)
+ %88 : Float(*, *) = aten::mm(%87, %83)
+ %89 : Float(*, *) = aten::t(%88)
+ return (%62, %65, %71, %80, %86, %89);
}
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%1 : Tensor
%Wx.1 : Float(*, *) = aten::mm(%7, %9)
%11 : Float(*, *) = aten::t(%6)
%Uz.1 : Float(*, *) = aten::mm(%5, %11)
- %13 : int[] = aten::size(%4)
- %14 : int[] = aten::size(%Wx.1)
- %15 : Float(*, *) = aten::mul(%4, %Wx.1)
- %16 : int[] = aten::size(%15)
- %17 : int[] = aten::size(%Uz.1)
- %18 : int[] = aten::size(%3)
- %19 : int[] = aten::size(%2)
- %20 : int[] = aten::size(%1)
- %21 : Tensor[] = prim::ListConstruct(%1, %2, %Uz.1, %15, %3, %Wx.1)
- %22 : Tensor[] = aten::broadcast_tensors(%21)
- %23 : Tensor, %24 : Tensor, %25 : Tensor, %26 : Tensor, %27 : Tensor, %28 : Tensor = prim::ListUnpack(%22)
- %29 : int[] = prim::BroadcastSizes(%18, %14)
- %30 : int[] = prim::BroadcastSizes(%16, %17)
- %31 : int[] = prim::BroadcastSizes(%19, %17)
- %32 : int[] = prim::BroadcastSizes(%30, %29)
- %33 : int[] = prim::BroadcastSizes(%32, %31)
- %34 : int[] = aten::size(%0)
- %hy : Float(*, *), %36 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %28, %27, %26, %25, %24, %23)
- %42 : int[] = aten::size(%36)
- %43 : int[] = aten::size(%outgate.1)
- %44 : int[] = aten::size(%cellgate.1)
- %45 : int[] = aten::size(%forgetgate.1)
- %46 : int[] = aten::size(%ingate.1)
- %47 : int[] = prim::BroadcastSizes(%45, %34)
- %48 : int[] = prim::BroadcastSizes(%46, %44)
- return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %13, %14, %15, %16, %17, %18, %14, %30, %29, %19, %17, %32, %31, %33, %20, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %45, %34, %46, %44, %47, %48, %36, %43, %42);
+ %13 : Float(*, *) = aten::mul(%4, %Wx.1)
+ %14 : int[] = aten::size(%1)
+ %15 : Tensor[] = prim::ListConstruct(%1, %2, %Uz.1, %13, %3, %Wx.1)
+ %16 : Tensor[] = aten::broadcast_tensors(%15)
+ %17 : Tensor, %18 : Tensor, %19 : Tensor, %20 : Tensor, %21 : Tensor, %22 : Tensor = prim::ListUnpack(%16)
+ %23 : int[] = aten::size(%3)
+ %24 : int[] = aten::size(%Wx.1)
+ %25 : int[] = prim::BroadcastSizes(%23, %24)
+ %26 : int[] = aten::size(%13)
+ %27 : int[] = aten::size(%Uz.1)
+ %28 : int[] = prim::BroadcastSizes(%26, %27)
+ %29 : int[] = aten::size(%2)
+ %30 : int[] = prim::BroadcastSizes(%29, %27)
+ %31 : int[] = prim::BroadcastSizes(%28, %25)
+ %32 : int[] = prim::BroadcastSizes(%31, %30)
+ %hy : Float(*, *), %34 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %22, %21, %20, %19, %18, %17)
+ %40 : int[] = aten::size(%0)
+ %41 : int[] = aten::size(%cellgate.1)
+ %42 : int[] = aten::size(%forgetgate.1)
+ %43 : int[] = aten::size(%ingate.1)
+ %44 : int[] = prim::BroadcastSizes(%42, %40)
+ %45 : int[] = prim::BroadcastSizes(%43, %41)
+ return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %13, %28, %25, %31, %30, %32, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %44, %45, %34);
}
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%1 : Tensor
#include "torch/csrc/jit/autodiff.h"
+#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/passes/constant_pooling.h"
#include "torch/csrc/jit/symbolic_variable.h"
return ReverseDetails(std::move(grad_map), reverse_block);
}
+// Returns a topologically-sorted list of values produced in f and used in its reverse program.
+static value_list getReverseCaptures(Gradient& grad_desc) {
+ auto & graph = *grad_desc.f;
+ auto primal_block = graph.block();
+
+ value_set reverse_captures_set;
+ value_list reverse_captures; // Invariant: topo sorted
+ auto check_uses = [&](Value *v) {
+ for (auto use : v->uses()) {
+ if (use.user->owningBlock() == primal_block)
+ continue;
+ if (/* bool unseen = */ reverse_captures_set.emplace(v).second) {
+ reverse_captures.push_back(v);
+ }
+ }
+ };
+ for (Value * input : graph.inputs()) {
+ check_uses(input);
+ }
+ for (Node * node : graph.nodes()) {
+ for (Value * output : node->outputs())
+ check_uses(output);
+ }
+ return reverse_captures;
+}
+
// Any temporary value from the primal graph needs to be captured for later use in the
// reverse graph, to avoid costly recomputations. However, a lot of the nodes we have
// in our graphs are simply constants, which are cheap to execute and replicate, and so
}
}
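+// Size captures that are used only by the reverse graph don't have to be
+// computed in f at all: aten::size is cheap to recompute, so when the tensor
+// it was taken from is captured anyway, we replace the captured int[] with a
+// fresh aten::size call emitted at the top of the reverse block. For example,
+// a captured %3 : int[] = aten::size(%x) that is used only in df becomes a
+// size call on the already-captured %x inside df.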
+static void deduplicateSizeCaptures(Gradient& grad_desc, ReverseDetails& rev_info) {
+ Block * primal_block = grad_desc.f->block();
+ const auto usedOnlyInReverse = [primal_block](Value * v) {
+ const auto & uses = v->uses();
+ return std::all_of(uses.begin(), uses.end(),
+ [primal_block](const Use& u) { return u.user->owningBlock() != primal_block; });
+ };
+ auto captures = getReverseCaptures(grad_desc);
+ value_set capture_set(captures.begin(), captures.end());
+ for (Value * capture : captures) {
+ Node * node = capture->node();
+ if (!node->matches("aten::size(Tensor self) -> int[]")) {
+ continue;
+ }
+ if (usedOnlyInReverse(capture) && capture_set.count(node->input())) {
+ WithInsertPoint insert_guard { *rev_info.reverse_block->nodes().begin() };
+ capture->replaceAllUsesWith(SymbolicVariable(node->input()).size());
+ node->destroy();
+ }
+ }
+}
+
+static void eliminateDeadCode(ReverseDetails& rev_info) {
+ // addReverseInline has to call gradientForNode if *any* of the inputs
+ // require grad, but it will emit vjps for *all* inputs. Use DCE to remove
+ // unnecessary nodes. Additionally, requires_grad() on intermediates is an
+ // overapproximation of the real state, so we might have emitted some
+ // gradients, only to realize that they were unnecessary once we reach a
+ // point that doesn't require grad.
+ // Of course, we need to filter out corresponding entries of grad_map, because
+ // we don't want to accidentally access freed pointers later.
+ auto dead_nodes = FindDeadNodes(rev_info.reverse_block);
+ if (dead_nodes.empty()) {
+ return;
+ }
+ std::vector<Value*> to_erase;
+ for (auto & entry : rev_info.grad_map) {
+ if (dead_nodes.count(entry.second->node()) > 0) {
+ to_erase.push_back(entry.first);
+ }
+ }
+ for (Value * v : to_erase) {
+ rev_info.grad_map.erase(v);
+ }
+ std::vector<Node*> sorted_dead_nodes(dead_nodes.begin(), dead_nodes.end());
+ std::sort(sorted_dead_nodes.begin(), sorted_dead_nodes.end(), [](Node* a, Node* b) { return a->isAfter(b); });
+ for (Node * n : sorted_dead_nodes) {
+ n->destroy();
+ }
+}
+
+static void Optimize(Gradient& grad_desc, ReverseDetails& rev_info) {
+ // TODO: we are sometimes emitting expressions like SumToSize(SumToSize(x, s1), s2),
+ // which are equivalent to SumToSize(x, s2), and could save us some captures, but I'm
+ // not 100% sure how to optimize this at this stage, since we don't know which
+ // GradOf blocks will be stitched together to form the derivative. I guess a smart
+ // analysis could implement this, but I didn't have time before the 1.0 release,
+ // so I put this only as a peephole optimization.
+ liftConstants(grad_desc, rev_info);
+ // We generally add a lot of aten::size calls (for derivatives of broadcasting
+ // operators), and they often end up duplicated, and would get captured multiple
+ // times. Make sure we deduplicate them before lifting.
+ EliminateCommonSubexpression(grad_desc.f);
+ deduplicateSizeCaptures(grad_desc, rev_info);
+ eliminateDeadCode(rev_info);
+}
+
// Takes a grad_desc.f returned from `addReverseInline` and splits off the
// reverse_block into its own graph, storing it in df.
// All intermediates needed in the second stage are captured
// and used in df. They will need to be added as inputs of df,
// and some of them may also need to be appended as outputs of f if
// they are not already an input or an output of f.
- value_set reverse_captures_set;
- value_list reverse_captures; // Invariant: topo sorted
- auto check_uses = [&](Value *v) {
- for (auto use : v->uses()) {
- if (use.user->owningBlock() == primal_block)
- continue;
- if (/* bool unseen = */ reverse_captures_set.emplace(v).second) {
- reverse_captures.push_back(v);
- }
- }
- };
- for (Value * input : graph.inputs()) {
- check_uses(input);
- }
- for (Node * node : graph.nodes()) {
- for (Value * output : node->outputs())
- check_uses(output);
- }
+ // Invariant: topo sorted
+ value_list reverse_captures = getReverseCaptures(grad_desc);
// --------------------------------------------------------------------------
// 2. Prepare input/outputs lists for f and df
for (size_t i = grad_desc.f_real_outputs; i < graph.outputs().size(); ++i) {
Value * tmp = graph.outputs().at(i);
// Add VJP inputs only for intermediates that actually required grad.
- if (!tmp->requires_grad()) continue;
+ // Note that we check the contents of the grad_map instead of tmp->requires_grad(),
+ // because it's actually a more faithful source. tmp->requires_grad() is really an
+ // overapproximation (i.e. it can have false positives), while the gradients we will
+ // emit for this value can get DCE-d in the optimization pass (because it has no
+ // influence on the real f's outputs that we differentiate).
+ if (rev_info.grad_map.count(tmp) == 0) continue;
Value * tmp_vjp_in = reverse_block->addInput()->setType(tmp->type());
Value * tmp_vjp_prev = rev_info.grad_map.at(tmp);
// This is quite weird because we can't first make a sum and then replace all uses
reverse_block->owningNode()->destroy();
}
+
Gradient differentiate(std::shared_ptr<Graph>& graph) {
Gradient grad_desc;
// Take ownership of the graph
WithInsertPoint guard(grad_desc.f->block());
// Fills in df_input_vjps and df_output_vjps
auto rev_info = addReverseInline(grad_desc);
- // Lift constants captured for the reverse graph into it
- liftConstants(grad_desc, rev_info);
+ Optimize(grad_desc, rev_info);
// Fills in f, df, f_real_outputs, df_input_captures,
// modifies df_input_vjps (new vjps are added for temporaries)
lambdaLiftReverse(grad_desc, rev_info);
// It's possible that we've cloned the same constants many times, so
// de-duplicate them
ConstantPooling(grad_desc.df);
- // addReverseInline has to call gradientForNode if *any* of the inputs
- // require grad, but it will emit vjps for *all* inputs. Use DCE to remove
- // unnecessary nodes.
- EliminateDeadCode(grad_desc.df->block());
return grad_desc;
}
public:
explicit DeadCodeEliminator(std::shared_ptr<Graph> graph)
: aliasDb_(AliasAnalysis(graph)) {}
- DeadCodeEliminator(){};
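+ // When collect_only is true, run() only records which nodes are dead
+ // (retrievable via get_dead_once()) instead of destroying them.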
+ DeadCodeEliminator(bool collect_only = false)
+ : collect_only_(collect_only) {}
// The algorithm is an inverse mark-and-sweep. Starting from the return node,
// we mark "live" nodes that are necessary for the output. Nodes that have
sweep(block, recurse);
}
+ // Named get_dead_once because the result is moved out of the eliminator,
+ // so it should be called only once per run() call.
+ std::unordered_set<Node*> get_dead_once() {
+ JIT_ASSERT(collect_only_);
+ return std::move(dead_nodes_);
+ }
+
private:
void setLastWildcard() {
if (!aliasDb_) {
auto node = *it;
// note these occur before the recursion because we want to uncover
// dead code in the blocks used to calculate the output
- removeDeadIfOutputs(node);
- removeDeadLoopOutputs(node);
+ if (!collect_only_) {
+ removeDeadIfOutputs(node);
+ removeDeadLoopOutputs(node);
+ }
if (recurse) {
for (Block* block : node->blocks()) {
sweep(block, true);
// TODO(suo): We shouldn't really have to check whether a node has uses,
// since the mark algorithm should do that. But currently, the marking
// doesn't reach loop counters in certain cases (see TestScript.test_pass)
- if (!marked_.count(node) && !node->hasUses()) {
- it.destroyCurrent();
+ if (!marked_.count(node) && !hasUsesOutsideDeadNodes(node)) {
+ if (collect_only_) {
+ dead_nodes_.insert(node);
+ } else {
+ it.destroyCurrent();
+ }
+ }
+ }
+ }
+
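+ // In collect_only mode, uses by nodes that are already known to be dead
+ // don't keep their producer alive; the entire dead chain gets collected.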
+ bool hasUsesOutsideDeadNodes(Node * n) {
+ if (!collect_only_) {
+ return n->hasUses();
+ }
+ for (Value * output : n->outputs()) {
+ for (const Use & u : output->uses()) {
+ if (dead_nodes_.count(u.user) == 0) {
+ return true;
+ }
}
}
+ return false;
}
void markAndEnqueue(Node* n) {
std::list<Node*> workQueue_;
c10::optional<const Node*> lastWildcard_;
+ bool collect_only_ = false;
+ std::unordered_set<Node*> dead_nodes_; // Will be filled iff collect_only_ is true
};
void EliminateDeadCode(const std::shared_ptr<Graph>& graph) {
- DeadCodeEliminator(graph).run(graph->block(), true);
+ DeadCodeEliminator(graph).run(graph->block(), /*recurse=*/true);
}
void EliminateDeadCode(Block* block, bool recurse) {
DeadCodeEliminator().run(block, recurse);
}
+std::unordered_set<Node*> FindDeadNodes(Block* block, bool recurse) {
+ DeadCodeEliminator eliminator(/*collect_only=*/true);
+ eliminator.run(block, recurse);
+ return eliminator.get_dead_once();
+}
+
} // namespace jit
} // namespace torch
TORCH_API void EliminateDeadCode(const std::shared_ptr<Graph>& graph);
TORCH_API void EliminateDeadCode(Block *block, bool recurse=true);
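+// Like EliminateDeadCode, but returns the set of dead nodes instead of
+// destroying them, leaving the graph unchanged.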
+TORCH_API std::unordered_set<Node*> FindDeadNodes(Block *block, bool recurse=true);
+
}}
[](const Use& u) { return u.user->matches("aten::size(Tensor self) -> int[]"); });
}
+ // Builds up expressions that compute shapes of all intermediates (and outputs)
+ // of the fusion group, based on the sizes of inputs. You should run DCE to remove
+ // those that you end up not using.
std::unordered_map<Value*, Value*> buildShapeExpressions(Node * fusion_group) {
WithInsertPoint insert_guard { fusion_group->next() };
std::unordered_map<Value*, Value*> shape_of;
sizes_node->i_(attr::chunks, n->i(attr::chunks));
Value * regular_size = sizes_node->outputs().at(0);
Value * last_size = sizes_node->outputs().at(1);
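+ // Both outputs hold shapes; type them as int[] explicitly so the shape
+ // expressions built here are correctly typed for downstream consumers.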
+ regular_size->setType(ListType::ofInts());
+ last_size->setType(ListType::ofInts());
auto outputs = n->outputs();
for (Value * o : outputs.slice(0, outputs.size() - 1)) {
shape_of.emplace(o, regular_size);
node->get<at::Scalar>(attr::other)->toDouble() == 0) {
node->output()->replaceAllUsesWith(node->input(0));
}
- } else if(node->kind() == prim::TensorToNum || node->kind() == prim::ImplicitTensorToNum) {
+ } else if (node->kind() == prim::TensorToNum || node->kind() == prim::ImplicitTensorToNum) {
Node* input_node = node->input()->node();
if (input_node->kind() == prim::NumToTensor) {
node->output()->replaceAllUsesWith(input_node->input());
}
+ } else if (node->matches("prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)")) {
+ auto uses = node->output()->uses();
+ for (Use u : uses) {
+ if (u.user->matches("prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)")) {
+ u.user->replaceInput(0, node->inputs().at(0));
+ }
+ }
}
}
}