Reduce broadcasted inputs in derivative code (#14485)
authorAdam Paszke <adam.paszke@gmail.com>
Mon, 3 Dec 2018 21:41:05 +0000 (13:41 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 3 Dec 2018 21:44:18 +0000 (13:44 -0800)
Summary:
Previously symbolic AD formulas assumed that no broadcasting happened,
and would return gradients of incorrect shapes (possibly leading to
silent errors later).

Fixes a few bugs (known and unknown):
- #11736
- ArgumentSpec didn't compute the input types correctly [(it didn't advance the offset for non-tensor args)](https://github.com/pytorch/pytorch/pull/14485/files#diff-4fd3157a056596aefb8cdf41022a208bR153)
- Symbolic AD could suffer from use after free (dangling pointers in grad map), because [`EliminateDeadCode` could have removed nodes](https://github.com/pytorch/pytorch/pull/14485/files#diff-25d33ad1ed6855684dec79d927ca6142L781) that referenced gradients of certain values.
- Undefined behavior in `aten::size`

During my tests I've also found a few new problems, and I have opened issues for them:
- FusionGroup seems to think that cat nodes broadcast their inputs (#14483)
- `prim::ConstantChunk` derivative formula doesn't handle undefined inputs (#14484)

This patch unfortunately deoptimizes some of our code (Fusion doesn't happen past chunk nodes, and outputs more tensors only because we have to get their size). I know how to fix those issues, but wanted to fix this terrible bug quickly.

cc zou3519 zdevito ngimel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14485

Differential Revision: D13280899

Pulled By: soumith

fbshipit-source-id: 80cc5ec9331be80e1bb9ddfe85b81c2b997e0b0c

15 files changed:
aten/src/ATen/ExpandUtils.h
aten/src/ATen/core/interned_strings.h
test/cpp/jit/tests.h
test/expect/TestJit.test_cpp_cuda.expect
test/expect/TestScript.test_lstm_fusion_cuda-backward.expect
test/expect/TestScript.test_lstm_fusion_cuda-forward.expect
test/expect/TestScript.test_milstm_fusion_cuda-backward.expect
test/expect/TestScript.test_milstm_fusion_cuda-forward.expect
test/test_jit.py
torch/csrc/autograd/engine.cpp
torch/csrc/jit/argument_spec.h
torch/csrc/jit/autodiff.cpp
torch/csrc/jit/register_prim_ops.cpp
torch/csrc/jit/register_special_ops.cpp
torch/csrc/jit/symbolic_variable.h

index a5e9db4..4a8cbd7 100644 (file)
@@ -136,20 +136,25 @@ inline std::vector<Tensor> expand_outplace(TensorList to_expand) {
 
 // Sums `tensor` repeatedly to produce a tensor of shape `shape`.
 // Precondition: is_expandable_to(shape, tensor.sizes()) must be true
-static inline Tensor sum_to(Tensor tensor, IntList shape) {
+static inline Tensor sum_to(Tensor tensor, const IntList shape) {
   if (shape.size() == 0) {
     return tensor.sum();
   }
-  Tensor result = tensor;
-  while (result.dim() > (int64_t)shape.size()) {
-    result = result.sum(0, false);
+  c10::SmallVector<int64_t, 8> reduce_dims;
+  const at::IntList sizes = tensor.sizes();
+  const int64_t leading_dims = sizes.size() - shape.size();
+  for (int64_t i = 0; i < leading_dims; ++i) {
+    reduce_dims.push_back(i);
   }
-  for (int64_t i = 0; i < result.dim(); ++i) {
-    if (shape[i] == 1 && result.sizes()[i] > 1) {
-      result = result.sum(i, true);
+  for (int64_t i = leading_dims; i < static_cast<int64_t>(sizes.size()); ++i) {
+    if (shape[i - leading_dims] == 1 && sizes[i] > 1) {
+      reduce_dims.push_back(i);
     }
   }
-  return result;
+  if (!reduce_dims.empty()) {
+    tensor = tensor.sum(reduce_dims, /*keepdim=*/true);
+  }
+  return leading_dims > 0 ? tensor.view(shape) : tensor;
 }
 
 // True if `shape` can be broadcasted to `desired`
index 8f4ba55..ad64cb0 100644 (file)
@@ -64,6 +64,7 @@ namespace c10 {
   _(prim, requires_grad)           \
   _(prim, AutogradAdd)             \
   _(prim, GradOf)                  \
+  _(prim, SumToSize)               \
   _(prim, AnyDefined)              \
   _(prim, FusedConcat)             \
   _(prim, ConstantChunk)           \
index 2a4edc1..d728a65 100644 (file)
@@ -439,38 +439,38 @@ std::shared_ptr<Graph> build_lstm() {
   return r;
 }
 
-void run(InterpreterState & interp, const std::vector<at::Tensor> & inputs, std::vector<at::Tensor> & outputs) {
+std::vector<at::Tensor> run(InterpreterState & interp, const std::vector<at::Tensor> & inputs) {
   std::vector<IValue> stack(inputs.begin(), inputs.end());
   interp.run(stack);
-  outputs.clear();
-  for (auto& ivalue : stack) {
-    outputs.push_back(std::move(ivalue).toTensor());
-  }
+  return fmap(stack, [](const IValue& i) { return i.toTensor(); });
 }
 
 std::pair<tensor_list, tensor_list> runGradient(
     Gradient& grad_spec,
     tensor_list& tensors_in,
     tensor_list& tensor_grads_in) {
-  tensor_list tensors_out, tensor_grads_out;
+  static const auto as_tensorlist = [](const Stack& stack) {
+    return fmap(stack, [](const IValue& i) { return i.toTensor(); });
+  };
   Code f_code{grad_spec.f}, df_code{grad_spec.df};
   InterpreterState f_interpreter{f_code}, df_interpreter{df_code};
 
-  run(f_interpreter, tensors_in, tensors_out);
+  auto f_stack = fmap<IValue>(tensors_in);
+  f_interpreter.run(f_stack);
 
-  tensor_list df_inputs;
-  df_inputs.insert(
-      df_inputs.end(), tensor_grads_in.begin(), tensor_grads_in.end());
+  Stack df_stack;
+  df_stack.insert(
+      df_stack.end(), tensor_grads_in.begin(), tensor_grads_in.end());
   for (auto offset : grad_spec.df_input_captured_inputs)
-    df_inputs.push_back(tensors_in[offset]);
+    df_stack.push_back(tensors_in[offset]);
   for (auto offset : grad_spec.df_input_captured_outputs)
-    df_inputs.push_back(tensors_out[offset]);
-  run(df_interpreter, df_inputs, tensor_grads_out);
+    df_stack.push_back(f_stack[offset]);
+  df_interpreter.run(df_stack);
 
   // Outputs of f needs to be sliced
-  tensors_out.erase(
-      tensors_out.begin() + grad_spec.f_real_outputs, tensors_out.end());
-  return std::make_pair(tensors_out, tensor_grads_out);
+  f_stack.erase(
+      f_stack.begin() + grad_spec.f_real_outputs, f_stack.end());
+  return std::make_pair(as_tensorlist(f_stack), as_tensorlist(df_stack));
 }
 
 void assertAllClose(const tensor_list& a, const tensor_list& b) {
@@ -496,9 +496,8 @@ void testInterp() {
 
   auto lstm_g = build_lstm();
   Code lstm_function(lstm_g);
-  std::vector<at::Tensor> outputs;
   InterpreterState lstm_interp(lstm_function);
-  run(lstm_interp, {input[0], hx, cx, w_ih, w_hh}, outputs);
+  auto outputs = run(lstm_interp, {input[0], hx, cx, w_ih, w_hh});
   std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
 
   // std::cout << almostEqual(outputs[0],hx) << "\n";
@@ -836,8 +835,8 @@ void testDifferentiate(std::ostream& out = std::cout) {
 
   auto grad_spec = differentiate(graph);
   std::vector<size_t> expected_captured_inputs = {0, 1};
-  std::vector<size_t> expected_captured_outputs = {1};
-  std::vector<size_t> expected_input_vjps = {0, 1};
+  std::vector<size_t> expected_captured_outputs = {1, 2, 3, 4, 5, 6, 7};
+  std::vector<size_t> expected_input_vjps = {0, 3};
   std::vector<size_t> expected_output_vjps = {0, 1};
   ASSERT_EQ(grad_spec.f_real_outputs, 1);
   ASSERT_EQ(grad_spec.df_input_captured_inputs, expected_captured_inputs);
@@ -867,11 +866,11 @@ void testDifferentiateWithRequiresGrad(std::ostream& out = std::cout) {
   PropagateRequiresGrad(graph);
 
   auto grad_spec = differentiate(graph);
-  std::vector<size_t> expected_input_vjps = {1, 2}; // for e and %4 = (d + a)
+  std::vector<size_t> expected_input_vjps = {1, 4}; // for e and %4 = (d + a)
   std::vector<size_t> expected_output_vjps = {0}; // only a requires grad
-  ASSERT_EQ(grad_spec.f_real_outputs, 2); // we need one temporary %4 = (d + a)
+  ASSERT_EQ(grad_spec.f_real_outputs, 2);
   ASSERT_EQ(grad_spec.df_input_captured_inputs, std::vector<size_t>({0}));
-  ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector<size_t>({2}));
+  ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector<size_t>({2, 3, 4, 5, 6, 7, 8}));
   ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps);
   ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps);
   out << "testDifferentiateWithRequiresGrad\n";
index 50ec297..ddcac4c 100644 (file)
@@ -93,39 +93,57 @@ with prim::DifferentiableGraph_0 = graph(%13 : Tensor
 testDifferentiate
 graph(%0 : Float(2, 3, 4)
       %1 : Float(2, 3, 4)) {
+  %23 : int[] = aten::size(%0)
+  %26 : int[] = aten::size(%1)
   %2 : Float(2, 3, 4) = aten::mul(%0, %1)
+  %15 : int[] = aten::size(%2)
+  %18 : int[] = aten::size(%0)
   %3 : Float(2, 3, 4) = aten::mul(%2, %0)
   %4 : int = prim::Constant[value=1]()
+  %7 : int[] = aten::size(%3)
+  %10 : int[] = aten::size(%1)
   %5 : Float(2, 3, 4) = aten::add(%3, %1, %4)
-  return (%5, %2);
+  return (%5, %23, %26, %2, %15, %18, %7, %10);
 }
 graph(%0 : Float(2, 3, 4)
       %1 : Float(2, 3, 4)
       %2 : Float(2, 3, 4)
       %3 : Float(2, 3, 4)
-      %4 : Float(2, 3, 4)) {
-  %5 : int = prim::Constant[value=1]()
-  %6 : Float(2, 3, 4), %7 : Float(2, 3, 4) = prim::GradOf[name="aten::add"](%0)
+      %4 : int[]
+      %5 : int[]
+      %6 : Float(2, 3, 4)
+      %7 : int[]
+      %8 : int[]
+      %9 : int[]
+      %10 : int[]) {
+  %11 : int = prim::Constant[value=1]()
+  %12 : Tensor, %13 : Tensor = prim::GradOf[name="aten::add"](%0)
     block0() {
-      %8 : Float(2, 3, 4) = aten::mul(%0, %5)
-      -> (%0, %8)
+      %14 : Tensor = prim::SumToSize(%0, %9)
+      %15 : Float(2, 3, 4) = aten::mul(%0, %11)
+      %16 : Tensor = prim::SumToSize(%15, %10)
+      -> (%14, %16)
     }
-  %9 : Float(2, 3, 4), %10 : Float(2, 3, 4) = prim::GradOf[name="aten::mul"](%6)
+  %17 : Tensor, %18 : Tensor = prim::GradOf[name="aten::mul"](%12)
     block0() {
-      %11 : Float(2, 3, 4) = aten::mul(%6, %2)
-      %12 : Float(2, 3, 4) = aten::mul(%6, %4)
-      -> (%11, %12)
+      %19 : Tensor = aten::mul(%12, %2)
+      %20 : Tensor = prim::SumToSize(%19, %7)
+      %21 : Tensor = aten::mul(%12, %6)
+      %22 : Tensor = prim::SumToSize(%21, %8)
+      -> (%20, %22)
     }
-  %13 : Tensor = prim::AutogradAdd(%1, %9)
-  %14 : Float(2, 3, 4), %15 : Float(2, 3, 4) = prim::GradOf[name="aten::mul"](%13)
+  %23 : Tensor = prim::AutogradAdd(%1, %17)
+  %24 : Tensor, %25 : Tensor = prim::GradOf[name="aten::mul"](%23)
     block0() {
-      %16 : Float(2, 3, 4) = aten::mul(%13, %3)
-      %17 : Float(2, 3, 4) = aten::mul(%13, %2)
-      -> (%16, %17)
+      %26 : Tensor = aten::mul(%23, %3)
+      %27 : Tensor = prim::SumToSize(%26, %4)
+      %28 : Tensor = aten::mul(%23, %2)
+      %29 : Tensor = prim::SumToSize(%28, %5)
+      -> (%27, %29)
     }
-  %18 : Tensor = prim::AutogradAdd(%10, %14)
-  %19 : Tensor = prim::AutogradAdd(%7, %15)
-  return (%18, %19);
+  %30 : Tensor = prim::AutogradAdd(%18, %24)
+  %31 : Tensor = prim::AutogradAdd(%13, %25)
+  return (%30, %31);
 }
 
 testDifferentiateWithRequiresGrad
@@ -135,35 +153,53 @@ graph(%0 : Float(*)
   %3 : int = prim::Constant[value=1]()
   %4 : Float(*) = aten::add(%2, %1, %3)
   %5 : int = prim::Constant[value=1]()
+  %26 : int[] = aten::size(%4)
+  %29 : int[] = aten::size(%0)
   %6 : Float(*) = aten::add(%4, %0, %5)
+  %19 : int[] = aten::size(%6)
+  %22 : int[] = aten::size(%0)
   %7 : Float(*) = aten::mul(%6, %0)
   %8 : int = prim::Constant[value=1]()
+  %11 : int[] = aten::size(%7)
+  %14 : int[] = aten::size(%1)
   %9 : Float(*) = aten::add(%7, %1, %8)
-  return (%4, %9, %6);
+  return (%4, %9, %26, %29, %6, %19, %22, %11, %14);
 }
 graph(%0 : Float(*)
       %1 : Float(*)
       %2 : Float(*)
-      %3 : Float(*)) {
-  %4 : int = prim::Constant[value=1]()
-  %6 : Float(*), %7 : Tensor = prim::GradOf[name="aten::add"](%0)
+      %3 : int[]
+      %4 : int[]
+      %5 : Float(*)
+      %6 : int[]
+      %7 : int[]
+      %8 : int[]
+      %9 : int[]) {
+  %10 : int = prim::Constant[value=1]()
+  %12 : Tensor, %13 : Tensor = prim::GradOf[name="aten::add"](%0)
     block0() {
-      %8 : Tensor = aten::mul(%0, %4)
-      -> (%0, %8)
+      %14 : Tensor = prim::SumToSize(%0, %8)
+      %15 : Tensor = aten::mul(%0, %10)
+      %16 : Tensor = prim::SumToSize(%15, %9)
+      -> (%14, %16)
     }
-  %9 : Tensor, %10 : Tensor = prim::GradOf[name="aten::mul"](%6)
+  %17 : Tensor, %18 : Tensor = prim::GradOf[name="aten::mul"](%12)
     block0() {
-      %11 : Tensor = aten::mul(%6, %2)
-      %12 : Tensor = aten::mul(%6, %3)
-      -> (%11, %12)
+      %19 : Tensor = aten::mul(%12, %2)
+      %20 : Tensor = prim::SumToSize(%19, %6)
+      %21 : Tensor = aten::mul(%12, %5)
+      %22 : Tensor = prim::SumToSize(%21, %7)
+      -> (%20, %22)
     }
-  %13 : Tensor = prim::AutogradAdd(%1, %9)
-  %14 : Tensor, %15 : Tensor = prim::GradOf[name="aten::add"](%13)
+  %23 : Tensor = prim::AutogradAdd(%1, %17)
+  %24 : Tensor, %25 : Tensor = prim::GradOf[name="aten::add"](%23)
     block0() {
-      %16 : Tensor = aten::mul(%13, %4)
-      -> (%13, %16)
+      %26 : Tensor = prim::SumToSize(%23, %3)
+      %27 : Tensor = aten::mul(%23, %10)
+      %28 : Tensor = prim::SumToSize(%27, %4)
+      -> (%26, %28)
     }
-  %17 : Tensor = prim::AutogradAdd(%10, %15)
-  return (%17);
+  %29 : Tensor = prim::AutogradAdd(%18, %25)
+  return (%29);
 }
 
index eabf945..9ee5162 100644 (file)
@@ -12,71 +12,99 @@ graph(%0 : Float(*, *)
       %11 : Float(*, *)
       %12 : Float(*, *)
       %13 : Float(*, *)
+      %14 : int[]
+      %15 : int[]
+      %16 : int[]
+      %17 : int[]
+      %18 : int[]
+      %19 : int[]
       %ingate : Float(*, *)
       %forgetgate : Float(*, *)
       %cellgate : Float(*, *)
       %outgate : Float(*, *)
-      %18 : Float(*, *)) {
-  %19 : Float(*, *), %20 : Float(*, *) = prim::FusionGroup_0(%forgetgate, %ingate, %cellgate, %outgate, %9, %1, %18, %0)
-  %21 : Float(*, *) = aten::t(%13)
-  %22 : Float(*, *) = aten::mm(%20, %21)
-  %23 : Float(*, *) = aten::t(%10)
-  %24 : Float(*, *) = aten::mm(%23, %20)
-  %25 : Float(*, *) = aten::t(%24)
-  %26 : Float(*, *) = aten::t(%12)
-  %27 : Float(*, *) = aten::mm(%20, %26)
-  %28 : Float(*, *) = aten::t(%11)
-  %29 : Float(*, *) = aten::mm(%28, %20)
-  %30 : Float(*, *) = aten::t(%29)
-  return (%19, %20, %20, %22, %25, %27, %30);
+      %24 : int[]
+      %25 : int[]
+      %26 : int[]
+      %27 : int[]
+      %28 : int[]
+      %29 : int[]
+      %30 : Float(*, *)
+      %31 : int[]
+      %32 : int[]) {
+  %33 : Float(*, *) = aten::mul(%0, %30)
+  %34 : Tensor = prim::SumToSize(%33, %31)
+  %35 : Float(*, *) = aten::mul(%0, %outgate)
+  %36 : Tensor = prim::SumToSize(%35, %32)
+  %37 : Tensor = prim::FusionGroup_0(%1, %36, %30)
+  %38 : Tensor = prim::SumToSize(%37, %28)
+  %39 : Tensor = prim::SumToSize(%37, %29)
+  %40 : Tensor = aten::mul(%39, %cellgate)
+  %41 : Tensor = prim::SumToSize(%40, %26)
+  %42 : Tensor = aten::mul(%39, %ingate)
+  %43 : Tensor = prim::SumToSize(%42, %27)
+  %44 : Tensor = aten::mul(%38, %9)
+  %45 : Tensor = prim::SumToSize(%44, %24)
+  %46 : Tensor = aten::mul(%38, %forgetgate)
+  %47 : Tensor = prim::SumToSize(%46, %25)
+  %48 : Tensor = prim::FusionGroup_1(%41, %ingate, %45, %forgetgate, %43, %cellgate, %34, %outgate)
+  %49 : Tensor = prim::SumToSize(%48, %18)
+  %50 : Tensor = prim::SumToSize(%48, %19)
+  %51 : Tensor = prim::SumToSize(%49, %16)
+  %52 : Tensor = prim::SumToSize(%49, %17)
+  %53 : Tensor = prim::SumToSize(%51, %14)
+  %54 : Tensor = prim::SumToSize(%51, %15)
+  %55 : Float(*, *) = aten::t(%13)
+  %56 : Float(*, *) = aten::mm(%54, %55)
+  %57 : Float(*, *) = aten::t(%10)
+  %58 : Float(*, *) = aten::mm(%57, %54)
+  %59 : Float(*, *) = aten::t(%58)
+  %60 : Float(*, *) = aten::t(%12)
+  %61 : Float(*, *) = aten::mm(%53, %60)
+  %62 : Float(*, *) = aten::t(%11)
+  %63 : Float(*, *) = aten::mm(%62, %53)
+  %64 : Float(*, *) = aten::t(%63)
+  return (%47, %50, %52, %56, %59, %61, %64);
 }
 with prim::FusionGroup_0 = graph(%0 : Float(*, *)
+      %1 : Tensor
+      %2 : Float(*, *)) {
+  %3 : Float(*, *) = aten::mul(%2, %2)
+  %4 : Float(*, *) = aten::neg(%3)
+  %5 : int = prim::Constant[value=1]()
+  %6 : Float(*, *) = aten::add(%4, %5, %5)
+  %7 : Tensor = aten::mul(%1, %6)
+  %8 : int = prim::Constant[value=1]()
+  %9 : Tensor = aten::add(%0, %7, %8)
+  return (%9);
+}
+with prim::FusionGroup_1 = graph(%0 : Tensor
       %1 : Float(*, *)
-      %2 : Float(*, *)
+      %2 : Tensor
       %3 : Float(*, *)
-      %4 : Float(*, *)
+      %4 : Tensor
       %5 : Float(*, *)
-      %6 : Float(*, *)
+      %6 : Tensor
       %7 : Float(*, *)) {
-  %8 : Float(*, *) = aten::mul(%7, %6)
-  %9 : Float(*, *) = aten::mul(%7, %3)
-  %10 : Float(*, *) = aten::mul(%6, %6)
-  %11 : Float(*, *) = aten::neg(%10)
-  %12 : int = prim::Constant[value=1]()
-  %13 : int = prim::Constant[value=1]()
-  %14 : Float(*, *) = aten::add(%11, %13, %13)
-  %15 : Float(*, *) = aten::mul(%9, %14)
-  %16 : int = prim::Constant[value=1]()
-  %17 : int = prim::Constant[value=1]()
-  %18 : Float(*, *) = aten::add(%5, %15, %17)
-  %19 : Float(*, *) = aten::mul(%18, %2)
-  %20 : Float(*, *) = aten::mul(%18, %1)
-  %21 : Float(*, *) = aten::mul(%18, %4)
-  %22 : Float(*, *) = aten::neg(%3)
-  %23 : int = prim::Constant[value=1]()
+  %8 : Float(*, *) = aten::neg(%7)
+  %9 : int = prim::Constant[value=1]()
+  %10 : Float(*, *) = aten::add(%8, %9, %9)
+  %11 : Float(*, *) = aten::mul(%10, %7)
+  %12 : Tensor = aten::mul(%11, %6)
+  %13 : Float(*, *) = aten::mul(%5, %5)
+  %14 : Float(*, *) = aten::neg(%13)
+  %15 : int = prim::Constant[value=1]()
+  %16 : Float(*, *) = aten::add(%14, %15, %15)
+  %17 : Tensor = aten::mul(%4, %16)
+  %18 : Float(*, *) = aten::neg(%3)
+  %19 : int = prim::Constant[value=1]()
+  %20 : Float(*, *) = aten::add(%18, %19, %19)
+  %21 : Float(*, *) = aten::mul(%20, %3)
+  %22 : Tensor = aten::mul(%21, %2)
+  %23 : Float(*, *) = aten::neg(%1)
   %24 : int = prim::Constant[value=1]()
-  %25 : Float(*, *) = aten::add(%22, %24, %24)
-  %26 : Float(*, *) = aten::mul(%25, %3)
-  %27 : Float(*, *) = aten::mul(%26, %8)
-  %28 : Float(*, *) = aten::mul(%2, %2)
-  %29 : Float(*, *) = aten::neg(%28)
-  %30 : int = prim::Constant[value=1]()
-  %31 : int = prim::Constant[value=1]()
-  %32 : Float(*, *) = aten::add(%29, %31, %31)
-  %33 : Float(*, *) = aten::mul(%20, %32)
-  %34 : Float(*, *) = aten::neg(%0)
-  %35 : int = prim::Constant[value=1]()
-  %36 : int = prim::Constant[value=1]()
-  %37 : Float(*, *) = aten::add(%34, %36, %36)
-  %38 : Float(*, *) = aten::mul(%37, %0)
-  %39 : Float(*, *) = aten::mul(%38, %21)
-  %40 : Float(*, *) = aten::neg(%1)
-  %41 : int = prim::Constant[value=1]()
-  %42 : int = prim::Constant[value=1]()
-  %43 : Float(*, *) = aten::add(%40, %42, %42)
-  %44 : Float(*, *) = aten::mul(%43, %1)
-  %45 : Float(*, *) = aten::mul(%44, %19)
-  %46 : Float(*, *) = prim::FusedConcat[dim=1](%45, %39, %33, %27)
-  %47 : Float(*, *) = aten::mul(%18, %0)
-  return (%47, %46);
+  %25 : Float(*, *) = aten::add(%23, %24, %24)
+  %26 : Float(*, *) = aten::mul(%25, %1)
+  %27 : Tensor = aten::mul(%26, %0)
+  %28 : Tensor = prim::FusedConcat[dim=1](%27, %22, %17, %12)
+  return (%28);
 }
index 3772edc..a10d0c3 100644 (file)
@@ -20,54 +20,58 @@ with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *)
   %9 : Float(*, *) = aten::t(%4)
   %10 : Float(*, *) = aten::mm(%3, %9)
   %11 : int = prim::Constant[value=1]()
-  %12 : Tensor[] = prim::ListConstruct(%1, %2, %8, %10)
-  %13 : Tensor[] = aten::broadcast_tensors(%12)
-  %14 : Tensor, %15 : Tensor, %16 : Tensor, %17 : Tensor = prim::ListUnpack(%13)
-  %hy : Float(*, *), %19 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %17, %16, %15, %14)
-  return (%hy, %cy, %7, %9, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %19);
+  %12 : int[] = aten::size(%8)
+  %13 : int[] = aten::size(%10)
+  %14 : int[] = aten::size(%2)
+  %15 : Float(*, *), %16 : Float(*, *) = prim::FusionGroup_0(%2, %8, %10)
+  %17 : int[] = aten::size(%16)
+  %18 : int[] = aten::size(%15)
+  %19 : int[] = aten::size(%1)
+  %20 : Tensor[] = prim::ListConstruct(%15, %1)
+  %21 : Tensor[] = aten::broadcast_tensors(%20)
+  %22 : Tensor, %23 : Tensor = prim::ListUnpack(%21)
+  %24 : int[] = aten::size(%0)
+  %hy : Float(*, *), %26 : Float(*, *), %cy : Float(*, *), %28 : Float(*, *), %29 : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1(%0, %23, %22)
+  %34 : int[] = aten::size(%ingate.1)
+  %35 : int[] = aten::size(%forgetgate.1)
+  %36 : int[] = aten::size(%cellgate.1)
+  %37 : int[] = aten::size(%outgate.1)
+  %38 : int[] = aten::size(%29)
+  %39 : int[] = aten::size(%28)
+  %40 : int[] = aten::size(%26)
+  return (%hy, %cy, %7, %9, %12, %13, %17, %14, %18, %19, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %35, %24, %34, %36, %38, %39, %26, %37, %40);
 }
-with prim::FusionGroup_0 = graph(%0 : Float(*, *)
+with prim::FusionGroup_0 = graph(%0 : Float(*)
+      %1 : Float(*, *)
+      %2 : Float(*, *)) {
+  %3 : int = prim::Constant[value=1]()
+  %4 : Float(*, *) = aten::add(%1, %2, %3)
+  %5 : int = prim::Constant[value=1]()
+  %6 : Float(*, *) = aten::add(%4, %0, %5)
+  return (%6, %4);
+}
+with prim::FusionGroup_1 = graph(%0 : Float(*, *)
       %1 : Tensor
-      %2 : Tensor
-      %3 : Tensor
-      %4 : Tensor) {
-  %5 : Float(*, *), %6 : Float(*, *), %7 : Float(*, *), %8 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%4)
-  %9 : Float(*, *), %10 : Float(*, *), %11 : Float(*, *), %12 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%3)
-  %13 : Float(*, *), %14 : Float(*, *), %15 : Float(*, *), %16 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
-  %17 : Float(*, *), %18 : Float(*, *), %19 : Float(*, *), %20 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
-  %21 : int = prim::Constant[value=1]()
-  %22 : Float(*, *) = aten::add(%13, %17, %21)
-  %23 : int = prim::Constant[value=1]()
-  %24 : Float(*, *) = aten::add(%14, %18, %23)
+      %2 : Tensor) {
+  %3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
+  %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
+  %11 : int = prim::Constant[value=1]()
+  %12 : Float(*, *) = aten::add(%3, %7, %11)
+  %13 : int = prim::Constant[value=1]()
+  %14 : Float(*, *) = aten::add(%4, %8, %13)
+  %15 : int = prim::Constant[value=1]()
+  %16 : Float(*, *) = aten::add(%5, %9, %15)
+  %17 : int = prim::Constant[value=1]()
+  %18 : Float(*, *) = aten::add(%6, %10, %17)
+  %ingate.1 : Float(*, *) = aten::sigmoid(%12)
+  %forgetgate.1 : Float(*, *) = aten::sigmoid(%14)
+  %cellgate.1 : Float(*, *) = aten::tanh(%16)
+  %outgate.1 : Float(*, *) = aten::sigmoid(%18)
+  %23 : Float(*, *) = aten::mul(%forgetgate.1, %0)
+  %24 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
   %25 : int = prim::Constant[value=1]()
-  %26 : Float(*, *) = aten::add(%15, %19, %25)
-  %27 : int = prim::Constant[value=1]()
-  %28 : Float(*, *) = aten::add(%16, %20, %27)
-  %29 : int = prim::Constant[value=1]()
-  %30 : Float(*, *) = aten::add(%22, %9, %29)
-  %31 : int = prim::Constant[value=1]()
-  %32 : Float(*, *) = aten::add(%24, %10, %31)
-  %33 : int = prim::Constant[value=1]()
-  %34 : Float(*, *) = aten::add(%26, %11, %33)
-  %35 : int = prim::Constant[value=1]()
-  %36 : Float(*, *) = aten::add(%28, %12, %35)
-  %37 : int = prim::Constant[value=1]()
-  %38 : Float(*, *) = aten::add(%30, %5, %37)
-  %39 : int = prim::Constant[value=1]()
-  %40 : Float(*, *) = aten::add(%32, %6, %39)
-  %41 : int = prim::Constant[value=1]()
-  %42 : Float(*, *) = aten::add(%34, %7, %41)
-  %43 : int = prim::Constant[value=1]()
-  %44 : Float(*, *) = aten::add(%36, %8, %43)
-  %ingate.1 : Float(*, *) = aten::sigmoid(%38)
-  %forgetgate.1 : Float(*, *) = aten::sigmoid(%40)
-  %cellgate.1 : Float(*, *) = aten::tanh(%42)
-  %outgate.1 : Float(*, *) = aten::sigmoid(%44)
-  %49 : Float(*, *) = aten::mul(%forgetgate.1, %0)
-  %50 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
-  %51 : int = prim::Constant[value=1]()
-  %cy : Float(*, *) = aten::add(%49, %50, %51)
-  %53 : Float(*, *) = aten::tanh(%cy)
-  %hy : Float(*, *) = aten::mul(%outgate.1, %53)
-  return (%hy, %53, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
+  %cy : Float(*, *) = aten::add(%23, %24, %25)
+  %27 : Float(*, *) = aten::tanh(%cy)
+  %hy : Float(*, *) = aten::mul(%outgate.1, %27)
+  return (%hy, %27, %cy, %24, %23, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
 }
index c2c0bca..c53cb88 100644 (file)
@@ -20,87 +20,121 @@ graph(%0 : Float(*, *)
       %Wx : Float(*, *)
       %20 : Float(*, *)
       %Uz : Float(*, *)
-      %22 : Float(*, *)
+      %22 : int[]
+      %23 : int[]
+      %24 : Float(*, *)
+      %25 : int[]
+      %26 : int[]
+      %27 : int[]
+      %28 : int[]
+      %29 : int[]
+      %30 : int[]
+      %31 : int[]
+      %32 : int[]
+      %33 : int[]
+      %34 : int[]
+      %35 : int[]
+      %36 : int[]
       %ingate : Float(*, *)
       %forgetgate : Float(*, *)
       %cellgate : Float(*, *)
       %outgate : Float(*, *)
-      %27 : Float(*, *)) {
-  %28 : Float(*, *) = prim::FusionGroup_0(%ingate, %forgetgate, %cellgate, %outgate, %12, %1, %27, %0)
-  %29 : Float(*, *) = aten::mul(%28, %Wx)
-  %30 : Float(*, *) = prim::FusionGroup_1(%28, %22, %13)
-  %31 : Float(*, *), %32 : Float(*, *), %33 : Float(*, *) = prim::FusionGroup_2(%Wx, %15, %28, %14, %Uz)
-  %34 : Float(*, *) = aten::t(%16)
-  %35 : Float(*, *) = aten::mm(%34, %30)
-  %36 : Float(*, *) = aten::t(%35)
-  %37 : Float(*, *) = aten::t(%17)
-  %38 : Float(*, *) = aten::mm(%37, %32)
-  %39 : Float(*, *) = aten::t(%38)
-  return (%28, %33, %29, %31, %36, %39);
+      %41 : int[]
+      %42 : int[]
+      %43 : int[]
+      %44 : int[]
+      %45 : int[]
+      %46 : int[]
+      %47 : Float(*, *)
+      %48 : int[]
+      %49 : int[]) {
+  %50 : int = prim::Constant[value=1]()
+  %51 : Float(*, *) = aten::mul(%0, %47)
+  %52 : Tensor = prim::SumToSize(%51, %48)
+  %53 : Float(*, *) = aten::mul(%0, %outgate)
+  %54 : Tensor = prim::SumToSize(%53, %49)
+  %55 : Tensor = prim::FusionGroup_0(%1, %54, %47)
+  %56 : Tensor = prim::SumToSize(%55, %45)
+  %57 : Tensor = prim::SumToSize(%55, %46)
+  %58 : Tensor = aten::mul(%57, %cellgate)
+  %59 : Tensor = prim::SumToSize(%58, %43)
+  %60 : Tensor = aten::mul(%57, %ingate)
+  %61 : Tensor = prim::SumToSize(%60, %44)
+  %62 : Tensor = aten::mul(%56, %12)
+  %63 : Tensor = prim::SumToSize(%62, %41)
+  %64 : Tensor = prim::FusionGroup_1(%59, %ingate, %63, %forgetgate, %61, %cellgate, %52, %outgate)
+  %65 : Tensor = prim::SumToSize(%64, %35)
+  %66 : Tensor = prim::SumToSize(%64, %36)
+  %67 : Tensor = prim::SumToSize(%65, %33)
+  %68 : Tensor = prim::SumToSize(%65, %34)
+  %69 : Tensor = aten::mul(%68, %Uz)
+  %70 : Tensor = prim::SumToSize(%69, %31)
+  %71 : Tensor = aten::mul(%68, %13)
+  %72 : Tensor = prim::SumToSize(%71, %32)
+  %73 : Tensor = prim::SumToSize(%67, %29)
+  %74 : Tensor = prim::SumToSize(%67, %30)
+  %75 : Tensor = aten::mul(%74, %Wx)
+  %76 : Tensor = prim::SumToSize(%75, %27)
+  %77 : Tensor = aten::mul(%74, %14)
+  %78 : Tensor = prim::SumToSize(%77, %28)
+  %79 : Tensor = aten::mul(%73, %Uz)
+  %80 : Tensor = prim::SumToSize(%79, %25)
+  %81 : Tensor = aten::mul(%73, %24)
+  %82 : Tensor = prim::SumToSize(%81, %26)
+  %83 : Tensor = aten::add(%72, %82, %50)
+  %84 : Tensor = aten::mul(%80, %Wx)
+  %85 : Tensor = prim::SumToSize(%84, %22)
+  %86 : Tensor = aten::mul(%80, %15)
+  %87 : Tensor = prim::SumToSize(%86, %23)
+  %88 : Tensor = aten::add(%78, %87, %50)
+  %89 : Float(*, *) = aten::t(%16)
+  %90 : Float(*, *) = aten::mm(%89, %83)
+  %91 : Float(*, *) = aten::t(%90)
+  %92 : Float(*, *) = aten::t(%17)
+  %93 : Float(*, *) = aten::mm(%92, %88)
+  %94 : Float(*, *) = aten::t(%93)
+  return (%66, %70, %76, %85, %91, %94);
 }
 with prim::FusionGroup_0 = graph(%0 : Float(*, *)
+      %1 : Tensor
+      %2 : Float(*, *)) {
+  %3 : Float(*, *) = aten::mul(%2, %2)
+  %4 : Float(*, *) = aten::neg(%3)
+  %5 : int = prim::Constant[value=1]()
+  %6 : Float(*, *) = aten::add(%4, %5, %5)
+  %7 : Tensor = aten::mul(%1, %6)
+  %8 : int = prim::Constant[value=1]()
+  %9 : Tensor = aten::add(%0, %7, %8)
+  return (%9);
+}
+with prim::FusionGroup_1 = graph(%0 : Tensor
       %1 : Float(*, *)
-      %2 : Float(*, *)
+      %2 : Tensor
       %3 : Float(*, *)
-      %4 : Float(*, *)
+      %4 : Tensor
       %5 : Float(*, *)
-      %6 : Float(*, *)
+      %6 : Tensor
       %7 : Float(*, *)) {
-  %8 : Float(*, *) = aten::mul(%7, %6)
-  %9 : Float(*, *) = aten::mul(%7, %3)
-  %10 : Float(*, *) = aten::mul(%6, %6)
-  %11 : Float(*, *) = aten::neg(%10)
-  %12 : int = prim::Constant[value=1]()
-  %13 : Float(*, *) = aten::add(%11, %12, %12)
-  %14 : Float(*, *) = aten::mul(%9, %13)
-  %15 : int = prim::Constant[value=1]()
-  %16 : Float(*, *) = aten::add(%5, %14, %15)
-  %17 : Float(*, *) = aten::mul(%16, %2)
-  %18 : Float(*, *) = aten::mul(%16, %0)
-  %19 : Float(*, *) = aten::mul(%16, %4)
-  %20 : Float(*, *) = aten::neg(%3)
-  %21 : int = prim::Constant[value=1]()
-  %22 : Float(*, *) = aten::add(%20, %21, %21)
-  %23 : Float(*, *) = aten::mul(%22, %3)
-  %24 : Float(*, *) = aten::mul(%23, %8)
-  %25 : Float(*, *) = aten::mul(%2, %2)
-  %26 : Float(*, *) = aten::neg(%25)
-  %27 : int = prim::Constant[value=1]()
-  %28 : Float(*, *) = aten::add(%26, %27, %27)
-  %29 : Float(*, *) = aten::mul(%18, %28)
-  %30 : Float(*, *) = aten::neg(%1)
-  %31 : int = prim::Constant[value=1]()
-  %32 : Float(*, *) = aten::add(%30, %31, %31)
-  %33 : Float(*, *) = aten::mul(%32, %1)
-  %34 : Float(*, *) = aten::mul(%33, %19)
-  %35 : Float(*, *) = aten::neg(%0)
-  %36 : int = prim::Constant[value=1]()
-  %37 : Float(*, *) = aten::add(%35, %36, %36)
-  %38 : Float(*, *) = aten::mul(%37, %0)
-  %39 : Float(*, *) = aten::mul(%38, %17)
-  %40 : Float(*, *) = prim::FusedConcat[dim=1](%39, %34, %29, %24)
-  return (%40);
-}
-with prim::FusionGroup_1 = graph(%0 : Float(*, *)
-      %1 : Float(*, *)
-      %2 : Float(*)) {
-  %3 : Float(*, *) = aten::mul(%0, %2)
-  %4 : Float(*, *) = aten::mul(%0, %1)
-  %5 : int = prim::Constant[value=1]()
-  %6 : Float(*, *) = aten::add(%3, %4, %5)
-  return (%6);
-}
-with prim::FusionGroup_2 = graph(%0 : Float(*, *)
-      %1 : Float(*)
-      %2 : Float(*, *)
-      %3 : Float(*)
-      %4 : Float(*, *)) {
-  %5 : Float(*, *) = aten::mul(%2, %4)
-  %6 : Float(*, *) = aten::mul(%2, %3)
-  %7 : Float(*, *) = aten::mul(%5, %1)
-  %8 : int = prim::Constant[value=1]()
+  %8 : Float(*, *) = aten::neg(%7)
   %9 : int = prim::Constant[value=1]()
-  %10 : Float(*, *) = aten::add(%6, %7, %9)
-  %11 : Float(*, *) = aten::mul(%5, %0)
-  return (%11, %10, %5);
+  %10 : Float(*, *) = aten::add(%8, %9, %9)
+  %11 : Float(*, *) = aten::mul(%10, %7)
+  %12 : Tensor = aten::mul(%11, %6)
+  %13 : Float(*, *) = aten::mul(%5, %5)
+  %14 : Float(*, *) = aten::neg(%13)
+  %15 : int = prim::Constant[value=1]()
+  %16 : Float(*, *) = aten::add(%14, %15, %15)
+  %17 : Tensor = aten::mul(%4, %16)
+  %18 : Float(*, *) = aten::neg(%3)
+  %19 : int = prim::Constant[value=1]()
+  %20 : Float(*, *) = aten::add(%18, %19, %19)
+  %21 : Float(*, *) = aten::mul(%20, %3)
+  %22 : Tensor = aten::mul(%21, %2)
+  %23 : Float(*, *) = aten::neg(%1)
+  %24 : int = prim::Constant[value=1]()
+  %25 : Float(*, *) = aten::add(%23, %24, %24)
+  %26 : Float(*, *) = aten::mul(%25, %1)
+  %27 : Tensor = aten::mul(%26, %0)
+  %28 : Tensor = prim::FusedConcat[dim=1](%27, %22, %17, %12)
+  return (%28);
 }
index 44cb7da..2e2726c 100644 (file)
@@ -23,72 +23,71 @@ with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *)
   %Wx.1 : Float(*, *) = aten::mm(%7, %9)
   %11 : Float(*, *) = aten::t(%6)
   %Uz.1 : Float(*, *) = aten::mm(%5, %11)
-  %13 : Float(*, *) = aten::mul(%4, %Wx.1)
-  %14 : int = prim::Constant[value=1]()
-  %15 : Tensor[] = prim::ListConstruct(%1, %2, %Uz.1, %13, %3, %Wx.1)
-  %16 : Tensor[] = aten::broadcast_tensors(%15)
-  %17 : Tensor, %18 : Tensor, %19 : Tensor, %20 : Tensor, %21 : Tensor, %22 : Tensor = prim::ListUnpack(%16)
-  %hy : Float(*, *), %24 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %22, %21, %20, %19, %18, %17)
-  return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %13, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %24);
+  %13 : int[] = aten::size(%4)
+  %14 : int[] = aten::size(%Wx.1)
+  %15 : int[] = aten::size(%Uz.1)
+  %16 : int[] = aten::size(%3)
+  %17 : int = prim::Constant[value=1]()
+  %18 : int[] = aten::size(%2)
+  %19 : Float(*, *), %20 : Float(*, *), %21 : Float(*, *), %22 : Float(*, *), %23 : Float(*, *), %24 : Float(*, *) = prim::FusionGroup_0(%2, %Uz.1, %3, %Wx.1, %4)
+  %25 : int[] = aten::size(%24)
+  %26 : int[] = aten::size(%23)
+  %27 : int[] = aten::size(%22)
+  %28 : int[] = aten::size(%21)
+  %29 : int[] = aten::size(%20)
+  %30 : int[] = aten::size(%19)
+  %31 : int[] = aten::size(%1)
+  %32 : Tensor[] = prim::ListConstruct(%19, %1)
+  %33 : Tensor[] = aten::broadcast_tensors(%32)
+  %34 : Tensor, %35 : Tensor = prim::ListUnpack(%33)
+  %36 : int[] = aten::size(%0)
+  %hy : Float(*, *), %38 : Float(*, *), %cy : Float(*, *), %40 : Float(*, *), %41 : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1(%0, %35, %34)
+  %46 : int[] = aten::size(%ingate.1)
+  %47 : int[] = aten::size(%forgetgate.1)
+  %48 : int[] = aten::size(%cellgate.1)
+  %49 : int[] = aten::size(%outgate.1)
+  %50 : int[] = aten::size(%41)
+  %51 : int[] = aten::size(%40)
+  %52 : int[] = aten::size(%38)
+  return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %13, %14, %24, %25, %15, %16, %14, %26, %27, %18, %15, %28, %29, %30, %31, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %47, %36, %46, %48, %50, %51, %38, %49, %52);
 }
-with prim::FusionGroup_0 = graph(%0 : Float(*, *)
+with prim::FusionGroup_0 = graph(%0 : Float(*)
+      %1 : Float(*, *)
+      %2 : Float(*)
+      %3 : Float(*, *)
+      %4 : Float(*)) {
+  %5 : Float(*, *) = aten::mul(%4, %3)
+  %6 : Float(*, *) = aten::mul(%5, %1)
+  %7 : Float(*, *) = aten::mul(%2, %3)
+  %8 : int = prim::Constant[value=1]()
+  %9 : Float(*, *) = aten::add(%6, %7, %8)
+  %10 : Float(*, *) = aten::mul(%0, %1)
+  %11 : int = prim::Constant[value=1]()
+  %12 : Float(*, *) = aten::add(%9, %10, %11)
+  return (%12, %10, %9, %7, %6, %5);
+}
+with prim::FusionGroup_1 = graph(%0 : Float(*, *)
       %1 : Tensor
-      %2 : Tensor
-      %3 : Tensor
-      %4 : Tensor
-      %5 : Tensor
-      %6 : Tensor) {
-  %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%6)
-  %11 : Float(*, *), %12 : Float(*, *), %13 : Float(*, *), %14 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%5)
-  %15 : Float(*, *), %16 : Float(*, *), %17 : Float(*, *), %18 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%4)
-  %19 : Float(*, *), %20 : Float(*, *), %21 : Float(*, *), %22 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%3)
-  %23 : Float(*, *), %24 : Float(*, *), %25 : Float(*, *), %26 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
-  %27 : Float(*, *), %28 : Float(*, *), %29 : Float(*, *), %30 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
-  %31 : Float(*, *) = aten::mul(%23, %27)
-  %32 : Float(*, *) = aten::mul(%24, %28)
-  %33 : Float(*, *) = aten::mul(%25, %29)
-  %34 : Float(*, *) = aten::mul(%26, %30)
-  %35 : Float(*, *) = aten::mul(%19, %15)
-  %36 : Float(*, *) = aten::mul(%20, %16)
-  %37 : Float(*, *) = aten::mul(%21, %17)
-  %38 : Float(*, *) = aten::mul(%22, %18)
-  %39 : Float(*, *) = aten::mul(%11, %15)
-  %40 : Float(*, *) = aten::mul(%12, %16)
-  %41 : Float(*, *) = aten::mul(%13, %17)
-  %42 : Float(*, *) = aten::mul(%14, %18)
-  %43 : int = prim::Constant[value=1]()
-  %44 : Float(*, *) = aten::add(%35, %31, %43)
-  %45 : int = prim::Constant[value=1]()
-  %46 : Float(*, *) = aten::add(%36, %32, %45)
-  %47 : int = prim::Constant[value=1]()
-  %48 : Float(*, *) = aten::add(%37, %33, %47)
-  %49 : int = prim::Constant[value=1]()
-  %50 : Float(*, *) = aten::add(%38, %34, %49)
-  %51 : int = prim::Constant[value=1]()
-  %52 : Float(*, *) = aten::add(%44, %39, %51)
-  %53 : int = prim::Constant[value=1]()
-  %54 : Float(*, *) = aten::add(%46, %40, %53)
-  %55 : int = prim::Constant[value=1]()
-  %56 : Float(*, *) = aten::add(%48, %41, %55)
-  %57 : int = prim::Constant[value=1]()
-  %58 : Float(*, *) = aten::add(%50, %42, %57)
-  %59 : int = prim::Constant[value=1]()
-  %60 : Float(*, *) = aten::add(%52, %7, %59)
-  %61 : int = prim::Constant[value=1]()
-  %62 : Float(*, *) = aten::add(%54, %8, %61)
-  %63 : int = prim::Constant[value=1]()
-  %64 : Float(*, *) = aten::add(%56, %9, %63)
-  %65 : int = prim::Constant[value=1]()
-  %66 : Float(*, *) = aten::add(%58, %10, %65)
-  %ingate.1 : Float(*, *) = aten::sigmoid(%60)
-  %forgetgate.1 : Float(*, *) = aten::sigmoid(%62)
-  %cellgate.1 : Float(*, *) = aten::tanh(%64)
-  %outgate.1 : Float(*, *) = aten::sigmoid(%66)
-  %71 : Float(*, *) = aten::mul(%forgetgate.1, %0)
-  %72 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
-  %73 : int = prim::Constant[value=1]()
-  %cy : Float(*, *) = aten::add(%71, %72, %73)
-  %75 : Float(*, *) = aten::tanh(%cy)
-  %hy : Float(*, *) = aten::mul(%outgate.1, %75)
-  return (%hy, %75, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
+      %2 : Tensor) {
+  %3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
+  %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
+  %11 : int = prim::Constant[value=1]()
+  %12 : Float(*, *) = aten::add(%3, %7, %11)
+  %13 : int = prim::Constant[value=1]()
+  %14 : Float(*, *) = aten::add(%4, %8, %13)
+  %15 : int = prim::Constant[value=1]()
+  %16 : Float(*, *) = aten::add(%5, %9, %15)
+  %17 : int = prim::Constant[value=1]()
+  %18 : Float(*, *) = aten::add(%6, %10, %17)
+  %ingate.1 : Float(*, *) = aten::sigmoid(%12)
+  %forgetgate.1 : Float(*, *) = aten::sigmoid(%14)
+  %cellgate.1 : Float(*, *) = aten::tanh(%16)
+  %outgate.1 : Float(*, *) = aten::sigmoid(%18)
+  %23 : Float(*, *) = aten::mul(%forgetgate.1, %0)
+  %24 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
+  %25 : int = prim::Constant[value=1]()
+  %cy : Float(*, *) = aten::add(%23, %24, %25)
+  %27 : Float(*, *) = aten::tanh(%cy)
+  %hy : Float(*, *) = aten::mul(%outgate.1, %27)
+  return (%hy, %27, %cy, %24, %23, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
 }
index 6b3f1c8..653750d 100644 (file)
@@ -358,7 +358,7 @@ class JitTestCase(TestCase):
                 vs = vs[:-drop]
             # we don't want all the grad for all the outputs to be the same
             # so we multiply each by a constant
-            return sum([(i + 1) * v.sum() for i, v in enumerate(vs) if v is not None])
+            return sum([math.log(i + 2) * v.sum() for i, v in enumerate(vs) if v is not None])
         if input_tensors is None:
             input_tensors = reference_tensors
 
@@ -433,10 +433,11 @@ class JitTestCase(TestCase):
 
         return ge
 
-    def assertAllFused(self, graph):
+    def assertAllFused(self, graph, except_for=()):
         if [n.kind() for n in graph.nodes()] == ['prim::DifferentiableGraph']:
             graph = next(graph.nodes()).g('Subgraph')
-        self.assertTrue(all(node.kind() in {'prim::Constant', 'prim::FusionGroup'} for node in graph.nodes()),
+        allowed_nodes = {'prim::Constant', 'prim::FusionGroup'} | set(except_for)
+        self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()),
                         'got {}'.format(graph))
         self.assertTrue([node.kind() for node in graph.nodes()].count('prim::FusionGroup') == 1)
 
@@ -3225,12 +3226,12 @@ a")
         funcs = (func2, funcInf)
         for f in funcs:
             s = self.checkScript(f, (a, b))
-            self.assertAllFused(s.graph_for(a, b))
+            self.assertAllFused(s.graph_for(a, b), except_for={'aten::size'})
 
             c = s(a, b)
             c.sum().backward()
             graph = backward_graph(s)
-            self.assertAllFused(graph)
+            self.assertAllFused(graph, except_for={'prim::SumToSize'})
 
     def test_mul(self):
         def func(a, b):
@@ -4212,8 +4213,9 @@ a")
         inputs = get_lstm_inputs('cuda', training=True)
         module = self.checkScript(LSTMCellS, inputs)
         forward_graph = module.graph_for(*inputs)
-        self.assertGraphContainsExactly(
-            forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
+        with self.assertRaises(AssertionError):
+            self.assertGraphContainsExactly(
+                forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
         self.assertExpectedGraph(forward_graph, subname='forward')
 
         hy, cy = module(*inputs)
@@ -4227,8 +4229,9 @@ a")
         inputs = get_milstm_inputs('cuda', training=True)
         module = self.checkScript(MiLSTMCell, inputs)
         forward_graph = module.graph_for(*inputs)
-        self.assertGraphContainsExactly(
-            forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
+        with self.assertRaises(AssertionError):
+            self.assertGraphContainsExactly(
+                forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
         self.assertExpectedGraph(forward_graph, subname='forward')
 
         hy, cy = module(*inputs)
index 1bc1586..01e24be 100644 (file)
@@ -340,7 +340,7 @@ static void validate_outputs(const edge_list& edges, variable_list& grads, const
         ss << metadata.shape();
         AT_ERROR(format_error(ss.str()));
       }
-      grads[i] = at::sum_to(grads[i], metadata.shape());
+      grads[i] = at::sum_to(std::move(grads[i]), metadata.shape());
     }
     if (!is_compatible_type(metadata.type(), grads[i].type())) {
       std::stringstream ss;
index 4aba7b6..731753f 100644 (file)
@@ -124,6 +124,9 @@ struct ArgumentSpec {
   size_t size() const {
     return args.size();
   }
+  const ArgumentInfo& at(size_t i) const {
+    return args[i];
+  }
   size_t hashCode() const {
     return hash_code;
   }
@@ -147,6 +150,7 @@ private:
         return fillType(subtype, offset);
       }));
     } else {
+      offset++;
       return original;
     }
   }
@@ -314,6 +318,28 @@ private:
   const int i;
 };
 
+inline std::ostream & operator<<(std::ostream & out, const ArgumentInfo & info) {
+  if(!info.defined()) {
+    return out << "<undefined>";
+  }
+  out << "Tensor(device=" << info.device()
+    << ", type=" << toString(info.type())
+    << ", requires_grad=" << info.requires_grad()
+    << ", dims=" << info.dim() << ")";
+  return out;
+}
+
+inline std::ostream& operator<<(std::ostream & out, const ArgumentSpec & spec) {
+  out << "{";
+  for(size_t i = 0; i < spec.size(); ++i) {
+    if (i > 0)
+      out << ", ";
+    out << spec.at(i);
+  }
+  out << "}";
+  return out;
+}
+
 inline std::ostream & operator<<(std::ostream & out, const CompleteArgumentInfo & info) {
   if(!info.defined()) {
     return out << "<undefined>";
index 8d95b0b..eb5bb21 100644 (file)
@@ -78,6 +78,7 @@ bool isDifferentiable(Node * n) {
     "aten::sinh(Tensor self) -> Tensor",
     "aten::tan(Tensor self) -> Tensor",
     "aten::trunc(Tensor self) -> Tensor",
+    "prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)",
     "aten::log_softmax(Tensor self, int dim) -> Tensor",
     "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
     "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)",
@@ -137,48 +138,64 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
     "aten::eq(Tensor self, Tensor other) -> Tensor",
     "aten::ne(Tensor self, Tensor other) -> Tensor"
   };
-  const auto build_sym_grad = [node](const std::vector<SymbolicVariable>& grads) -> std::vector<SymbolicVariable> {
+  const auto sumToSizeOf = [node](SymbolicVariable v, Symbol input_name) -> SymbolicVariable {
+    Value * size;
+    {
+      WithInsertPoint insert_guard {node};
+      size = SymbolicVariable(node->namedInput(input_name)).size();
+    }
+    return v.sumToSize(size);
+  };
+  const auto build_sym_grad = [node, &sumToSizeOf](const std::vector<SymbolicVariable>& grads) -> std::vector<SymbolicVariable> {
     auto inputs = fmap<SymbolicVariable>(node->inputs());
     auto outputs = fmap<SymbolicVariable>(node->outputs());
 
     if (node->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
-      return {grads.at(0), grads.at(0) * node->namedInput(attr::alpha), nullptr};
+      return {sumToSizeOf(grads.at(0), attr::self),
+              sumToSizeOf(grads.at(0) * node->namedInput(attr::alpha), attr::other),
+              nullptr};
 
     } else if (node->matches("aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor")) {
       return {grads.at(0), nullptr, nullptr};
 
     } else if (node->kind() == prim::AutogradAdd) {
+      // NB: AutogradAdds don't broadcast
       return {grads.at(0), grads.at(0)};
 
     } else if (node->matches("aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
-      return {grads.at(0), -grads.at(0) * node->namedInput(attr::alpha), nullptr};
+      return {sumToSizeOf(grads.at(0), attr::self),
+              sumToSizeOf(-grads.at(0) * node->namedInput(attr::alpha), attr::other),
+              nullptr};
 
     } else if (node->matches("aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor")) {
       return {grads.at(0), nullptr, nullptr};
 
     } else if (node->matches("aten::mul(Tensor self, Tensor other) -> Tensor")) {
-      return {grads.at(0) * inputs.at(1), grads.at(0) * inputs.at(0)};
+      return {sumToSizeOf(grads.at(0) * inputs.at(1), attr::self),
+              sumToSizeOf(grads.at(0) * inputs.at(0), attr::other)};
 
     } else if (node->matches("aten::mul(Tensor self, Scalar other) -> Tensor")) {
       return {grads.at(0) * inputs.at(1), nullptr};
 
     } else if (node->matches("aten::div(Tensor self, Tensor other) -> Tensor")) {
-      return {grads.at(0) / inputs.at(1), -grads.at(0) * inputs.at(0) / (inputs.at(1) * inputs.at(1))};
+      return {sumToSizeOf(grads.at(0) / inputs.at(1), attr::self),
+              sumToSizeOf(-grads.at(0) * inputs.at(0) / (inputs.at(1) * inputs.at(1)), attr::other)};
 
     } else if (node->matches("aten::div(Tensor self, Scalar other) -> Tensor")) {
       return {grads.at(0) / inputs.at(1), nullptr};
 
     } else if (node->matches("aten::max(Tensor self, Tensor other) -> Tensor")) {
-      return {grads.at(0) * (inputs.at(0) > inputs.at(1)).type_as(grads.at(0)),
-              grads.at(0) * (inputs.at(1) > inputs.at(0)).type_as(grads.at(0))};
+      return {sumToSizeOf(grads.at(0) * (inputs.at(0) > inputs.at(1)).type_as(grads.at(0)), attr::self),
+              sumToSizeOf(grads.at(0) * (inputs.at(1) > inputs.at(0)).type_as(grads.at(0)), attr::other)};
 
     } else if (node->matches("aten::min(Tensor self, Tensor other) -> Tensor")) {
-      return {grads.at(0) * (inputs.at(0) < inputs.at(1)).type_as(grads.at(0)),
-              grads.at(0) * (inputs.at(1) < inputs.at(0)).type_as(grads.at(0))};
+      return {sumToSizeOf(grads.at(0) * (inputs.at(0) < inputs.at(1)).type_as(grads.at(0)), attr::self),
+              sumToSizeOf(grads.at(0) * (inputs.at(1) < inputs.at(0)).type_as(grads.at(0)), attr::other)};
 
     } else if (node->matches("aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor")) {
-      return {nullptr, grads.at(0) * inputs.at(0).type_as(grads.at(0)),
-                       grads.at(0) * (1 - inputs.at(0)).type_as(grads.at(0))};
+      return {nullptr,
+              sumToSizeOf(grads.at(0) * inputs.at(0).type_as(grads.at(0)), attr::self),
+              sumToSizeOf(grads.at(0) * (1 - inputs.at(0)).type_as(grads.at(0)), attr::other)};
 
     } else if (node->matches("aten::sigmoid(Tensor self) -> Tensor")) {
       // TODO: The order of operations matter in this case. This
@@ -239,6 +256,14 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
     } else if (node->matches("aten::atan(Tensor self) -> Tensor")) {
       return {grads.at(0) / (inputs.at(0) * inputs.at(0) + at::Scalar(1))};
 
+    } else if (node->matches("prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)")) {
+      Value * self_size;
+      {
+        WithInsertPoint insert_guard { node };
+        self_size = inputs.at(0).size();
+      }
+      return {grads.at(0).expand(self_size), nullptr};
+
     } else if (node->matches("aten::ceil(Tensor self) -> Tensor")) {
       return {SymbolicVariable::zeros_like(grads.at(0))};
 
@@ -315,7 +340,7 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
       return {grads.at(0).squeeze(node->namedInput(attr::dim)), nullptr};
 
     } else if (node->matches("aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor")) {
-      return {grads.at(0) * node->namedInput(attr::beta),
+      return {sumToSizeOf(grads.at(0) * node->namedInput(attr::beta), attr::self),
               grads.at(0).mm(inputs.at(2).t()) * node->namedInput(attr::alpha),
               inputs.at(1).t().mm(grads.at(0)) * node->namedInput(attr::alpha),
               nullptr, nullptr};
@@ -421,15 +446,15 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
     } else if (node->matches("aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)")) {
       auto graph = node->owningGraph();
       auto backward_value = graph->insert(aten::thnn_conv2d_backward, {
-       grads.at(0).value(),
-       inputs.at(0).value(),
-       inputs.at(1).value(),
-       node->namedInput(attr::kernel_size),
-       node->namedInput(attr::stride),
-       node->namedInput(attr::padding),
-       outputs.at(1).value(),
-       outputs.at(2).value(),
-       graph->insertConstant(std::vector<bool>{true, true, true})
+        grads.at(0).value(),
+        inputs.at(0).value(),
+        inputs.at(1).value(),
+        node->namedInput(attr::kernel_size),
+        node->namedInput(attr::stride),
+        node->namedInput(attr::padding),
+        outputs.at(1).value(),
+        outputs.at(2).value(),
+        graph->insertConstant(std::vector<bool>{true, true, true})
       });
       // graph->insert returns a tuple automatically if multiple outputs are returned. So unpack them again.
       Node* tuple_unpack_node = graph->insertNode(graph->createTupleUnpack(backward_value));
@@ -440,16 +465,16 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
     } else if (node->matches("aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)")) {
       auto graph = node->owningGraph();
       auto backward_value = graph->insert(aten::native_batch_norm_backward, {
-       grads.at(0).value(),
-       inputs.at(0).value(),
-       inputs.at(1).value(),
-       inputs.at(3).value(),
-       inputs.at(4).value(),
-       outputs.at(1).value(),
-       outputs.at(2).value(),
-       inputs.at(5).value(),
-       inputs.at(7).value(),
-       graph->insertConstant(std::vector<bool>{true, true, true})
+        grads.at(0).value(),
+        inputs.at(0).value(),
+        inputs.at(1).value(),
+        inputs.at(3).value(),
+        inputs.at(4).value(),
+        outputs.at(1).value(),
+        outputs.at(2).value(),
+        inputs.at(5).value(),
+        inputs.at(7).value(),
+        graph->insertConstant(std::vector<bool>{true, true, true})
       });
       // graph->insert returns a tuple automatically if multiple outputs are returned. So unpack them again.
       Node* tuple_unpack_node = graph->insertNode(graph->createTupleUnpack(backward_value));
@@ -775,16 +800,16 @@ Gradient differentiate(std::shared_ptr<Graph>& graph) {
   auto rev_info = addReverseInline(grad_desc);
   // Lift constants captured for the reverse graph into it
   liftConstants(grad_desc, rev_info);
-  // addReverseInline has to call gradientForNode if *any* of the outputs
-  // require grad, but it will emit vjps for *all* outputs. Use DCE to remove
-  // unnecessary nodes.
-  EliminateDeadCode(rev_info.reverse_block);
   // Fills in f, df, f_real_outputs, df_input_captures,
   // modifies df_input_vjps (new vjps are added for temporaries)
   lambdaLiftReverse(grad_desc, rev_info);
   // It's possible the we've cloned the same constants many times, so
   // de-duplicate them
   ConstantPooling(grad_desc.df);
+  // addReverseInline has to call gradientForNode if *any* of the inputs
+  // require grad, but it will emit vjps for *all* inputs. Use DCE to remove
+  // unnecessary nodes.
+  EliminateDeadCode(grad_desc.df);
   return grad_desc;
 }
 
index 3e802a9..5c85722 100644 (file)
@@ -9,9 +9,11 @@
 #include "torch/csrc/jit/operator.h"
 #include "torch/csrc/jit/custom_operator.h"
 #include "torch/csrc/jit/script/jit_exception.h"
-
 #include "torch/csrc/variable_tensor_functions.h"
 
+#include <ATen/ExpandUtils.h>
+#include <c10/util/SmallVector.h>
+
 #include <exception>
 #include <iostream>
 #include <limits>
@@ -361,6 +363,17 @@ RegisterOperators reg({
           };
         }),
     Operator(
+        "prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)",
+        [](const Node* node) {
+          return [=](Stack& stack) {
+            at::Tensor self;
+            Shared<IntList> desired_sizes;
+            pop(stack, self, desired_sizes);
+            push(stack, at::sum_to(std::move(self), desired_sizes->elements()));
+            return 0;
+          };
+        }),
+    Operator(
         prim::TupleUnpack,
         [](const Node* node) {
           size_t num_elems = node->outputs().size();
index 9a34502..d1eccbc 100644 (file)
@@ -33,8 +33,8 @@ RegisterOperators reg({
         "aten::size(Tensor self) -> int[]",
         [](Stack& stack) {
           autograd::profiler::RecordFunction record("sizes");
-          auto result = (std::move(pop(stack))).toTensor().sizes();
-          pack(stack, std::move(result));
+          auto t = std::move(pop(stack)).toTensor();
+          pack(stack, t.sizes().vec());
           return 0;
         }),
     Operator(
index a625f60..1d4b82f 100644 (file)
@@ -123,6 +123,15 @@ struct SymbolicVariable {
   SymbolicVariable operator%(at::Scalar rhs) const {
     return create(aten::remainder, {*this, insertConstant(rhs)})[0].typeLike(*this);
   }
+  Value* size() const {
+    return v->owningGraph()->insert(aten::size, {v});
+  }
+  SymbolicVariable sumToSize(Value * size) const {
+    return create(prim::SumToSize, {*this, size})[0];
+  }
+  SymbolicVariable expand(Value * size) const {
+    return v->owningGraph()->insert(aten::expand, {v, size});
+  }
   SymbolicVariable isnan() const {
     return create(aten::ne, {*this, *this})[0].typeLikeWithScalarType(*this, at::kByte);
   }