maskrcnn & bert AD coverage part 1 (#16689)
authorAiling Zhang <ailzhang@fb.com>
Thu, 14 Feb 2019 22:55:44 +0000 (14:55 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 14 Feb 2019 23:36:39 +0000 (15:36 -0800)
Summary:
- Moved a few functions from `autograd` namespace to `aten` namespace to be visible from JIT nativeResolver.
- Added a hack to loop up keyword only argument. Will add proper support for kw only later
- Simulate function overload in aten using `_<number>` as function name suffix.
- Even `forward` returns multiple outputs like in `kthvalue`, there's at most one requires grad that we currently support.
- Removed the `TensorList` related ops here since partial `TensorList` support is prone to bugs. Our symbolic diff for `cat` was never tested with autodiff, and it seems broken. Need to find another proper way to support these ops(either by properly supporting `TensorList` or sth like `prim::ConstantChunk`  and leave them for next PR.

Ops supported in this PR:
```
erf
expand_as
index
kthvalue
mean
permute
pow
rsub
select
sqrt
squeeze
t
to
topk
transpose
view
var
embedding
logsumexp
// grad is None
_dim_arange
contiguous
nonzero
ones_like
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16689

Differential Revision: D14020806

Pulled By: ailzhang

fbshipit-source-id: a5e2c144a7be5a0d39d7ac5f93cb402ec12503a5

20 files changed:
aten/src/ATen/native/ReduceOps.cpp
aten/src/ATen/native/TensorCompare.cpp
aten/src/ATen/native/TensorShape.cpp
aten/src/ATen/native/native_functions.yaml
test/common_methods_invocations.py
test/cpp/jit/test_misc.h
test/expect/TestFuser.test_lstm_cuda-backward.expect
test/expect/TestFuser.test_lstm_cuda-forward.expect
test/expect/TestFuser.test_milstm_cuda-backward.expect
test/expect/TestFuser.test_milstm_cuda-forward.expect
test/expect/TestJit.test_cpp_cuda.expect
test/test_jit.py
third_party/onnx
tools/autograd/derivatives.yaml
tools/autograd/templates/Functions.cpp
torch/csrc/jit/autodiff.cpp
torch/csrc/jit/graph_executor.cpp
torch/csrc/jit/passes/graph_fuser.cpp
torch/csrc/jit/symbolic_script.cpp
torch/nn/functional.py

index 27c2566..88c78a8 100644 (file)
@@ -335,6 +335,42 @@ Tensor& sum_out(Tensor& result, const Tensor& self, IntArrayRef dim, ScalarType
   return at::native::sum_out(result, self, dim, false, dtype);
 }
 
+int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim) {
+  int64_t size = 1;
+  if (sizes.size() == 0) {
+    return 1;
+  }
+  for (auto d : dim) {
+    d = at::maybe_wrap_dim(d, sizes.size());
+    size *= sizes[d];
+  }
+  return size;
+}
+
+Tensor unsqueeze_multiple(const Tensor & t, IntArrayRef dim, size_t n_dims) {
+    auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, n_dims);
+    Tensor res = t;
+    for (size_t i = 0; i < n_dims; i++){
+      if (dims_to_unsqueeze[i]) {
+        res = res.unsqueeze(i);
+      }
+    }
+    return res;
+}
+
+Tensor sum_backward(const Tensor & grad, IntArrayRef sizes, IntArrayRef dims, bool keepdim) {
+  if (!keepdim && sizes.size() > 0) {
+    if (dims.size()==1) {
+      return grad.unsqueeze(dims[0]).expand(sizes);
+    } else {
+      Tensor res = unsqueeze_multiple(grad, dims, sizes.size());
+      return res.expand(sizes);
+    }
+  } else {
+    return grad.expand(sizes);
+  }
+}
+
 Tensor& prod_out(Tensor& result, const Tensor& self, int64_t dim, bool keepdim, ScalarType dtype) {
   return at::native::prod_out(
       result, self, dim, keepdim, c10::optional<ScalarType>(dtype));
@@ -416,6 +452,16 @@ Tensor logsumexp(const Tensor &self, IntArrayRef dims, bool keepdim) {
   return at::native::logsumexp_out(result, self, dims, keepdim);
 }
 
+Tensor logsumexp_backward(const Tensor& grad, const Tensor & self, const Tensor& res, IntArrayRef dim, bool keepdim) {
+  Tensor grad_input = grad;
+  Tensor fwd_res = res;
+  if (!keepdim && self.dim() != 0) {
+    grad_input = unsqueeze_multiple(grad, dim, self.sizes().size());
+    fwd_res = unsqueeze_multiple(res, dim, self.sizes().size());
+  }
+  return grad_input * (self - fwd_res).exp();
+}
+
 static Tensor& norm_out(Tensor &result, const Tensor &self, optional<Scalar> opt_p,
                                IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
   auto p = opt_p.value_or(2.0);
@@ -628,6 +674,21 @@ Tensor &var_out(Tensor &result, const Tensor &self, IntArrayRef dim, bool unbias
   return std_var_out(result, self, dim, unbiased, keepdim, false);
 }
 
+Tensor var_backward(const Tensor & grad, const Tensor & self, bool unbiased) {
+  return (2.0 / (self.numel() - unbiased)) * grad * (self - self.mean());
+}
+
+Tensor var_backward(const Tensor & grad, const Tensor & self, IntArrayRef dim, bool unbiased, bool keepdim) {
+  if (self.dim() == 0) {
+    return at::var_backward(grad, self, unbiased);
+  }
+  Tensor unsqueezed_grad = grad;
+  if (!keepdim && self.dim() > 1) {
+    unsqueezed_grad = unsqueeze_multiple(grad, dim, self.sizes().size());
+  }
+  return (2.0 / (at::_safe_size(self.sizes(), dim) - unbiased)) * unsqueezed_grad * (self - self.mean(dim, true));
+}
+
 Tensor std(const Tensor& self, bool unbiased) {
   AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
            "std only supports CPU AND CUDA backend, got: ", toString(self.type().backend()));
index 0f58f4f..04ee3dd 100644 (file)
@@ -34,6 +34,14 @@ namespace at { namespace native {
 DEFINE_DISPATCH(max_kernel);
 DEFINE_DISPATCH(min_kernel);
 
+Tensor index_select_backward(const Tensor& grad, int64_t dim, const Tensor& indices, IntArrayRef sizes, bool keepdim) {
+  Tensor res = at::zeros(sizes, grad.options());
+  if (!keepdim && sizes.size() > 0) {
+    return res.scatter_(dim, indices.unsqueeze(dim), grad.unsqueeze(dim));
+  }
+  return res.scatter_(dim, indices, grad);
+}
+
 bool allclose(const Tensor& self, const Tensor& other, double rtol, double atol, bool equal_nan) {
   return at::isclose(self, other, rtol, atol, equal_nan).all().item<uint8_t>();
 }
index 2e7b5e6..bddf40a 100644 (file)
@@ -384,6 +384,16 @@ Tensor permute(const Tensor& self, IntArrayRef dims) {
   return self.as_strided(newSizes, newStrides);
 }
 
+Tensor permute_backwards(const Tensor & grad, IntArrayRef fwd_dims) {
+  // invert the permutation
+  auto ndims = fwd_dims.size();
+  std::vector<int64_t> dims(ndims);
+  for (size_t i = 0; i < ndims; i++) {
+    dims[at::maybe_wrap_dim(fwd_dims[i], ndims)] = i;
+  }
+  return grad.permute(dims);
+}
+
 Tensor repeat(const Tensor& self, IntArrayRef repeats) {
   AT_CHECK(repeats.size() >= (size_t)self.dim(),
            "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor");
@@ -451,6 +461,12 @@ Tensor select(const Tensor& self, int64_t dim, int64_t index) {
   return self.as_strided(sizes, strides, storage_offset);
 }
 
+Tensor select_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) {
+  auto grad_input = at::zeros(input_sizes, grad.options());
+  grad_input.select(dim, index).copy_(grad);
+  return grad_input;
+}
+
 Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
   int64_t ndim = self.dim();
   if (ndim == 0) {
@@ -484,6 +500,12 @@ Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_
   return self.as_strided(sizes, strides, storage_offset);
 }
 
+Tensor slice_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
+  auto grad_input = at::zeros(input_sizes, grad.options());
+  grad_input.slice(dim, start, end, step).copy_(grad);
+  return grad_input;
+}
+
 std::vector<Tensor> split(const Tensor& self, int64_t split_size, int64_t dim) {
   AT_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor");
   AT_CHECK(split_size >= 0,  "split expects split_size be non-negative, but got split_size=", split_size);
@@ -690,6 +712,28 @@ Tensor squeeze(const Tensor& self) {
   return self.as_strided(std::get<0>(g), std::get<1>(g));
 }
 
+Tensor unsqueeze_to(const Tensor & self, IntArrayRef sizes) {
+  auto result = self;
+
+  int64_t nDims = sizes.size();
+  for (int64_t dim = 0; dim < nDims; dim++) {
+    if (sizes[dim] == 1) {
+      result = result.unsqueeze(dim);
+    }
+  }
+  return result;
+}
+
+Tensor unsqueeze_to(const Tensor & self, int64_t dim, IntArrayRef sizes) {
+  dim = at::maybe_wrap_dim(dim, sizes.size());
+  // in NumPy it's not an error to unsqueeze a scalar, but we still need to avoided
+  // unsqueezing in the backward.
+  if (sizes.size() > 0 && sizes[dim] == 1) {
+    return self.unsqueeze(dim);
+  }
+  return self;
+}
+
 Tensor squeeze(const Tensor& self, int64_t dim) {
   int64_t dims = self.dim();
   dim = maybe_wrap_dim(dim, dims);
index 7f2d1a2..386b608 100644 (file)
   dispatch:
     CUDA: _cudnn_init_dropout_state
 
+- func: index_select_backward(Tensor grad, int64_t dim, Tensor indices, int[] sizes, bool keepdim) -> Tensor
+
+- func: select_backward(Tensor grad, int[] input_sizes, int64_t dim, int64_t index) -> Tensor
+
 - func: _fused_dropout(Tensor self, float p, Generator? generator=None) -> (Tensor, Tensor)
   matches_jit_signature: True
   variants: function
 - func: logsumexp(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
   matches_jit_signature: True
 
+- func: logsumexp_backward(Tensor grad, Tensor self, Tensor res, int[1] dim, bool keepdim) -> Tensor
+
 - func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor
   matches_jit_signature: True
 
 - func: mean(Tensor self, int[1] dim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
   matches_jit_signature: True
 
+- func: sum_backward(Tensor grad, int[] sizes, int[] dims, bool keepdim) -> Tensor
+
 - func: median(Tensor self, int dim, bool keepdim=False) -> (Tensor, Tensor)
   matches_jit_signature: True
   variants: function, method
   matches_jit_signature: True
   variants: method  # This is method-only to match the previous tensor API. In the future we could make this a function too.
 
+- func: permute_backwards(Tensor grad, int[] fwd_dims) -> Tensor
+
 - func: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor
   matches_jit_signature: True
 
   variants: function, method
   device_guard: False
 
+- func: _safe_size(int[] sizes, int[] dim) -> int64_t
+
 - func: slice(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)
   matches_jit_signature: True
   variants: function, method
   device_guard: False
 
+- func: slice_backward(Tensor grad, int[] input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) -> Tensor
+
 - func: slogdet(Tensor self) -> (Tensor, Tensor)
   matches_jit_signature: True
   variants: function, method
   variants: function, method
   device_guard: False
 
+- func: unsqueeze_to(Tensor self, int[] sizes) -> Tensor
+
+- func: unsqueeze_to(Tensor self, int64_t dim, int[] sizes) -> Tensor
+
 - func: squeeze_(Tensor(a!) self) -> Tensor(a!)
   matches_jit_signature: True
   variants: method
 - func: var(Tensor self, int[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
   matches_jit_signature: True
 
+- func: var_backward(Tensor grad, Tensor self, bool unbiased) -> Tensor
+
+- func: var_backward(Tensor grad, Tensor self, int[] dim, bool unbiased, bool keepdim) -> Tensor
+
 - func: view_as(Tensor self, Tensor other) -> Tensor
   matches_jit_signature: True
   variants: method
index 6f15114..0c8be4c 100644 (file)
@@ -209,6 +209,7 @@ def method_tests():
         ('expand', (1, S), (1, 1, S), 'new_dim_front_old_front_1'),
         ('expand', (), (dont_convert(()),), 'scalar_to_scalar'),
         ('expand', (), (1, 3, 2), 'scalar_to_dims'),
+        ('expand_as', (S, 1, 1), (torch.rand(S, S, S),)),
         ('exp', (S, S, S), NO_ARGS),
         ('exp', (), NO_ARGS, 'scalar'),
         ('expm1', (S, S, S), NO_ARGS),
@@ -1020,6 +1021,8 @@ EXCLUDE_GRADGRADCHECK_BY_TEST_NAME = {
     'test_det_dim2_null',
     'test_det_rank1',
     'test_det_rank2',
+    # `other` expand_as(self, other) is not used in autograd.
+    'test_expand_as',
     'test_logdet',
     'test_logdet_1x1',
     'test_logdet_symmetric',
index 9039dc7..6f50469 100644 (file)
@@ -848,7 +848,24 @@ void testDifferentiate(std::ostream& out = std::cout) {
 
   auto grad_spec = differentiate(graph);
   std::vector<size_t> expected_captured_inputs = {0, 1};
-  std::vector<size_t> expected_captured_outputs = {1, 2};
+  // With add/mul implemented using torchscript, we passes sizes of
+  // self & other instead passing the tensors themselve.
+  // The forward graph is now
+  //graph(%0 : Float(2, 3, 4)
+  //      %1 : Float(2, 3, 4)) {
+  //  %2 : Float(2, 3, 4) = aten::mul(%0, %1)
+  //  %self_size.4 : int[] = aten::size(%0)
+  //  %other_size.4 : int[] = aten::size(%1)
+  //  %3 : Float(2, 3, 4) = aten::mul(%2, %0)
+  //  %self_size.2 : int[] = aten::size(%2)
+  //  %4 : int = prim::Constant[value=1]()
+  //  %7 : int[] = aten::size(%3)
+  //  %5 : Float(2, 3, 4) = aten::add(%3, %1, %4)
+  //  return (%5, %2, %self_size.4, %other_size.4, %self_size.2, %7);
+  //}
+  // Thus all the sizes info added in forward outputs are saved
+  // in grad_spec.df_input_caputered_outputs.
+  std::vector<size_t> expected_captured_outputs = {1, 2, 3, 4, 5};
   std::vector<size_t> expected_input_vjps = {0, 1};
   std::vector<size_t> expected_output_vjps = {0, 1};
   ASSERT_EQ(grad_spec.f_real_outputs, 1);
@@ -880,12 +897,29 @@ void testDifferentiateWithRequiresGrad(std::ostream& out = std::cout) {
   PropagateInputShapes(graph);
   PropagateRequiresGrad(graph);
 
+  // With add/mul implemented using torchscript, we passes sizes of
+  // self & other instead passing the tensors themselve.
+  // The forward graph is now
+  // graph(%0 : Float(*)
+  //       %1 : Float(*)) {
+  //   %2 : Float(*) = aten::mul(%1, %1)
+  //   %3 : int = prim::Constant[value=1]()
+  //   %4 : Float(*) = aten::add(%2, %1, %3)
+  //   %39 : int[] = aten::size(%0)
+  //   %6 : Float(*) = aten::add(%4, %0, %3)
+  //   %7 : Float(*) = aten::mul(%6, %0)
+  //   %self_size.2 : int[] = aten::size(%6)
+  //   %11 : int[] = aten::size(%7)
+  //   %9 : Float(*) = aten::add(%7, %1, %3)
+  //   return (%4, %9, %39, %6, %self_size.2, %11);
+  // }
+
   auto grad_spec = differentiate(graph);
-  std::vector<size_t> expected_input_vjps = {1, 2}; // for e and %4 = (d + a)
+  std::vector<size_t> expected_input_vjps = {1, 3}; // for e and %6 = (d + a)
   std::vector<size_t> expected_output_vjps = {0}; // only a requires grad
   ASSERT_EQ(grad_spec.f_real_outputs, 2);
   ASSERT_EQ(grad_spec.df_input_captured_inputs, std::vector<size_t>({0}));
-  ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector<size_t>({2, 3}));
+  ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector<size_t>({2, 3, 4, 5}));
   ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps);
   ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps);
   out << "testDifferentiateWithRequiresGrad\n";
index 8c43a7e..dda0c4f 100644 (file)
@@ -22,20 +22,20 @@ graph(%0 : Float(*, *),
       %forgetgate : Float(*, *),
       %cellgate : Float(*, *),
       %outgate : Float(*, *),
-      %24 : int[],
-      %25 : int[],
-      %26 : Float(*, *)):
-  %27 : int = prim::Constant[value=1]()
-  %28 : int[] = aten::size(%outgate)
-  %29 : int[] = aten::size(%26)
-  %30 : int[] = aten::size(%ingate)
-  %31 : int[] = aten::size(%cellgate)
-  %32 : int[] = aten::size(%forgetgate)
-  %33 : int[] = aten::size(%9)
-  %34 : Tensor = prim::FusionGroup_0(%outgate, %0, %26, %28)
-  %grad_other.5 : Tensor, %36 : Tensor, %37 : Tensor, %38 : Tensor = prim::FusionGroup_1(%forgetgate, %9, %ingate, %cellgate, %1, %26, %0, %outgate, %33, %32, %24, %31, %30, %25, %29)
+      %self_size.5 : int[],
+      %other_size.5 : int[],
+      %self_size.3 : int[],
+      %other_size.3 : int[],
+      %28 : int[],
+      %29 : int[],
+      %30 : Float(*, *),
+      %self_size.1 : int[],
+      %other_size.1 : int[]):
+  %33 : int = prim::Constant[value=1]()
+  %34 : Tensor = prim::FusionGroup_0(%outgate, %0, %30, %self_size.1)
+  %grad_other.5 : Tensor, %36 : Tensor, %37 : Tensor, %38 : Tensor = prim::FusionGroup_1(%forgetgate, %9, %ingate, %cellgate, %1, %30, %0, %outgate, %other_size.5, %self_size.5, %28, %other_size.3, %self_size.3, %29, %other_size.1)
   %39 : Tensor[] = prim::ListConstruct(%38, %36, %37, %34)
-  %40 : Tensor = aten::cat(%39, %27)
+  %40 : Tensor = aten::cat(%39, %33)
   %41 : Tensor = aten::_grad_sum_to_size(%40, %19)
   %42 : Tensor = aten::_grad_sum_to_size(%40, %17)
   %43 : Tensor = aten::_grad_sum_to_size(%40, %14)
@@ -44,13 +44,13 @@ graph(%0 : Float(*, *),
   %46 : Float(*, *) = aten::mm(%44, %45)
   %47 : Float(*, *) = aten::t(%10)
   %48 : Float(*, *) = aten::mm(%47, %44)
-  %49 : Float(*, *) = aten::t(%48)
+  %grad_self.7 : Float(*, *) = aten::t(%48)
   %50 : Float(*, *) = aten::t(%12)
   %51 : Float(*, *) = aten::mm(%43, %50)
   %52 : Float(*, *) = aten::t(%11)
   %53 : Float(*, *) = aten::mm(%52, %43)
-  %54 : Float(*, *) = aten::t(%53)
-  return (%grad_other.5, %41, %42, %46, %49, %51, %54)
+  %grad_self.9 : Float(*, *) = aten::t(%53)
+  return (%grad_other.5, %41, %42, %46, %grad_self.7, %51, %grad_self.9)
 with prim::FusionGroup_0 = graph(%0 : Float(*, *),
       %1 : Float(*, *),
       %2 : Float(*, *),
index e53deb8..414a041 100644 (file)
@@ -28,14 +28,16 @@ with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *),
   %17 : Tensor, %18 : Tensor, %19 : Tensor, %20 : Tensor = prim::ListUnpack(%16)
   %21 : int[] = prim::BroadcastSizes(%11, %12)
   %22 : int[] = prim::BroadcastSizes(%21, %13)
-  %hy : Float(*, *), %24 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %20, %19, %18, %17)
-  %30 : int[] = aten::size(%0)
-  %31 : int[] = aten::size(%cellgate.1)
-  %32 : int[] = aten::size(%forgetgate.1)
-  %33 : int[] = aten::size(%ingate.1)
-  %34 : int[] = prim::BroadcastSizes(%32, %30)
-  %35 : int[] = prim::BroadcastSizes(%33, %31)
-  return (%hy, %cy, %7, %9, %11, %12, %21, %13, %22, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %34, %35, %24)
+  %other_size.6 : int[] = aten::size(%0)
+  %hy : Float(*, *), %25 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %20, %19, %18, %17)
+  %31 : int[] = aten::size(%25)
+  %32 : int[] = aten::size(%outgate.1)
+  %33 : int[] = aten::size(%cellgate.1)
+  %34 : int[] = aten::size(%forgetgate.1)
+  %35 : int[] = aten::size(%ingate.1)
+  %36 : int[] = prim::BroadcastSizes(%34, %other_size.6)
+  %37 : int[] = prim::BroadcastSizes(%35, %33)
+  return (%hy, %cy, %7, %9, %11, %12, %21, %13, %22, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %34, %other_size.6, %35, %33, %36, %37, %25, %32, %31)
 with prim::FusionGroup_0 = graph(%0 : Float(*, *),
       %1 : Tensor,
       %2 : Tensor,
index eb8e612..590d427 100644 (file)
@@ -17,50 +17,50 @@ graph(%0 : Float(*, *),
       %Wx : Float(*, *),
       %Uz : Float(*, *),
       %18 : Float(*, *),
-      %19 : int[],
-      %20 : int[],
-      %21 : int[],
-      %22 : int[],
-      %23 : int[],
+      %self_size.13 : int[],
+      %other_size.13 : int[],
+      %self_size.11 : int[],
+      %other_size.11 : int[],
+      %self_size.9 : int[],
       %24 : int[],
+      %25 : int[],
+      %self_size.7 : int[],
+      %27 : int[],
+      %28 : int[],
+      %29 : int[],
+      %30 : int[],
       %ingate : Float(*, *),
       %forgetgate : Float(*, *),
       %cellgate : Float(*, *),
       %outgate : Float(*, *),
-      %29 : int[],
-      %30 : int[],
-      %31 : Float(*, *)):
-  %32 : int = prim::Constant[value=1]()
-  %33 : int[] = aten::size(%outgate)
-  %34 : int[] = aten::size(%31)
-  %35 : int[] = aten::size(%ingate)
-  %36 : int[] = aten::size(%cellgate)
-  %37 : int[] = aten::size(%forgetgate)
-  %38 : Tensor = prim::FusionGroup_0(%outgate, %0, %31, %33)
-  %39 : Tensor, %40 : Tensor, %41 : Tensor = prim::FusionGroup_1(%10, %ingate, %cellgate, %1, %31, %0, %outgate, %forgetgate, %37, %29, %36, %35, %30, %34)
-  %42 : Tensor[] = prim::ListConstruct(%41, %39, %40, %38)
-  %43 : Tensor = aten::cat(%42, %32)
-  %44 : Tensor = aten::_grad_sum_to_size(%43, %24)
-  %45 : Tensor = aten::_grad_sum_to_size(%43, %22)
-  %46 : int[] = aten::size(%11)
-  %grad_self.7 : Tensor = prim::FusionGroup_2(%45, %Uz, %46)
-  %48 : int[] = aten::size(%Uz)
-  %49 : Tensor = aten::_grad_sum_to_size(%43, %19)
-  %50 : Tensor = aten::_grad_sum_to_size(%43, %20)
-  %51 : int[] = aten::size(%12)
-  %grad_self.9 : Tensor = prim::FusionGroup_3(%50, %Wx, %51)
-  %53 : int[] = aten::size(%Wx)
-  %54 : int[] = aten::size(%18)
-  %55 : Tensor = prim::FusionGroup_4(%49, %18, %45, %11, %48)
-  %56 : int[] = aten::size(%13)
-  %grad_self.13 : Tensor, %58 : Tensor = prim::FusionGroup_5(%Wx, %13, %49, %Uz, %50, %12, %56, %53, %54)
+      %self_size.5 : int[],
+      %self_size.3 : int[],
+      %other_size.3 : int[],
+      %38 : int[],
+      %39 : int[],
+      %40 : Float(*, *),
+      %self_size.1 : int[],
+      %other_size.1 : int[]):
+  %43 : int = prim::Constant[value=1]()
+  %44 : Tensor = prim::FusionGroup_0(%outgate, %0, %40, %self_size.1)
+  %45 : Tensor, %46 : Tensor, %47 : Tensor = prim::FusionGroup_1(%10, %ingate, %cellgate, %1, %40, %0, %outgate, %forgetgate, %self_size.5, %38, %other_size.3, %self_size.3, %39, %other_size.1)
+  %48 : Tensor[] = prim::ListConstruct(%47, %45, %46, %44)
+  %49 : Tensor = aten::cat(%48, %43)
+  %50 : Tensor = aten::_grad_sum_to_size(%49, %30)
+  %51 : Tensor = aten::_grad_sum_to_size(%49, %28)
+  %grad_self.7 : Tensor = prim::FusionGroup_2(%51, %Uz, %self_size.7)
+  %53 : Tensor = aten::_grad_sum_to_size(%49, %24)
+  %54 : Tensor = aten::_grad_sum_to_size(%49, %25)
+  %grad_self.9 : Tensor = prim::FusionGroup_3(%54, %Wx, %self_size.9)
+  %56 : Tensor = prim::FusionGroup_4(%53, %18, %51, %11, %other_size.11)
+  %grad_self.13 : Tensor, %58 : Tensor = prim::FusionGroup_5(%Wx, %13, %53, %Uz, %54, %12, %self_size.13, %other_size.13, %self_size.11)
   %59 : Float(*, *) = aten::t(%14)
-  %60 : Float(*, *) = aten::mm(%59, %55)
-  %61 : Float(*, *) = aten::t(%60)
+  %60 : Float(*, *) = aten::mm(%59, %56)
+  %grad_self.15 : Float(*, *) = aten::t(%60)
   %62 : Float(*, *) = aten::t(%15)
   %63 : Float(*, *) = aten::mm(%62, %58)
-  %64 : Float(*, *) = aten::t(%63)
-  return (%44, %grad_self.7, %grad_self.9, %grad_self.13, %61, %64)
+  %grad_self.17 : Float(*, *) = aten::t(%63)
+  return (%50, %grad_self.7, %grad_self.9, %grad_self.13, %grad_self.15, %grad_self.17)
 with prim::FusionGroup_0 = graph(%0 : Float(*, *),
       %1 : Float(*, *),
       %2 : Float(*, *),
index a8b3b77..019db7d 100644 (file)
@@ -24,28 +24,31 @@ with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *),
   %11 : Float(*, *) = aten::t(%6)
   %Uz.1 : Float(*, *) = aten::mm(%5, %11)
   %13 : Float(*, *) = aten::mul(%4, %Wx.1)
-  %14 : int[] = aten::size(%1)
-  %15 : Tensor[] = prim::ListConstruct(%1, %2, %Uz.1, %13, %3, %Wx.1)
-  %16 : Tensor[] = aten::broadcast_tensors(%15)
-  %17 : Tensor, %18 : Tensor, %19 : Tensor, %20 : Tensor, %21 : Tensor, %22 : Tensor = prim::ListUnpack(%16)
-  %23 : int[] = aten::size(%3)
-  %24 : int[] = aten::size(%Wx.1)
-  %25 : int[] = prim::BroadcastSizes(%23, %24)
-  %26 : int[] = aten::size(%13)
-  %27 : int[] = aten::size(%Uz.1)
-  %28 : int[] = prim::BroadcastSizes(%26, %27)
-  %29 : int[] = aten::size(%2)
-  %30 : int[] = prim::BroadcastSizes(%29, %27)
-  %31 : int[] = prim::BroadcastSizes(%28, %25)
-  %32 : int[] = prim::BroadcastSizes(%31, %30)
-  %hy : Float(*, *), %34 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %22, %21, %20, %19, %18, %17)
-  %40 : int[] = aten::size(%0)
-  %41 : int[] = aten::size(%cellgate.1)
-  %42 : int[] = aten::size(%forgetgate.1)
-  %43 : int[] = aten::size(%ingate.1)
-  %44 : int[] = prim::BroadcastSizes(%42, %40)
-  %45 : int[] = prim::BroadcastSizes(%43, %41)
-  return (%hy, %cy, %Wx.1, %Uz.1, %13, %28, %25, %31, %30, %32, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %44, %45, %34)
+  %self_size.14 : int[] = aten::size(%4)
+  %other_size.14 : int[] = aten::size(%Wx.1)
+  %self_size.12 : int[] = aten::size(%13)
+  %other_size.12 : int[] = aten::size(%Uz.1)
+  %self_size.10 : int[] = aten::size(%3)
+  %self_size.8 : int[] = aten::size(%2)
+  %20 : int[] = aten::size(%1)
+  %21 : Tensor[] = prim::ListConstruct(%1, %2, %Uz.1, %13, %3, %Wx.1)
+  %22 : Tensor[] = aten::broadcast_tensors(%21)
+  %23 : Tensor, %24 : Tensor, %25 : Tensor, %26 : Tensor, %27 : Tensor, %28 : Tensor = prim::ListUnpack(%22)
+  %29 : int[] = prim::BroadcastSizes(%self_size.10, %other_size.14)
+  %30 : int[] = prim::BroadcastSizes(%self_size.12, %other_size.12)
+  %31 : int[] = prim::BroadcastSizes(%self_size.8, %other_size.12)
+  %32 : int[] = prim::BroadcastSizes(%30, %29)
+  %33 : int[] = prim::BroadcastSizes(%32, %31)
+  %hy : Float(*, *), %35 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %28, %27, %26, %25, %24, %23)
+  %41 : int[] = aten::size(%0)
+  %42 : int[] = aten::size(%35)
+  %43 : int[] = aten::size(%outgate.1)
+  %44 : int[] = aten::size(%cellgate.1)
+  %45 : int[] = aten::size(%forgetgate.1)
+  %46 : int[] = aten::size(%ingate.1)
+  %47 : int[] = prim::BroadcastSizes(%45, %41)
+  %48 : int[] = prim::BroadcastSizes(%46, %44)
+  return (%hy, %cy, %Wx.1, %Uz.1, %13, %self_size.14, %other_size.14, %self_size.12, %other_size.12, %self_size.10, %30, %29, %self_size.8, %32, %31, %33, %20, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %45, %46, %44, %47, %48, %35, %43, %42)
 with prim::FusionGroup_0 = graph(%0 : Float(*, *),
       %1 : Tensor,
       %2 : Tensor,
index e15ab3c..5e465c3 100644 (file)
@@ -85,47 +85,48 @@ testDifferentiate
 graph(%0 : Float(2, 3, 4),
       %1 : Float(2, 3, 4)):
   %2 : Float(2, 3, 4) = aten::mul(%0, %1)
+  %self_size.4 : int[] = aten::size(%0)
+  %other_size.4 : int[] = aten::size(%1)
   %3 : Float(2, 3, 4) = aten::mul(%2, %0)
+  %self_size.2 : int[] = aten::size(%2)
   %4 : int = prim::Constant[value=1]()
   %7 : int[] = aten::size(%3)
   %5 : Float(2, 3, 4) = aten::add(%3, %1, %4)
-  return (%5, %2, %7)
+  return (%5, %2, %self_size.4, %other_size.4, %self_size.2, %7)
 graph(%0 : Float(2, 3, 4),
       %1 : Float(2, 3, 4),
       %2 : Float(2, 3, 4),
       %3 : Float(2, 3, 4),
       %4 : Float(2, 3, 4),
-      %5 : int[]):
-  %7 : int = prim::Constant[value=1]()
-  %6 : int[] = aten::size(%3)
-  %8 : Tensor, %9 : Tensor = prim::GradOf[name="aten::add"](%0)
+      %self_size.3 : int[],
+      %other_size.3 : int[],
+      %self_size.1 : int[],
+      %8 : int[]):
+  %9 : int = prim::Constant[value=1]()
+  %10 : Tensor, %11 : Tensor = prim::GradOf[name="aten::add"](%0)
     block0():
-      %10 : Tensor = aten::_grad_sum_to_size(%0, %5)
-      %11 : Float(2, 3, 4) = aten::mul(%0, %7)
-      %12 : Tensor = aten::_grad_sum_to_size(%11, %6)
-      -> (%10, %12)
-  %grad_self.2 : Tensor, %grad_other.2 : Tensor = prim::GradOf[name="aten::mul"](%8)
+      %12 : Tensor = aten::_grad_sum_to_size(%0, %8)
+      %13 : Float(2, 3, 4) = aten::mul(%0, %9)
+      %14 : Tensor = aten::_grad_sum_to_size(%13, %other_size.3)
+      -> (%12, %14)
+  %grad_self.2 : Tensor, %grad_other.2 : Tensor = prim::GradOf[name="aten::mul"](%10)
     block0():
-      %15 : Tensor = aten::mul(%8, %2)
-      %16 : int[] = aten::size(%4)
-      %grad_self.1 : Tensor = aten::_grad_sum_to_size(%15, %16)
-      %18 : Tensor = aten::mul(%8, %4)
-      %19 : int[] = aten::size(%2)
-      %grad_other.1 : Tensor = aten::_grad_sum_to_size(%18, %19)
+      %17 : Tensor = aten::mul(%10, %2)
+      %grad_self.1 : Tensor = aten::_grad_sum_to_size(%17, %self_size.1)
+      %19 : Tensor = aten::mul(%10, %4)
+      %grad_other.1 : Tensor = aten::_grad_sum_to_size(%19, %self_size.3)
       -> (%grad_self.1, %grad_other.1)
   %21 : Tensor = prim::AutogradAdd(%1, %grad_self.2)
   %grad_self : Tensor, %grad_other : Tensor = prim::GradOf[name="aten::mul"](%21)
     block0():
       %24 : Tensor = aten::mul(%21, %3)
-      %25 : int[] = aten::size(%2)
-      %grad_self.3 : Tensor = aten::_grad_sum_to_size(%24, %25)
-      %27 : Tensor = aten::mul(%21, %2)
-      %28 : int[] = aten::size(%3)
-      %grad_other.3 : Tensor = aten::_grad_sum_to_size(%27, %28)
+      %grad_self.3 : Tensor = aten::_grad_sum_to_size(%24, %self_size.3)
+      %26 : Tensor = aten::mul(%21, %2)
+      %grad_other.3 : Tensor = aten::_grad_sum_to_size(%26, %other_size.3)
       -> (%grad_self.3, %grad_other.3)
-  %30 : Tensor = prim::AutogradAdd(%grad_other.2, %grad_self)
-  %31 : Tensor = prim::AutogradAdd(%9, %grad_other)
-  return (%30, %31)
+  %28 : Tensor = prim::AutogradAdd(%grad_other.2, %grad_self)
+  %29 : Tensor = prim::AutogradAdd(%11, %grad_other)
+  return (%28, %29)
 
 testDifferentiateWithRequiresGrad
 graph(%0 : Float(*),
@@ -133,37 +134,38 @@ graph(%0 : Float(*),
   %2 : Float(*) = aten::mul(%1, %1)
   %3 : int = prim::Constant[value=1]()
   %4 : Float(*) = aten::add(%2, %1, %3)
+  %39 : int[] = aten::size(%0)
   %6 : Float(*) = aten::add(%4, %0, %3)
   %7 : Float(*) = aten::mul(%6, %0)
+  %self_size.2 : int[] = aten::size(%6)
   %11 : int[] = aten::size(%7)
   %9 : Float(*) = aten::add(%7, %1, %3)
-  return (%4, %9, %6, %11)
+  return (%4, %9, %39, %6, %self_size.2, %11)
 graph(%0 : Float(*),
       %1 : Float(*),
       %2 : Float(*),
-      %3 : Float(*),
-      %4 : int[]):
-  %6 : int = prim::Constant[value=1]()
-  %5 : int[] = aten::size(%2)
-  %7 : Tensor = prim::GradOf[name="aten::add"](%0)
+      %3 : int[],
+      %4 : Float(*),
+      %self_size.1 : int[],
+      %6 : int[]):
+  %7 : int = prim::Constant[value=1]()
+  %8 : Tensor = prim::GradOf[name="aten::add"](%0)
     block0():
-      %8 : Tensor = aten::_grad_sum_to_size(%0, %4)
-      -> (%8)
-  %grad_self : Tensor, %grad_other : Tensor = prim::GradOf[name="aten::mul"](%7)
+      %9 : Tensor = aten::_grad_sum_to_size(%0, %6)
+      -> (%9)
+  %grad_self : Tensor, %grad_other : Tensor = prim::GradOf[name="aten::mul"](%8)
     block0():
-      %11 : Tensor = aten::mul(%7, %2)
-      %12 : int[] = aten::size(%3)
-      %grad_self.1 : Tensor = aten::_grad_sum_to_size(%11, %12)
-      %14 : Tensor = aten::mul(%7, %3)
-      %15 : int[] = aten::size(%2)
-      %grad_other.1 : Tensor = aten::_grad_sum_to_size(%14, %15)
+      %12 : Tensor = aten::mul(%8, %2)
+      %grad_self.1 : Tensor = aten::_grad_sum_to_size(%12, %self_size.1)
+      %14 : Tensor = aten::mul(%8, %4)
+      %grad_other.1 : Tensor = aten::_grad_sum_to_size(%14, %3)
       -> (%grad_self.1, %grad_other.1)
-  %17 : Tensor = prim::AutogradAdd(%1, %grad_self)
-  %18 : Tensor = prim::GradOf[name="aten::add"](%17)
+  %16 : Tensor = prim::AutogradAdd(%1, %grad_self)
+  %17 : Tensor = prim::GradOf[name="aten::add"](%16)
     block0():
-      %19 : Tensor = aten::mul(%17, %6)
-      %20 : Tensor = aten::_grad_sum_to_size(%19, %5)
-      -> (%20)
-  %21 : Tensor = prim::AutogradAdd(%grad_other, %18)
-  return (%21)
+      %18 : Tensor = aten::mul(%16, %7)
+      %19 : Tensor = aten::_grad_sum_to_size(%18, %3)
+      -> (%19)
+  %20 : Tensor = prim::AutogradAdd(%grad_other, %17)
+  return (%20)
 
index c2b2238..926501f 100644 (file)
@@ -4998,6 +4998,39 @@ a")
                     self.assertIs(torch.int32, s(b, "t.to(dtype=torch.int32)").dtype)
                     self.assertEqual(b.device, s(b, "t.to(dtype=torch.int32)").device)
 
+        # Test AD: aten::to(Tensor self, int dtype, bool non_blocking, bool copy) -> Tensor
+        t = torch.tensor(5).float().requires_grad_()
+        out_ref = t.to(torch.float32)
+        out = s(t, "t.to(torch.float32)")
+        self.assertEqual(out_ref, out)
+
+        grad_ref = torch.autograd.grad(out_ref.sum(), t)
+        grad = torch.autograd.grad(out.sum(), t)
+        self.assertEqual(grad_ref, grad)
+
+        # Test AD: aten::to(Tensor self, Device? device, int? dtype, bool non_blocking, bool copy) -> Tensor
+        out_ref = t.to('cpu')
+        out = s(t, "t.to('cpu')")
+        self.assertEqual(out_ref, out)
+
+        grad_ref = torch.autograd.grad(out_ref.sum(), t)
+        grad = torch.autograd.grad(out.sum(), t)
+        self.assertEqual(grad_ref, grad)
+
+        # Test AD: aten::to(Tensor self, Tensor other, bool non_blocking, bool copy) -> Tensor
+        @torch.jit.script
+        def func2(t, t_ref):
+            return t.to(t_ref)
+
+        func2.debug_disable_autodiff_subgraph_inlining()
+
+        t_ref = torch.tensor(4).double()
+        out_ref = t.to(t_ref)
+        out = func2(t, t_ref)
+        grad_ref = torch.autograd.grad(out_ref.sum(), t)
+        grad = torch.autograd.grad(out.sum(), t)
+        self.assertEqual(grad_ref, grad)
+
     @unittest.skipIf(not RUN_CUDA, "No CUDA")
     def test_tensor_number_math_cuda(self):
         self._test_tensor_number_math(device='cuda')
@@ -10637,6 +10670,7 @@ EXCLUDE_SCRIPT_MODULES = {
 DISABLE_AUTODIFF_SUBGRAPH_INLINING = {
     'test_nn_avg_pool2d',
     'test_nn_adaptive_avg_pool2d',
+    'test_nn_embedding',
     'test_nn_log_softmax',
     'test_nn_threshold',
     'test_nn_nll_loss',
index 822d8df..15c33c9 160000 (submodule)
@@ -1 +1 @@
-Subproject commit 822d8df0a2a32233c6022f50a158817a0f19bdc7
+Subproject commit 15c33c945851907411619f599900c3852108e7e3
index 641b55d..be8ac7b 100644 (file)
   self: zeros_like(grad)
 
 - name: logsumexp(Tensor self, IntArrayRef dim, bool keepdim)
-  self: logsumexp_backward(grad, self, result, dim, keepdim)
+  self: at::logsumexp_backward(grad, self, result, dim, keepdim)
 
 - name: lt_(Tensor self, Scalar other)
   self: zeros_like(self)
   self: grad.expand(self.sizes()).to(self.type().scalarType()) / self.numel()
 
 - name: mean(Tensor self, IntArrayRef dim, bool keepdim)
-  self: sum_backward(grad, self.sizes(), dim, keepdim) / _safe_size(self.sizes(), dim)
+  self: sum_backward(grad, self.sizes(), dim, keepdim) / at::_safe_size(self.sizes(), dim)
 
 - name: mean(Tensor self, IntArrayRef dim, ScalarType dtype)
-  self: sum_backward(grad, self.sizes(), dim, false).to(self.type().scalarType()) / _safe_size(self.sizes(), dim)
+  self: sum_backward(grad, self.sizes(), dim, false).to(self.type().scalarType()) / at::_safe_size(self.sizes(), dim)
 
 - name: mean(Tensor self, IntArrayRef dim, bool keepdim, ScalarType dtype)
-  self: sum_backward(grad, self.sizes(), dim, keepdim).to(self.type().scalarType()) / _safe_size(self.sizes(), dim)
+  self: sum_backward(grad, self.sizes(), dim, keepdim).to(self.type().scalarType()) / at::_safe_size(self.sizes(), dim)
 
 - name: median(Tensor self)
   self: select_equals_backward(grad, self, result)
index d44ec5a..63b4416 100644 (file)
@@ -76,18 +76,6 @@ Tensor maybe_multiply(const Tensor & t, const Scalar & s) {
   }
 }
 
-int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim) {
-  int64_t size = 1;
-  if (sizes.size() == 0) {
-    return 1;
-  }
-  for (auto d : dim) {
-    d = at::maybe_wrap_dim(d, sizes.size());
-    size *= sizes[d];
-  }
-  return size;
-}
-
 Tensor norm_backward(const Tensor & grad, const Tensor & self, const optional<Scalar> & p_, const Tensor & norm) {
   double p = p_.value_or(2.0).toDouble();
   Tensor self_scaled;
@@ -160,40 +148,6 @@ Tensor mvlgamma_backward(Tensor grad, const Tensor & self, int64_t p) {
   return grad * args.digamma_().sum(-1);
 }
 
-Tensor permute_backwards(const Tensor & grad, IntArrayRef fwd_dims) {
-  // invert the permutation
-  auto ndims = fwd_dims.size();
-  std::vector<int64_t> dims(ndims);
-  for (size_t i = 0; i < ndims; i++) {
-    dims[at::maybe_wrap_dim(fwd_dims[i], ndims)] = i;
-  }
-  return grad.permute(dims);
-}
-
-Tensor unsqueeze_multiple(const Tensor & t, IntArrayRef dim, size_t n_dims) {
-    auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, n_dims);
-    Tensor res = t;
-    for (size_t i = 0; i < n_dims; i++){
-      if (dims_to_unsqueeze[i]) {
-        res = res.unsqueeze(i);
-      }
-    }
-    return res;
-}
-
-Tensor sum_backward(const Tensor & grad, IntArrayRef sizes, IntArrayRef dims, bool keepdim) {
-  if (!keepdim && sizes.size() > 0) {
-    if (dims.size()==1) {
-      return grad.unsqueeze(dims[0]).expand(sizes);
-    } else {
-      Tensor res = unsqueeze_multiple(grad, dims, sizes.size());
-      return res.expand(sizes);
-    }
-  } else {
-    return grad.expand(sizes);
-  }
-}
-
 std::vector<int64_t> reverse_list(const IntArrayRef list) {
   auto result = std::vector<int64_t>();
   result.reserve(list.size());
@@ -429,13 +383,13 @@ Tensor cumsum_backward(const Tensor &x, int64_t dim, ScalarType input_dtype) {
   return cumsum_backward(x.to(input_dtype), dim);
 }
 
-Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, IntArrayRef dim, bool keepdim) {
-  if (!keepdim && self.dim() != 0) {
-    grad = unsqueeze_multiple(grad, dim, self.sizes().size());
-    result = unsqueeze_multiple(result, dim, self.sizes().size());
-  }
-  return grad * (self - result).exp();
-}
+//Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, IntArrayRef dim, bool keepdim) {
+//  if (!keepdim && self.dim() != 0) {
+//    grad = unsqueeze_multiple(grad, dim, self.sizes().size());
+//    result = unsqueeze_multiple(result, dim, self.sizes().size());
+//  }
+//  return grad * (self - result).exp();
+//}
 
 Tensor unbind_backward(const variable_list& grads, int64_t dim) {
   IntArrayRef sizes;
@@ -454,28 +408,6 @@ Tensor unbind_backward(const variable_list& grads, int64_t dim) {
   return at::stack(grads_tensors, dim);
 }
 
-Tensor unsqueeze_to(const Tensor & self, IntArrayRef sizes) {
-  auto result = self;
-
-  int64_t nDims = sizes.size();
-  for (int64_t dim = 0; dim < nDims; dim++) {
-    if (sizes[dim] == 1) {
-      result = result.unsqueeze(dim);
-    }
-  }
-  return result;
-}
-
-Tensor unsqueeze_to(const Tensor & self, int64_t dim, IntArrayRef sizes) {
-  dim = at::maybe_wrap_dim(dim, sizes.size());
-  // in NumPy it's not an error to unsqueeze a scalar, but we still need to avoided
-  // unsqueezing in the backward.
-  if (sizes.size() > 0 && sizes[dim] == 1) {
-    return self.unsqueeze(dim);
-  }
-  return self;
-}
-
 std::vector<Tensor> cat_tensors_backward(const Tensor & grad, const std::vector<std::vector<int64_t>> &sizes, int64_t dim) {
   dim = at::legacy_cat_wrap_dim(dim, sizes);
   std::vector<Tensor> grad_inputs(sizes.size());
@@ -606,26 +538,6 @@ Tensor select_equals_backward(Tensor grad, const Tensor & input, const Tensor &
   return grad_input;
 }
 
-Tensor index_select_backward(Tensor grad, int64_t dim, Tensor indices, IntArrayRef sizes, bool keepdim) {
-  if (!keepdim && sizes.size() > 0) {
-    grad = grad.unsqueeze(dim);
-    indices = indices.unsqueeze(dim);
-  }
-  return at::zeros(sizes, grad.options()).scatter_(dim, indices, grad);
-}
-
-Tensor slice_backward(Tensor grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
-  auto grad_input = at::zeros(input_sizes, grad.options());
-  grad_input.slice(dim, start, end, step).copy_(grad);
-  return grad_input;
-}
-
-Tensor select_backward(Tensor grad, IntArrayRef input_sizes, int64_t dim, int64_t index) {
-  auto grad_input = at::zeros(input_sizes, grad.options());
-  grad_input.select(dim, index).copy_(grad);
-  return grad_input;
-}
-
 Tensor trace_backward(const Tensor & grad, IntArrayRef sizes) {
   if (sizes.size() != 2) {
     throw std::runtime_error("expected matrix input");
@@ -651,19 +563,6 @@ Tensor unfold_backward(const Tensor & grad, IntArrayRef input_sizes, int64_t dim
   return grad_input.view(input_sizes);
 }
 
-Tensor var_backward(const Tensor & grad, const Tensor & self, bool unbiased) {
-  return (2.0 / (self.numel() - unbiased)) * grad * (self - self.mean());
-}
-
-Tensor var_backward(Tensor grad, const Tensor & self, IntArrayRef dim, bool unbiased, bool keepdim) {
-  if (self.dim() == 0) {
-    return var_backward(grad, self, unbiased);
-  }
-  if (!keepdim && self.dim() > 1) {
-    grad = unsqueeze_multiple(grad, dim, self.sizes().size());
-  }
-  return (2.0 / (_safe_size(self.sizes(), dim) - unbiased)) * grad * (self - self.mean(dim, true));
-}
 
 Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntArrayRef sizes) {
   int64_t numel = 1;
index c43669d..b63f069 100644 (file)
@@ -27,6 +27,19 @@ void wrapDim(int64_t& dim, const std::vector<int64_t>& sizes) {
   }
 }
 
+// kthvalue returns (kthvalue, index of kthvalue), currently autodiff only
+// supports at most one output that requires grad. Thus we need to remove
+// the grad for index that doesn't require grad.
+bool needTrimGrad(Node* n) {
+  static OperatorSet need_trim_grad_ops = {
+      "aten::kthvalue(Tensor self, int k, int dim, bool keepdim) -> (Tensor, Tensor)",
+  };
+  if (need_trim_grad_ops.find(n)) {
+    return true;
+  }
+  return false;
+}
+
 bool isDifferentiable(Node* n) {
   // TODO: scalar-tensor ops should be canonicalized
   static OperatorSet differentiable_ops = {
@@ -194,15 +207,20 @@ static c10::optional<std::vector<Value*>> build_script_grad(
     auto fw_graph = compiled_graphs->forward;
     new_outputs = inlineCallTo(
         *graph, *fw_graph, node->inputs(), /*unpack_outputs=*/true);
-    for (size_t i = 0; i < node->outputs().size(); ++i) {
-      new_outputs.at(i)->setType(node->outputs()[i]->type());
-      new_outputs.at(i)->replaceAllUsesWith(node->outputs()[i]);
+    auto outputs = node->outputs();
+    AT_ASSERT(new_outputs.size() == outputs.size() + 1);
+    for (size_t i = 0; i < outputs.size(); ++i) {
+      new_outputs.at(i)->setType(outputs[i]->type());
+      outputs[i]->replaceAllUsesWith(new_outputs.at(i));
     }
   }
 
   // Use backward graph to construct reverse_block
   auto bw_graph = compiled_graphs->backward;
   auto grad_vec = grads.vec();
+  if (needTrimGrad(node)) {
+    grad_vec.erase(grad_vec.begin()+1, grad_vec.end());
+  }
   auto it = grad_vec.begin();
   grad_vec.insert(it, new_outputs.back());
   ArrayRef<Value*> grad(grad_vec);
@@ -578,35 +596,6 @@ class GradientHelper {
       return {sizes.at(dim) > 1 ? grads.at(0) : grads.at(0).unsqueeze(dim),
               nullptr};
 
-    } else if (node->matches(
-                   "aten::cat(Tensor[] tensors, int dim) -> Tensor",
-                   /*const_inputs=*/attr::dim)) {
-      int dim = *node->get<int64_t>(attr::dim);
-      auto tensor_inputs = inputs;
-      tensor_inputs.pop_back();
-      const auto& first_sizes = tensor_inputs.at(0).sizes();
-      const auto has_first_sizes = [&first_sizes](SymbolicVariable var) {
-        return var.sizes() == first_sizes;
-      };
-
-      // NB: this is a specialization for the common case where all inputs are
-      // of equal sizes. We can use a single split operation to handle that.
-      if (std::all_of(
-              tensor_inputs.begin(), tensor_inputs.end(), has_first_sizes)) {
-        auto tensor_grads = grads.at(0).chunk(tensor_inputs.size(), dim);
-        tensor_grads.emplace_back(nullptr); // for attr::dim
-        return tensor_grads;
-      } else {
-        size_t offset = 0;
-        auto grad = grads.at(0);
-        std::vector<SymbolicVariable> tensor_grads;
-        for (auto input : tensor_inputs) {
-          tensor_grads.push_back(grad.narrow(dim, offset, input.sizes()[dim]));
-          offset += input.sizes()[dim];
-        }
-        tensor_grads.emplace_back(nullptr); // for attr::dim
-        return tensor_grads;
-      }
     } else if (comparison_ops.find(node)) {
       return {nullptr, nullptr};
 
@@ -775,6 +764,11 @@ static std::vector<Value*> linearGradientForNode(
     Node* node,
     ArrayRef<Value*> grad_values) {
   auto& graph = *node->owningGraph();
+
+  // FIXME: In case forward has multi outputs, we only support one requires grad
+  if (needTrimGrad(node)) {
+    grad_values = grad_values.at(0);
+  }
   auto linear = graph.insertNode(graph.create(prim::GradOf, {grad_values}, 0));
   // to make reading gradient graphs easier, remember the name of the forward op
   linear->s_(attr::name, node->kind().toDisplayString());
index 2beafd4..2332dd1 100644 (file)
@@ -108,7 +108,9 @@ struct DifferentiableGraphBackward : public autograd::Function {
     variable_list outputs;
     outputs.reserve(num_outputs());
     for (size_t i = 0; i < num_outputs(); ++i) {
-      if (should_compute_output(i)) {
+      // Input grad can also be None even if it requires grad
+      // Example: `other` in expand_as(self, other)
+      if (should_compute_output(i) && !stack[i].isNone()) {
         auto output = std::move(stack[i]).toTensor();
         const auto& edge = next_edge(i);
         if (output.defined()) {
index 4f43440..5ebb169 100644 (file)
@@ -1272,6 +1272,7 @@ bool trackSingleGradSumToSizeToOutputs(
       "aten::div(Tensor self, Scalar other) -> Tensor",
       "aten::neg(Tensor self) -> Tensor",
       "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
+      "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
       // add this used to be prim::AutogradAdd
   }};
 
index 327626d..bfbf4ae 100644 (file)
@@ -6,13 +6,303 @@ namespace {
 std::mutex lock;
 const std::vector<std::string> functions = {
     R"(
+
+        def _dim_arange(like,
+                        dim: int):
+            def backward(grad_output):
+                return None, None
+
+            return torch._dim_arange(like, dim), backward
+
+        def contiguous(self):
+            def backward(grad_output):
+                return None
+
+            return self.contiguous(), backward
+
+        def erf(self):
+            def backward(grad_output):
+                # Precomputed constant C = 2.0 / math.sqrt(math.pi)
+                C = 1.1283791670955126
+                grad_self =  C * torch.exp(- self.pow(2)) * grad_output
+                return grad_self
+
+            return torch.erf(self), backward
+
+        def expand(self,
+                   size: List[int],
+                   implicit: bool=False):
+            self_size = self.size()
+            def backward(grad_output):
+                grad_self = torch._grad_sum_to_size(grad_output, self_size)
+                return grad_self, None, None
+
+            return torch.expand(self, size, implicit=implicit), backward
+
+        def expand_as(self, other):
+            self_size = self.size()
+            def backward(grad_output):
+                grad_self = grad_output._grad_sum_to_size(self_size)
+                return grad_self, None
+
+            return torch.expand_as(self, other), backward
+
+        def full_like(self,
+                      fill_value: float):
+            def backward(grad_output):
+                return None, None
+
+            return torch.full_like(self, fill_value), backward
+
+        def kthvalue(self,
+                     k: int,
+                     dim: int,
+                     keepdim: bool):
+            result0, result1 = torch.kthvalue(self, k, dim, keepdim)
+            self_size = self.size()
+            def backward(grad_output):
+                grad_self = torch.index_select_backward(grad_output, dim, result1, self_size, keepdim)
+                return grad_self, None, None, None
+
+            return result0, result1, backward
+
+        def logsumexp(self,
+                      dim: List[int],
+                      keepdim: bool):
+            result = torch.logsumexp(self, dim, keepdim)
+            self_dim = self.dim()
+            def backward(grad_output):
+                grad_self = torch.logsumexp_backward(grad_output, self, result, dim, keepdim)
+                return grad_self, None, None
+
+            return result, backward
+
+        def mean_0(self):
+            self_size = self.size()
+            self_numel = self.numel()
+            def backward(grad_output):
+                grad_self = grad_output.expand(self_size) / self_numel
+                return grad_self
+
+            return torch.mean(self), backward
+
+        def mean_1(self,
+                   dim: List[int],
+                   keepdim: bool):
+            self_size = self.size()
+            def backward(grad_output):
+                grad_self = torch.sum_backward(grad_output, self_size, dim, keepdim) / torch._safe_size(self_size, dim)
+                return grad_self, None, None
+
+            return torch.mean(self, dim, keepdim), backward
+
         def mul(self, other):
+            self_size = self.size()
+            other_size = other.size()
             def backward(grad_output):
-                grad_self = (grad_output * other)._grad_sum_to_size(self.size())
-                grad_other = (grad_output * self)._grad_sum_to_size(other.size())
+                grad_self = (grad_output * other)._grad_sum_to_size(self_size)
+                grad_other = (grad_output * self)._grad_sum_to_size(other_size)
                 return grad_self, grad_other
+
             return self * other, backward
 
+        def nonzero(self):
+            def backward(grad_output):
+                return None
+
+            return torch.nonzero(self), backward
+
+        def ones_like(self):
+            def backward(grad_output):
+                return None
+
+            return torch.ones_like(self), backward
+
+        def permute(self,
+                    dims: List[int]):
+            def backward(grad_output):
+                grad_self = torch.permute_backwards(grad_output, dims)
+                return grad_self, None
+
+            return torch.permute(self, dims), backward
+
+        def pow_0(self,
+                  exponent: float):
+            def backward(grad_output):
+                grad_self = torch.where(torch.tensor(exponent == 0.0), torch.zeros_like(self), grad_output * exponent * torch.pow(self, exponent - 1))
+                return grad_self, None
+
+            return torch.pow(self, exponent), backward
+
+        def pow_1(self, exponent):
+            self_size = self.size()
+            exponent_size = exponent.size()
+            def backward(grad_output):
+                grad_self = torch.where(exponent == 0.0, torch.zeros_like(self), grad_output * exponent * torch.pow(self, exponent - 1))._grad_sum_to_size(self_size)
+                grad_exponent = (grad_output * torch.pow(self, exponent) * torch.log(self))._grad_sum_to_size(exponent_size)
+                return grad_self, grad_exponent
+
+            return torch.pow(self, exponent), backward
+
+        def pow_2(self: float,
+                  exponent):
+            def backward(grad_output):
+                grad_exponent = grad_output * torch.pow(self, exponent) * torch.log(torch.tensor(self))
+                return None, grad_exponent
+
+            return torch.pow(self, exponent), backward
+
+        def rsub_0(self, other,
+                   alpha: float = 1.0):
+            self_size = self.size()
+            other_size = other.size()
+            def backward(grad_output):
+                grad_self = (- grad_output * alpha)._grad_sum_to_size(self_size)
+                grad_other = (grad_output)._grad_sum_to_size(other_size)
+                return grad_self, grad_other, None
+
+            return torch.rsub(self, other, alpha), backward
+
+        def rsub_1(self,
+                   other: float,
+                   alpha: float = 1.0):
+            self_size = self.size()
+            def backward(grad_output):
+                grad_self = (- grad_output * alpha)._grad_sum_to_size(self_size)
+                return grad_self, None, None
+
+            return torch.rsub(self, other, alpha), backward
+
+        def select(self,
+                   dim: int,
+                   index: int):
+            self_size = self.size()
+            def backward(grad_output):
+                grad_self = torch.select_backward(grad_output, self_size, dim, index)
+                return grad_self, None, None
+
+            return torch.select(self, dim, index), backward
+
+        def sqrt(self):
+            result = torch.sqrt(self)
+            def backward(grad_output):
+                grad_self = grad_output / (2 * result)
+                return grad_self
+
+            return result, backward
+
+        def squeeze_0(self):
+            self_size = self.size()
+            def backward(grad_output):
+                grad_self = torch.unsqueeze_to(grad_output, self_size)
+                return grad_self
+
+            return torch.squeeze(self), backward
+
+        def squeeze_1(self,
+                      dim: int):
+            self_size = self.size()
+            def backward(grad_output):
+                grad_self = torch.unsqueeze_to(grad_output, dim, self_size)
+                return grad_self, None
+
+            return torch.squeeze(self, dim), backward
+
+        def t(self):
+            def backward(grad_output):
+                grad_self = torch.t(grad_output)
+                return grad_self
+
+            return torch.t(self), backward
+
+        def to_0(self,
+                 device: Optional[Device],
+                 dtype: Optional[int],
+                 non_blocking: bool=False,
+                 copy: bool=False):
+            self_device = self.device
+            self_dtype = self.dtype
+            if device is not None:
+                result = self.to(device, dtype=dtype, non_blocking=non_blocking, copy=copy)
+            else:
+                result = self.to(dtype, non_blocking=non_blocking, copy=copy)
+            def backward(grad_output):
+                grad_self = grad_output.to(self_device, dtype=self_dtype, non_blocking=non_blocking, copy=copy)
+                return grad_self, None, None, None, None
+
+            return result, backward
+
+
+        def to_1(self,
+                 dtype: int,
+                 non_blocking: bool=False,
+                 copy: bool=False):
+            self_dtype = self.dtype
+            def backward(grad_output):
+                grad_self = grad_output.to(self_dtype, non_blocking, copy)
+                return grad_self, None, None, None
+
+            return self.to(dtype=dtype, non_blocking=non_blocking, copy=copy), backward
+
+        def to_2(self,
+                 other,
+                 non_blocking: bool=False,
+                 copy: bool=False):
+            def backward(grad_output):
+                grad_self = grad_output.to(self, non_blocking, copy)
+                return grad_self, None, None, None
+
+            return self.to(other, non_blocking=non_blocking, copy=copy), backward
+
+        def topk(self,
+                 k,
+                 dim: int = -1,
+                 largest: bool = True,
+                 sorted: bool = True):
+            result0, result1 = torch.topk(self, k, dim, largest, sorted)
+            self_size = self.size()
+            def backward(grad_output):
+                grad_self = torch.index_select_backward(grad_output, dim, result1, self_size, True)
+                return grad_self, None, None, None, None
+
+            return result0, result1, backward
+
+        def transpose(self,
+                      dim0: int,
+                      dim1: int):
+            def backward(grad_output):
+                grad_self = torch.transpose(grad_output, dim0, dim1)
+                return grad_self, None, None
+
+            return torch.transpose(self, dim0, dim1), backward
+
+        def var_0(self,
+                  unbiased: bool=True):
+            def backward(grad_output):
+                grad_self = torch.var_backward(grad_output, self, unbiased)
+                return grad_self, None
+
+            return torch.var(self, unbiased), backward
+
+        def var_1(self,
+                  dim: List[int],
+                  unbiased: bool,
+                  keepdim: bool):
+            def backward(grad_output):
+                grad_self = torch.var_backward(grad_output, self, dim, unbiased, keepdim)
+                return grad_self, None, None, None
+
+            return torch.var(self, dim, unbiased, keepdim), backward
+
+        def view(self,
+                 size: List[int]):
+            self_size = self.size()
+            def backward(grad_output):
+                grad_self = grad_output.reshape(self_size)
+                return grad_self, None
+
+            return torch.view(self, size), backward
+
         def adaptive_avg_pool2d(self,
                                 output_size: List[int]):
             def backward(grad_output):
@@ -20,6 +310,19 @@ const std::vector<std::string> functions = {
                 return grad_self, None
 
             return torch.adaptive_avg_pool2d(self, output_size), backward
+
+        def embedding(weight,
+                      indices,
+                      padding_idx: int,
+                      scale_grad_by_freq: bool,
+                      sparse: bool):
+            weight_size_0 = weight.size()[0]
+            def backward(grad_output):
+                grad_weight = torch.embedding_backward(grad_output, indices, weight_size_0, padding_idx, scale_grad_by_freq, sparse)
+                return grad_weight, None, None, None, None
+
+            return torch.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse), backward
+
       )"};
 std::unordered_map<std::string, GradientPair> schema_to_graphs;
 
@@ -51,6 +354,24 @@ Argument originalReturnType(const TupleTypePtr& tup) {
   return Argument("", TupleType::create(std::move(types)));
 }
 
+// In torchscript AD formulas, we define {func_0, func_1, ...} as
+// overloaded functions of `func`.
+// Remove the suffix before adding the schema string to map
+// schema_to_graphs.
+std::string overloadedSchemaString(const FunctionSchema& schema) {
+  const auto& schema_name = schema.name();
+  auto pos = schema_name.find_last_of('_');
+  auto schema_name_suffix = schema_name.substr(pos + 1);
+  std::string schema_string = canonicalSchemaString(schema);
+  if (!schema_name_suffix.empty()
+      && schema_name_suffix.find_first_not_of("0123456789") == string::npos) {
+    schema_string.replace(schema_string.find(schema_name),
+                          schema_name.length(),
+                          schema_name.substr(0, pos));
+  }
+  return schema_string;
+}
+
 void loadModule(const std::shared_ptr<script::Module>& module) {
   for (const auto& method_ : module->get_methods()) {
     const auto& method = method_.value();
@@ -90,8 +411,12 @@ void loadModule(const std::shared_ptr<script::Module>& module) {
         Symbol::aten(loaded_schema.name()),
         loaded_schema.arguments(),
         {originalReturnType(new_tuple->type()->expect<TupleType>())});
-    std::string key = canonicalSchemaString(actual_schema);
-    schema_to_graphs[key] = std::move(pair);
+
+    // modify canonical string for function overloading
+    // prefer not to modify the schema name
+    auto schema_string = overloadedSchemaString(actual_schema);
+
+    schema_to_graphs[schema_string] = std::move(pair);
   }
 }
 
@@ -114,6 +439,14 @@ c10::optional<GradientPair> gradientInfoForSchema(
     return cache_it->second;
   } else {
     auto schema_str = canonicalSchemaString(schema);
+    // JIT doesn't support keyword only arguments.
+    // Remove ' *,' in schema before looking up
+    // TODO: #16921 properly support keyword only arguments in JIT.
+    auto n = schema_str.find("*, ");
+    if (n != std::string::npos) {
+      schema_str = schema_str.erase(n, 3);
+    }
+
     auto sym_script_it = schema_to_graphs.find(schema_str);
     if (sym_script_it != schema_to_graphs.end()) {
       cached_gradient_pairs.emplace_hint(
index 9594591..3cd11d2 100644 (file)
@@ -2920,10 +2920,10 @@ def normalize(input, p=2, dim=1, eps=1e-12, out=None):
                                 operation won't be differentiable.
     """
     if out is None:
-        denom = input.norm(p, dim, True).clamp(min=eps).expand_as(input)
+        denom = input.norm(p, dim, True).clamp_min(eps).expand_as(input)
         ret = input / denom
     else:
-        denom = input.norm(p, dim, True).clamp_(min=eps).expand_as(input)
+        denom = input.norm(p, dim, True).clamp_min(eps).expand_as(input)
         ret = torch.div(input, denom, out=out)
     return ret