Improvements for current AD (#17187)
authorAiling Zhang <ailzhang@fb.com>
Fri, 22 Feb 2019 22:19:04 +0000 (14:19 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 22 Feb 2019 22:34:14 +0000 (14:34 -0800)
Summary:
This PR removes a few size of `self` that passed from forward pass to backward pass when `self` is already required in backward pass. This could be reason that cause the potential slow down in #16689 . I will attach a few perf numbers (still a bit volatile among runs tho) I got in the comment.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17187

Differential Revision: D14179512

Pulled By: ailzhang

fbshipit-source-id: 5f3b1f6f26a3fef6dec15623b940380cc13656fa

16 files changed:
aten/src/ATen/native/ReduceOps.cpp
aten/src/ATen/native/TensorCompare.cpp
aten/src/ATen/native/TensorShape.cpp
aten/src/ATen/native/native_functions.yaml
test/cpp/jit/test_misc.h
test/expect/TestFuser.test_lstm_cuda-backward.expect
test/expect/TestFuser.test_lstm_cuda-forward.expect
test/expect/TestFuser.test_milstm_cuda-backward.expect
test/expect/TestFuser.test_milstm_cuda-forward.expect
test/expect/TestJit.test_cpp_cuda.expect
test/test_jit.py
tools/autograd/derivatives.yaml
tools/autograd/templates/Functions.cpp
torch/csrc/jit/autodiff.cpp
torch/csrc/jit/register_prim_ops.cpp
torch/csrc/jit/symbolic_script.cpp

index 88c78a8..27c2566 100644 (file)
@@ -335,42 +335,6 @@ Tensor& sum_out(Tensor& result, const Tensor& self, IntArrayRef dim, ScalarType
   return at::native::sum_out(result, self, dim, false, dtype);
 }
 
-int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim) {
-  int64_t size = 1;
-  if (sizes.size() == 0) {
-    return 1;
-  }
-  for (auto d : dim) {
-    d = at::maybe_wrap_dim(d, sizes.size());
-    size *= sizes[d];
-  }
-  return size;
-}
-
-Tensor unsqueeze_multiple(const Tensor & t, IntArrayRef dim, size_t n_dims) {
-    auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, n_dims);
-    Tensor res = t;
-    for (size_t i = 0; i < n_dims; i++){
-      if (dims_to_unsqueeze[i]) {
-        res = res.unsqueeze(i);
-      }
-    }
-    return res;
-}
-
-Tensor sum_backward(const Tensor & grad, IntArrayRef sizes, IntArrayRef dims, bool keepdim) {
-  if (!keepdim && sizes.size() > 0) {
-    if (dims.size()==1) {
-      return grad.unsqueeze(dims[0]).expand(sizes);
-    } else {
-      Tensor res = unsqueeze_multiple(grad, dims, sizes.size());
-      return res.expand(sizes);
-    }
-  } else {
-    return grad.expand(sizes);
-  }
-}
-
 Tensor& prod_out(Tensor& result, const Tensor& self, int64_t dim, bool keepdim, ScalarType dtype) {
   return at::native::prod_out(
       result, self, dim, keepdim, c10::optional<ScalarType>(dtype));
@@ -452,16 +416,6 @@ Tensor logsumexp(const Tensor &self, IntArrayRef dims, bool keepdim) {
   return at::native::logsumexp_out(result, self, dims, keepdim);
 }
 
-Tensor logsumexp_backward(const Tensor& grad, const Tensor & self, const Tensor& res, IntArrayRef dim, bool keepdim) {
-  Tensor grad_input = grad;
-  Tensor fwd_res = res;
-  if (!keepdim && self.dim() != 0) {
-    grad_input = unsqueeze_multiple(grad, dim, self.sizes().size());
-    fwd_res = unsqueeze_multiple(res, dim, self.sizes().size());
-  }
-  return grad_input * (self - fwd_res).exp();
-}
-
 static Tensor& norm_out(Tensor &result, const Tensor &self, optional<Scalar> opt_p,
                                IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
   auto p = opt_p.value_or(2.0);
@@ -674,21 +628,6 @@ Tensor &var_out(Tensor &result, const Tensor &self, IntArrayRef dim, bool unbias
   return std_var_out(result, self, dim, unbiased, keepdim, false);
 }
 
-Tensor var_backward(const Tensor & grad, const Tensor & self, bool unbiased) {
-  return (2.0 / (self.numel() - unbiased)) * grad * (self - self.mean());
-}
-
-Tensor var_backward(const Tensor & grad, const Tensor & self, IntArrayRef dim, bool unbiased, bool keepdim) {
-  if (self.dim() == 0) {
-    return at::var_backward(grad, self, unbiased);
-  }
-  Tensor unsqueezed_grad = grad;
-  if (!keepdim && self.dim() > 1) {
-    unsqueezed_grad = unsqueeze_multiple(grad, dim, self.sizes().size());
-  }
-  return (2.0 / (at::_safe_size(self.sizes(), dim) - unbiased)) * unsqueezed_grad * (self - self.mean(dim, true));
-}
-
 Tensor std(const Tensor& self, bool unbiased) {
   AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
            "std only supports CPU AND CUDA backend, got: ", toString(self.type().backend()));
index 04ee3dd..0f58f4f 100644 (file)
@@ -34,14 +34,6 @@ namespace at { namespace native {
 DEFINE_DISPATCH(max_kernel);
 DEFINE_DISPATCH(min_kernel);
 
-Tensor index_select_backward(const Tensor& grad, int64_t dim, const Tensor& indices, IntArrayRef sizes, bool keepdim) {
-  Tensor res = at::zeros(sizes, grad.options());
-  if (!keepdim && sizes.size() > 0) {
-    return res.scatter_(dim, indices.unsqueeze(dim), grad.unsqueeze(dim));
-  }
-  return res.scatter_(dim, indices, grad);
-}
-
 bool allclose(const Tensor& self, const Tensor& other, double rtol, double atol, bool equal_nan) {
   return at::isclose(self, other, rtol, atol, equal_nan).all().item<uint8_t>();
 }
index bddf40a..2e7b5e6 100644 (file)
@@ -384,16 +384,6 @@ Tensor permute(const Tensor& self, IntArrayRef dims) {
   return self.as_strided(newSizes, newStrides);
 }
 
-Tensor permute_backwards(const Tensor & grad, IntArrayRef fwd_dims) {
-  // invert the permutation
-  auto ndims = fwd_dims.size();
-  std::vector<int64_t> dims(ndims);
-  for (size_t i = 0; i < ndims; i++) {
-    dims[at::maybe_wrap_dim(fwd_dims[i], ndims)] = i;
-  }
-  return grad.permute(dims);
-}
-
 Tensor repeat(const Tensor& self, IntArrayRef repeats) {
   AT_CHECK(repeats.size() >= (size_t)self.dim(),
            "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor");
@@ -461,12 +451,6 @@ Tensor select(const Tensor& self, int64_t dim, int64_t index) {
   return self.as_strided(sizes, strides, storage_offset);
 }
 
-Tensor select_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) {
-  auto grad_input = at::zeros(input_sizes, grad.options());
-  grad_input.select(dim, index).copy_(grad);
-  return grad_input;
-}
-
 Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
   int64_t ndim = self.dim();
   if (ndim == 0) {
@@ -500,12 +484,6 @@ Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_
   return self.as_strided(sizes, strides, storage_offset);
 }
 
-Tensor slice_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
-  auto grad_input = at::zeros(input_sizes, grad.options());
-  grad_input.slice(dim, start, end, step).copy_(grad);
-  return grad_input;
-}
-
 std::vector<Tensor> split(const Tensor& self, int64_t split_size, int64_t dim) {
   AT_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor");
   AT_CHECK(split_size >= 0,  "split expects split_size be non-negative, but got split_size=", split_size);
@@ -712,28 +690,6 @@ Tensor squeeze(const Tensor& self) {
   return self.as_strided(std::get<0>(g), std::get<1>(g));
 }
 
-Tensor unsqueeze_to(const Tensor & self, IntArrayRef sizes) {
-  auto result = self;
-
-  int64_t nDims = sizes.size();
-  for (int64_t dim = 0; dim < nDims; dim++) {
-    if (sizes[dim] == 1) {
-      result = result.unsqueeze(dim);
-    }
-  }
-  return result;
-}
-
-Tensor unsqueeze_to(const Tensor & self, int64_t dim, IntArrayRef sizes) {
-  dim = at::maybe_wrap_dim(dim, sizes.size());
-  // in NumPy it's not an error to unsqueeze a scalar, but we still need to avoided
-  // unsqueezing in the backward.
-  if (sizes.size() > 0 && sizes[dim] == 1) {
-    return self.unsqueeze(dim);
-  }
-  return self;
-}
-
 Tensor squeeze(const Tensor& self, int64_t dim) {
   int64_t dims = self.dim();
   dim = maybe_wrap_dim(dim, dims);
index 71e7a7b..399c654 100644 (file)
   dispatch:
     CUDA: _cudnn_init_dropout_state
 
-- func: index_select_backward(Tensor grad, int64_t dim, Tensor indices, int[] sizes, bool keepdim) -> Tensor
-
-- func: select_backward(Tensor grad, int[] input_sizes, int64_t dim, int64_t index) -> Tensor
-
 - func: _fused_dropout(Tensor self, float p, Generator? generator=None) -> (Tensor, Tensor)
   matches_jit_signature: True
   variants: function
 - func: logsumexp(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
   matches_jit_signature: True
 
-- func: logsumexp_backward(Tensor grad, Tensor self, Tensor res, int[1] dim, bool keepdim) -> Tensor
-  matches_jit_signature: True
-
 - func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor
   matches_jit_signature: True
 
 - func: mean(Tensor self, int[1] dim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
   matches_jit_signature: True
 
-- func: sum_backward(Tensor grad, int[] sizes, int[] dims, bool keepdim) -> Tensor
-  matches_jit_signature: True
-
 - func: median(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
   matches_jit_signature: True
   variants: function, method
   matches_jit_signature: True
   variants: method  # This is method-only to match the previous tensor API. In the future we could make this a function too.
 
-- func: permute_backwards(Tensor grad, int[] fwd_dims) -> Tensor
-  matches_jit_signature: True
-
 - func: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor
   matches_jit_signature: True
 
   variants: function, method
   device_guard: False
 
-- func: _safe_size(int[] sizes, int[] dim) -> int64_t
-
 - func: slice(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)
   matches_jit_signature: True
   variants: function, method
   device_guard: False
 
-- func: slice_backward(Tensor grad, int[] input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) -> Tensor
-
 - func: slogdet(Tensor self) -> (Tensor, Tensor)
   matches_jit_signature: True
   variants: function, method
   variants: function, method
   device_guard: False
 
-- func: unsqueeze_to(Tensor self, int[] sizes) -> Tensor
-  matches_jit_signature: True
-
-- func: unsqueeze_to(Tensor self, int64_t dim, int[] sizes) -> Tensor
-
 - func: squeeze_(Tensor(a!) self) -> Tensor(a!)
   matches_jit_signature: True
   variants: method
 - func: var(Tensor self, int[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
   matches_jit_signature: True
 
-- func: var_backward(Tensor grad, Tensor self, bool unbiased) -> Tensor
-  matches_jit_signature: True
-
-- func: var_backward(Tensor grad, Tensor self, int[] dim, bool unbiased, bool keepdim) -> Tensor
-  matches_jit_signature: True
-
 - func: view_as(Tensor self, Tensor other) -> Tensor
   matches_jit_signature: True
   variants: method
index 5dd7a65..6a064b9 100644 (file)
@@ -848,24 +848,7 @@ void testDifferentiate(std::ostream& out = std::cout) {
 
   auto grad_spec = differentiate(graph);
   std::vector<size_t> expected_captured_inputs = {0, 1};
-  // With add/mul implemented using torchscript, we passes sizes of
-  // self & other instead passing the tensors themselve.
-  // The forward graph is now
-  //graph(%0 : Float(2, 3, 4)
-  //      %1 : Float(2, 3, 4)) {
-  //  %2 : Float(2, 3, 4) = aten::mul(%0, %1)
-  //  %self_size.4 : int[] = aten::size(%0)
-  //  %other_size.4 : int[] = aten::size(%1)
-  //  %3 : Float(2, 3, 4) = aten::mul(%2, %0)
-  //  %self_size.2 : int[] = aten::size(%2)
-  //  %4 : int = prim::Constant[value=1]()
-  //  %7 : int[] = aten::size(%3)
-  //  %5 : Float(2, 3, 4) = aten::add(%3, %1, %4)
-  //  return (%5, %2, %self_size.4, %other_size.4, %self_size.2, %7);
-  //}
-  // Thus all the sizes info added in forward outputs are saved
-  // in grad_spec.df_input_caputered_outputs.
-  std::vector<size_t> expected_captured_outputs = {1, 2, 3, 4, 5};
+  std::vector<size_t> expected_captured_outputs = {1, 2};
   std::vector<size_t> expected_input_vjps = {0, 1};
   std::vector<size_t> expected_output_vjps = {0, 1};
   ASSERT_EQ(grad_spec.f_real_outputs, 1);
@@ -897,29 +880,12 @@ void testDifferentiateWithRequiresGrad(std::ostream& out = std::cout) {
   PropagateInputShapes(graph);
   PropagateRequiresGrad(graph);
 
-  // With add/mul implemented using torchscript, we passes sizes of
-  // self & other instead passing the tensors themselve.
-  // The forward graph is now
-  // graph(%0 : Float(*)
-  //       %1 : Float(*)) {
-  //   %2 : Float(*) = aten::mul(%1, %1)
-  //   %3 : int = prim::Constant[value=1]()
-  //   %4 : Float(*) = aten::add(%2, %1, %3)
-  //   %39 : int[] = aten::size(%0)
-  //   %6 : Float(*) = aten::add(%4, %0, %3)
-  //   %7 : Float(*) = aten::mul(%6, %0)
-  //   %self_size.2 : int[] = aten::size(%6)
-  //   %11 : int[] = aten::size(%7)
-  //   %9 : Float(*) = aten::add(%7, %1, %3)
-  //   return (%4, %9, %39, %6, %self_size.2, %11);
-  // }
-
   auto grad_spec = differentiate(graph);
-  std::vector<size_t> expected_input_vjps = {1, 3}; // for e and %6 = (d + a)
+  std::vector<size_t> expected_input_vjps = {1, 2}; // for e and %4 = (d + a)
   std::vector<size_t> expected_output_vjps = {0}; // only a requires grad
   ASSERT_EQ(grad_spec.f_real_outputs, 2);
   ASSERT_EQ(grad_spec.df_input_captured_inputs, std::vector<size_t>({0}));
-  ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector<size_t>({2, 3, 4, 5}));
+  ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector<size_t>({2, 3}));
   ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps);
   ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps);
   out << "testDifferentiateWithRequiresGrad\n";
index dda0c4f..77f8d4c 100644 (file)
@@ -22,35 +22,35 @@ graph(%0 : Float(*, *),
       %forgetgate : Float(*, *),
       %cellgate : Float(*, *),
       %outgate : Float(*, *),
-      %self_size.5 : int[],
-      %other_size.5 : int[],
-      %self_size.3 : int[],
-      %other_size.3 : int[],
-      %28 : int[],
-      %29 : int[],
-      %30 : Float(*, *),
-      %self_size.1 : int[],
-      %other_size.1 : int[]):
-  %33 : int = prim::Constant[value=1]()
-  %34 : Tensor = prim::FusionGroup_0(%outgate, %0, %30, %self_size.1)
-  %grad_other.5 : Tensor, %36 : Tensor, %37 : Tensor, %38 : Tensor = prim::FusionGroup_1(%forgetgate, %9, %ingate, %cellgate, %1, %30, %0, %outgate, %other_size.5, %self_size.5, %28, %other_size.3, %self_size.3, %29, %other_size.1)
+      %24 : int[],
+      %25 : int[],
+      %26 : Float(*, *)):
+  %27 : int = prim::Constant[value=1]()
+  %28 : int[] = aten::size(%outgate)
+  %29 : int[] = aten::size(%26)
+  %30 : int[] = aten::size(%ingate)
+  %31 : int[] = aten::size(%cellgate)
+  %32 : int[] = aten::size(%forgetgate)
+  %33 : int[] = aten::size(%9)
+  %34 : Tensor = prim::FusionGroup_0(%outgate, %0, %26, %28)
+  %grad_other.5 : Tensor, %36 : Tensor, %37 : Tensor, %38 : Tensor = prim::FusionGroup_1(%forgetgate, %9, %ingate, %cellgate, %1, %26, %0, %outgate, %33, %32, %24, %31, %30, %25, %29)
   %39 : Tensor[] = prim::ListConstruct(%38, %36, %37, %34)
-  %40 : Tensor = aten::cat(%39, %33)
+  %40 : Tensor = aten::cat(%39, %27)
   %41 : Tensor = aten::_grad_sum_to_size(%40, %19)
   %42 : Tensor = aten::_grad_sum_to_size(%40, %17)
   %43 : Tensor = aten::_grad_sum_to_size(%40, %14)
   %44 : Tensor = aten::_grad_sum_to_size(%40, %15)
   %45 : Float(*, *) = aten::t(%13)
-  %46 : Float(*, *) = aten::mm(%44, %45)
+  %grad_self.7 : Float(*, *) = aten::mm(%44, %45)
   %47 : Float(*, *) = aten::t(%10)
-  %48 : Float(*, *) = aten::mm(%47, %44)
-  %grad_self.7 : Float(*, *) = aten::t(%48)
+  %grad_mat2.1 : Float(*, *) = aten::mm(%47, %44)
+  %grad_self.9 : Float(*, *) = aten::t(%grad_mat2.1)
   %50 : Float(*, *) = aten::t(%12)
-  %51 : Float(*, *) = aten::mm(%43, %50)
+  %grad_self.11 : Float(*, *) = aten::mm(%43, %50)
   %52 : Float(*, *) = aten::t(%11)
-  %53 : Float(*, *) = aten::mm(%52, %43)
-  %grad_self.9 : Float(*, *) = aten::t(%53)
-  return (%grad_other.5, %41, %42, %46, %grad_self.7, %51, %grad_self.9)
+  %grad_mat2.3 : Float(*, *) = aten::mm(%52, %43)
+  %grad_self.13 : Float(*, *) = aten::t(%grad_mat2.3)
+  return (%grad_other.5, %41, %42, %grad_self.7, %grad_self.9, %grad_self.11, %grad_self.13)
 with prim::FusionGroup_0 = graph(%0 : Float(*, *),
       %1 : Float(*, *),
       %2 : Float(*, *),
index 414a041..e53deb8 100644 (file)
@@ -28,16 +28,14 @@ with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *),
   %17 : Tensor, %18 : Tensor, %19 : Tensor, %20 : Tensor = prim::ListUnpack(%16)
   %21 : int[] = prim::BroadcastSizes(%11, %12)
   %22 : int[] = prim::BroadcastSizes(%21, %13)
-  %other_size.6 : int[] = aten::size(%0)
-  %hy : Float(*, *), %25 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %20, %19, %18, %17)
-  %31 : int[] = aten::size(%25)
-  %32 : int[] = aten::size(%outgate.1)
-  %33 : int[] = aten::size(%cellgate.1)
-  %34 : int[] = aten::size(%forgetgate.1)
-  %35 : int[] = aten::size(%ingate.1)
-  %36 : int[] = prim::BroadcastSizes(%34, %other_size.6)
-  %37 : int[] = prim::BroadcastSizes(%35, %33)
-  return (%hy, %cy, %7, %9, %11, %12, %21, %13, %22, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %34, %other_size.6, %35, %33, %36, %37, %25, %32, %31)
+  %hy : Float(*, *), %24 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %20, %19, %18, %17)
+  %30 : int[] = aten::size(%0)
+  %31 : int[] = aten::size(%cellgate.1)
+  %32 : int[] = aten::size(%forgetgate.1)
+  %33 : int[] = aten::size(%ingate.1)
+  %34 : int[] = prim::BroadcastSizes(%32, %30)
+  %35 : int[] = prim::BroadcastSizes(%33, %31)
+  return (%hy, %cy, %7, %9, %11, %12, %21, %13, %22, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %34, %35, %24)
 with prim::FusionGroup_0 = graph(%0 : Float(*, *),
       %1 : Tensor,
       %2 : Tensor,
index 590d427..711fb9d 100644 (file)
@@ -17,50 +17,50 @@ graph(%0 : Float(*, *),
       %Wx : Float(*, *),
       %Uz : Float(*, *),
       %18 : Float(*, *),
-      %self_size.13 : int[],
-      %other_size.13 : int[],
-      %self_size.11 : int[],
-      %other_size.11 : int[],
-      %self_size.9 : int[],
+      %19 : int[],
+      %20 : int[],
+      %21 : int[],
+      %22 : int[],
+      %23 : int[],
       %24 : int[],
-      %25 : int[],
-      %self_size.7 : int[],
-      %27 : int[],
-      %28 : int[],
-      %29 : int[],
-      %30 : int[],
       %ingate : Float(*, *),
       %forgetgate : Float(*, *),
       %cellgate : Float(*, *),
       %outgate : Float(*, *),
-      %self_size.5 : int[],
-      %self_size.3 : int[],
-      %other_size.3 : int[],
-      %38 : int[],
-      %39 : int[],
-      %40 : Float(*, *),
-      %self_size.1 : int[],
-      %other_size.1 : int[]):
-  %43 : int = prim::Constant[value=1]()
-  %44 : Tensor = prim::FusionGroup_0(%outgate, %0, %40, %self_size.1)
-  %45 : Tensor, %46 : Tensor, %47 : Tensor = prim::FusionGroup_1(%10, %ingate, %cellgate, %1, %40, %0, %outgate, %forgetgate, %self_size.5, %38, %other_size.3, %self_size.3, %39, %other_size.1)
-  %48 : Tensor[] = prim::ListConstruct(%47, %45, %46, %44)
-  %49 : Tensor = aten::cat(%48, %43)
-  %50 : Tensor = aten::_grad_sum_to_size(%49, %30)
-  %51 : Tensor = aten::_grad_sum_to_size(%49, %28)
-  %grad_self.7 : Tensor = prim::FusionGroup_2(%51, %Uz, %self_size.7)
-  %53 : Tensor = aten::_grad_sum_to_size(%49, %24)
-  %54 : Tensor = aten::_grad_sum_to_size(%49, %25)
-  %grad_self.9 : Tensor = prim::FusionGroup_3(%54, %Wx, %self_size.9)
-  %56 : Tensor = prim::FusionGroup_4(%53, %18, %51, %11, %other_size.11)
-  %grad_self.13 : Tensor, %58 : Tensor = prim::FusionGroup_5(%Wx, %13, %53, %Uz, %54, %12, %self_size.13, %other_size.13, %self_size.11)
+      %29 : int[],
+      %30 : int[],
+      %31 : Float(*, *)):
+  %32 : int = prim::Constant[value=1]()
+  %33 : int[] = aten::size(%outgate)
+  %34 : int[] = aten::size(%31)
+  %35 : int[] = aten::size(%ingate)
+  %36 : int[] = aten::size(%cellgate)
+  %37 : int[] = aten::size(%forgetgate)
+  %38 : Tensor = prim::FusionGroup_0(%outgate, %0, %31, %33)
+  %39 : Tensor, %40 : Tensor, %41 : Tensor = prim::FusionGroup_1(%10, %ingate, %cellgate, %1, %31, %0, %outgate, %forgetgate, %37, %29, %36, %35, %30, %34)
+  %42 : Tensor[] = prim::ListConstruct(%41, %39, %40, %38)
+  %43 : Tensor = aten::cat(%42, %32)
+  %44 : Tensor = aten::_grad_sum_to_size(%43, %24)
+  %45 : Tensor = aten::_grad_sum_to_size(%43, %22)
+  %46 : int[] = aten::size(%11)
+  %grad_self.7 : Tensor = prim::FusionGroup_2(%45, %Uz, %46)
+  %48 : int[] = aten::size(%Uz)
+  %49 : Tensor = aten::_grad_sum_to_size(%43, %19)
+  %50 : Tensor = aten::_grad_sum_to_size(%43, %20)
+  %51 : int[] = aten::size(%12)
+  %grad_self.9 : Tensor = prim::FusionGroup_3(%50, %Wx, %51)
+  %53 : int[] = aten::size(%Wx)
+  %54 : int[] = aten::size(%18)
+  %55 : Tensor = prim::FusionGroup_4(%49, %18, %45, %11, %48)
+  %56 : int[] = aten::size(%13)
+  %grad_self.13 : Tensor, %58 : Tensor = prim::FusionGroup_5(%Wx, %13, %49, %Uz, %50, %12, %56, %53, %54)
   %59 : Float(*, *) = aten::t(%14)
-  %60 : Float(*, *) = aten::mm(%59, %56)
-  %grad_self.15 : Float(*, *) = aten::t(%60)
+  %grad_mat2.1 : Float(*, *) = aten::mm(%59, %55)
+  %grad_self.17 : Float(*, *) = aten::t(%grad_mat2.1)
   %62 : Float(*, *) = aten::t(%15)
-  %63 : Float(*, *) = aten::mm(%62, %58)
-  %grad_self.17 : Float(*, *) = aten::t(%63)
-  return (%50, %grad_self.7, %grad_self.9, %grad_self.13, %grad_self.15, %grad_self.17)
+  %grad_mat2.3 : Float(*, *) = aten::mm(%62, %58)
+  %grad_self.21 : Float(*, *) = aten::t(%grad_mat2.3)
+  return (%44, %grad_self.7, %grad_self.9, %grad_self.13, %grad_self.17, %grad_self.21)
 with prim::FusionGroup_0 = graph(%0 : Float(*, *),
       %1 : Float(*, *),
       %2 : Float(*, *),
index 019db7d..a8b3b77 100644 (file)
@@ -24,31 +24,28 @@ with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *),
   %11 : Float(*, *) = aten::t(%6)
   %Uz.1 : Float(*, *) = aten::mm(%5, %11)
   %13 : Float(*, *) = aten::mul(%4, %Wx.1)
-  %self_size.14 : int[] = aten::size(%4)
-  %other_size.14 : int[] = aten::size(%Wx.1)
-  %self_size.12 : int[] = aten::size(%13)
-  %other_size.12 : int[] = aten::size(%Uz.1)
-  %self_size.10 : int[] = aten::size(%3)
-  %self_size.8 : int[] = aten::size(%2)
-  %20 : int[] = aten::size(%1)
-  %21 : Tensor[] = prim::ListConstruct(%1, %2, %Uz.1, %13, %3, %Wx.1)
-  %22 : Tensor[] = aten::broadcast_tensors(%21)
-  %23 : Tensor, %24 : Tensor, %25 : Tensor, %26 : Tensor, %27 : Tensor, %28 : Tensor = prim::ListUnpack(%22)
-  %29 : int[] = prim::BroadcastSizes(%self_size.10, %other_size.14)
-  %30 : int[] = prim::BroadcastSizes(%self_size.12, %other_size.12)
-  %31 : int[] = prim::BroadcastSizes(%self_size.8, %other_size.12)
-  %32 : int[] = prim::BroadcastSizes(%30, %29)
-  %33 : int[] = prim::BroadcastSizes(%32, %31)
-  %hy : Float(*, *), %35 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %28, %27, %26, %25, %24, %23)
-  %41 : int[] = aten::size(%0)
-  %42 : int[] = aten::size(%35)
-  %43 : int[] = aten::size(%outgate.1)
-  %44 : int[] = aten::size(%cellgate.1)
-  %45 : int[] = aten::size(%forgetgate.1)
-  %46 : int[] = aten::size(%ingate.1)
-  %47 : int[] = prim::BroadcastSizes(%45, %41)
-  %48 : int[] = prim::BroadcastSizes(%46, %44)
-  return (%hy, %cy, %Wx.1, %Uz.1, %13, %self_size.14, %other_size.14, %self_size.12, %other_size.12, %self_size.10, %30, %29, %self_size.8, %32, %31, %33, %20, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %45, %46, %44, %47, %48, %35, %43, %42)
+  %14 : int[] = aten::size(%1)
+  %15 : Tensor[] = prim::ListConstruct(%1, %2, %Uz.1, %13, %3, %Wx.1)
+  %16 : Tensor[] = aten::broadcast_tensors(%15)
+  %17 : Tensor, %18 : Tensor, %19 : Tensor, %20 : Tensor, %21 : Tensor, %22 : Tensor = prim::ListUnpack(%16)
+  %23 : int[] = aten::size(%3)
+  %24 : int[] = aten::size(%Wx.1)
+  %25 : int[] = prim::BroadcastSizes(%23, %24)
+  %26 : int[] = aten::size(%13)
+  %27 : int[] = aten::size(%Uz.1)
+  %28 : int[] = prim::BroadcastSizes(%26, %27)
+  %29 : int[] = aten::size(%2)
+  %30 : int[] = prim::BroadcastSizes(%29, %27)
+  %31 : int[] = prim::BroadcastSizes(%28, %25)
+  %32 : int[] = prim::BroadcastSizes(%31, %30)
+  %hy : Float(*, *), %34 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %22, %21, %20, %19, %18, %17)
+  %40 : int[] = aten::size(%0)
+  %41 : int[] = aten::size(%cellgate.1)
+  %42 : int[] = aten::size(%forgetgate.1)
+  %43 : int[] = aten::size(%ingate.1)
+  %44 : int[] = prim::BroadcastSizes(%42, %40)
+  %45 : int[] = prim::BroadcastSizes(%43, %41)
+  return (%hy, %cy, %Wx.1, %Uz.1, %13, %28, %25, %31, %30, %32, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %44, %45, %34)
 with prim::FusionGroup_0 = graph(%0 : Float(*, *),
       %1 : Tensor,
       %2 : Tensor,
index 5e465c3..e15ab3c 100644 (file)
@@ -85,48 +85,47 @@ testDifferentiate
 graph(%0 : Float(2, 3, 4),
       %1 : Float(2, 3, 4)):
   %2 : Float(2, 3, 4) = aten::mul(%0, %1)
-  %self_size.4 : int[] = aten::size(%0)
-  %other_size.4 : int[] = aten::size(%1)
   %3 : Float(2, 3, 4) = aten::mul(%2, %0)
-  %self_size.2 : int[] = aten::size(%2)
   %4 : int = prim::Constant[value=1]()
   %7 : int[] = aten::size(%3)
   %5 : Float(2, 3, 4) = aten::add(%3, %1, %4)
-  return (%5, %2, %self_size.4, %other_size.4, %self_size.2, %7)
+  return (%5, %2, %7)
 graph(%0 : Float(2, 3, 4),
       %1 : Float(2, 3, 4),
       %2 : Float(2, 3, 4),
       %3 : Float(2, 3, 4),
       %4 : Float(2, 3, 4),
-      %self_size.3 : int[],
-      %other_size.3 : int[],
-      %self_size.1 : int[],
-      %8 : int[]):
-  %9 : int = prim::Constant[value=1]()
-  %10 : Tensor, %11 : Tensor = prim::GradOf[name="aten::add"](%0)
+      %5 : int[]):
+  %7 : int = prim::Constant[value=1]()
+  %6 : int[] = aten::size(%3)
+  %8 : Tensor, %9 : Tensor = prim::GradOf[name="aten::add"](%0)
     block0():
-      %12 : Tensor = aten::_grad_sum_to_size(%0, %8)
-      %13 : Float(2, 3, 4) = aten::mul(%0, %9)
-      %14 : Tensor = aten::_grad_sum_to_size(%13, %other_size.3)
-      -> (%12, %14)
-  %grad_self.2 : Tensor, %grad_other.2 : Tensor = prim::GradOf[name="aten::mul"](%10)
+      %10 : Tensor = aten::_grad_sum_to_size(%0, %5)
+      %11 : Float(2, 3, 4) = aten::mul(%0, %7)
+      %12 : Tensor = aten::_grad_sum_to_size(%11, %6)
+      -> (%10, %12)
+  %grad_self.2 : Tensor, %grad_other.2 : Tensor = prim::GradOf[name="aten::mul"](%8)
     block0():
-      %17 : Tensor = aten::mul(%10, %2)
-      %grad_self.1 : Tensor = aten::_grad_sum_to_size(%17, %self_size.1)
-      %19 : Tensor = aten::mul(%10, %4)
-      %grad_other.1 : Tensor = aten::_grad_sum_to_size(%19, %self_size.3)
+      %15 : Tensor = aten::mul(%8, %2)
+      %16 : int[] = aten::size(%4)
+      %grad_self.1 : Tensor = aten::_grad_sum_to_size(%15, %16)
+      %18 : Tensor = aten::mul(%8, %4)
+      %19 : int[] = aten::size(%2)
+      %grad_other.1 : Tensor = aten::_grad_sum_to_size(%18, %19)
       -> (%grad_self.1, %grad_other.1)
   %21 : Tensor = prim::AutogradAdd(%1, %grad_self.2)
   %grad_self : Tensor, %grad_other : Tensor = prim::GradOf[name="aten::mul"](%21)
     block0():
       %24 : Tensor = aten::mul(%21, %3)
-      %grad_self.3 : Tensor = aten::_grad_sum_to_size(%24, %self_size.3)
-      %26 : Tensor = aten::mul(%21, %2)
-      %grad_other.3 : Tensor = aten::_grad_sum_to_size(%26, %other_size.3)
+      %25 : int[] = aten::size(%2)
+      %grad_self.3 : Tensor = aten::_grad_sum_to_size(%24, %25)
+      %27 : Tensor = aten::mul(%21, %2)
+      %28 : int[] = aten::size(%3)
+      %grad_other.3 : Tensor = aten::_grad_sum_to_size(%27, %28)
       -> (%grad_self.3, %grad_other.3)
-  %28 : Tensor = prim::AutogradAdd(%grad_other.2, %grad_self)
-  %29 : Tensor = prim::AutogradAdd(%11, %grad_other)
-  return (%28, %29)
+  %30 : Tensor = prim::AutogradAdd(%grad_other.2, %grad_self)
+  %31 : Tensor = prim::AutogradAdd(%9, %grad_other)
+  return (%30, %31)
 
 testDifferentiateWithRequiresGrad
 graph(%0 : Float(*),
@@ -134,38 +133,37 @@ graph(%0 : Float(*),
   %2 : Float(*) = aten::mul(%1, %1)
   %3 : int = prim::Constant[value=1]()
   %4 : Float(*) = aten::add(%2, %1, %3)
-  %39 : int[] = aten::size(%0)
   %6 : Float(*) = aten::add(%4, %0, %3)
   %7 : Float(*) = aten::mul(%6, %0)
-  %self_size.2 : int[] = aten::size(%6)
   %11 : int[] = aten::size(%7)
   %9 : Float(*) = aten::add(%7, %1, %3)
-  return (%4, %9, %39, %6, %self_size.2, %11)
+  return (%4, %9, %6, %11)
 graph(%0 : Float(*),
       %1 : Float(*),
       %2 : Float(*),
-      %3 : int[],
-      %4 : Float(*),
-      %self_size.1 : int[],
-      %6 : int[]):
-  %7 : int = prim::Constant[value=1]()
-  %8 : Tensor = prim::GradOf[name="aten::add"](%0)
+      %3 : Float(*),
+      %4 : int[]):
+  %6 : int = prim::Constant[value=1]()
+  %5 : int[] = aten::size(%2)
+  %7 : Tensor = prim::GradOf[name="aten::add"](%0)
     block0():
-      %9 : Tensor = aten::_grad_sum_to_size(%0, %6)
-      -> (%9)
-  %grad_self : Tensor, %grad_other : Tensor = prim::GradOf[name="aten::mul"](%8)
+      %8 : Tensor = aten::_grad_sum_to_size(%0, %4)
+      -> (%8)
+  %grad_self : Tensor, %grad_other : Tensor = prim::GradOf[name="aten::mul"](%7)
     block0():
-      %12 : Tensor = aten::mul(%8, %2)
-      %grad_self.1 : Tensor = aten::_grad_sum_to_size(%12, %self_size.1)
-      %14 : Tensor = aten::mul(%8, %4)
-      %grad_other.1 : Tensor = aten::_grad_sum_to_size(%14, %3)
+      %11 : Tensor = aten::mul(%7, %2)
+      %12 : int[] = aten::size(%3)
+      %grad_self.1 : Tensor = aten::_grad_sum_to_size(%11, %12)
+      %14 : Tensor = aten::mul(%7, %3)
+      %15 : int[] = aten::size(%2)
+      %grad_other.1 : Tensor = aten::_grad_sum_to_size(%14, %15)
       -> (%grad_self.1, %grad_other.1)
-  %16 : Tensor = prim::AutogradAdd(%1, %grad_self)
-  %17 : Tensor = prim::GradOf[name="aten::add"](%16)
+  %17 : Tensor = prim::AutogradAdd(%1, %grad_self)
+  %18 : Tensor = prim::GradOf[name="aten::add"](%17)
     block0():
-      %18 : Tensor = aten::mul(%16, %7)
-      %19 : Tensor = aten::_grad_sum_to_size(%18, %3)
-      -> (%19)
-  %20 : Tensor = prim::AutogradAdd(%grad_other, %17)
-  return (%20)
+      %19 : Tensor = aten::mul(%17, %6)
+      %20 : Tensor = aten::_grad_sum_to_size(%19, %5)
+      -> (%20)
+  %21 : Tensor = prim::AutogradAdd(%grad_other, %18)
+  return (%21)
 
index c5fdf84..222b295 100644 (file)
@@ -11124,12 +11124,19 @@ EXCLUDE_SCRIPT_MODULES = {
 
 DISABLE_AUTODIFF_SUBGRAPH_INLINING = {
     'test_nn_avg_pool2d',
+    'test_nn_adaptive_avg_pool1d',
     'test_nn_adaptive_avg_pool2d',
+    'test_nn_adaptive_avg_pool3d',
     'test_nn_batch_norm',
     'test_nn_embedding',
     'test_nn_log_softmax',
+    'test_nn_softmax',
+    'test_nn_softmax_with_all_args',
     'test_nn_threshold',
     'test_nn_nll_loss',
+    # Should have added all test_nn_interpolate_* here,
+    # but it's using autodiff since its subgraph is over
+    # 2 nodes.
 }
 
 
@@ -12531,7 +12538,6 @@ nn_functional_tests = [
     ('multilabel_soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
     ('cosine_embedding_loss', (S, S), ((S, S), non_differentiable(torch.rand(S,))),),
     ('pixel_shuffle', (1, 9, 4, 4), (3,),),
-    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,),),
     ('affine_grid', (S, 2, 3), (torch.Size([S, 1, 7, 7]),),),
     ('pad', (3, 3, 4, 2), ([1, 1],),),
     ('pairwise_distance', (S, S), ((S, S),),),
@@ -12557,8 +12563,35 @@ nn_functional_tests = [
       torch.randint(1, S, (S,), dtype=torch.long))),
     ('upsample', torch.randn(S, S, M, M), (None, 2), 'with_scale'),
     ('upsample', torch.randn(S, S, M, M), (4,), 'with_size'),
-    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'with_scale'),
-    ('interpolate', torch.randn(S, S, M, M), (4,), 'with_size'),
+    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d'),
+    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'nearest_4d_with_scale'),
+    ('interpolate', torch.randn(S, S, M, M), (4,), 'nearest_4d_with_size'),
+    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'area_4d'),
+    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'area_4d_with_scale'),
+    ('interpolate', torch.randn(S, S, M, M), (4,), 'area_4d_with_size'),
+    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bilinear_4d'),
+    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bilinear_4d_with_scale'),
+    ('interpolate', torch.randn(S, S, M, M), (4,), 'bilinear_4d_with_size'),
+    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bicubic_4d'),
+    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bicubic_4d_with_scale'),
+    ('interpolate', torch.randn(S, S, M, M), (4,), 'bicubic_4d_with_size'),
+    ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'nearest_3d'),
+    ('interpolate', torch.randn(S, M, M), (None, 2.), 'nearest_3d_with_scale'),
+    ('interpolate', torch.randn(S, M, M), (4,), 'nearest_3d_with_size'),
+    ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'area_3d'),
+    ('interpolate', torch.randn(S, M, M), (None, 2.), 'area_3d_with_scale'),
+    ('interpolate', torch.randn(S, M, M), (4,), 'area_3d_with_size'),
+    ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'linear_3d'),
+    ('interpolate', torch.randn(S, M, M), (None, 2.), 'linear_3d_with_scale'),
+    ('interpolate', torch.randn(S, M, M), (4,), 'linear_3d_with_size'),
+    ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'nearest_5d_with_scale'),
+    ('interpolate', torch.randn(S, M, M, M, M), (4,), 'nearest_5d_with_size'),
+    ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'area_5d'),
+    ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'area_5d_with_scale'),
+    ('interpolate', torch.randn(S, M, M, M, M), (4,), 'area_5d_with_size'),
+    ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'trilinear_5d'),
+    ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'trilinear_5d_with_scale'),
+    ('interpolate', torch.randn(S, M, M, M, M), (4,), 'trilinear_5d_with_size'),
 ]
 
 
index 6dac188..476414a 100644 (file)
   self: zeros_like(grad)
 
 - name: logsumexp(Tensor self, IntArrayRef dim, bool keepdim)
-  self: at::logsumexp_backward(grad, self, result, dim, keepdim)
+  self: logsumexp_backward(grad, self, result, dim, keepdim)
 
 - name: lt_(Tensor self, Scalar other)
   self: zeros_like(self)
   self: grad.expand(self.sizes()).to(self.type().scalarType()) / self.numel()
 
 - name: mean(Tensor self, IntArrayRef dim, bool keepdim)
-  self: sum_backward(grad, self.sizes(), dim, keepdim) / at::_safe_size(self.sizes(), dim)
+  self: sum_backward(grad, self.sizes(), dim, keepdim) / _safe_size(self.sizes(), dim)
 
 - name: mean(Tensor self, IntArrayRef dim, ScalarType dtype)
-  self: sum_backward(grad, self.sizes(), dim, false).to(self.type().scalarType()) / at::_safe_size(self.sizes(), dim)
+  self: sum_backward(grad, self.sizes(), dim, false).to(self.type().scalarType()) / _safe_size(self.sizes(), dim)
 
 - name: mean(Tensor self, IntArrayRef dim, bool keepdim, ScalarType dtype)
-  self: sum_backward(grad, self.sizes(), dim, keepdim).to(self.type().scalarType()) / at::_safe_size(self.sizes(), dim)
+  self: sum_backward(grad, self.sizes(), dim, keepdim).to(self.type().scalarType()) / _safe_size(self.sizes(), dim)
 
 - name: median(Tensor self)
   self: select_equals_backward(grad, self, result)
index 63b4416..d44ec5a 100644 (file)
@@ -76,6 +76,18 @@ Tensor maybe_multiply(const Tensor & t, const Scalar & s) {
   }
 }
 
+int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim) {
+  int64_t size = 1;
+  if (sizes.size() == 0) {
+    return 1;
+  }
+  for (auto d : dim) {
+    d = at::maybe_wrap_dim(d, sizes.size());
+    size *= sizes[d];
+  }
+  return size;
+}
+
 Tensor norm_backward(const Tensor & grad, const Tensor & self, const optional<Scalar> & p_, const Tensor & norm) {
   double p = p_.value_or(2.0).toDouble();
   Tensor self_scaled;
@@ -148,6 +160,40 @@ Tensor mvlgamma_backward(Tensor grad, const Tensor & self, int64_t p) {
   return grad * args.digamma_().sum(-1);
 }
 
+Tensor permute_backwards(const Tensor & grad, IntArrayRef fwd_dims) {
+  // invert the permutation
+  auto ndims = fwd_dims.size();
+  std::vector<int64_t> dims(ndims);
+  for (size_t i = 0; i < ndims; i++) {
+    dims[at::maybe_wrap_dim(fwd_dims[i], ndims)] = i;
+  }
+  return grad.permute(dims);
+}
+
+Tensor unsqueeze_multiple(const Tensor & t, IntArrayRef dim, size_t n_dims) {
+    auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, n_dims);
+    Tensor res = t;
+    for (size_t i = 0; i < n_dims; i++){
+      if (dims_to_unsqueeze[i]) {
+        res = res.unsqueeze(i);
+      }
+    }
+    return res;
+}
+
+Tensor sum_backward(const Tensor & grad, IntArrayRef sizes, IntArrayRef dims, bool keepdim) {
+  if (!keepdim && sizes.size() > 0) {
+    if (dims.size()==1) {
+      return grad.unsqueeze(dims[0]).expand(sizes);
+    } else {
+      Tensor res = unsqueeze_multiple(grad, dims, sizes.size());
+      return res.expand(sizes);
+    }
+  } else {
+    return grad.expand(sizes);
+  }
+}
+
 std::vector<int64_t> reverse_list(const IntArrayRef list) {
   auto result = std::vector<int64_t>();
   result.reserve(list.size());
@@ -383,13 +429,13 @@ Tensor cumsum_backward(const Tensor &x, int64_t dim, ScalarType input_dtype) {
   return cumsum_backward(x.to(input_dtype), dim);
 }
 
-//Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, IntArrayRef dim, bool keepdim) {
-//  if (!keepdim && self.dim() != 0) {
-//    grad = unsqueeze_multiple(grad, dim, self.sizes().size());
-//    result = unsqueeze_multiple(result, dim, self.sizes().size());
-//  }
-//  return grad * (self - result).exp();
-//}
+Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, IntArrayRef dim, bool keepdim) {
+  if (!keepdim && self.dim() != 0) {
+    grad = unsqueeze_multiple(grad, dim, self.sizes().size());
+    result = unsqueeze_multiple(result, dim, self.sizes().size());
+  }
+  return grad * (self - result).exp();
+}
 
 Tensor unbind_backward(const variable_list& grads, int64_t dim) {
   IntArrayRef sizes;
@@ -408,6 +454,28 @@ Tensor unbind_backward(const variable_list& grads, int64_t dim) {
   return at::stack(grads_tensors, dim);
 }
 
+Tensor unsqueeze_to(const Tensor & self, IntArrayRef sizes) {
+  auto result = self;
+
+  int64_t nDims = sizes.size();
+  for (int64_t dim = 0; dim < nDims; dim++) {
+    if (sizes[dim] == 1) {
+      result = result.unsqueeze(dim);
+    }
+  }
+  return result;
+}
+
+Tensor unsqueeze_to(const Tensor & self, int64_t dim, IntArrayRef sizes) {
+  dim = at::maybe_wrap_dim(dim, sizes.size());
+  // in NumPy it's not an error to unsqueeze a scalar, but we still need to avoided
+  // unsqueezing in the backward.
+  if (sizes.size() > 0 && sizes[dim] == 1) {
+    return self.unsqueeze(dim);
+  }
+  return self;
+}
+
 std::vector<Tensor> cat_tensors_backward(const Tensor & grad, const std::vector<std::vector<int64_t>> &sizes, int64_t dim) {
   dim = at::legacy_cat_wrap_dim(dim, sizes);
   std::vector<Tensor> grad_inputs(sizes.size());
@@ -538,6 +606,26 @@ Tensor select_equals_backward(Tensor grad, const Tensor & input, const Tensor &
   return grad_input;
 }
 
+Tensor index_select_backward(Tensor grad, int64_t dim, Tensor indices, IntArrayRef sizes, bool keepdim) {
+  if (!keepdim && sizes.size() > 0) {
+    grad = grad.unsqueeze(dim);
+    indices = indices.unsqueeze(dim);
+  }
+  return at::zeros(sizes, grad.options()).scatter_(dim, indices, grad);
+}
+
+Tensor slice_backward(Tensor grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
+  auto grad_input = at::zeros(input_sizes, grad.options());
+  grad_input.slice(dim, start, end, step).copy_(grad);
+  return grad_input;
+}
+
+Tensor select_backward(Tensor grad, IntArrayRef input_sizes, int64_t dim, int64_t index) {
+  auto grad_input = at::zeros(input_sizes, grad.options());
+  grad_input.select(dim, index).copy_(grad);
+  return grad_input;
+}
+
 Tensor trace_backward(const Tensor & grad, IntArrayRef sizes) {
   if (sizes.size() != 2) {
     throw std::runtime_error("expected matrix input");
@@ -563,6 +651,19 @@ Tensor unfold_backward(const Tensor & grad, IntArrayRef input_sizes, int64_t dim
   return grad_input.view(input_sizes);
 }
 
+Tensor var_backward(const Tensor & grad, const Tensor & self, bool unbiased) {
+  return (2.0 / (self.numel() - unbiased)) * grad * (self - self.mean());
+}
+
+Tensor var_backward(Tensor grad, const Tensor & self, IntArrayRef dim, bool unbiased, bool keepdim) {
+  if (self.dim() == 0) {
+    return var_backward(grad, self, unbiased);
+  }
+  if (!keepdim && self.dim() > 1) {
+    grad = unsqueeze_multiple(grad, dim, self.sizes().size());
+  }
+  return (2.0 / (_safe_size(self.sizes(), dim) - unbiased)) * grad * (self - self.mean(dim, true));
+}
 
 Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntArrayRef sizes) {
   int64_t numel = 1;
index 5a3e8fa..2d4f41a 100644 (file)
@@ -27,12 +27,16 @@ void wrapDim(int64_t& dim, const std::vector<int64_t>& sizes) {
   }
 }
 
+// need_trim_grad_ops contains functions that return multiple outputs in
+// forward, but only the first one requires grad.
+// Example:
 // kthvalue returns (kthvalue, index of kthvalue), currently autodiff only
 // supports at most one output that requires grad. Thus we need to remove
 // the grad for index that doesn't require grad.
 bool needTrimGrad(Node* n) {
   static OperatorSet need_trim_grad_ops = {
       "aten::kthvalue(Tensor self, int k, int dim, bool keepdim) -> (Tensor, Tensor)",
+      "aten::topk(Tensor self, int k, int dim, bool largest, bool sorted) -> (Tensor, Tensor)",
   };
   if (need_trim_grad_ops.find(n)) {
     return true;
index 0089f99..4b091b2 100644 (file)
@@ -1177,6 +1177,46 @@ int listAdd(Stack& stack) {
   return 0;
 }
 
+template <class TList, class TElement>
+int listMulIntLeft(Stack& stack) {
+  TList list;
+  int64_t n;
+  pop(stack, list, n);
+
+  std::vector<TElement> ret;
+  const auto size = list->elements().size() * n;
+  ret.reserve(size);
+
+  for (auto i = 0; i < n; i++) {
+    for (const auto& e : list->elements()) {
+      ret.push_back(e);
+    }
+  }
+
+  push(stack, ret);
+  return 0;
+}
+
+template <class TList, class TElement>
+int listMulIntRight(Stack& stack) {
+  TList list;
+  int64_t n;
+  pop(stack, n, list);
+
+  std::vector<TElement> ret;
+  const auto size = list->elements().size() * n;
+  ret.reserve(size);
+
+  for (auto i = 0; i < n; i++) {
+    for (const auto& e : list->elements()) {
+      ret.push_back(e);
+    }
+  }
+
+  push(stack, ret);
+  return 0;
+}
+
 template <typename TList, typename TElement>
 int listSlice(Stack& stack) {
   TList list;
@@ -1415,10 +1455,17 @@ RegisterOperators reg2({
           "[] l, int start, int end=9223372036854775807, int step=1) -> " decl_type \
           "[]",                                                                     \
           listSlice<Shared<c_type>, c_type::ElemType>),                             \
-      Operator("aten::list(" decl_type "[] l) -> " decl_type "[]", listList)
+      Operator("aten::list(" decl_type "[] l) -> " decl_type "[]", listList),       \
+      Operator(                                                                     \
+          "aten::mul(" decl_type "[] l, int n) -> " decl_type "[]",                 \
+          listMulIntLeft<Shared<c_type>, c_type::ElemType>),                        \
+      Operator(                                                                     \
+          "aten::mul(int n, " decl_type "[] l) -> " decl_type "[]",                 \
+          listMulIntRight<Shared<c_type>, c_type::ElemType>)
 
     CREATE_LIST_OPS("int", IntList),
     CREATE_LIST_OPS("float", DoubleList),
+    CREATE_LIST_OPS("bool", BoolList),
     CREATE_LIST_OPS("Tensor", TensorList),
     CREATE_LIST_OPS("t", GenericList),
 #undef CREATE_LIST_OPS
index 394af9b..cff83d6 100644 (file)
@@ -7,6 +7,331 @@ std::mutex lock;
 const std::vector<std::string> functions = {
     R"(
 
+        ####     HELPER FUNCTIONS           ###
+        ####     PREFIX: AD_                ###
+        ####     SCHEMA NOT SAVED IN CACHE  ###
+
+        def AD_unsqueeze_multiple(t,
+                                  dims: List[int],
+                                  n_dims: int):
+            seen = [False] * n_dims
+            for i in range(len(dims)):
+                seen[dims[i]] = True
+
+            for d in range(n_dims):
+                if seen[d]:
+                    t = t.unsqueeze(d)
+            return t
+
+        def AD_sum_backward(grad,
+                            sizes: List[int],
+                            dims: List[int],
+                            keepdim: bool):
+            if not keepdim and len(sizes) > 0:
+                if len(dims) == 1:
+                    return grad.unsqueeze(dims[0]).expand(sizes)
+                else:
+                    res = AD_unsqueeze_multiple(grad, dims, len(sizes))
+                    return res.expand(sizes)
+            else:
+                return grad.expand(sizes)
+
+        def AD_logsumexp_backward(grad, self, result,
+                                  dim: List[int],
+                                  keepdim: bool):
+            if not keepdim and self.dim() != 0:
+                n_dims = len(self.size())
+                grad = AD_unsqueeze_multiple(grad, dim, n_dims)
+                result = AD_unsqueeze_multiple(result, dim, n_dims)
+            return grad * (self - result).exp()
+
+        def mean_0(self):
+            self_size = self.size()
+            self_numel = self.numel()
+            def backward(grad_output):
+                grad_self = grad_output.expand(self_size) / self_numel
+                return grad_self
+
+            return torch.mean(self), backward
+
+        def mean_1(self,
+                   dim: List[int],
+                   keepdim: bool):
+            self_size = self.size()
+            def backward(grad_output):
+                grad_self = AD_sum_backward(grad_output, self_size, dim, keepdim) / AD_safe_size(self_size, dim)
+                return grad_self, None, None
+
+            return torch.mean(self, dim, keepdim), backward
+
+        def logsumexp(self,
+                      dim: List[int],
+                      keepdim: bool):
+            result = torch.logsumexp(self, dim, keepdim)
+            self_dim = self.dim()
+            def backward(grad_output):
+                grad_self = AD_logsumexp_backward(grad_output, self, result, dim, keepdim)
+                return grad_self, None, None
+
+            return result, backward
+
+        def AD_bool_to_int(b: bool):
+            # FIXME: torchscript: int - bool
+            if b:
+                i = 1
+            else:
+                i = 0
+            return i
+
+        def AD_var_backward_0(grad, self, unbiased: bool):
+            b = AD_bool_to_int(unbiased)
+
+            # FIXME: torchscript: div(float, float)
+            return  grad * (self - self.mean()) * 2.0 / (self.numel() - b)
+
+        def AD_safe_size(sizes: List[int],
+                         dims: List[int]):
+            if len(sizes) == 0:
+                return 1
+
+            size = 1
+            for i in range(len(dims)):
+                d = dims[i]
+                size *= sizes[d]
+
+            return size
+
+        def AD_var_backward_1(grad,
+                              self,
+                              dim: List[int],
+                              unbiased: bool,
+                              keepdim: bool):
+            if self.dim() == 0:
+                return AD_var_backward_0(grad, self, unbiased)
+            self_size = self.size()
+            b = AD_bool_to_int(unbiased)
+            if not keepdim and self.dim() > 1:
+                grad = AD_unsqueeze_multiple(grad, dim, len(self_size))
+
+            # FIXME: torchscript: div(float, float)
+            return grad * (self - self.mean(dim, True)) * 2.0 / (AD_safe_size(self_size, dim) - b)
+
+        def std_0(self,
+                  unbiased: bool=True):
+            std_out = torch.std(self, unbiased)
+            def backward(grad_output):
+                grad_self = AD_var_backward_0(grad_output / (std_out * 2), self, unbiased)
+                return grad_self, None
+
+            return std_out, backward
+
+        def std_1(self,
+                  dim: List[int],
+                  unbiased: bool,
+                  keepdim: bool):
+            std_out = torch.std(self, dim, unbiased, keepdim)
+            def backward(grad_output):
+                grad_self = AD_var_backward_1(grad_output / (std_out * 2), self, dim, unbiased, keepdim)
+                return grad_self, None, None, None
+
+            return std_out, backward
+
+        def var_0(self,
+                  unbiased: bool=True):
+            def backward(grad_output):
+                grad_self = AD_var_backward_0(grad_output, self, unbiased)
+                return grad_self, None
+
+            return torch.var(self, unbiased), backward
+
+        def var_1(self,
+                  dim: List[int],
+                  unbiased: bool,
+                  keepdim: bool):
+            def backward(grad_output):
+                grad_self = AD_var_backward_1(grad_output, self, dim, unbiased, keepdim)
+                return grad_self, None, None, None
+
+            return torch.var(self, dim, unbiased, keepdim), backward
+
+        def AD_index_select_backward(grad,
+                                     dim: int,
+                                     indices,
+                                     sizes: List[int],
+                                     keepdim: bool):
+            if not keepdim and len(sizes) > 0:
+                grad = grad.unsqueeze(dim)
+                indices = indices.unsqueeze(dim)
+
+            # FIXME: torchscript: torch.zeros(sizes, grad.options())
+            return torch.zeros(sizes).to(grad).scatter_(dim, indices, grad)
+
+        def topk(self,
+                 k: int,
+                 dim: int = -1,
+                 largest: bool = True,
+                 sorted: bool = True):
+            result0, result1 = torch.topk(self, k, dim, largest, sorted)
+            self_size = self.size()
+            def backward(grad_output):
+                grad_self = AD_index_select_backward(grad_output, dim, result1, self_size, True)
+                return grad_self, None, None, None, None
+
+            return result0, result1, backward
+
+        def kthvalue(self,
+                     k: int,
+                     dim: int,
+                     keepdim: bool):
+            result0, result1 = torch.kthvalue(self, k, dim, keepdim)
+            self_size = self.size()
+            def backward(grad_output):
+                grad_self = AD_index_select_backward(grad_output, dim, result1, self_size, keepdim)
+                return grad_self, None, None, None
+
+            return result0, result1, backward
+
+        def AD_mm_backward_self(grad, mat2):
+            return grad.mm(mat2.t())
+
+        def AD_mm_backward_mat2(grad, self):
+            return self.t().mm(grad)
+
+        def mm(self, mat2):
+            def backward(grad_output):
+                grad_self = AD_mm_backward_self(grad_output, mat2)
+                grad_mat2 = AD_mm_backward_mat2(grad_output, self)
+                return grad_self, grad_mat2
+
+            return torch.mm(self, mat2), backward
+
+        def AD_permute_backward(grad,
+                                fwd_dims: List[int]):
+            ndims = len(fwd_dims)
+            dims = [0] * ndims
+
+            for i in range(ndims):
+                dims[fwd_dims[i]] = i
+
+            return grad.permute(dims)
+
+        def permute(self,
+                    dims: List[int]):
+            def backward(grad_output):
+                grad_self = AD_permute_backward(grad_output, dims)
+                return grad_self, None
+
+            return torch.permute(self, dims), backward
+
+        def AD_select_backward(grad,
+                               input_sizes: List[int],
+                               dim: int,
+                               index: int):
+            # FIXME: torchscript: torch.zeros(sizes, grad.options())
+            grad_input = torch.zeros(input_sizes).to(grad)
+            grad_input.select(dim, index).copy_(grad)
+            return grad_input
+
+        def select(self,
+                   dim: int,
+                   index: int):
+            self_size = self.size()
+            def backward(grad_output):
+                grad_self = AD_select_backward(grad_output, self_size, dim, index)
+                return grad_self, None, None
+
+            return torch.select(self, dim, index), backward
+
+        def AD_slice_backward(grad,
+                              input_sizes: List[int],
+                              dim: int,
+                              start: int,
+                              end: int,
+                              step: int):
+            # FIXME: torchscript: torch.zeros(sizes, grad.options())
+            grad_input = torch.zeros(input_sizes).to(grad)
+            grad_input.slice(dim, start, end, step).copy_(grad)
+            return grad_input
+
+        # DON'T enable slice unless we can correctly handle view ops in graph executor.
+        # It triggers failure of TestJit.test_sample in test_distributions.py.
+        # def slice(self,
+        #           dim: int=0,
+        #           start: int=0,
+        #           end: int=9223372036854775807,
+        #           step: int=1):
+        #     def backward(grad_output):
+        #         grad_self = AD_slice_backward(grad_output, self.size(), dim, start, end, step)
+        #         return grad_self, None, None, None, None
+
+        #     return torch.slice(self, dim, start, end, step), backward
+
+        def AD_unsqueeze_to_0(self,
+                              sizes: List[int]):
+            ndims = len(sizes)
+            for i in range(ndims):
+                if sizes[i] == 1:
+                    self = self.unsqueeze(i)
+
+            return self
+
+        def AD_unsqueeze_to_1(self,
+                              dim: int,
+                              sizes: List[int]):
+            if len(sizes) > 0 and sizes[dim] == 1:
+                return self.unsqueeze(dim)
+            return self
+
+        def squeeze_0(self):
+            self_size = self.size()
+            def backward(grad_output):
+                grad_self = AD_unsqueeze_to_0(grad_output, self_size)
+                return grad_self
+
+            return torch.squeeze(self), backward
+
+        def squeeze_1(self,
+                      dim: int):
+            self_size = self.size()
+            def backward(grad_output):
+                grad_self = AD_unsqueeze_to_1(grad_output, dim, self_size)
+                return grad_self, None
+
+            return torch.squeeze(self, dim), backward
+
+        def AD_infer_size(a: List[int],
+                          b: List[int]):
+            dimsA = len(a)
+            dimsB = len(b)
+
+            ndim = dimsA if dimsA > dimsB else dimsB
+            expand_sizes = [0] * ndim
+
+            for i in range(ndim):
+                idx = - i + ndim - 1
+                sizeA = a[i] if dimsA + i >= 0 else 1
+                sizeB = b[i] if dimsB + i >= 0 else 1
+
+                # Assert sizeA == sizeB or sizeA == 1 or sizeB == 1
+                expand_sizes[i] = sizeB if sizeA == 1 else sizeA
+
+            return expand_sizes
+
+        def AD_bmm_backward_self(grad, mat2):
+            return grad.bmm(mat2.transpose(1, 2))
+
+        def AD_bmm_backward_mat2(grad, self):
+            return self.transpose(1, 2).bmm(grad)
+
+        def bmm(self, mat2):
+            def backward(grad_output):
+                grad_self = AD_bmm_backward_self(grad_output, mat2)
+                grad_mat2 = AD_bmm_backward_mat2(grad_output, self)
+                return grad_self, grad_mat2
+            return torch.bmm(self, mat2), backward
+
+    )",
+    R"(
         def _dim_arange(like,
                         dim: int):
             def backward(grad_output):
@@ -20,11 +345,19 @@ const std::vector<std::string> functions = {
 
             return self.contiguous(), backward
 
+        def dot(self, tensor):
+            def backward(grad_output):
+                grad_self = grad_output * tensor
+                grad_tensor = grad_output * self
+                return grad_self, grad_tensor
+
+            return torch.dot(self, tensor), backward
+
         def erf(self):
             def backward(grad_output):
                 # Precomputed constant C = 2.0 / math.sqrt(math.pi)
                 C = 1.1283791670955126
-                grad_self =  C * torch.exp(- self.pow(2)) * grad_output
+                grad_self =  C * torch.exp(- self * self) * grad_output
                 return grad_self
 
             return torch.erf(self), backward
@@ -55,58 +388,24 @@ const std::vector<std::string> functions = {
 
             return torch.full_like(self, fill_value), backward
 
-        def kthvalue(self,
-                     k: int,
-                     dim: int,
-                     keepdim: bool):
-            result0, result1 = torch.kthvalue(self, k, dim, keepdim)
-            self_size = self.size()
-            def backward(grad_output):
-                grad_self = torch.index_select_backward(grad_output, dim, result1, self_size, keepdim)
-                return grad_self, None, None, None
-
-            return result0, result1, backward
-
-        def logsumexp(self,
-                      dim: List[int],
-                      keepdim: bool):
-            result = torch.logsumexp(self, dim, keepdim)
-            self_dim = self.dim()
-            def backward(grad_output):
-                grad_self = torch.logsumexp_backward(grad_output, self, result, dim, keepdim)
-                return grad_self, None, None
-
-            return result, backward
-
-        def mean_0(self):
-            self_size = self.size()
-            self_numel = self.numel()
-            def backward(grad_output):
-                grad_self = grad_output.expand(self_size) / self_numel
-                return grad_self
-
-            return torch.mean(self), backward
-
-        def mean_1(self,
-                   dim: List[int],
-                   keepdim: bool):
-            self_size = self.size()
-            def backward(grad_output):
-                grad_self = torch.sum_backward(grad_output, self_size, dim, keepdim) / torch._safe_size(self_size, dim)
-                return grad_self, None, None
-
-            return torch.mean(self, dim, keepdim), backward
-
         def mul(self, other):
-            self_size = self.size()
-            other_size = other.size()
             def backward(grad_output):
-                grad_self = (grad_output * other)._grad_sum_to_size(self_size)
-                grad_other = (grad_output * self)._grad_sum_to_size(other_size)
+                # self & other are used in backward. No need to pass in their size
+                # from forward pass
+                grad_self = (grad_output * other)._grad_sum_to_size(self.size())
+                grad_other = (grad_output * self)._grad_sum_to_size(other.size())
                 return grad_self, grad_other
 
             return self * other, backward
 
+        def mv(self, vec):
+            def backward(grad_output):
+                grad_self = grad_output.ger(vec)
+                grad_vec = self.t().mv(grad_output)
+                return grad_self, grad_vec
+
+            return torch.mv(self, vec), backward
+
         def nonzero(self):
             def backward(grad_output):
                 return None
@@ -119,14 +418,6 @@ const std::vector<std::string> functions = {
 
             return torch.ones_like(self), backward
 
-        def permute(self,
-                    dims: List[int]):
-            def backward(grad_output):
-                grad_self = torch.permute_backwards(grad_output, dims)
-                return grad_self, None
-
-            return torch.permute(self, dims), backward
-
         def pow_0(self,
                   exponent: float):
             def backward(grad_output):
@@ -136,11 +427,10 @@ const std::vector<std::string> functions = {
             return torch.pow(self, exponent), backward
 
         def pow_1(self, exponent):
-            self_size = self.size()
-            exponent_size = exponent.size()
             def backward(grad_output):
-                grad_self = torch.where(exponent == 0.0, torch.zeros_like(self), grad_output * exponent * torch.pow(self, exponent - 1))._grad_sum_to_size(self_size)
-                grad_exponent = (grad_output * torch.pow(self, exponent) * torch.log(self))._grad_sum_to_size(exponent_size)
+                # self & exponent are used in backward, no need to pass in its size explicitly
+                grad_self = torch.where(exponent == 0.0, torch.zeros_like(self), grad_output * exponent * torch.pow(self, exponent - 1))._grad_sum_to_size(self.size())
+                grad_exponent = (grad_output * torch.pow(self, exponent) * torch.log(self))._grad_sum_to_size(exponent.size())
                 return grad_self, grad_exponent
 
             return torch.pow(self, exponent), backward
@@ -174,16 +464,6 @@ const std::vector<std::string> functions = {
 
             return torch.rsub(self, other, alpha), backward
 
-        def select(self,
-                   dim: int,
-                   index: int):
-            self_size = self.size()
-            def backward(grad_output):
-                grad_self = torch.select_backward(grad_output, self_size, dim, index)
-                return grad_self, None, None
-
-            return torch.select(self, dim, index), backward
-
         def sqrt(self):
             result = torch.sqrt(self)
             def backward(grad_output):
@@ -192,23 +472,6 @@ const std::vector<std::string> functions = {
 
             return result, backward
 
-        def squeeze_0(self):
-            self_size = self.size()
-            def backward(grad_output):
-                grad_self = torch.unsqueeze_to(grad_output, self_size)
-                return grad_self
-
-            return torch.squeeze(self), backward
-
-        def squeeze_1(self,
-                      dim: int):
-            self_size = self.size()
-            def backward(grad_output):
-                grad_self = torch.unsqueeze_to(grad_output, dim, self_size)
-                return grad_self, None
-
-            return torch.squeeze(self, dim), backward
-
         def t(self):
             def backward(grad_output):
                 grad_self = torch.t(grad_output)
@@ -255,19 +518,6 @@ const std::vector<std::string> functions = {
 
             return self.to(other, non_blocking=non_blocking, copy=copy), backward
 
-        def topk(self,
-                 k,
-                 dim: int = -1,
-                 largest: bool = True,
-                 sorted: bool = True):
-            result0, result1 = torch.topk(self, k, dim, largest, sorted)
-            self_size = self.size()
-            def backward(grad_output):
-                grad_self = torch.index_select_backward(grad_output, dim, result1, self_size, True)
-                return grad_self, None, None, None, None
-
-            return result0, result1, backward
-
         def transpose(self,
                       dim0: int,
                       dim1: int):
@@ -277,63 +527,57 @@ const std::vector<std::string> functions = {
 
             return torch.transpose(self, dim0, dim1), backward
 
-        def var_0(self,
-                  unbiased: bool=True):
+        def view(self,
+                 size: List[int]):
+            self_size = self.size()
             def backward(grad_output):
-                grad_self = torch.var_backward(grad_output, self, unbiased)
+                grad_self = grad_output.reshape(self_size)
                 return grad_self, None
 
-            return torch.var(self, unbiased), backward
+            return torch.view(self, size), backward
+    )",
+    R"(
+        def AD_adaptive_avg_pool2d_backward(grad,
+                                            self,
+                                            output_size: List[int]):
+            if output_size[0] == 1 and output_size[1] == 1:
+                self_size = self.size()
+                grad_self = grad.expand(self.size()) / (self_size[-1] * self_size[-2])
+            else:
+                grad_self = torch._adaptive_avg_pool2d_backward(grad, self)
 
-        def var_1(self,
-                  dim: List[int],
-                  unbiased: bool,
-                  keepdim: bool):
-            def backward(grad_output):
-                grad_self = torch.var_backward(grad_output, self, dim, unbiased, keepdim)
-                return grad_self, None, None, None
+            return grad_self
 
-            return torch.var(self, dim, unbiased, keepdim), backward
+        def AD_adaptive_avg_pool1d_backward(grad,
+                                            input,
+                                            output_size: List[int]):
+            output_size_2d = [1, output_size[0]]
+            grad_input = AD_adaptive_avg_pool2d_backward(grad.unsqueeze(2), input.unsqueeze(2), output_size_2d).squeeze(2)
+            return grad_input
 
-        def std_0(self,
-                  unbiased: bool=True):
-            std_out = torch.std(self, unbiased)
+        def adaptive_avg_pool1d(self,
+                                output_size: List[int]):
             def backward(grad_output):
-                grad_self = torch.var_backward(grad_output / (std_out * 2), self, unbiased)
+                grad_self = AD_adaptive_avg_pool1d_backward(grad_output, self, output_size)
                 return grad_self, None
 
-            return std_out, backward
+            return torch.adaptive_avg_pool1d(self, output_size), backward
 
-        def std_1(self,
-                  dim: List[int],
-                  unbiased: bool,
-                  keepdim: bool):
-            std_out = torch.std(self, dim, unbiased, keepdim)
-            def backward(grad_output):
-                grad_self = torch.var_backward(grad_output / (std_out * 2), self, dim, unbiased, keepdim)
-                return grad_self, None, None, None
-
-            return std_out, backward
-
-        def view(self,
-                 size: List[int]):
-            self_size = self.size()
+        def adaptive_avg_pool2d(self,
+                                output_size: List[int]):
             def backward(grad_output):
-                grad_self = grad_output.reshape(self_size)
+                # self is used in backward, no need to pass in its size explicitly
+                grad_self = AD_adaptive_avg_pool2d_backward(grad_output, self, output_size)
                 return grad_self, None
+            return torch.adaptive_avg_pool2d(self, output_size), backward
 
-            return torch.view(self, size), backward
-
-        def adaptive_avg_pool2d(self,
+        def adaptive_avg_pool3d(self,
                                 output_size: List[int]):
-            self_size = self.size()
             def backward(grad_output):
-                if output_size[0] == 1 and output_size[1] == 1:
-                    grad_self = grad_output.expand(self_size) / (self_size[-1] * self_size[-2])
-                else:
-                    grad_self = torch._adaptive_avg_pool2d_backward(grad_output, self)
+                grad_self = torch.adaptive_avg_pool3d_backward(grad_output, self)
                 return grad_self, None
-            return torch.adaptive_avg_pool2d(self, output_size), backward
+
+            return torch.adaptive_avg_pool3d(self, output_size), backward
 
         def batch_norm(input : Tensor,
                        weight : Optional[Tensor],
@@ -376,6 +620,109 @@ const std::vector<std::string> functions = {
             def backward(grad):
                 return torch.nll_loss_backward(grad, self, target, weight, reduction, ignore_index, total_weight), None, None, None, None
             return result, backward
+
+        def softmax_0(self, dim: int):
+            result = torch.softmax(self, dim)
+            def backward(grad_output):
+                grad_self = torch._softmax_backward_data(grad_output, result, dim, self)
+                return grad_self, None
+
+            return result, backward
+
+        def softmax_1(self, dim: int, dtype: int):
+            result = torch.softmax(self, dim, dtype)
+            def backward(grad_output):
+                grad_self = torch._softmax_backward_data(grad_output, result, dim, self)
+                return grad_self, None, None
+
+            return torch.softmax(self, dim, dtype), backward
+
+        def AD_interpolate_backward(grad,
+                                    input,
+                                    mode: str,
+                                    align_corners: bool):
+            output_size = grad.size()[2:]
+            input_size = input.size()
+            input_dim = len(input_size)
+            if input_dim == 3 and mode == 'nearest':
+                grad_input = torch.upsample_nearest1d_backward(grad, output_size, input_size)
+            elif input_dim == 4 and mode == 'nearest':
+                grad_input = torch.upsample_nearest2d_backward(grad, output_size, input_size)
+            elif input_dim == 5 and mode == 'nearest':
+                grad_input = torch.upsample_nearest3d_backward(grad, output_size, input_size)
+            elif input_dim == 3 and mode == 'linear':
+                grad_input = torch.upsample_linear1d_backward(grad, output_size, input_size, align_corners)
+            elif input_dim == 4 and mode == 'bilinear':
+                grad_input = torch.upsample_bilinear2d_backward(grad, output_size, input_size, align_corners)
+            elif input_dim == 5 and mode == 'trilinear':
+                grad_input = torch.upsample_trilinear3d_backward(grad, output_size, input_size, align_corners)
+            elif input_dim == 4 and mode == 'bicubic':
+                grad_input = torch.upsample_bicubic2d_backward(grad, output_size, input_size, align_corners)
+            elif input_dim == 3 and mode == 'area':
+                grad_input = AD_adaptive_avg_pool1d_backward(grad, input, output_size)
+            elif input_dim == 4 and mode == 'area':
+                grad_input = AD_adaptive_avg_pool2d_backward(grad, input, output_size)
+            elif input_dim == 5 and mode == 'area':
+                grad_input = torch.adaptive_avg_pool3d_backward(grad, input)
+            else:
+                # NEVER REACH HERE
+                grad_input = torch.zeros_like(input)
+                raise RuntimeError('Input Error: Only 3D, 4D and 5D input Tensors supported')
+
+            return grad_input
+
+        def __interpolate_0(input,
+                            size: Optional[int],
+                            scale_factor: Optional[List[float]],
+                            mode: str='nearest',
+                            align_corners: Optional[bool]):
+            def backward(grad_output):
+                if align_corners is None:
+                    align_corners = False
+                grad_self = AD_interpolate_backward(grad_output, input, mode, align_corners)
+                return grad_self, None, None, None, None
+
+            return torch.__interpolate(input, size, scale_factor, mode, align_corners), backward
+
+        def __interpolate_1(input,
+                            size: Optional[List[int]],
+                            scale_factor: Optional[List[float]],
+                            mode: str='nearest',
+                            align_corners: Optional[bool]):
+            def backward(grad_output):
+                if align_corners is None:
+                    align_corners = False
+                grad_self = AD_interpolate_backward(grad_output, input, mode, align_corners)
+                return grad_self, None, None, None, None
+
+            return torch.__interpolate(input, size, scale_factor, mode, align_corners), backward
+
+        def __interpolate_2(input,
+                            size: Optional[int],
+                            scale_factor: Optional[float],
+                            mode: str='nearest',
+                            align_corners: Optional[bool]):
+            def backward(grad_output):
+                if align_corners is None:
+                    align_corners = False
+                grad_self = AD_interpolate_backward(grad_output, input, mode, align_corners)
+                return grad_self, None, None, None, None
+
+            return torch.__interpolate(input, size, scale_factor, mode, align_corners), backward
+
+        def __interpolate_3(input,
+                            size: Optional[List[int]],
+                            scale_factor: Optional[float],
+                            mode: str='nearest',
+                            align_corners: Optional[bool]):
+            def backward(grad_output):
+                if align_corners is None:
+                    align_corners = False
+                grad_self = AD_interpolate_backward(grad_output, input, mode, align_corners)
+                return grad_self, None, None, None, None
+
+            return torch.__interpolate(input, size, scale_factor, mode, align_corners), backward
+
       )"};
 std::unordered_map<std::string, GradientPair> schema_to_graphs;
 
@@ -416,17 +763,26 @@ std::string overloadedSchemaString(const FunctionSchema& schema) {
   auto pos = schema_name.find_last_of('_');
   auto schema_name_suffix = schema_name.substr(pos + 1);
   std::string schema_string = canonicalSchemaString(schema);
-  if (!schema_name_suffix.empty()
-      && schema_name_suffix.find_first_not_of("0123456789") == std::string::npos) {
-    schema_string.replace(schema_string.find(schema_name),
-                          schema_name.length(),
-                          schema_name.substr(0, pos));
+  if (!schema_name_suffix.empty() &&
+      schema_name_suffix.find_first_not_of("0123456789") == std::string::npos) {
+    schema_string.replace(
+        schema_string.find(schema_name),
+        schema_name.length(),
+        schema_name.substr(0, pos));
   }
   return schema_string;
 }
 
+bool isHelperFunction(const std::string& method_name) {
+  std::string helper_prefix = "AD_";
+  return method_name.compare(0, helper_prefix.length(), helper_prefix) == 0;
+}
+
 void loadModule(const std::shared_ptr<script::Module>& module) {
   for (const auto& method_ : module->get_methods()) {
+    if (isHelperFunction(method_.key()))
+      continue;
+
     const auto& method = method_.value();
     GradientPair pair;
     pair.forward = method->graph();