From 9ac845f73469e6b600fe02424c202af7e57aab73 Mon Sep 17 00:00:00 2001
From: Michael Suo
Date: Mon, 3 Dec 2018 14:52:18 -0800
Subject: [PATCH] Revert D13280899: [pytorch][PR] Reduce broadcasted inputs in derivative code

Differential Revision: D13280899

Original commit changeset: 80cc5ec9331b

fbshipit-source-id: 2335093cca8fd7db95470fd83b9299adfa17aa8e
---
 aten/src/ATen/ExpandUtils.h | 21 +--
 aten/src/ATen/core/interned_strings.h | 1 -
 test/cpp/jit/tests.h | 45 +++---
 test/expect/TestJit.test_cpp_cuda.expect | 104 ++++--------
 ...estScript.test_lstm_fusion_cuda-backward.expect | 140 +++++++---------
 ...TestScript.test_lstm_fusion_cuda-forward.expect | 98 ++++++------
 ...tScript.test_milstm_fusion_cuda-backward.expect | 178 +++++++++------------
 ...stScript.test_milstm_fusion_cuda-forward.expect | 131 +++++++--------
 test/test_jit.py | 21 ++-
 torch/csrc/autograd/engine.cpp | 2 +-
 torch/csrc/jit/argument_spec.h | 26 ---
 torch/csrc/jit/autodiff.cpp | 95 ++++-------
 torch/csrc/jit/register_prim_ops.cpp | 15 +-
 torch/csrc/jit/register_special_ops.cpp | 4 +-
 torch/csrc/jit/symbolic_variable.h | 9 --
 15 files changed, 354 insertions(+), 536 deletions(-)

diff --git a/aten/src/ATen/ExpandUtils.h b/aten/src/ATen/ExpandUtils.h
index 4a8cbd7..a5e9db4 100644
--- a/aten/src/ATen/ExpandUtils.h
+++ b/aten/src/ATen/ExpandUtils.h
@@ -136,25 +136,20 @@ inline std::vector<Tensor> expand_outplace(TensorList to_expand) {
 
 // Sums `tensor` repeatedly to produce a tensor of shape `shape`.
 // Precondition: is_expandable_to(shape, tensor.sizes()) must be true
-static inline Tensor sum_to(Tensor tensor, const IntList shape) {
+static inline Tensor sum_to(Tensor tensor, IntList shape) {
   if (shape.size() == 0) {
     return tensor.sum();
   }
-  c10::SmallVector<int64_t, 8> reduce_dims;
-  const at::IntList sizes = tensor.sizes();
-  const int64_t leading_dims = sizes.size() - shape.size();
-  for (int64_t i = 0; i < leading_dims; ++i) {
-    reduce_dims.push_back(i);
+  Tensor result = tensor;
+  while (result.dim() > (int64_t)shape.size()) {
+    result = result.sum(0, false);
   }
-  for (int64_t i = leading_dims; i < static_cast<int64_t>(sizes.size()); ++i) {
-    if (shape[i - leading_dims] == 1 && sizes[i] > 1) {
-      reduce_dims.push_back(i);
+  for (int64_t i = 0; i < result.dim(); ++i) {
+    if (shape[i] == 1 && result.sizes()[i] > 1) {
+      result = result.sum(i, true);
     }
   }
-  if (!reduce_dims.empty()) {
-    tensor = tensor.sum(reduce_dims, /*keepdim=*/true);
-  }
-  return leading_dims > 0 ? tensor.view(shape) : tensor;
+  return result;
 }
 
 // True if `shape` can be broadcasted to `desired`
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index ad64cb0..8f4ba55 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -64,7 +64,6 @@ namespace c10 {
   _(prim, requires_grad) \
   _(prim, AutogradAdd) \
   _(prim, GradOf) \
-  _(prim, SumToSize) \
   _(prim, AnyDefined) \
   _(prim, FusedConcat) \
   _(prim, ConstantChunk) \
diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h
index d728a65..2a4edc1 100644
--- a/test/cpp/jit/tests.h
+++ b/test/cpp/jit/tests.h
@@ -439,38 +439,38 @@ std::shared_ptr<Graph> build_lstm() {
   return r;
 }
 
-std::vector<at::Tensor> run(InterpreterState & interp, const std::vector<at::Tensor> & inputs) {
+void run(InterpreterState & interp, const std::vector<at::Tensor> & inputs, std::vector<at::Tensor> & outputs) {
   std::vector<IValue> stack(inputs.begin(), inputs.end());
   interp.run(stack);
-  return fmap(stack, [](const IValue& i) { return i.toTensor(); });
+  outputs.clear();
+  for (auto& ivalue : stack) {
+    outputs.push_back(std::move(ivalue).toTensor());
+  }
 }
 
 std::pair<tensor_list, tensor_list> runGradient(
     Gradient& grad_spec,
     tensor_list& tensors_in,
     tensor_list& tensor_grads_in) {
-  static const auto as_tensorlist = [](const Stack& stack) {
-    return fmap(stack, [](const IValue& i) { return i.toTensor(); });
-  };
+  tensor_list tensors_out, tensor_grads_out;
   Code f_code{grad_spec.f}, df_code{grad_spec.df};
   InterpreterState f_interpreter{f_code}, df_interpreter{df_code};
 
-  auto f_stack = fmap<IValue>(tensors_in);
-  f_interpreter.run(f_stack);
+  run(f_interpreter, tensors_in, tensors_out);
 
-  Stack df_stack;
-  df_stack.insert(
-      df_stack.end(), tensor_grads_in.begin(), tensor_grads_in.end());
+  tensor_list df_inputs;
+  df_inputs.insert(
+      df_inputs.end(), tensor_grads_in.begin(), tensor_grads_in.end());
   for (auto offset : grad_spec.df_input_captured_inputs)
-    df_stack.push_back(tensors_in[offset]);
+    df_inputs.push_back(tensors_in[offset]);
   for (auto offset : grad_spec.df_input_captured_outputs)
-    df_stack.push_back(f_stack[offset]);
-  df_interpreter.run(df_stack);
+    df_inputs.push_back(tensors_out[offset]);
+  run(df_interpreter, df_inputs, tensor_grads_out);
 
   // Outputs of f needs to be sliced
-  f_stack.erase(
-      f_stack.begin() + grad_spec.f_real_outputs, f_stack.end());
-  return std::make_pair(as_tensorlist(f_stack), as_tensorlist(df_stack));
+  tensors_out.erase(
+      tensors_out.begin() + grad_spec.f_real_outputs, tensors_out.end());
+  return std::make_pair(tensors_out, tensor_grads_out);
 }
 
 void assertAllClose(const tensor_list& a, const tensor_list& b) {
@@ -496,8 +496,9 @@ void testInterp() {
 
   auto lstm_g = build_lstm();
   Code lstm_function(lstm_g);
+  std::vector<at::Tensor> outputs;
   InterpreterState lstm_interp(lstm_function);
-  auto outputs = run(lstm_interp, {input[0], hx, cx, w_ih, w_hh});
+  run(lstm_interp, {input[0], hx, cx, w_ih, w_hh}, outputs);
   std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
 
   // std::cout << almostEqual(outputs[0],hx) << "\n";
@@ -835,8 +836,8 @@ void testDifferentiate(std::ostream& out = std::cout) {
 
   auto grad_spec = differentiate(graph);
   std::vector<size_t> expected_captured_inputs = {0, 1};
-  std::vector<size_t> expected_captured_outputs = {1, 2, 3, 4, 5, 6, 7};
-  std::vector<size_t> expected_input_vjps = {0, 3};
+  std::vector<size_t> expected_captured_outputs = {1};
+  std::vector<size_t> expected_input_vjps = {0, 1};
   std::vector<size_t> expected_output_vjps = {0, 1};
   ASSERT_EQ(grad_spec.f_real_outputs, 1);
   ASSERT_EQ(grad_spec.df_input_captured_inputs, expected_captured_inputs);
@@ -866,11 +867,11 @@ void
testDifferentiateWithRequiresGrad(std::ostream& out = std::cout) { PropagateRequiresGrad(graph); auto grad_spec = differentiate(graph); - std::vector expected_input_vjps = {1, 4}; // for e and %4 = (d + a) + std::vector expected_input_vjps = {1, 2}; // for e and %4 = (d + a) std::vector expected_output_vjps = {0}; // only a requires grad - ASSERT_EQ(grad_spec.f_real_outputs, 2); + ASSERT_EQ(grad_spec.f_real_outputs, 2); // we need one temporary %4 = (d + a) ASSERT_EQ(grad_spec.df_input_captured_inputs, std::vector({0})); - ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector({2, 3, 4, 5, 6, 7, 8})); + ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector({2})); ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps); ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps); out << "testDifferentiateWithRequiresGrad\n"; diff --git a/test/expect/TestJit.test_cpp_cuda.expect b/test/expect/TestJit.test_cpp_cuda.expect index ddcac4c..50ec297 100644 --- a/test/expect/TestJit.test_cpp_cuda.expect +++ b/test/expect/TestJit.test_cpp_cuda.expect @@ -93,57 +93,39 @@ with prim::DifferentiableGraph_0 = graph(%13 : Tensor testDifferentiate graph(%0 : Float(2, 3, 4) %1 : Float(2, 3, 4)) { - %23 : int[] = aten::size(%0) - %26 : int[] = aten::size(%1) %2 : Float(2, 3, 4) = aten::mul(%0, %1) - %15 : int[] = aten::size(%2) - %18 : int[] = aten::size(%0) %3 : Float(2, 3, 4) = aten::mul(%2, %0) %4 : int = prim::Constant[value=1]() - %7 : int[] = aten::size(%3) - %10 : int[] = aten::size(%1) %5 : Float(2, 3, 4) = aten::add(%3, %1, %4) - return (%5, %23, %26, %2, %15, %18, %7, %10); + return (%5, %2); } graph(%0 : Float(2, 3, 4) %1 : Float(2, 3, 4) %2 : Float(2, 3, 4) %3 : Float(2, 3, 4) - %4 : int[] - %5 : int[] - %6 : Float(2, 3, 4) - %7 : int[] - %8 : int[] - %9 : int[] - %10 : int[]) { - %11 : int = prim::Constant[value=1]() - %12 : Tensor, %13 : Tensor = prim::GradOf[name="aten::add"](%0) + %4 : Float(2, 3, 4)) { + %5 : int = prim::Constant[value=1]() + %6 : Float(2, 3, 4), %7 : Float(2, 3, 4) = prim::GradOf[name="aten::add"](%0) block0() { - %14 : Tensor = prim::SumToSize(%0, %9) - %15 : Float(2, 3, 4) = aten::mul(%0, %11) - %16 : Tensor = prim::SumToSize(%15, %10) - -> (%14, %16) + %8 : Float(2, 3, 4) = aten::mul(%0, %5) + -> (%0, %8) } - %17 : Tensor, %18 : Tensor = prim::GradOf[name="aten::mul"](%12) + %9 : Float(2, 3, 4), %10 : Float(2, 3, 4) = prim::GradOf[name="aten::mul"](%6) block0() { - %19 : Tensor = aten::mul(%12, %2) - %20 : Tensor = prim::SumToSize(%19, %7) - %21 : Tensor = aten::mul(%12, %6) - %22 : Tensor = prim::SumToSize(%21, %8) - -> (%20, %22) + %11 : Float(2, 3, 4) = aten::mul(%6, %2) + %12 : Float(2, 3, 4) = aten::mul(%6, %4) + -> (%11, %12) } - %23 : Tensor = prim::AutogradAdd(%1, %17) - %24 : Tensor, %25 : Tensor = prim::GradOf[name="aten::mul"](%23) + %13 : Tensor = prim::AutogradAdd(%1, %9) + %14 : Float(2, 3, 4), %15 : Float(2, 3, 4) = prim::GradOf[name="aten::mul"](%13) block0() { - %26 : Tensor = aten::mul(%23, %3) - %27 : Tensor = prim::SumToSize(%26, %4) - %28 : Tensor = aten::mul(%23, %2) - %29 : Tensor = prim::SumToSize(%28, %5) - -> (%27, %29) + %16 : Float(2, 3, 4) = aten::mul(%13, %3) + %17 : Float(2, 3, 4) = aten::mul(%13, %2) + -> (%16, %17) } - %30 : Tensor = prim::AutogradAdd(%18, %24) - %31 : Tensor = prim::AutogradAdd(%13, %25) - return (%30, %31); + %18 : Tensor = prim::AutogradAdd(%10, %14) + %19 : Tensor = prim::AutogradAdd(%7, %15) + return (%18, %19); } testDifferentiateWithRequiresGrad @@ -153,53 +135,35 @@ graph(%0 : Float(*) %3 : int = 
prim::Constant[value=1]() %4 : Float(*) = aten::add(%2, %1, %3) %5 : int = prim::Constant[value=1]() - %26 : int[] = aten::size(%4) - %29 : int[] = aten::size(%0) %6 : Float(*) = aten::add(%4, %0, %5) - %19 : int[] = aten::size(%6) - %22 : int[] = aten::size(%0) %7 : Float(*) = aten::mul(%6, %0) %8 : int = prim::Constant[value=1]() - %11 : int[] = aten::size(%7) - %14 : int[] = aten::size(%1) %9 : Float(*) = aten::add(%7, %1, %8) - return (%4, %9, %26, %29, %6, %19, %22, %11, %14); + return (%4, %9, %6); } graph(%0 : Float(*) %1 : Float(*) %2 : Float(*) - %3 : int[] - %4 : int[] - %5 : Float(*) - %6 : int[] - %7 : int[] - %8 : int[] - %9 : int[]) { - %10 : int = prim::Constant[value=1]() - %12 : Tensor, %13 : Tensor = prim::GradOf[name="aten::add"](%0) + %3 : Float(*)) { + %4 : int = prim::Constant[value=1]() + %6 : Float(*), %7 : Tensor = prim::GradOf[name="aten::add"](%0) block0() { - %14 : Tensor = prim::SumToSize(%0, %8) - %15 : Tensor = aten::mul(%0, %10) - %16 : Tensor = prim::SumToSize(%15, %9) - -> (%14, %16) + %8 : Tensor = aten::mul(%0, %4) + -> (%0, %8) } - %17 : Tensor, %18 : Tensor = prim::GradOf[name="aten::mul"](%12) + %9 : Tensor, %10 : Tensor = prim::GradOf[name="aten::mul"](%6) block0() { - %19 : Tensor = aten::mul(%12, %2) - %20 : Tensor = prim::SumToSize(%19, %6) - %21 : Tensor = aten::mul(%12, %5) - %22 : Tensor = prim::SumToSize(%21, %7) - -> (%20, %22) + %11 : Tensor = aten::mul(%6, %2) + %12 : Tensor = aten::mul(%6, %3) + -> (%11, %12) } - %23 : Tensor = prim::AutogradAdd(%1, %17) - %24 : Tensor, %25 : Tensor = prim::GradOf[name="aten::add"](%23) + %13 : Tensor = prim::AutogradAdd(%1, %9) + %14 : Tensor, %15 : Tensor = prim::GradOf[name="aten::add"](%13) block0() { - %26 : Tensor = prim::SumToSize(%23, %3) - %27 : Tensor = aten::mul(%23, %10) - %28 : Tensor = prim::SumToSize(%27, %4) - -> (%26, %28) + %16 : Tensor = aten::mul(%13, %4) + -> (%13, %16) } - %29 : Tensor = prim::AutogradAdd(%18, %25) - return (%29); + %17 : Tensor = prim::AutogradAdd(%10, %15) + return (%17); } diff --git a/test/expect/TestScript.test_lstm_fusion_cuda-backward.expect b/test/expect/TestScript.test_lstm_fusion_cuda-backward.expect index 9ee5162..eabf945 100644 --- a/test/expect/TestScript.test_lstm_fusion_cuda-backward.expect +++ b/test/expect/TestScript.test_lstm_fusion_cuda-backward.expect @@ -12,99 +12,71 @@ graph(%0 : Float(*, *) %11 : Float(*, *) %12 : Float(*, *) %13 : Float(*, *) - %14 : int[] - %15 : int[] - %16 : int[] - %17 : int[] - %18 : int[] - %19 : int[] %ingate : Float(*, *) %forgetgate : Float(*, *) %cellgate : Float(*, *) %outgate : Float(*, *) - %24 : int[] - %25 : int[] - %26 : int[] - %27 : int[] - %28 : int[] - %29 : int[] - %30 : Float(*, *) - %31 : int[] - %32 : int[]) { - %33 : Float(*, *) = aten::mul(%0, %30) - %34 : Tensor = prim::SumToSize(%33, %31) - %35 : Float(*, *) = aten::mul(%0, %outgate) - %36 : Tensor = prim::SumToSize(%35, %32) - %37 : Tensor = prim::FusionGroup_0(%1, %36, %30) - %38 : Tensor = prim::SumToSize(%37, %28) - %39 : Tensor = prim::SumToSize(%37, %29) - %40 : Tensor = aten::mul(%39, %cellgate) - %41 : Tensor = prim::SumToSize(%40, %26) - %42 : Tensor = aten::mul(%39, %ingate) - %43 : Tensor = prim::SumToSize(%42, %27) - %44 : Tensor = aten::mul(%38, %9) - %45 : Tensor = prim::SumToSize(%44, %24) - %46 : Tensor = aten::mul(%38, %forgetgate) - %47 : Tensor = prim::SumToSize(%46, %25) - %48 : Tensor = prim::FusionGroup_1(%41, %ingate, %45, %forgetgate, %43, %cellgate, %34, %outgate) - %49 : Tensor = prim::SumToSize(%48, %18) - %50 : Tensor = 
prim::SumToSize(%48, %19) - %51 : Tensor = prim::SumToSize(%49, %16) - %52 : Tensor = prim::SumToSize(%49, %17) - %53 : Tensor = prim::SumToSize(%51, %14) - %54 : Tensor = prim::SumToSize(%51, %15) - %55 : Float(*, *) = aten::t(%13) - %56 : Float(*, *) = aten::mm(%54, %55) - %57 : Float(*, *) = aten::t(%10) - %58 : Float(*, *) = aten::mm(%57, %54) - %59 : Float(*, *) = aten::t(%58) - %60 : Float(*, *) = aten::t(%12) - %61 : Float(*, *) = aten::mm(%53, %60) - %62 : Float(*, *) = aten::t(%11) - %63 : Float(*, *) = aten::mm(%62, %53) - %64 : Float(*, *) = aten::t(%63) - return (%47, %50, %52, %56, %59, %61, %64); + %18 : Float(*, *)) { + %19 : Float(*, *), %20 : Float(*, *) = prim::FusionGroup_0(%forgetgate, %ingate, %cellgate, %outgate, %9, %1, %18, %0) + %21 : Float(*, *) = aten::t(%13) + %22 : Float(*, *) = aten::mm(%20, %21) + %23 : Float(*, *) = aten::t(%10) + %24 : Float(*, *) = aten::mm(%23, %20) + %25 : Float(*, *) = aten::t(%24) + %26 : Float(*, *) = aten::t(%12) + %27 : Float(*, *) = aten::mm(%20, %26) + %28 : Float(*, *) = aten::t(%11) + %29 : Float(*, *) = aten::mm(%28, %20) + %30 : Float(*, *) = aten::t(%29) + return (%19, %20, %20, %22, %25, %27, %30); } with prim::FusionGroup_0 = graph(%0 : Float(*, *) - %1 : Tensor - %2 : Float(*, *)) { - %3 : Float(*, *) = aten::mul(%2, %2) - %4 : Float(*, *) = aten::neg(%3) - %5 : int = prim::Constant[value=1]() - %6 : Float(*, *) = aten::add(%4, %5, %5) - %7 : Tensor = aten::mul(%1, %6) - %8 : int = prim::Constant[value=1]() - %9 : Tensor = aten::add(%0, %7, %8) - return (%9); -} -with prim::FusionGroup_1 = graph(%0 : Tensor %1 : Float(*, *) - %2 : Tensor + %2 : Float(*, *) %3 : Float(*, *) - %4 : Tensor + %4 : Float(*, *) %5 : Float(*, *) - %6 : Tensor + %6 : Float(*, *) %7 : Float(*, *)) { - %8 : Float(*, *) = aten::neg(%7) - %9 : int = prim::Constant[value=1]() - %10 : Float(*, *) = aten::add(%8, %9, %9) - %11 : Float(*, *) = aten::mul(%10, %7) - %12 : Tensor = aten::mul(%11, %6) - %13 : Float(*, *) = aten::mul(%5, %5) - %14 : Float(*, *) = aten::neg(%13) - %15 : int = prim::Constant[value=1]() - %16 : Float(*, *) = aten::add(%14, %15, %15) - %17 : Tensor = aten::mul(%4, %16) - %18 : Float(*, *) = aten::neg(%3) - %19 : int = prim::Constant[value=1]() - %20 : Float(*, *) = aten::add(%18, %19, %19) - %21 : Float(*, *) = aten::mul(%20, %3) - %22 : Tensor = aten::mul(%21, %2) - %23 : Float(*, *) = aten::neg(%1) + %8 : Float(*, *) = aten::mul(%7, %6) + %9 : Float(*, *) = aten::mul(%7, %3) + %10 : Float(*, *) = aten::mul(%6, %6) + %11 : Float(*, *) = aten::neg(%10) + %12 : int = prim::Constant[value=1]() + %13 : int = prim::Constant[value=1]() + %14 : Float(*, *) = aten::add(%11, %13, %13) + %15 : Float(*, *) = aten::mul(%9, %14) + %16 : int = prim::Constant[value=1]() + %17 : int = prim::Constant[value=1]() + %18 : Float(*, *) = aten::add(%5, %15, %17) + %19 : Float(*, *) = aten::mul(%18, %2) + %20 : Float(*, *) = aten::mul(%18, %1) + %21 : Float(*, *) = aten::mul(%18, %4) + %22 : Float(*, *) = aten::neg(%3) + %23 : int = prim::Constant[value=1]() %24 : int = prim::Constant[value=1]() - %25 : Float(*, *) = aten::add(%23, %24, %24) - %26 : Float(*, *) = aten::mul(%25, %1) - %27 : Tensor = aten::mul(%26, %0) - %28 : Tensor = prim::FusedConcat[dim=1](%27, %22, %17, %12) - return (%28); + %25 : Float(*, *) = aten::add(%22, %24, %24) + %26 : Float(*, *) = aten::mul(%25, %3) + %27 : Float(*, *) = aten::mul(%26, %8) + %28 : Float(*, *) = aten::mul(%2, %2) + %29 : Float(*, *) = aten::neg(%28) + %30 : int = prim::Constant[value=1]() + %31 : int = 
prim::Constant[value=1]() + %32 : Float(*, *) = aten::add(%29, %31, %31) + %33 : Float(*, *) = aten::mul(%20, %32) + %34 : Float(*, *) = aten::neg(%0) + %35 : int = prim::Constant[value=1]() + %36 : int = prim::Constant[value=1]() + %37 : Float(*, *) = aten::add(%34, %36, %36) + %38 : Float(*, *) = aten::mul(%37, %0) + %39 : Float(*, *) = aten::mul(%38, %21) + %40 : Float(*, *) = aten::neg(%1) + %41 : int = prim::Constant[value=1]() + %42 : int = prim::Constant[value=1]() + %43 : Float(*, *) = aten::add(%40, %42, %42) + %44 : Float(*, *) = aten::mul(%43, %1) + %45 : Float(*, *) = aten::mul(%44, %19) + %46 : Float(*, *) = prim::FusedConcat[dim=1](%45, %39, %33, %27) + %47 : Float(*, *) = aten::mul(%18, %0) + return (%47, %46); } diff --git a/test/expect/TestScript.test_lstm_fusion_cuda-forward.expect b/test/expect/TestScript.test_lstm_fusion_cuda-forward.expect index a10d0c3..3772edc 100644 --- a/test/expect/TestScript.test_lstm_fusion_cuda-forward.expect +++ b/test/expect/TestScript.test_lstm_fusion_cuda-forward.expect @@ -20,58 +20,54 @@ with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *) %9 : Float(*, *) = aten::t(%4) %10 : Float(*, *) = aten::mm(%3, %9) %11 : int = prim::Constant[value=1]() - %12 : int[] = aten::size(%8) - %13 : int[] = aten::size(%10) - %14 : int[] = aten::size(%2) - %15 : Float(*, *), %16 : Float(*, *) = prim::FusionGroup_0(%2, %8, %10) - %17 : int[] = aten::size(%16) - %18 : int[] = aten::size(%15) - %19 : int[] = aten::size(%1) - %20 : Tensor[] = prim::ListConstruct(%15, %1) - %21 : Tensor[] = aten::broadcast_tensors(%20) - %22 : Tensor, %23 : Tensor = prim::ListUnpack(%21) - %24 : int[] = aten::size(%0) - %hy : Float(*, *), %26 : Float(*, *), %cy : Float(*, *), %28 : Float(*, *), %29 : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1(%0, %23, %22) - %34 : int[] = aten::size(%ingate.1) - %35 : int[] = aten::size(%forgetgate.1) - %36 : int[] = aten::size(%cellgate.1) - %37 : int[] = aten::size(%outgate.1) - %38 : int[] = aten::size(%29) - %39 : int[] = aten::size(%28) - %40 : int[] = aten::size(%26) - return (%hy, %cy, %7, %9, %12, %13, %17, %14, %18, %19, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %35, %24, %34, %36, %38, %39, %26, %37, %40); + %12 : Tensor[] = prim::ListConstruct(%1, %2, %8, %10) + %13 : Tensor[] = aten::broadcast_tensors(%12) + %14 : Tensor, %15 : Tensor, %16 : Tensor, %17 : Tensor = prim::ListUnpack(%13) + %hy : Float(*, *), %19 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %17, %16, %15, %14) + return (%hy, %cy, %7, %9, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %19); } -with prim::FusionGroup_0 = graph(%0 : Float(*) - %1 : Float(*, *) - %2 : Float(*, *)) { - %3 : int = prim::Constant[value=1]() - %4 : Float(*, *) = aten::add(%1, %2, %3) - %5 : int = prim::Constant[value=1]() - %6 : Float(*, *) = aten::add(%4, %0, %5) - return (%6, %4); -} -with prim::FusionGroup_1 = graph(%0 : Float(*, *) +with prim::FusionGroup_0 = graph(%0 : Float(*, *) %1 : Tensor - %2 : Tensor) { - %3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2) - %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1) - %11 : int = prim::Constant[value=1]() - %12 : Float(*, *) = aten::add(%3, %7, %11) - %13 : int = 
prim::Constant[value=1]() - %14 : Float(*, *) = aten::add(%4, %8, %13) - %15 : int = prim::Constant[value=1]() - %16 : Float(*, *) = aten::add(%5, %9, %15) - %17 : int = prim::Constant[value=1]() - %18 : Float(*, *) = aten::add(%6, %10, %17) - %ingate.1 : Float(*, *) = aten::sigmoid(%12) - %forgetgate.1 : Float(*, *) = aten::sigmoid(%14) - %cellgate.1 : Float(*, *) = aten::tanh(%16) - %outgate.1 : Float(*, *) = aten::sigmoid(%18) - %23 : Float(*, *) = aten::mul(%forgetgate.1, %0) - %24 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1) + %2 : Tensor + %3 : Tensor + %4 : Tensor) { + %5 : Float(*, *), %6 : Float(*, *), %7 : Float(*, *), %8 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%4) + %9 : Float(*, *), %10 : Float(*, *), %11 : Float(*, *), %12 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%3) + %13 : Float(*, *), %14 : Float(*, *), %15 : Float(*, *), %16 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2) + %17 : Float(*, *), %18 : Float(*, *), %19 : Float(*, *), %20 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1) + %21 : int = prim::Constant[value=1]() + %22 : Float(*, *) = aten::add(%13, %17, %21) + %23 : int = prim::Constant[value=1]() + %24 : Float(*, *) = aten::add(%14, %18, %23) %25 : int = prim::Constant[value=1]() - %cy : Float(*, *) = aten::add(%23, %24, %25) - %27 : Float(*, *) = aten::tanh(%cy) - %hy : Float(*, *) = aten::mul(%outgate.1, %27) - return (%hy, %27, %cy, %24, %23, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1); + %26 : Float(*, *) = aten::add(%15, %19, %25) + %27 : int = prim::Constant[value=1]() + %28 : Float(*, *) = aten::add(%16, %20, %27) + %29 : int = prim::Constant[value=1]() + %30 : Float(*, *) = aten::add(%22, %9, %29) + %31 : int = prim::Constant[value=1]() + %32 : Float(*, *) = aten::add(%24, %10, %31) + %33 : int = prim::Constant[value=1]() + %34 : Float(*, *) = aten::add(%26, %11, %33) + %35 : int = prim::Constant[value=1]() + %36 : Float(*, *) = aten::add(%28, %12, %35) + %37 : int = prim::Constant[value=1]() + %38 : Float(*, *) = aten::add(%30, %5, %37) + %39 : int = prim::Constant[value=1]() + %40 : Float(*, *) = aten::add(%32, %6, %39) + %41 : int = prim::Constant[value=1]() + %42 : Float(*, *) = aten::add(%34, %7, %41) + %43 : int = prim::Constant[value=1]() + %44 : Float(*, *) = aten::add(%36, %8, %43) + %ingate.1 : Float(*, *) = aten::sigmoid(%38) + %forgetgate.1 : Float(*, *) = aten::sigmoid(%40) + %cellgate.1 : Float(*, *) = aten::tanh(%42) + %outgate.1 : Float(*, *) = aten::sigmoid(%44) + %49 : Float(*, *) = aten::mul(%forgetgate.1, %0) + %50 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1) + %51 : int = prim::Constant[value=1]() + %cy : Float(*, *) = aten::add(%49, %50, %51) + %53 : Float(*, *) = aten::tanh(%cy) + %hy : Float(*, *) = aten::mul(%outgate.1, %53) + return (%hy, %53, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1); } diff --git a/test/expect/TestScript.test_milstm_fusion_cuda-backward.expect b/test/expect/TestScript.test_milstm_fusion_cuda-backward.expect index c53cb88..c2c0bca 100644 --- a/test/expect/TestScript.test_milstm_fusion_cuda-backward.expect +++ b/test/expect/TestScript.test_milstm_fusion_cuda-backward.expect @@ -20,121 +20,87 @@ graph(%0 : Float(*, *) %Wx : Float(*, *) %20 : Float(*, *) %Uz : Float(*, *) - %22 : int[] - %23 : int[] - %24 : Float(*, *) - %25 : int[] - %26 : int[] - %27 : int[] - %28 : int[] - %29 : int[] - %30 : int[] - %31 : int[] - %32 : int[] - %33 : int[] - %34 : int[] - %35 : int[] - %36 : int[] + %22 : Float(*, *) %ingate : Float(*, *) %forgetgate : 
Float(*, *) %cellgate : Float(*, *) %outgate : Float(*, *) - %41 : int[] - %42 : int[] - %43 : int[] - %44 : int[] - %45 : int[] - %46 : int[] - %47 : Float(*, *) - %48 : int[] - %49 : int[]) { - %50 : int = prim::Constant[value=1]() - %51 : Float(*, *) = aten::mul(%0, %47) - %52 : Tensor = prim::SumToSize(%51, %48) - %53 : Float(*, *) = aten::mul(%0, %outgate) - %54 : Tensor = prim::SumToSize(%53, %49) - %55 : Tensor = prim::FusionGroup_0(%1, %54, %47) - %56 : Tensor = prim::SumToSize(%55, %45) - %57 : Tensor = prim::SumToSize(%55, %46) - %58 : Tensor = aten::mul(%57, %cellgate) - %59 : Tensor = prim::SumToSize(%58, %43) - %60 : Tensor = aten::mul(%57, %ingate) - %61 : Tensor = prim::SumToSize(%60, %44) - %62 : Tensor = aten::mul(%56, %12) - %63 : Tensor = prim::SumToSize(%62, %41) - %64 : Tensor = prim::FusionGroup_1(%59, %ingate, %63, %forgetgate, %61, %cellgate, %52, %outgate) - %65 : Tensor = prim::SumToSize(%64, %35) - %66 : Tensor = prim::SumToSize(%64, %36) - %67 : Tensor = prim::SumToSize(%65, %33) - %68 : Tensor = prim::SumToSize(%65, %34) - %69 : Tensor = aten::mul(%68, %Uz) - %70 : Tensor = prim::SumToSize(%69, %31) - %71 : Tensor = aten::mul(%68, %13) - %72 : Tensor = prim::SumToSize(%71, %32) - %73 : Tensor = prim::SumToSize(%67, %29) - %74 : Tensor = prim::SumToSize(%67, %30) - %75 : Tensor = aten::mul(%74, %Wx) - %76 : Tensor = prim::SumToSize(%75, %27) - %77 : Tensor = aten::mul(%74, %14) - %78 : Tensor = prim::SumToSize(%77, %28) - %79 : Tensor = aten::mul(%73, %Uz) - %80 : Tensor = prim::SumToSize(%79, %25) - %81 : Tensor = aten::mul(%73, %24) - %82 : Tensor = prim::SumToSize(%81, %26) - %83 : Tensor = aten::add(%72, %82, %50) - %84 : Tensor = aten::mul(%80, %Wx) - %85 : Tensor = prim::SumToSize(%84, %22) - %86 : Tensor = aten::mul(%80, %15) - %87 : Tensor = prim::SumToSize(%86, %23) - %88 : Tensor = aten::add(%78, %87, %50) - %89 : Float(*, *) = aten::t(%16) - %90 : Float(*, *) = aten::mm(%89, %83) - %91 : Float(*, *) = aten::t(%90) - %92 : Float(*, *) = aten::t(%17) - %93 : Float(*, *) = aten::mm(%92, %88) - %94 : Float(*, *) = aten::t(%93) - return (%66, %70, %76, %85, %91, %94); + %27 : Float(*, *)) { + %28 : Float(*, *) = prim::FusionGroup_0(%ingate, %forgetgate, %cellgate, %outgate, %12, %1, %27, %0) + %29 : Float(*, *) = aten::mul(%28, %Wx) + %30 : Float(*, *) = prim::FusionGroup_1(%28, %22, %13) + %31 : Float(*, *), %32 : Float(*, *), %33 : Float(*, *) = prim::FusionGroup_2(%Wx, %15, %28, %14, %Uz) + %34 : Float(*, *) = aten::t(%16) + %35 : Float(*, *) = aten::mm(%34, %30) + %36 : Float(*, *) = aten::t(%35) + %37 : Float(*, *) = aten::t(%17) + %38 : Float(*, *) = aten::mm(%37, %32) + %39 : Float(*, *) = aten::t(%38) + return (%28, %33, %29, %31, %36, %39); } with prim::FusionGroup_0 = graph(%0 : Float(*, *) - %1 : Tensor - %2 : Float(*, *)) { - %3 : Float(*, *) = aten::mul(%2, %2) - %4 : Float(*, *) = aten::neg(%3) - %5 : int = prim::Constant[value=1]() - %6 : Float(*, *) = aten::add(%4, %5, %5) - %7 : Tensor = aten::mul(%1, %6) - %8 : int = prim::Constant[value=1]() - %9 : Tensor = aten::add(%0, %7, %8) - return (%9); -} -with prim::FusionGroup_1 = graph(%0 : Tensor %1 : Float(*, *) - %2 : Tensor + %2 : Float(*, *) %3 : Float(*, *) - %4 : Tensor + %4 : Float(*, *) %5 : Float(*, *) - %6 : Tensor + %6 : Float(*, *) %7 : Float(*, *)) { - %8 : Float(*, *) = aten::neg(%7) - %9 : int = prim::Constant[value=1]() - %10 : Float(*, *) = aten::add(%8, %9, %9) - %11 : Float(*, *) = aten::mul(%10, %7) - %12 : Tensor = aten::mul(%11, %6) - %13 : Float(*, *) = aten::mul(%5, 
%5) - %14 : Float(*, *) = aten::neg(%13) + %8 : Float(*, *) = aten::mul(%7, %6) + %9 : Float(*, *) = aten::mul(%7, %3) + %10 : Float(*, *) = aten::mul(%6, %6) + %11 : Float(*, *) = aten::neg(%10) + %12 : int = prim::Constant[value=1]() + %13 : Float(*, *) = aten::add(%11, %12, %12) + %14 : Float(*, *) = aten::mul(%9, %13) %15 : int = prim::Constant[value=1]() - %16 : Float(*, *) = aten::add(%14, %15, %15) - %17 : Tensor = aten::mul(%4, %16) - %18 : Float(*, *) = aten::neg(%3) - %19 : int = prim::Constant[value=1]() - %20 : Float(*, *) = aten::add(%18, %19, %19) - %21 : Float(*, *) = aten::mul(%20, %3) - %22 : Tensor = aten::mul(%21, %2) - %23 : Float(*, *) = aten::neg(%1) - %24 : int = prim::Constant[value=1]() - %25 : Float(*, *) = aten::add(%23, %24, %24) - %26 : Float(*, *) = aten::mul(%25, %1) - %27 : Tensor = aten::mul(%26, %0) - %28 : Tensor = prim::FusedConcat[dim=1](%27, %22, %17, %12) - return (%28); + %16 : Float(*, *) = aten::add(%5, %14, %15) + %17 : Float(*, *) = aten::mul(%16, %2) + %18 : Float(*, *) = aten::mul(%16, %0) + %19 : Float(*, *) = aten::mul(%16, %4) + %20 : Float(*, *) = aten::neg(%3) + %21 : int = prim::Constant[value=1]() + %22 : Float(*, *) = aten::add(%20, %21, %21) + %23 : Float(*, *) = aten::mul(%22, %3) + %24 : Float(*, *) = aten::mul(%23, %8) + %25 : Float(*, *) = aten::mul(%2, %2) + %26 : Float(*, *) = aten::neg(%25) + %27 : int = prim::Constant[value=1]() + %28 : Float(*, *) = aten::add(%26, %27, %27) + %29 : Float(*, *) = aten::mul(%18, %28) + %30 : Float(*, *) = aten::neg(%1) + %31 : int = prim::Constant[value=1]() + %32 : Float(*, *) = aten::add(%30, %31, %31) + %33 : Float(*, *) = aten::mul(%32, %1) + %34 : Float(*, *) = aten::mul(%33, %19) + %35 : Float(*, *) = aten::neg(%0) + %36 : int = prim::Constant[value=1]() + %37 : Float(*, *) = aten::add(%35, %36, %36) + %38 : Float(*, *) = aten::mul(%37, %0) + %39 : Float(*, *) = aten::mul(%38, %17) + %40 : Float(*, *) = prim::FusedConcat[dim=1](%39, %34, %29, %24) + return (%40); +} +with prim::FusionGroup_1 = graph(%0 : Float(*, *) + %1 : Float(*, *) + %2 : Float(*)) { + %3 : Float(*, *) = aten::mul(%0, %2) + %4 : Float(*, *) = aten::mul(%0, %1) + %5 : int = prim::Constant[value=1]() + %6 : Float(*, *) = aten::add(%3, %4, %5) + return (%6); +} +with prim::FusionGroup_2 = graph(%0 : Float(*, *) + %1 : Float(*) + %2 : Float(*, *) + %3 : Float(*) + %4 : Float(*, *)) { + %5 : Float(*, *) = aten::mul(%2, %4) + %6 : Float(*, *) = aten::mul(%2, %3) + %7 : Float(*, *) = aten::mul(%5, %1) + %8 : int = prim::Constant[value=1]() + %9 : int = prim::Constant[value=1]() + %10 : Float(*, *) = aten::add(%6, %7, %9) + %11 : Float(*, *) = aten::mul(%5, %0) + return (%11, %10, %5); } diff --git a/test/expect/TestScript.test_milstm_fusion_cuda-forward.expect b/test/expect/TestScript.test_milstm_fusion_cuda-forward.expect index 2e2726c..44cb7da 100644 --- a/test/expect/TestScript.test_milstm_fusion_cuda-forward.expect +++ b/test/expect/TestScript.test_milstm_fusion_cuda-forward.expect @@ -23,71 +23,72 @@ with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *) %Wx.1 : Float(*, *) = aten::mm(%7, %9) %11 : Float(*, *) = aten::t(%6) %Uz.1 : Float(*, *) = aten::mm(%5, %11) - %13 : int[] = aten::size(%4) - %14 : int[] = aten::size(%Wx.1) - %15 : int[] = aten::size(%Uz.1) - %16 : int[] = aten::size(%3) - %17 : int = prim::Constant[value=1]() - %18 : int[] = aten::size(%2) - %19 : Float(*, *), %20 : Float(*, *), %21 : Float(*, *), %22 : Float(*, *), %23 : Float(*, *), %24 : Float(*, *) = prim::FusionGroup_0(%2, %Uz.1, %3, %Wx.1, 
%4) - %25 : int[] = aten::size(%24) - %26 : int[] = aten::size(%23) - %27 : int[] = aten::size(%22) - %28 : int[] = aten::size(%21) - %29 : int[] = aten::size(%20) - %30 : int[] = aten::size(%19) - %31 : int[] = aten::size(%1) - %32 : Tensor[] = prim::ListConstruct(%19, %1) - %33 : Tensor[] = aten::broadcast_tensors(%32) - %34 : Tensor, %35 : Tensor = prim::ListUnpack(%33) - %36 : int[] = aten::size(%0) - %hy : Float(*, *), %38 : Float(*, *), %cy : Float(*, *), %40 : Float(*, *), %41 : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1(%0, %35, %34) - %46 : int[] = aten::size(%ingate.1) - %47 : int[] = aten::size(%forgetgate.1) - %48 : int[] = aten::size(%cellgate.1) - %49 : int[] = aten::size(%outgate.1) - %50 : int[] = aten::size(%41) - %51 : int[] = aten::size(%40) - %52 : int[] = aten::size(%38) - return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %13, %14, %24, %25, %15, %16, %14, %26, %27, %18, %15, %28, %29, %30, %31, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %47, %36, %46, %48, %50, %51, %38, %49, %52); + %13 : Float(*, *) = aten::mul(%4, %Wx.1) + %14 : int = prim::Constant[value=1]() + %15 : Tensor[] = prim::ListConstruct(%1, %2, %Uz.1, %13, %3, %Wx.1) + %16 : Tensor[] = aten::broadcast_tensors(%15) + %17 : Tensor, %18 : Tensor, %19 : Tensor, %20 : Tensor, %21 : Tensor, %22 : Tensor = prim::ListUnpack(%16) + %hy : Float(*, *), %24 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %22, %21, %20, %19, %18, %17) + return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %13, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %24); } -with prim::FusionGroup_0 = graph(%0 : Float(*) - %1 : Float(*, *) - %2 : Float(*) - %3 : Float(*, *) - %4 : Float(*)) { - %5 : Float(*, *) = aten::mul(%4, %3) - %6 : Float(*, *) = aten::mul(%5, %1) - %7 : Float(*, *) = aten::mul(%2, %3) - %8 : int = prim::Constant[value=1]() - %9 : Float(*, *) = aten::add(%6, %7, %8) - %10 : Float(*, *) = aten::mul(%0, %1) - %11 : int = prim::Constant[value=1]() - %12 : Float(*, *) = aten::add(%9, %10, %11) - return (%12, %10, %9, %7, %6, %5); -} -with prim::FusionGroup_1 = graph(%0 : Float(*, *) +with prim::FusionGroup_0 = graph(%0 : Float(*, *) %1 : Tensor - %2 : Tensor) { - %3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2) - %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1) - %11 : int = prim::Constant[value=1]() - %12 : Float(*, *) = aten::add(%3, %7, %11) - %13 : int = prim::Constant[value=1]() - %14 : Float(*, *) = aten::add(%4, %8, %13) - %15 : int = prim::Constant[value=1]() - %16 : Float(*, *) = aten::add(%5, %9, %15) - %17 : int = prim::Constant[value=1]() - %18 : Float(*, *) = aten::add(%6, %10, %17) - %ingate.1 : Float(*, *) = aten::sigmoid(%12) - %forgetgate.1 : Float(*, *) = aten::sigmoid(%14) - %cellgate.1 : Float(*, *) = aten::tanh(%16) - %outgate.1 : Float(*, *) = aten::sigmoid(%18) - %23 : Float(*, *) = aten::mul(%forgetgate.1, %0) - %24 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1) - %25 : int = prim::Constant[value=1]() - %cy : Float(*, *) = aten::add(%23, %24, %25) - %27 : Float(*, *) = aten::tanh(%cy) - %hy : Float(*, *) = aten::mul(%outgate.1, %27) - return (%hy, %27, %cy, %24, %23, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1); + %2 : Tensor + %3 : Tensor + %4 : Tensor + %5 : 
Tensor + %6 : Tensor) { + %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%6) + %11 : Float(*, *), %12 : Float(*, *), %13 : Float(*, *), %14 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%5) + %15 : Float(*, *), %16 : Float(*, *), %17 : Float(*, *), %18 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%4) + %19 : Float(*, *), %20 : Float(*, *), %21 : Float(*, *), %22 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%3) + %23 : Float(*, *), %24 : Float(*, *), %25 : Float(*, *), %26 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2) + %27 : Float(*, *), %28 : Float(*, *), %29 : Float(*, *), %30 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1) + %31 : Float(*, *) = aten::mul(%23, %27) + %32 : Float(*, *) = aten::mul(%24, %28) + %33 : Float(*, *) = aten::mul(%25, %29) + %34 : Float(*, *) = aten::mul(%26, %30) + %35 : Float(*, *) = aten::mul(%19, %15) + %36 : Float(*, *) = aten::mul(%20, %16) + %37 : Float(*, *) = aten::mul(%21, %17) + %38 : Float(*, *) = aten::mul(%22, %18) + %39 : Float(*, *) = aten::mul(%11, %15) + %40 : Float(*, *) = aten::mul(%12, %16) + %41 : Float(*, *) = aten::mul(%13, %17) + %42 : Float(*, *) = aten::mul(%14, %18) + %43 : int = prim::Constant[value=1]() + %44 : Float(*, *) = aten::add(%35, %31, %43) + %45 : int = prim::Constant[value=1]() + %46 : Float(*, *) = aten::add(%36, %32, %45) + %47 : int = prim::Constant[value=1]() + %48 : Float(*, *) = aten::add(%37, %33, %47) + %49 : int = prim::Constant[value=1]() + %50 : Float(*, *) = aten::add(%38, %34, %49) + %51 : int = prim::Constant[value=1]() + %52 : Float(*, *) = aten::add(%44, %39, %51) + %53 : int = prim::Constant[value=1]() + %54 : Float(*, *) = aten::add(%46, %40, %53) + %55 : int = prim::Constant[value=1]() + %56 : Float(*, *) = aten::add(%48, %41, %55) + %57 : int = prim::Constant[value=1]() + %58 : Float(*, *) = aten::add(%50, %42, %57) + %59 : int = prim::Constant[value=1]() + %60 : Float(*, *) = aten::add(%52, %7, %59) + %61 : int = prim::Constant[value=1]() + %62 : Float(*, *) = aten::add(%54, %8, %61) + %63 : int = prim::Constant[value=1]() + %64 : Float(*, *) = aten::add(%56, %9, %63) + %65 : int = prim::Constant[value=1]() + %66 : Float(*, *) = aten::add(%58, %10, %65) + %ingate.1 : Float(*, *) = aten::sigmoid(%60) + %forgetgate.1 : Float(*, *) = aten::sigmoid(%62) + %cellgate.1 : Float(*, *) = aten::tanh(%64) + %outgate.1 : Float(*, *) = aten::sigmoid(%66) + %71 : Float(*, *) = aten::mul(%forgetgate.1, %0) + %72 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1) + %73 : int = prim::Constant[value=1]() + %cy : Float(*, *) = aten::add(%71, %72, %73) + %75 : Float(*, *) = aten::tanh(%cy) + %hy : Float(*, *) = aten::mul(%outgate.1, %75) + return (%hy, %75, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1); } diff --git a/test/test_jit.py b/test/test_jit.py index 6dca770..9ced108 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -358,7 +358,7 @@ class JitTestCase(TestCase): vs = vs[:-drop] # we don't want all the grad for all the outputs to be the same # so we multiply each by a constant - return sum([math.log(i + 2) * v.sum() for i, v in enumerate(vs) if v is not None]) + return sum([(i + 1) * v.sum() for i, v in enumerate(vs) if v is not None]) if input_tensors is None: input_tensors = reference_tensors @@ -433,11 +433,10 @@ class JitTestCase(TestCase): return ge - def assertAllFused(self, graph, except_for=()): + def assertAllFused(self, graph): if [n.kind() for n in graph.nodes()] == 
['prim::DifferentiableGraph']: graph = next(graph.nodes()).g('Subgraph') - allowed_nodes = {'prim::Constant', 'prim::FusionGroup'} | set(except_for) - self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()), + self.assertTrue(all(node.kind() in {'prim::Constant', 'prim::FusionGroup'} for node in graph.nodes()), 'got {}'.format(graph)) self.assertTrue([node.kind() for node in graph.nodes()].count('prim::FusionGroup') == 1) @@ -3270,12 +3269,12 @@ a") funcs = (func2, funcInf) for f in funcs: s = self.checkScript(f, (a, b)) - self.assertAllFused(s.graph_for(a, b), except_for={'aten::size'}) + self.assertAllFused(s.graph_for(a, b)) c = s(a, b) c.sum().backward() graph = backward_graph(s) - self.assertAllFused(graph, except_for={'prim::SumToSize'}) + self.assertAllFused(graph) def test_mul(self): def func(a, b): @@ -4257,9 +4256,8 @@ a") inputs = get_lstm_inputs('cuda', training=True) module = self.checkScript(LSTMCellS, inputs) forward_graph = module.graph_for(*inputs) - with self.assertRaises(AssertionError): - self.assertGraphContainsExactly( - forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True) + self.assertGraphContainsExactly( + forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True) self.assertExpectedGraph(forward_graph, subname='forward') hy, cy = module(*inputs) @@ -4273,9 +4271,8 @@ a") inputs = get_milstm_inputs('cuda', training=True) module = self.checkScript(MiLSTMCell, inputs) forward_graph = module.graph_for(*inputs) - with self.assertRaises(AssertionError): - self.assertGraphContainsExactly( - forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True) + self.assertGraphContainsExactly( + forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True) self.assertExpectedGraph(forward_graph, subname='forward') hy, cy = module(*inputs) diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index 01e24be..1bc1586 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -340,7 +340,7 @@ static void validate_outputs(const edge_list& edges, variable_list& grads, const ss << metadata.shape(); AT_ERROR(format_error(ss.str())); } - grads[i] = at::sum_to(std::move(grads[i]), metadata.shape()); + grads[i] = at::sum_to(grads[i], metadata.shape()); } if (!is_compatible_type(metadata.type(), grads[i].type())) { std::stringstream ss; diff --git a/torch/csrc/jit/argument_spec.h b/torch/csrc/jit/argument_spec.h index 731753f..4aba7b6 100644 --- a/torch/csrc/jit/argument_spec.h +++ b/torch/csrc/jit/argument_spec.h @@ -124,9 +124,6 @@ struct ArgumentSpec { size_t size() const { return args.size(); } - const ArgumentInfo& at(size_t i) const { - return args[i]; - } size_t hashCode() const { return hash_code; } @@ -150,7 +147,6 @@ private: return fillType(subtype, offset); })); } else { - offset++; return original; } } @@ -318,28 +314,6 @@ private: const int i; }; -inline std::ostream & operator<<(std::ostream & out, const ArgumentInfo & info) { - if(!info.defined()) { - return out << ""; - } - out << "Tensor(device=" << info.device() - << ", type=" << toString(info.type()) - << ", requires_grad=" << info.requires_grad() - << ", dims=" << info.dim() << ")"; - return out; -} - -inline std::ostream& operator<<(std::ostream & out, const ArgumentSpec & spec) { - out << "{"; - for(size_t i = 0; i < spec.size(); ++i) { - if (i > 0) - out << ", "; - out << spec.at(i); - } - out << "}"; - return out; -} - inline std::ostream & operator<<(std::ostream & out, const CompleteArgumentInfo & info) { if(!info.defined()) { 
return out << ""; diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp index eb5bb21..8d95b0b 100644 --- a/torch/csrc/jit/autodiff.cpp +++ b/torch/csrc/jit/autodiff.cpp @@ -78,7 +78,6 @@ bool isDifferentiable(Node * n) { "aten::sinh(Tensor self) -> Tensor", "aten::tan(Tensor self) -> Tensor", "aten::trunc(Tensor self) -> Tensor", - "prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)", "aten::log_softmax(Tensor self, int dim) -> Tensor", "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor", "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)", @@ -138,64 +137,48 @@ static std::vector gradientForNode(Node* node, ArrayRef grad_val "aten::eq(Tensor self, Tensor other) -> Tensor", "aten::ne(Tensor self, Tensor other) -> Tensor" }; - const auto sumToSizeOf = [node](SymbolicVariable v, Symbol input_name) -> SymbolicVariable { - Value * size; - { - WithInsertPoint insert_guard {node}; - size = SymbolicVariable(node->namedInput(input_name)).size(); - } - return v.sumToSize(size); - }; - const auto build_sym_grad = [node, &sumToSizeOf](const std::vector& grads) -> std::vector { + const auto build_sym_grad = [node](const std::vector& grads) -> std::vector { auto inputs = fmap(node->inputs()); auto outputs = fmap(node->outputs()); if (node->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) { - return {sumToSizeOf(grads.at(0), attr::self), - sumToSizeOf(grads.at(0) * node->namedInput(attr::alpha), attr::other), - nullptr}; + return {grads.at(0), grads.at(0) * node->namedInput(attr::alpha), nullptr}; } else if (node->matches("aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor")) { return {grads.at(0), nullptr, nullptr}; } else if (node->kind() == prim::AutogradAdd) { - // NB: AutogradAdds don't broadcast return {grads.at(0), grads.at(0)}; } else if (node->matches("aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) { - return {sumToSizeOf(grads.at(0), attr::self), - sumToSizeOf(-grads.at(0) * node->namedInput(attr::alpha), attr::other), - nullptr}; + return {grads.at(0), -grads.at(0) * node->namedInput(attr::alpha), nullptr}; } else if (node->matches("aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor")) { return {grads.at(0), nullptr, nullptr}; } else if (node->matches("aten::mul(Tensor self, Tensor other) -> Tensor")) { - return {sumToSizeOf(grads.at(0) * inputs.at(1), attr::self), - sumToSizeOf(grads.at(0) * inputs.at(0), attr::other)}; + return {grads.at(0) * inputs.at(1), grads.at(0) * inputs.at(0)}; } else if (node->matches("aten::mul(Tensor self, Scalar other) -> Tensor")) { return {grads.at(0) * inputs.at(1), nullptr}; } else if (node->matches("aten::div(Tensor self, Tensor other) -> Tensor")) { - return {sumToSizeOf(grads.at(0) / inputs.at(1), attr::self), - sumToSizeOf(-grads.at(0) * inputs.at(0) / (inputs.at(1) * inputs.at(1)), attr::other)}; + return {grads.at(0) / inputs.at(1), -grads.at(0) * inputs.at(0) / (inputs.at(1) * inputs.at(1))}; } else if (node->matches("aten::div(Tensor self, Scalar other) -> Tensor")) { return {grads.at(0) / inputs.at(1), nullptr}; } else if (node->matches("aten::max(Tensor self, Tensor other) -> Tensor")) { - return {sumToSizeOf(grads.at(0) * (inputs.at(0) > inputs.at(1)).type_as(grads.at(0)), attr::self), - sumToSizeOf(grads.at(0) * (inputs.at(1) > inputs.at(0)).type_as(grads.at(0)), attr::other)}; + return 
{grads.at(0) * (inputs.at(0) > inputs.at(1)).type_as(grads.at(0)), + grads.at(0) * (inputs.at(1) > inputs.at(0)).type_as(grads.at(0))}; } else if (node->matches("aten::min(Tensor self, Tensor other) -> Tensor")) { - return {sumToSizeOf(grads.at(0) * (inputs.at(0) < inputs.at(1)).type_as(grads.at(0)), attr::self), - sumToSizeOf(grads.at(0) * (inputs.at(1) < inputs.at(0)).type_as(grads.at(0)), attr::other)}; + return {grads.at(0) * (inputs.at(0) < inputs.at(1)).type_as(grads.at(0)), + grads.at(0) * (inputs.at(1) < inputs.at(0)).type_as(grads.at(0))}; } else if (node->matches("aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor")) { - return {nullptr, - sumToSizeOf(grads.at(0) * inputs.at(0).type_as(grads.at(0)), attr::self), - sumToSizeOf(grads.at(0) * (1 - inputs.at(0)).type_as(grads.at(0)), attr::other)}; + return {nullptr, grads.at(0) * inputs.at(0).type_as(grads.at(0)), + grads.at(0) * (1 - inputs.at(0)).type_as(grads.at(0))}; } else if (node->matches("aten::sigmoid(Tensor self) -> Tensor")) { // TODO: The order of operations matter in this case. This @@ -256,14 +239,6 @@ static std::vector gradientForNode(Node* node, ArrayRef grad_val } else if (node->matches("aten::atan(Tensor self) -> Tensor")) { return {grads.at(0) / (inputs.at(0) * inputs.at(0) + at::Scalar(1))}; - } else if (node->matches("prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)")) { - Value * self_size; - { - WithInsertPoint insert_guard { node }; - self_size = inputs.at(0).size(); - } - return {grads.at(0).expand(self_size), nullptr}; - } else if (node->matches("aten::ceil(Tensor self) -> Tensor")) { return {SymbolicVariable::zeros_like(grads.at(0))}; @@ -340,7 +315,7 @@ static std::vector gradientForNode(Node* node, ArrayRef grad_val return {grads.at(0).squeeze(node->namedInput(attr::dim)), nullptr}; } else if (node->matches("aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor")) { - return {sumToSizeOf(grads.at(0) * node->namedInput(attr::beta), attr::self), + return {grads.at(0) * node->namedInput(attr::beta), grads.at(0).mm(inputs.at(2).t()) * node->namedInput(attr::alpha), inputs.at(1).t().mm(grads.at(0)) * node->namedInput(attr::alpha), nullptr, nullptr}; @@ -446,15 +421,15 @@ static std::vector gradientForNode(Node* node, ArrayRef grad_val } else if (node->matches("aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)")) { auto graph = node->owningGraph(); auto backward_value = graph->insert(aten::thnn_conv2d_backward, { - grads.at(0).value(), - inputs.at(0).value(), - inputs.at(1).value(), - node->namedInput(attr::kernel_size), - node->namedInput(attr::stride), - node->namedInput(attr::padding), - outputs.at(1).value(), - outputs.at(2).value(), - graph->insertConstant(std::vector{true, true, true}) + grads.at(0).value(), + inputs.at(0).value(), + inputs.at(1).value(), + node->namedInput(attr::kernel_size), + node->namedInput(attr::stride), + node->namedInput(attr::padding), + outputs.at(1).value(), + outputs.at(2).value(), + graph->insertConstant(std::vector{true, true, true}) }); // graph->insert returns a tuple automatically if multiple outputs are returned. So unpack them again. Node* tuple_unpack_node = graph->insertNode(graph->createTupleUnpack(backward_value)); @@ -465,16 +440,16 @@ static std::vector gradientForNode(Node* node, ArrayRef grad_val } else if (node->matches("aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? 
running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)")) { auto graph = node->owningGraph(); auto backward_value = graph->insert(aten::native_batch_norm_backward, { - grads.at(0).value(), - inputs.at(0).value(), - inputs.at(1).value(), - inputs.at(3).value(), - inputs.at(4).value(), - outputs.at(1).value(), - outputs.at(2).value(), - inputs.at(5).value(), - inputs.at(7).value(), - graph->insertConstant(std::vector{true, true, true}) + grads.at(0).value(), + inputs.at(0).value(), + inputs.at(1).value(), + inputs.at(3).value(), + inputs.at(4).value(), + outputs.at(1).value(), + outputs.at(2).value(), + inputs.at(5).value(), + inputs.at(7).value(), + graph->insertConstant(std::vector{true, true, true}) }); // graph->insert returns a tuple automatically if multiple outputs are returned. So unpack them again. Node* tuple_unpack_node = graph->insertNode(graph->createTupleUnpack(backward_value)); @@ -800,16 +775,16 @@ Gradient differentiate(std::shared_ptr& graph) { auto rev_info = addReverseInline(grad_desc); // Lift constants captured for the reverse graph into it liftConstants(grad_desc, rev_info); + // addReverseInline has to call gradientForNode if *any* of the outputs + // require grad, but it will emit vjps for *all* outputs. Use DCE to remove + // unnecessary nodes. + EliminateDeadCode(rev_info.reverse_block); // Fills in f, df, f_real_outputs, df_input_captures, // modifies df_input_vjps (new vjps are added for temporaries) lambdaLiftReverse(grad_desc, rev_info); // It's possible the we've cloned the same constants many times, so // de-duplicate them ConstantPooling(grad_desc.df); - // addReverseInline has to call gradientForNode if *any* of the inputs - // require grad, but it will emit vjps for *all* inputs. Use DCE to remove - // unnecessary nodes. 
- EliminateDeadCode(grad_desc.df); return grad_desc; } diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index 5c85722..3e802a9 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -9,10 +9,8 @@ #include "torch/csrc/jit/operator.h" #include "torch/csrc/jit/custom_operator.h" #include "torch/csrc/jit/script/jit_exception.h" -#include "torch/csrc/variable_tensor_functions.h" -#include -#include +#include "torch/csrc/variable_tensor_functions.h" #include #include @@ -363,17 +361,6 @@ RegisterOperators reg({ }; }), Operator( - "prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)", - [](const Node* node) { - return [=](Stack& stack) { - at::Tensor self; - Shared desired_sizes; - pop(stack, self, desired_sizes); - push(stack, at::sum_to(std::move(self), desired_sizes->elements())); - return 0; - }; - }), - Operator( prim::TupleUnpack, [](const Node* node) { size_t num_elems = node->outputs().size(); diff --git a/torch/csrc/jit/register_special_ops.cpp b/torch/csrc/jit/register_special_ops.cpp index d1eccbc..9a34502 100644 --- a/torch/csrc/jit/register_special_ops.cpp +++ b/torch/csrc/jit/register_special_ops.cpp @@ -33,8 +33,8 @@ RegisterOperators reg({ "aten::size(Tensor self) -> int[]", [](Stack& stack) { autograd::profiler::RecordFunction record("sizes"); - auto t = std::move(pop(stack)).toTensor(); - pack(stack, t.sizes().vec()); + auto result = (std::move(pop(stack))).toTensor().sizes(); + pack(stack, std::move(result)); return 0; }), Operator( diff --git a/torch/csrc/jit/symbolic_variable.h b/torch/csrc/jit/symbolic_variable.h index 1d4b82f..a625f60 100644 --- a/torch/csrc/jit/symbolic_variable.h +++ b/torch/csrc/jit/symbolic_variable.h @@ -123,15 +123,6 @@ struct SymbolicVariable { SymbolicVariable operator%(at::Scalar rhs) const { return create(aten::remainder, {*this, insertConstant(rhs)})[0].typeLike(*this); } - Value* size() const { - return v->owningGraph()->insert(aten::size, {v}); - } - SymbolicVariable sumToSize(Value * size) const { - return create(prim::SumToSize, {*this, size})[0]; - } - SymbolicVariable expand(Value * size) const { - return v->owningGraph()->insert(aten::expand, {v, size}); - } SymbolicVariable isnan() const { return create(aten::ne, {*this, *this})[0].typeLikeWithScalarType(*this, at::kByte); } -- 2.7.4