return at::native::sum_out(result, self, dim, false, dtype);
}
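+// Product of the input sizes over the (wrapped) dims in `dim`; returns 1 for a
+// zero-dim input. E.g. sizes = [2, 3, 4], dim = [0, 2] gives 8. Used as the
+// reduction count in the mean and var backwards.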
+int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim) {
+ int64_t size = 1;
+ if (sizes.size() == 0) {
+ return 1;
+ }
+ for (auto d : dim) {
+ d = at::maybe_wrap_dim(d, sizes.size());
+ size *= sizes[d];
+ }
+ return size;
+}
+
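+// Re-inserts size-1 dims at the positions in `dim` of an n_dims-dimensional
+// input, e.g. a [5, 7] grad reduced over dim 1 of a 3-d input becomes [5, 1, 7].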
+Tensor unsqueeze_multiple(const Tensor & t, IntArrayRef dim, size_t n_dims) {
+ auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, n_dims);
+ Tensor res = t;
+ for (size_t i = 0; i < n_dims; i++){
+ if (dims_to_unsqueeze[i]) {
+ res = res.unsqueeze(i);
+ }
+ }
+ return res;
+}
+
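+// Expands grad back to the input shape `sizes`; when keepdim was false, the
+// reduced `dims` are unsqueezed first so the expand broadcasts correctly.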
+Tensor sum_backward(const Tensor & grad, IntArrayRef sizes, IntArrayRef dims, bool keepdim) {
+ if (!keepdim && sizes.size() > 0) {
+ if (dims.size()==1) {
+ return grad.unsqueeze(dims[0]).expand(sizes);
+ } else {
+ Tensor res = unsqueeze_multiple(grad, dims, sizes.size());
+ return res.expand(sizes);
+ }
+ } else {
+ return grad.expand(sizes);
+ }
+}
+
Tensor& prod_out(Tensor& result, const Tensor& self, int64_t dim, bool keepdim, ScalarType dtype) {
return at::native::prod_out(
result, self, dim, keepdim, c10::optional<ScalarType>(dtype));
return at::native::logsumexp_out(result, self, dims, keepdim);
}
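+// Gradient of logsumexp: d/dx_i logsumexp(x, dim) = exp(x_i - logsumexp(x, dim)),
+// i.e. the softmax of x along dim, so grad_input = grad * (self - res).exp(),
+// with grad and res unsqueezed back over dim when keepdim was false.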
+Tensor logsumexp_backward(const Tensor& grad, const Tensor & self, const Tensor& res, IntArrayRef dim, bool keepdim) {
+ Tensor grad_input = grad;
+ Tensor fwd_res = res;
+ if (!keepdim && self.dim() != 0) {
+ grad_input = unsqueeze_multiple(grad, dim, self.sizes().size());
+ fwd_res = unsqueeze_multiple(res, dim, self.sizes().size());
+ }
+ return grad_input * (self - fwd_res).exp();
+}
+
static Tensor& norm_out(Tensor &result, const Tensor &self, optional<Scalar> opt_p,
IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
auto p = opt_p.value_or(2.0);
return std_var_out(result, self, dim, unbiased, keepdim, false);
}
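+// Gradient of var: d var(x) / d x_i = 2 * (x_i - mean(x)) / (N - 1) when unbiased,
+// or 2 * (x_i - mean(x)) / N otherwise; the dim overload below applies the same
+// formula over the reduced dims.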
+Tensor var_backward(const Tensor & grad, const Tensor & self, bool unbiased) {
+ return (2.0 / (self.numel() - unbiased)) * grad * (self - self.mean());
+}
+
+Tensor var_backward(const Tensor & grad, const Tensor & self, IntArrayRef dim, bool unbiased, bool keepdim) {
+ if (self.dim() == 0) {
+ return at::var_backward(grad, self, unbiased);
+ }
+ Tensor unsqueezed_grad = grad;
+ if (!keepdim && self.dim() > 1) {
+ unsqueezed_grad = unsqueeze_multiple(grad, dim, self.sizes().size());
+ }
+ return (2.0 / (at::_safe_size(self.sizes(), dim) - unbiased)) * unsqueezed_grad * (self - self.mean(dim, true));
+}
+
Tensor std(const Tensor& self, bool unbiased) {
AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
"std only supports CPU AND CUDA backend, got: ", toString(self.type().backend()));
DEFINE_DISPATCH(max_kernel);
DEFINE_DISPATCH(min_kernel);
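+// Scatters grad back into a zero tensor of the original `sizes` at `indices`
+// along `dim`; used as the backward of kthvalue/topk in the TorchScript formulas.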
+Tensor index_select_backward(const Tensor& grad, int64_t dim, const Tensor& indices, IntArrayRef sizes, bool keepdim) {
+ Tensor res = at::zeros(sizes, grad.options());
+ if (!keepdim && sizes.size() > 0) {
+ return res.scatter_(dim, indices.unsqueeze(dim), grad.unsqueeze(dim));
+ }
+ return res.scatter_(dim, indices, grad);
+}
+
bool allclose(const Tensor& self, const Tensor& other, double rtol, double atol, bool equal_nan) {
return at::isclose(self, other, rtol, atol, equal_nan).all().item<uint8_t>();
}
return self.as_strided(newSizes, newStrides);
}
+Tensor permute_backwards(const Tensor & grad, IntArrayRef fwd_dims) {
+ // invert the permutation
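+ // e.g. fwd_dims = [2, 0, 1] gives dims = [1, 2, 0]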
+ auto ndims = fwd_dims.size();
+ std::vector<int64_t> dims(ndims);
+ for (size_t i = 0; i < ndims; i++) {
+ dims[at::maybe_wrap_dim(fwd_dims[i], ndims)] = i;
+ }
+ return grad.permute(dims);
+}
+
Tensor repeat(const Tensor& self, IntArrayRef repeats) {
AT_CHECK(repeats.size() >= (size_t)self.dim(),
"Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor");
return self.as_strided(sizes, strides, storage_offset);
}
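+// Backward of select: copy grad into the selected slice of a zero tensor
+// shaped like the input.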
+Tensor select_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) {
+ auto grad_input = at::zeros(input_sizes, grad.options());
+ grad_input.select(dim, index).copy_(grad);
+ return grad_input;
+}
+
Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
int64_t ndim = self.dim();
if (ndim == 0) {
return self.as_strided(sizes, strides, storage_offset);
}
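+// Backward of slice: copy grad into the sliced region of a zero tensor
+// shaped like the input.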
+Tensor slice_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
+ auto grad_input = at::zeros(input_sizes, grad.options());
+ grad_input.slice(dim, start, end, step).copy_(grad);
+ return grad_input;
+}
+
std::vector<Tensor> split(const Tensor& self, int64_t split_size, int64_t dim) {
AT_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor");
AT_CHECK(split_size >= 0, "split expects split_size be non-negative, but got split_size=", split_size);
return self.as_strided(std::get<0>(g), std::get<1>(g));
}
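+// Backward helpers for squeeze: re-insert the size-1 dims recorded in `sizes`
+// (all of them for squeeze(), or only `dim` for squeeze(dim)).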
+Tensor unsqueeze_to(const Tensor & self, IntArrayRef sizes) {
+ auto result = self;
+
+ int64_t nDims = sizes.size();
+ for (int64_t dim = 0; dim < nDims; dim++) {
+ if (sizes[dim] == 1) {
+ result = result.unsqueeze(dim);
+ }
+ }
+ return result;
+}
+
+Tensor unsqueeze_to(const Tensor & self, int64_t dim, IntArrayRef sizes) {
+ dim = at::maybe_wrap_dim(dim, sizes.size());
+ // in NumPy it's not an error to unsqueeze a scalar, but we still need to avoid
+ // unsqueezing in the backward.
+ if (sizes.size() > 0 && sizes[dim] == 1) {
+ return self.unsqueeze(dim);
+ }
+ return self;
+}
+
Tensor squeeze(const Tensor& self, int64_t dim) {
int64_t dims = self.dim();
dim = maybe_wrap_dim(dim, dims);
dispatch:
CUDA: _cudnn_init_dropout_state
+- func: index_select_backward(Tensor grad, int64_t dim, Tensor indices, int[] sizes, bool keepdim) -> Tensor
+
+- func: select_backward(Tensor grad, int[] input_sizes, int64_t dim, int64_t index) -> Tensor
+
- func: _fused_dropout(Tensor self, float p, Generator? generator=None) -> (Tensor, Tensor)
matches_jit_signature: True
variants: function
- func: logsumexp(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
matches_jit_signature: True
+- func: logsumexp_backward(Tensor grad, Tensor self, Tensor res, int[1] dim, bool keepdim) -> Tensor
+
- func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor
matches_jit_signature: True
- func: mean(Tensor self, int[1] dim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
matches_jit_signature: True
+- func: sum_backward(Tensor grad, int[] sizes, int[] dims, bool keepdim) -> Tensor
+
- func: median(Tensor self, int dim, bool keepdim=False) -> (Tensor, Tensor)
matches_jit_signature: True
variants: function, method
matches_jit_signature: True
variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too.
+- func: permute_backwards(Tensor grad, int[] fwd_dims) -> Tensor
+
- func: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor
matches_jit_signature: True
variants: function, method
device_guard: False
+- func: _safe_size(int[] sizes, int[] dim) -> int64_t
+
- func: slice(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)
matches_jit_signature: True
variants: function, method
device_guard: False
+- func: slice_backward(Tensor grad, int[] input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) -> Tensor
+
- func: slogdet(Tensor self) -> (Tensor, Tensor)
matches_jit_signature: True
variants: function, method
variants: function, method
device_guard: False
+- func: unsqueeze_to(Tensor self, int[] sizes) -> Tensor
+
+- func: unsqueeze_to(Tensor self, int64_t dim, int[] sizes) -> Tensor
+
- func: squeeze_(Tensor(a!) self) -> Tensor(a!)
matches_jit_signature: True
variants: method
- func: var(Tensor self, int[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
matches_jit_signature: True
+- func: var_backward(Tensor grad, Tensor self, bool unbiased) -> Tensor
+
+- func: var_backward(Tensor grad, Tensor self, int[] dim, bool unbiased, bool keepdim) -> Tensor
+
- func: view_as(Tensor self, Tensor other) -> Tensor
matches_jit_signature: True
variants: method
('expand', (1, S), (1, 1, S), 'new_dim_front_old_front_1'),
('expand', (), (dont_convert(()),), 'scalar_to_scalar'),
('expand', (), (1, 3, 2), 'scalar_to_dims'),
+ ('expand_as', (S, 1, 1), (torch.rand(S, S, S),)),
('exp', (S, S, S), NO_ARGS),
('exp', (), NO_ARGS, 'scalar'),
('expm1', (S, S, S), NO_ARGS),
'test_det_dim2_null',
'test_det_rank1',
'test_det_rank2',
+ # `other` in expand_as(self, other) is not used in autograd.
+ 'test_expand_as',
'test_logdet',
'test_logdet_1x1',
'test_logdet_symmetric',
auto grad_spec = differentiate(graph);
std::vector<size_t> expected_captured_inputs = {0, 1};
- std::vector<size_t> expected_captured_outputs = {1, 2};
+ // With add/mul implemented in TorchScript, we pass the sizes of
+ // self & other instead of the tensors themselves.
+ // The forward graph is now
+ // graph(%0 : Float(2, 3, 4)
+ // %1 : Float(2, 3, 4)) {
+ // %2 : Float(2, 3, 4) = aten::mul(%0, %1)
+ // %self_size.4 : int[] = aten::size(%0)
+ // %other_size.4 : int[] = aten::size(%1)
+ // %3 : Float(2, 3, 4) = aten::mul(%2, %0)
+ // %self_size.2 : int[] = aten::size(%2)
+ // %4 : int = prim::Constant[value=1]()
+ // %7 : int[] = aten::size(%3)
+ // %5 : Float(2, 3, 4) = aten::add(%3, %1, %4)
+ // return (%5, %2, %self_size.4, %other_size.4, %self_size.2, %7);
+ //}
+ // Thus all the size info added to the forward outputs is saved
+ // in grad_spec.df_input_captured_outputs.
+ std::vector<size_t> expected_captured_outputs = {1, 2, 3, 4, 5};
std::vector<size_t> expected_input_vjps = {0, 1};
std::vector<size_t> expected_output_vjps = {0, 1};
ASSERT_EQ(grad_spec.f_real_outputs, 1);
PropagateInputShapes(graph);
PropagateRequiresGrad(graph);
+ // With add/mul implemented in TorchScript, we pass the sizes of
+ // self & other instead of the tensors themselves.
+ // The forward graph is now
+ // graph(%0 : Float(*)
+ // %1 : Float(*)) {
+ // %2 : Float(*) = aten::mul(%1, %1)
+ // %3 : int = prim::Constant[value=1]()
+ // %4 : Float(*) = aten::add(%2, %1, %3)
+ // %39 : int[] = aten::size(%0)
+ // %6 : Float(*) = aten::add(%4, %0, %3)
+ // %7 : Float(*) = aten::mul(%6, %0)
+ // %self_size.2 : int[] = aten::size(%6)
+ // %11 : int[] = aten::size(%7)
+ // %9 : Float(*) = aten::add(%7, %1, %3)
+ // return (%4, %9, %39, %6, %self_size.2, %11);
+ // }
+
auto grad_spec = differentiate(graph);
- std::vector<size_t> expected_input_vjps = {1, 2}; // for e and %4 = (d + a)
+ std::vector<size_t> expected_input_vjps = {1, 3}; // for e and %6 = (d + a)
std::vector<size_t> expected_output_vjps = {0}; // only a requires grad
ASSERT_EQ(grad_spec.f_real_outputs, 2);
ASSERT_EQ(grad_spec.df_input_captured_inputs, std::vector<size_t>({0}));
- ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector<size_t>({2, 3}));
+ ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector<size_t>({2, 3, 4, 5}));
ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps);
ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps);
out << "testDifferentiateWithRequiresGrad\n";
%forgetgate : Float(*, *),
%cellgate : Float(*, *),
%outgate : Float(*, *),
- %24 : int[],
- %25 : int[],
- %26 : Float(*, *)):
- %27 : int = prim::Constant[value=1]()
- %28 : int[] = aten::size(%outgate)
- %29 : int[] = aten::size(%26)
- %30 : int[] = aten::size(%ingate)
- %31 : int[] = aten::size(%cellgate)
- %32 : int[] = aten::size(%forgetgate)
- %33 : int[] = aten::size(%9)
- %34 : Tensor = prim::FusionGroup_0(%outgate, %0, %26, %28)
- %grad_other.5 : Tensor, %36 : Tensor, %37 : Tensor, %38 : Tensor = prim::FusionGroup_1(%forgetgate, %9, %ingate, %cellgate, %1, %26, %0, %outgate, %33, %32, %24, %31, %30, %25, %29)
+ %self_size.5 : int[],
+ %other_size.5 : int[],
+ %self_size.3 : int[],
+ %other_size.3 : int[],
+ %28 : int[],
+ %29 : int[],
+ %30 : Float(*, *),
+ %self_size.1 : int[],
+ %other_size.1 : int[]):
+ %33 : int = prim::Constant[value=1]()
+ %34 : Tensor = prim::FusionGroup_0(%outgate, %0, %30, %self_size.1)
+ %grad_other.5 : Tensor, %36 : Tensor, %37 : Tensor, %38 : Tensor = prim::FusionGroup_1(%forgetgate, %9, %ingate, %cellgate, %1, %30, %0, %outgate, %other_size.5, %self_size.5, %28, %other_size.3, %self_size.3, %29, %other_size.1)
%39 : Tensor[] = prim::ListConstruct(%38, %36, %37, %34)
- %40 : Tensor = aten::cat(%39, %27)
+ %40 : Tensor = aten::cat(%39, %33)
%41 : Tensor = aten::_grad_sum_to_size(%40, %19)
%42 : Tensor = aten::_grad_sum_to_size(%40, %17)
%43 : Tensor = aten::_grad_sum_to_size(%40, %14)
%46 : Float(*, *) = aten::mm(%44, %45)
%47 : Float(*, *) = aten::t(%10)
%48 : Float(*, *) = aten::mm(%47, %44)
- %49 : Float(*, *) = aten::t(%48)
+ %grad_self.7 : Float(*, *) = aten::t(%48)
%50 : Float(*, *) = aten::t(%12)
%51 : Float(*, *) = aten::mm(%43, %50)
%52 : Float(*, *) = aten::t(%11)
%53 : Float(*, *) = aten::mm(%52, %43)
- %54 : Float(*, *) = aten::t(%53)
- return (%grad_other.5, %41, %42, %46, %49, %51, %54)
+ %grad_self.9 : Float(*, *) = aten::t(%53)
+ return (%grad_other.5, %41, %42, %46, %grad_self.7, %51, %grad_self.9)
with prim::FusionGroup_0 = graph(%0 : Float(*, *),
%1 : Float(*, *),
%2 : Float(*, *),
%17 : Tensor, %18 : Tensor, %19 : Tensor, %20 : Tensor = prim::ListUnpack(%16)
%21 : int[] = prim::BroadcastSizes(%11, %12)
%22 : int[] = prim::BroadcastSizes(%21, %13)
- %hy : Float(*, *), %24 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %20, %19, %18, %17)
- %30 : int[] = aten::size(%0)
- %31 : int[] = aten::size(%cellgate.1)
- %32 : int[] = aten::size(%forgetgate.1)
- %33 : int[] = aten::size(%ingate.1)
- %34 : int[] = prim::BroadcastSizes(%32, %30)
- %35 : int[] = prim::BroadcastSizes(%33, %31)
- return (%hy, %cy, %7, %9, %11, %12, %21, %13, %22, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %34, %35, %24)
+ %other_size.6 : int[] = aten::size(%0)
+ %hy : Float(*, *), %25 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %20, %19, %18, %17)
+ %31 : int[] = aten::size(%25)
+ %32 : int[] = aten::size(%outgate.1)
+ %33 : int[] = aten::size(%cellgate.1)
+ %34 : int[] = aten::size(%forgetgate.1)
+ %35 : int[] = aten::size(%ingate.1)
+ %36 : int[] = prim::BroadcastSizes(%34, %other_size.6)
+ %37 : int[] = prim::BroadcastSizes(%35, %33)
+ return (%hy, %cy, %7, %9, %11, %12, %21, %13, %22, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %34, %other_size.6, %35, %33, %36, %37, %25, %32, %31)
with prim::FusionGroup_0 = graph(%0 : Float(*, *),
%1 : Tensor,
%2 : Tensor,
%Wx : Float(*, *),
%Uz : Float(*, *),
%18 : Float(*, *),
- %19 : int[],
- %20 : int[],
- %21 : int[],
- %22 : int[],
- %23 : int[],
+ %self_size.13 : int[],
+ %other_size.13 : int[],
+ %self_size.11 : int[],
+ %other_size.11 : int[],
+ %self_size.9 : int[],
%24 : int[],
+ %25 : int[],
+ %self_size.7 : int[],
+ %27 : int[],
+ %28 : int[],
+ %29 : int[],
+ %30 : int[],
%ingate : Float(*, *),
%forgetgate : Float(*, *),
%cellgate : Float(*, *),
%outgate : Float(*, *),
- %29 : int[],
- %30 : int[],
- %31 : Float(*, *)):
- %32 : int = prim::Constant[value=1]()
- %33 : int[] = aten::size(%outgate)
- %34 : int[] = aten::size(%31)
- %35 : int[] = aten::size(%ingate)
- %36 : int[] = aten::size(%cellgate)
- %37 : int[] = aten::size(%forgetgate)
- %38 : Tensor = prim::FusionGroup_0(%outgate, %0, %31, %33)
- %39 : Tensor, %40 : Tensor, %41 : Tensor = prim::FusionGroup_1(%10, %ingate, %cellgate, %1, %31, %0, %outgate, %forgetgate, %37, %29, %36, %35, %30, %34)
- %42 : Tensor[] = prim::ListConstruct(%41, %39, %40, %38)
- %43 : Tensor = aten::cat(%42, %32)
- %44 : Tensor = aten::_grad_sum_to_size(%43, %24)
- %45 : Tensor = aten::_grad_sum_to_size(%43, %22)
- %46 : int[] = aten::size(%11)
- %grad_self.7 : Tensor = prim::FusionGroup_2(%45, %Uz, %46)
- %48 : int[] = aten::size(%Uz)
- %49 : Tensor = aten::_grad_sum_to_size(%43, %19)
- %50 : Tensor = aten::_grad_sum_to_size(%43, %20)
- %51 : int[] = aten::size(%12)
- %grad_self.9 : Tensor = prim::FusionGroup_3(%50, %Wx, %51)
- %53 : int[] = aten::size(%Wx)
- %54 : int[] = aten::size(%18)
- %55 : Tensor = prim::FusionGroup_4(%49, %18, %45, %11, %48)
- %56 : int[] = aten::size(%13)
- %grad_self.13 : Tensor, %58 : Tensor = prim::FusionGroup_5(%Wx, %13, %49, %Uz, %50, %12, %56, %53, %54)
+ %self_size.5 : int[],
+ %self_size.3 : int[],
+ %other_size.3 : int[],
+ %38 : int[],
+ %39 : int[],
+ %40 : Float(*, *),
+ %self_size.1 : int[],
+ %other_size.1 : int[]):
+ %43 : int = prim::Constant[value=1]()
+ %44 : Tensor = prim::FusionGroup_0(%outgate, %0, %40, %self_size.1)
+ %45 : Tensor, %46 : Tensor, %47 : Tensor = prim::FusionGroup_1(%10, %ingate, %cellgate, %1, %40, %0, %outgate, %forgetgate, %self_size.5, %38, %other_size.3, %self_size.3, %39, %other_size.1)
+ %48 : Tensor[] = prim::ListConstruct(%47, %45, %46, %44)
+ %49 : Tensor = aten::cat(%48, %43)
+ %50 : Tensor = aten::_grad_sum_to_size(%49, %30)
+ %51 : Tensor = aten::_grad_sum_to_size(%49, %28)
+ %grad_self.7 : Tensor = prim::FusionGroup_2(%51, %Uz, %self_size.7)
+ %53 : Tensor = aten::_grad_sum_to_size(%49, %24)
+ %54 : Tensor = aten::_grad_sum_to_size(%49, %25)
+ %grad_self.9 : Tensor = prim::FusionGroup_3(%54, %Wx, %self_size.9)
+ %56 : Tensor = prim::FusionGroup_4(%53, %18, %51, %11, %other_size.11)
+ %grad_self.13 : Tensor, %58 : Tensor = prim::FusionGroup_5(%Wx, %13, %53, %Uz, %54, %12, %self_size.13, %other_size.13, %self_size.11)
%59 : Float(*, *) = aten::t(%14)
- %60 : Float(*, *) = aten::mm(%59, %55)
- %61 : Float(*, *) = aten::t(%60)
+ %60 : Float(*, *) = aten::mm(%59, %56)
+ %grad_self.15 : Float(*, *) = aten::t(%60)
%62 : Float(*, *) = aten::t(%15)
%63 : Float(*, *) = aten::mm(%62, %58)
- %64 : Float(*, *) = aten::t(%63)
- return (%44, %grad_self.7, %grad_self.9, %grad_self.13, %61, %64)
+ %grad_self.17 : Float(*, *) = aten::t(%63)
+ return (%50, %grad_self.7, %grad_self.9, %grad_self.13, %grad_self.15, %grad_self.17)
with prim::FusionGroup_0 = graph(%0 : Float(*, *),
%1 : Float(*, *),
%2 : Float(*, *),
%11 : Float(*, *) = aten::t(%6)
%Uz.1 : Float(*, *) = aten::mm(%5, %11)
%13 : Float(*, *) = aten::mul(%4, %Wx.1)
- %14 : int[] = aten::size(%1)
- %15 : Tensor[] = prim::ListConstruct(%1, %2, %Uz.1, %13, %3, %Wx.1)
- %16 : Tensor[] = aten::broadcast_tensors(%15)
- %17 : Tensor, %18 : Tensor, %19 : Tensor, %20 : Tensor, %21 : Tensor, %22 : Tensor = prim::ListUnpack(%16)
- %23 : int[] = aten::size(%3)
- %24 : int[] = aten::size(%Wx.1)
- %25 : int[] = prim::BroadcastSizes(%23, %24)
- %26 : int[] = aten::size(%13)
- %27 : int[] = aten::size(%Uz.1)
- %28 : int[] = prim::BroadcastSizes(%26, %27)
- %29 : int[] = aten::size(%2)
- %30 : int[] = prim::BroadcastSizes(%29, %27)
- %31 : int[] = prim::BroadcastSizes(%28, %25)
- %32 : int[] = prim::BroadcastSizes(%31, %30)
- %hy : Float(*, *), %34 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %22, %21, %20, %19, %18, %17)
- %40 : int[] = aten::size(%0)
- %41 : int[] = aten::size(%cellgate.1)
- %42 : int[] = aten::size(%forgetgate.1)
- %43 : int[] = aten::size(%ingate.1)
- %44 : int[] = prim::BroadcastSizes(%42, %40)
- %45 : int[] = prim::BroadcastSizes(%43, %41)
- return (%hy, %cy, %Wx.1, %Uz.1, %13, %28, %25, %31, %30, %32, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %44, %45, %34)
+ %self_size.14 : int[] = aten::size(%4)
+ %other_size.14 : int[] = aten::size(%Wx.1)
+ %self_size.12 : int[] = aten::size(%13)
+ %other_size.12 : int[] = aten::size(%Uz.1)
+ %self_size.10 : int[] = aten::size(%3)
+ %self_size.8 : int[] = aten::size(%2)
+ %20 : int[] = aten::size(%1)
+ %21 : Tensor[] = prim::ListConstruct(%1, %2, %Uz.1, %13, %3, %Wx.1)
+ %22 : Tensor[] = aten::broadcast_tensors(%21)
+ %23 : Tensor, %24 : Tensor, %25 : Tensor, %26 : Tensor, %27 : Tensor, %28 : Tensor = prim::ListUnpack(%22)
+ %29 : int[] = prim::BroadcastSizes(%self_size.10, %other_size.14)
+ %30 : int[] = prim::BroadcastSizes(%self_size.12, %other_size.12)
+ %31 : int[] = prim::BroadcastSizes(%self_size.8, %other_size.12)
+ %32 : int[] = prim::BroadcastSizes(%30, %29)
+ %33 : int[] = prim::BroadcastSizes(%32, %31)
+ %hy : Float(*, *), %35 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %28, %27, %26, %25, %24, %23)
+ %41 : int[] = aten::size(%0)
+ %42 : int[] = aten::size(%35)
+ %43 : int[] = aten::size(%outgate.1)
+ %44 : int[] = aten::size(%cellgate.1)
+ %45 : int[] = aten::size(%forgetgate.1)
+ %46 : int[] = aten::size(%ingate.1)
+ %47 : int[] = prim::BroadcastSizes(%45, %41)
+ %48 : int[] = prim::BroadcastSizes(%46, %44)
+ return (%hy, %cy, %Wx.1, %Uz.1, %13, %self_size.14, %other_size.14, %self_size.12, %other_size.12, %self_size.10, %30, %29, %self_size.8, %32, %31, %33, %20, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %45, %46, %44, %47, %48, %35, %43, %42)
with prim::FusionGroup_0 = graph(%0 : Float(*, *),
%1 : Tensor,
%2 : Tensor,
graph(%0 : Float(2, 3, 4),
%1 : Float(2, 3, 4)):
%2 : Float(2, 3, 4) = aten::mul(%0, %1)
+ %self_size.4 : int[] = aten::size(%0)
+ %other_size.4 : int[] = aten::size(%1)
%3 : Float(2, 3, 4) = aten::mul(%2, %0)
+ %self_size.2 : int[] = aten::size(%2)
%4 : int = prim::Constant[value=1]()
%7 : int[] = aten::size(%3)
%5 : Float(2, 3, 4) = aten::add(%3, %1, %4)
- return (%5, %2, %7)
+ return (%5, %2, %self_size.4, %other_size.4, %self_size.2, %7)
graph(%0 : Float(2, 3, 4),
%1 : Float(2, 3, 4),
%2 : Float(2, 3, 4),
%3 : Float(2, 3, 4),
%4 : Float(2, 3, 4),
- %5 : int[]):
- %7 : int = prim::Constant[value=1]()
- %6 : int[] = aten::size(%3)
- %8 : Tensor, %9 : Tensor = prim::GradOf[name="aten::add"](%0)
+ %self_size.3 : int[],
+ %other_size.3 : int[],
+ %self_size.1 : int[],
+ %8 : int[]):
+ %9 : int = prim::Constant[value=1]()
+ %10 : Tensor, %11 : Tensor = prim::GradOf[name="aten::add"](%0)
block0():
- %10 : Tensor = aten::_grad_sum_to_size(%0, %5)
- %11 : Float(2, 3, 4) = aten::mul(%0, %7)
- %12 : Tensor = aten::_grad_sum_to_size(%11, %6)
- -> (%10, %12)
- %grad_self.2 : Tensor, %grad_other.2 : Tensor = prim::GradOf[name="aten::mul"](%8)
+ %12 : Tensor = aten::_grad_sum_to_size(%0, %8)
+ %13 : Float(2, 3, 4) = aten::mul(%0, %9)
+ %14 : Tensor = aten::_grad_sum_to_size(%13, %other_size.3)
+ -> (%12, %14)
+ %grad_self.2 : Tensor, %grad_other.2 : Tensor = prim::GradOf[name="aten::mul"](%10)
block0():
- %15 : Tensor = aten::mul(%8, %2)
- %16 : int[] = aten::size(%4)
- %grad_self.1 : Tensor = aten::_grad_sum_to_size(%15, %16)
- %18 : Tensor = aten::mul(%8, %4)
- %19 : int[] = aten::size(%2)
- %grad_other.1 : Tensor = aten::_grad_sum_to_size(%18, %19)
+ %17 : Tensor = aten::mul(%10, %2)
+ %grad_self.1 : Tensor = aten::_grad_sum_to_size(%17, %self_size.1)
+ %19 : Tensor = aten::mul(%10, %4)
+ %grad_other.1 : Tensor = aten::_grad_sum_to_size(%19, %self_size.3)
-> (%grad_self.1, %grad_other.1)
%21 : Tensor = prim::AutogradAdd(%1, %grad_self.2)
%grad_self : Tensor, %grad_other : Tensor = prim::GradOf[name="aten::mul"](%21)
block0():
%24 : Tensor = aten::mul(%21, %3)
- %25 : int[] = aten::size(%2)
- %grad_self.3 : Tensor = aten::_grad_sum_to_size(%24, %25)
- %27 : Tensor = aten::mul(%21, %2)
- %28 : int[] = aten::size(%3)
- %grad_other.3 : Tensor = aten::_grad_sum_to_size(%27, %28)
+ %grad_self.3 : Tensor = aten::_grad_sum_to_size(%24, %self_size.3)
+ %26 : Tensor = aten::mul(%21, %2)
+ %grad_other.3 : Tensor = aten::_grad_sum_to_size(%26, %other_size.3)
-> (%grad_self.3, %grad_other.3)
- %30 : Tensor = prim::AutogradAdd(%grad_other.2, %grad_self)
- %31 : Tensor = prim::AutogradAdd(%9, %grad_other)
- return (%30, %31)
+ %28 : Tensor = prim::AutogradAdd(%grad_other.2, %grad_self)
+ %29 : Tensor = prim::AutogradAdd(%11, %grad_other)
+ return (%28, %29)
testDifferentiateWithRequiresGrad
graph(%0 : Float(*),
%2 : Float(*) = aten::mul(%1, %1)
%3 : int = prim::Constant[value=1]()
%4 : Float(*) = aten::add(%2, %1, %3)
+ %39 : int[] = aten::size(%0)
%6 : Float(*) = aten::add(%4, %0, %3)
%7 : Float(*) = aten::mul(%6, %0)
+ %self_size.2 : int[] = aten::size(%6)
%11 : int[] = aten::size(%7)
%9 : Float(*) = aten::add(%7, %1, %3)
- return (%4, %9, %6, %11)
+ return (%4, %9, %39, %6, %self_size.2, %11)
graph(%0 : Float(*),
%1 : Float(*),
%2 : Float(*),
- %3 : Float(*),
- %4 : int[]):
- %6 : int = prim::Constant[value=1]()
- %5 : int[] = aten::size(%2)
- %7 : Tensor = prim::GradOf[name="aten::add"](%0)
+ %3 : int[],
+ %4 : Float(*),
+ %self_size.1 : int[],
+ %6 : int[]):
+ %7 : int = prim::Constant[value=1]()
+ %8 : Tensor = prim::GradOf[name="aten::add"](%0)
block0():
- %8 : Tensor = aten::_grad_sum_to_size(%0, %4)
- -> (%8)
- %grad_self : Tensor, %grad_other : Tensor = prim::GradOf[name="aten::mul"](%7)
+ %9 : Tensor = aten::_grad_sum_to_size(%0, %6)
+ -> (%9)
+ %grad_self : Tensor, %grad_other : Tensor = prim::GradOf[name="aten::mul"](%8)
block0():
- %11 : Tensor = aten::mul(%7, %2)
- %12 : int[] = aten::size(%3)
- %grad_self.1 : Tensor = aten::_grad_sum_to_size(%11, %12)
- %14 : Tensor = aten::mul(%7, %3)
- %15 : int[] = aten::size(%2)
- %grad_other.1 : Tensor = aten::_grad_sum_to_size(%14, %15)
+ %12 : Tensor = aten::mul(%8, %2)
+ %grad_self.1 : Tensor = aten::_grad_sum_to_size(%12, %self_size.1)
+ %14 : Tensor = aten::mul(%8, %4)
+ %grad_other.1 : Tensor = aten::_grad_sum_to_size(%14, %3)
-> (%grad_self.1, %grad_other.1)
- %17 : Tensor = prim::AutogradAdd(%1, %grad_self)
- %18 : Tensor = prim::GradOf[name="aten::add"](%17)
+ %16 : Tensor = prim::AutogradAdd(%1, %grad_self)
+ %17 : Tensor = prim::GradOf[name="aten::add"](%16)
block0():
- %19 : Tensor = aten::mul(%17, %6)
- %20 : Tensor = aten::_grad_sum_to_size(%19, %5)
- -> (%20)
- %21 : Tensor = prim::AutogradAdd(%grad_other, %18)
- return (%21)
+ %18 : Tensor = aten::mul(%16, %7)
+ %19 : Tensor = aten::_grad_sum_to_size(%18, %3)
+ -> (%19)
+ %20 : Tensor = prim::AutogradAdd(%grad_other, %17)
+ return (%20)
self.assertIs(torch.int32, s(b, "t.to(dtype=torch.int32)").dtype)
self.assertEqual(b.device, s(b, "t.to(dtype=torch.int32)").device)
+ # Test AD: aten::to(Tensor self, int dtype, bool non_blocking, bool copy) -> Tensor
+ t = torch.tensor(5).float().requires_grad_()
+ out_ref = t.to(torch.float32)
+ out = s(t, "t.to(torch.float32)")
+ self.assertEqual(out_ref, out)
+
+ grad_ref = torch.autograd.grad(out_ref.sum(), t)
+ grad = torch.autograd.grad(out.sum(), t)
+ self.assertEqual(grad_ref, grad)
+
+ # Test AD: aten::to(Tensor self, Device? device, int? dtype, bool non_blocking, bool copy) -> Tensor
+ out_ref = t.to('cpu')
+ out = s(t, "t.to('cpu')")
+ self.assertEqual(out_ref, out)
+
+ grad_ref = torch.autograd.grad(out_ref.sum(), t)
+ grad = torch.autograd.grad(out.sum(), t)
+ self.assertEqual(grad_ref, grad)
+
+ # Test AD: aten::to(Tensor self, Tensor other, bool non_blocking, bool copy) -> Tensor
+ @torch.jit.script
+ def func2(t, t_ref):
+ return t.to(t_ref)
+
+ func2.debug_disable_autodiff_subgraph_inlining()
+
+ t_ref = torch.tensor(4).double()
+ out_ref = t.to(t_ref)
+ out = func2(t, t_ref)
+ grad_ref = torch.autograd.grad(out_ref.sum(), t)
+ grad = torch.autograd.grad(out.sum(), t)
+ self.assertEqual(grad_ref, grad)
+
@unittest.skipIf(not RUN_CUDA, "No CUDA")
def test_tensor_number_math_cuda(self):
self._test_tensor_number_math(device='cuda')
DISABLE_AUTODIFF_SUBGRAPH_INLINING = {
'test_nn_avg_pool2d',
'test_nn_adaptive_avg_pool2d',
+ 'test_nn_embedding',
'test_nn_log_softmax',
'test_nn_threshold',
'test_nn_nll_loss',
-Subproject commit 822d8df0a2a32233c6022f50a158817a0f19bdc7
+Subproject commit 15c33c945851907411619f599900c3852108e7e3
self: zeros_like(grad)
- name: logsumexp(Tensor self, IntArrayRef dim, bool keepdim)
- self: logsumexp_backward(grad, self, result, dim, keepdim)
+ self: at::logsumexp_backward(grad, self, result, dim, keepdim)
- name: lt_(Tensor self, Scalar other)
self: zeros_like(self)
self: grad.expand(self.sizes()).to(self.type().scalarType()) / self.numel()
- name: mean(Tensor self, IntArrayRef dim, bool keepdim)
- self: sum_backward(grad, self.sizes(), dim, keepdim) / _safe_size(self.sizes(), dim)
+ self: sum_backward(grad, self.sizes(), dim, keepdim) / at::_safe_size(self.sizes(), dim)
- name: mean(Tensor self, IntArrayRef dim, ScalarType dtype)
- self: sum_backward(grad, self.sizes(), dim, false).to(self.type().scalarType()) / _safe_size(self.sizes(), dim)
+ self: sum_backward(grad, self.sizes(), dim, false).to(self.type().scalarType()) / at::_safe_size(self.sizes(), dim)
- name: mean(Tensor self, IntArrayRef dim, bool keepdim, ScalarType dtype)
- self: sum_backward(grad, self.sizes(), dim, keepdim).to(self.type().scalarType()) / _safe_size(self.sizes(), dim)
+ self: sum_backward(grad, self.sizes(), dim, keepdim).to(self.type().scalarType()) / at::_safe_size(self.sizes(), dim)
- name: median(Tensor self)
self: select_equals_backward(grad, self, result)
}
}
-int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim) {
- int64_t size = 1;
- if (sizes.size() == 0) {
- return 1;
- }
- for (auto d : dim) {
- d = at::maybe_wrap_dim(d, sizes.size());
- size *= sizes[d];
- }
- return size;
-}
-
Tensor norm_backward(const Tensor & grad, const Tensor & self, const optional<Scalar> & p_, const Tensor & norm) {
double p = p_.value_or(2.0).toDouble();
Tensor self_scaled;
return grad * args.digamma_().sum(-1);
}
-Tensor permute_backwards(const Tensor & grad, IntArrayRef fwd_dims) {
- // invert the permutation
- auto ndims = fwd_dims.size();
- std::vector<int64_t> dims(ndims);
- for (size_t i = 0; i < ndims; i++) {
- dims[at::maybe_wrap_dim(fwd_dims[i], ndims)] = i;
- }
- return grad.permute(dims);
-}
-
-Tensor unsqueeze_multiple(const Tensor & t, IntArrayRef dim, size_t n_dims) {
- auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, n_dims);
- Tensor res = t;
- for (size_t i = 0; i < n_dims; i++){
- if (dims_to_unsqueeze[i]) {
- res = res.unsqueeze(i);
- }
- }
- return res;
-}
-
-Tensor sum_backward(const Tensor & grad, IntArrayRef sizes, IntArrayRef dims, bool keepdim) {
- if (!keepdim && sizes.size() > 0) {
- if (dims.size()==1) {
- return grad.unsqueeze(dims[0]).expand(sizes);
- } else {
- Tensor res = unsqueeze_multiple(grad, dims, sizes.size());
- return res.expand(sizes);
- }
- } else {
- return grad.expand(sizes);
- }
-}
-
std::vector<int64_t> reverse_list(const IntArrayRef list) {
auto result = std::vector<int64_t>();
result.reserve(list.size());
return cumsum_backward(x.to(input_dtype), dim);
}
-Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, IntArrayRef dim, bool keepdim) {
- if (!keepdim && self.dim() != 0) {
- grad = unsqueeze_multiple(grad, dim, self.sizes().size());
- result = unsqueeze_multiple(result, dim, self.sizes().size());
- }
- return grad * (self - result).exp();
-}
+//Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, IntArrayRef dim, bool keepdim) {
+// if (!keepdim && self.dim() != 0) {
+// grad = unsqueeze_multiple(grad, dim, self.sizes().size());
+// result = unsqueeze_multiple(result, dim, self.sizes().size());
+// }
+// return grad * (self - result).exp();
+//}
Tensor unbind_backward(const variable_list& grads, int64_t dim) {
IntArrayRef sizes;
return at::stack(grads_tensors, dim);
}
-Tensor unsqueeze_to(const Tensor & self, IntArrayRef sizes) {
- auto result = self;
-
- int64_t nDims = sizes.size();
- for (int64_t dim = 0; dim < nDims; dim++) {
- if (sizes[dim] == 1) {
- result = result.unsqueeze(dim);
- }
- }
- return result;
-}
-
-Tensor unsqueeze_to(const Tensor & self, int64_t dim, IntArrayRef sizes) {
- dim = at::maybe_wrap_dim(dim, sizes.size());
- // in NumPy it's not an error to unsqueeze a scalar, but we still need to avoided
- // unsqueezing in the backward.
- if (sizes.size() > 0 && sizes[dim] == 1) {
- return self.unsqueeze(dim);
- }
- return self;
-}
-
std::vector<Tensor> cat_tensors_backward(const Tensor & grad, const std::vector<std::vector<int64_t>> &sizes, int64_t dim) {
dim = at::legacy_cat_wrap_dim(dim, sizes);
std::vector<Tensor> grad_inputs(sizes.size());
return grad_input;
}
-Tensor index_select_backward(Tensor grad, int64_t dim, Tensor indices, IntArrayRef sizes, bool keepdim) {
- if (!keepdim && sizes.size() > 0) {
- grad = grad.unsqueeze(dim);
- indices = indices.unsqueeze(dim);
- }
- return at::zeros(sizes, grad.options()).scatter_(dim, indices, grad);
-}
-
-Tensor slice_backward(Tensor grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
- auto grad_input = at::zeros(input_sizes, grad.options());
- grad_input.slice(dim, start, end, step).copy_(grad);
- return grad_input;
-}
-
-Tensor select_backward(Tensor grad, IntArrayRef input_sizes, int64_t dim, int64_t index) {
- auto grad_input = at::zeros(input_sizes, grad.options());
- grad_input.select(dim, index).copy_(grad);
- return grad_input;
-}
-
Tensor trace_backward(const Tensor & grad, IntArrayRef sizes) {
if (sizes.size() != 2) {
throw std::runtime_error("expected matrix input");
return grad_input.view(input_sizes);
}
-Tensor var_backward(const Tensor & grad, const Tensor & self, bool unbiased) {
- return (2.0 / (self.numel() - unbiased)) * grad * (self - self.mean());
-}
-
-Tensor var_backward(Tensor grad, const Tensor & self, IntArrayRef dim, bool unbiased, bool keepdim) {
- if (self.dim() == 0) {
- return var_backward(grad, self, unbiased);
- }
- if (!keepdim && self.dim() > 1) {
- grad = unsqueeze_multiple(grad, dim, self.sizes().size());
- }
- return (2.0 / (_safe_size(self.sizes(), dim) - unbiased)) * grad * (self - self.mean(dim, true));
-}
Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntArrayRef sizes) {
int64_t numel = 1;
}
}
+// kthvalue returns (values, indices of the kth values); autodiff currently
+// supports at most one output that requires grad, so we need to drop
+// the grad for the index output, which doesn't require grad.
+bool needTrimGrad(Node* n) {
+ static OperatorSet need_trim_grad_ops = {
+ "aten::kthvalue(Tensor self, int k, int dim, bool keepdim) -> (Tensor, Tensor)",
+ };
+ if (need_trim_grad_ops.find(n)) {
+ return true;
+ }
+ return false;
+}
+
bool isDifferentiable(Node* n) {
// TODO: scalar-tensor ops should be canonicalized
static OperatorSet differentiable_ops = {
auto fw_graph = compiled_graphs->forward;
new_outputs = inlineCallTo(
*graph, *fw_graph, node->inputs(), /*unpack_outputs=*/true);
- for (size_t i = 0; i < node->outputs().size(); ++i) {
- new_outputs.at(i)->setType(node->outputs()[i]->type());
- new_outputs.at(i)->replaceAllUsesWith(node->outputs()[i]);
+ auto outputs = node->outputs();
+ AT_ASSERT(new_outputs.size() == outputs.size() + 1);
+ for (size_t i = 0; i < outputs.size(); ++i) {
+ new_outputs.at(i)->setType(outputs[i]->type());
+ outputs[i]->replaceAllUsesWith(new_outputs.at(i));
}
}
// Use backward graph to construct reverse_block
auto bw_graph = compiled_graphs->backward;
auto grad_vec = grads.vec();
+ if (needTrimGrad(node)) {
+ grad_vec.erase(grad_vec.begin()+1, grad_vec.end());
+ }
auto it = grad_vec.begin();
grad_vec.insert(it, new_outputs.back());
ArrayRef<Value*> grad(grad_vec);
return {sizes.at(dim) > 1 ? grads.at(0) : grads.at(0).unsqueeze(dim),
nullptr};
- } else if (node->matches(
- "aten::cat(Tensor[] tensors, int dim) -> Tensor",
- /*const_inputs=*/attr::dim)) {
- int dim = *node->get<int64_t>(attr::dim);
- auto tensor_inputs = inputs;
- tensor_inputs.pop_back();
- const auto& first_sizes = tensor_inputs.at(0).sizes();
- const auto has_first_sizes = [&first_sizes](SymbolicVariable var) {
- return var.sizes() == first_sizes;
- };
-
- // NB: this is a specialization for the common case where all inputs are
- // of equal sizes. We can use a single split operation to handle that.
- if (std::all_of(
- tensor_inputs.begin(), tensor_inputs.end(), has_first_sizes)) {
- auto tensor_grads = grads.at(0).chunk(tensor_inputs.size(), dim);
- tensor_grads.emplace_back(nullptr); // for attr::dim
- return tensor_grads;
- } else {
- size_t offset = 0;
- auto grad = grads.at(0);
- std::vector<SymbolicVariable> tensor_grads;
- for (auto input : tensor_inputs) {
- tensor_grads.push_back(grad.narrow(dim, offset, input.sizes()[dim]));
- offset += input.sizes()[dim];
- }
- tensor_grads.emplace_back(nullptr); // for attr::dim
- return tensor_grads;
- }
} else if (comparison_ops.find(node)) {
return {nullptr, nullptr};
Node* node,
ArrayRef<Value*> grad_values) {
auto& graph = *node->owningGraph();
+
+ // FIXME: when the forward has multiple outputs, we only support grad for the
+ // first one, so the remaining grads are dropped here.
+ if (needTrimGrad(node)) {
+ grad_values = grad_values.at(0);
+ }
auto linear = graph.insertNode(graph.create(prim::GradOf, {grad_values}, 0));
// to make reading gradient graphs easier, remember the name of the forward op
linear->s_(attr::name, node->kind().toDisplayString());
variable_list outputs;
outputs.reserve(num_outputs());
for (size_t i = 0; i < num_outputs(); ++i) {
- if (should_compute_output(i)) {
+ // An input grad can be None even if the corresponding input requires grad.
+ // Example: `other` in expand_as(self, other)
+ if (should_compute_output(i) && !stack[i].isNone()) {
auto output = std::move(stack[i]).toTensor();
const auto& edge = next_edge(i);
if (output.defined()) {
"aten::div(Tensor self, Scalar other) -> Tensor",
"aten::neg(Tensor self) -> Tensor",
"aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
+ "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
// add this used to be prim::AutogradAdd
}};
std::mutex lock;
const std::vector<std::string> functions = {
R"(
+
+ def _dim_arange(like,
+ dim: int):
+ def backward(grad_output):
+ return None, None
+
+ return torch._dim_arange(like, dim), backward
+
+ def contiguous(self):
+ def backward(grad_output):
+ return None
+
+ return self.contiguous(), backward
+
+ def erf(self):
+ def backward(grad_output):
+ # Precomputed constant C = 2.0 / math.sqrt(math.pi)
+ C = 1.1283791670955126
+ grad_self = C * torch.exp(- self.pow(2)) * grad_output
+ return grad_self
+
+ return torch.erf(self), backward
+
+ def expand(self,
+ size: List[int],
+ implicit: bool=False):
+ self_size = self.size()
+ def backward(grad_output):
+ grad_self = torch._grad_sum_to_size(grad_output, self_size)
+ return grad_self, None, None
+
+ return torch.expand(self, size, implicit=implicit), backward
+
+ def expand_as(self, other):
+ self_size = self.size()
+ def backward(grad_output):
+ grad_self = grad_output._grad_sum_to_size(self_size)
+ return grad_self, None
+
+ return torch.expand_as(self, other), backward
+
+ def full_like(self,
+ fill_value: float):
+ def backward(grad_output):
+ return None, None
+
+ return torch.full_like(self, fill_value), backward
+
+ def kthvalue(self,
+ k: int,
+ dim: int,
+ keepdim: bool):
+ result0, result1 = torch.kthvalue(self, k, dim, keepdim)
+ self_size = self.size()
+ def backward(grad_output):
+ grad_self = torch.index_select_backward(grad_output, dim, result1, self_size, keepdim)
+ return grad_self, None, None, None
+
+ return result0, result1, backward
+
+ def logsumexp(self,
+ dim: List[int],
+ keepdim: bool):
+ result = torch.logsumexp(self, dim, keepdim)
+ self_dim = self.dim()
+ def backward(grad_output):
+ grad_self = torch.logsumexp_backward(grad_output, self, result, dim, keepdim)
+ return grad_self, None, None
+
+ return result, backward
+
+ def mean_0(self):
+ self_size = self.size()
+ self_numel = self.numel()
+ def backward(grad_output):
+ grad_self = grad_output.expand(self_size) / self_numel
+ return grad_self
+
+ return torch.mean(self), backward
+
+ def mean_1(self,
+ dim: List[int],
+ keepdim: bool):
+ self_size = self.size()
+ def backward(grad_output):
+ grad_self = torch.sum_backward(grad_output, self_size, dim, keepdim) / torch._safe_size(self_size, dim)
+ return grad_self, None, None
+
+ return torch.mean(self, dim, keepdim), backward
+
def mul(self, other):
+ self_size = self.size()
+ other_size = other.size()
def backward(grad_output):
- grad_self = (grad_output * other)._grad_sum_to_size(self.size())
- grad_other = (grad_output * self)._grad_sum_to_size(other.size())
+ grad_self = (grad_output * other)._grad_sum_to_size(self_size)
+ grad_other = (grad_output * self)._grad_sum_to_size(other_size)
return grad_self, grad_other
+
return self * other, backward
+
+ def nonzero(self):
+ def backward(grad_output):
+ return None
+
+ return torch.nonzero(self), backward
+
+ def ones_like(self):
+ def backward(grad_output):
+ return None
+
+ return torch.ones_like(self), backward
+
+ def permute(self,
+ dims: List[int]):
+ def backward(grad_output):
+ grad_self = torch.permute_backwards(grad_output, dims)
+ return grad_self, None
+
+ return torch.permute(self, dims), backward
+
+ def pow_0(self,
+ exponent: float):
+ def backward(grad_output):
+ grad_self = torch.where(torch.tensor(exponent == 0.0), torch.zeros_like(self), grad_output * exponent * torch.pow(self, exponent - 1))
+ return grad_self, None
+
+ return torch.pow(self, exponent), backward
+
+ def pow_1(self, exponent):
+ self_size = self.size()
+ exponent_size = exponent.size()
+ def backward(grad_output):
+ grad_self = torch.where(exponent == 0.0, torch.zeros_like(self), grad_output * exponent * torch.pow(self, exponent - 1))._grad_sum_to_size(self_size)
+ grad_exponent = (grad_output * torch.pow(self, exponent) * torch.log(self))._grad_sum_to_size(exponent_size)
+ return grad_self, grad_exponent
+
+ return torch.pow(self, exponent), backward
+
+ def pow_2(self: float,
+ exponent):
+ def backward(grad_output):
+ grad_exponent = grad_output * torch.pow(self, exponent) * torch.log(torch.tensor(self))
+ return None, grad_exponent
+
+ return torch.pow(self, exponent), backward
+
+ def rsub_0(self, other,
+ alpha: float = 1.0):
+ self_size = self.size()
+ other_size = other.size()
+ def backward(grad_output):
+ grad_self = (- grad_output * alpha)._grad_sum_to_size(self_size)
+ grad_other = (grad_output)._grad_sum_to_size(other_size)
+ return grad_self, grad_other, None
+
+ return torch.rsub(self, other, alpha), backward
+
+ def rsub_1(self,
+ other: float,
+ alpha: float = 1.0):
+ self_size = self.size()
+ def backward(grad_output):
+ grad_self = (- grad_output * alpha)._grad_sum_to_size(self_size)
+ return grad_self, None, None
+
+ return torch.rsub(self, other, alpha), backward
+
+ def select(self,
+ dim: int,
+ index: int):
+ self_size = self.size()
+ def backward(grad_output):
+ grad_self = torch.select_backward(grad_output, self_size, dim, index)
+ return grad_self, None, None
+
+ return torch.select(self, dim, index), backward
+
+ def sqrt(self):
+ result = torch.sqrt(self)
+ def backward(grad_output):
+ grad_self = grad_output / (2 * result)
+ return grad_self
+
+ return result, backward
+
+ def squeeze_0(self):
+ self_size = self.size()
+ def backward(grad_output):
+ grad_self = torch.unsqueeze_to(grad_output, self_size)
+ return grad_self
+
+ return torch.squeeze(self), backward
+
+ def squeeze_1(self,
+ dim: int):
+ self_size = self.size()
+ def backward(grad_output):
+ grad_self = torch.unsqueeze_to(grad_output, dim, self_size)
+ return grad_self, None
+
+ return torch.squeeze(self, dim), backward
+
+ def t(self):
+ def backward(grad_output):
+ grad_self = torch.t(grad_output)
+ return grad_self
+
+ return torch.t(self), backward
+
+ def to_0(self,
+ device: Optional[Device],
+ dtype: Optional[int],
+ non_blocking: bool=False,
+ copy: bool=False):
+ self_device = self.device
+ self_dtype = self.dtype
+ if device is not None:
+ result = self.to(device, dtype=dtype, non_blocking=non_blocking, copy=copy)
+ else:
+ result = self.to(dtype, non_blocking=non_blocking, copy=copy)
+ def backward(grad_output):
+ grad_self = grad_output.to(self_device, dtype=self_dtype, non_blocking=non_blocking, copy=copy)
+ return grad_self, None, None, None, None
+
+ return result, backward
+
+ def to_1(self,
+ dtype: int,
+ non_blocking: bool=False,
+ copy: bool=False):
+ self_dtype = self.dtype
+ def backward(grad_output):
+ grad_self = grad_output.to(self_dtype, non_blocking, copy)
+ return grad_self, None, None, None
+
+ return self.to(dtype=dtype, non_blocking=non_blocking, copy=copy), backward
+
+ def to_2(self,
+ other,
+ non_blocking: bool=False,
+ copy: bool=False):
+ def backward(grad_output):
+ grad_self = grad_output.to(self, non_blocking, copy)
+ return grad_self, None, None, None
+
+ return self.to(other, non_blocking=non_blocking, copy=copy), backward
+
+ def topk(self,
+ k: int,
+ dim: int = -1,
+ largest: bool = True,
+ sorted: bool = True):
+ result0, result1 = torch.topk(self, k, dim, largest, sorted)
+ self_size = self.size()
+ def backward(grad_output):
+ grad_self = torch.index_select_backward(grad_output, dim, result1, self_size, True)
+ return grad_self, None, None, None, None
+
+ return result0, result1, backward
+
+ def transpose(self,
+ dim0: int,
+ dim1: int):
+ def backward(grad_output):
+ grad_self = torch.transpose(grad_output, dim0, dim1)
+ return grad_self, None, None
+
+ return torch.transpose(self, dim0, dim1), backward
+
+ def var_0(self,
+ unbiased: bool=True):
+ def backward(grad_output):
+ grad_self = torch.var_backward(grad_output, self, unbiased)
+ return grad_self, None
+
+ return torch.var(self, unbiased), backward
+
+ def var_1(self,
+ dim: List[int],
+ unbiased: bool,
+ keepdim: bool):
+ def backward(grad_output):
+ grad_self = torch.var_backward(grad_output, self, dim, unbiased, keepdim)
+ return grad_self, None, None, None
+
+ return torch.var(self, dim, unbiased, keepdim), backward
+
+ def view(self,
+ size: List[int]):
+ self_size = self.size()
+ def backward(grad_output):
+ grad_self = grad_output.reshape(self_size)
+ return grad_self, None
+
+ return torch.view(self, size), backward
+
def adaptive_avg_pool2d(self,
output_size: List[int]):
def backward(grad_output):
return grad_self, None
return torch.adaptive_avg_pool2d(self, output_size), backward
+
+ def embedding(weight,
+ indices,
+ padding_idx: int,
+ scale_grad_by_freq: bool,
+ sparse: bool):
+ weight_size_0 = weight.size()[0]
+ def backward(grad_output):
+ grad_weight = torch.embedding_backward(grad_output, indices, weight_size_0, padding_idx, scale_grad_by_freq, sparse)
+ return grad_weight, None, None, None, None
+
+ return torch.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse), backward
+
)"};
std::unordered_map<std::string, GradientPair> schema_to_graphs;
return Argument("", TupleType::create(std::move(types)));
}
+// In TorchScript AD formulas, we define {func_0, func_1, ...} as
+// overloaded functions of `func`.
+// Remove the numeric suffix before adding the schema string to the
+// schema_to_graphs map.
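+// e.g. the graph defined above for `mean_1` is keyed under the `mean` schema.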
+std::string overloadedSchemaString(const FunctionSchema& schema) {
+ const auto& schema_name = schema.name();
+ auto pos = schema_name.find_last_of('_');
+ auto schema_name_suffix = schema_name.substr(pos + 1);
+ std::string schema_string = canonicalSchemaString(schema);
+ if (!schema_name_suffix.empty()
+ && schema_name_suffix.find_first_not_of("0123456789") == string::npos) {
+ schema_string.replace(schema_string.find(schema_name),
+ schema_name.length(),
+ schema_name.substr(0, pos));
+ }
+ return schema_string;
+}
+
void loadModule(const std::shared_ptr<script::Module>& module) {
for (const auto& method_ : module->get_methods()) {
const auto& method = method_.value();
Symbol::aten(loaded_schema.name()),
loaded_schema.arguments(),
{originalReturnType(new_tuple->type()->expect<TupleType>())});
- std::string key = canonicalSchemaString(actual_schema);
- schema_to_graphs[key] = std::move(pair);
+
+ // Modify the canonical schema string for overloaded functions,
+ // but prefer not to modify the schema name itself.
+ auto schema_string = overloadedSchemaString(actual_schema);
+
+ schema_to_graphs[schema_string] = std::move(pair);
}
}
return cache_it->second;
} else {
auto schema_str = canonicalSchemaString(schema);
+ // The JIT doesn't support keyword-only arguments.
+ // Remove "*, " from the schema string before looking it up.
+ // TODO: #16921 properly support keyword-only arguments in the JIT.
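+ // e.g. "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor"
+ // is looked up as "aten::add(Tensor self, Tensor other, Scalar alpha) -> Tensor".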
+ auto n = schema_str.find("*, ");
+ if (n != std::string::npos) {
+ schema_str = schema_str.erase(n, 3);
+ }
+
auto sym_script_it = schema_to_graphs.find(schema_str);
if (sym_script_it != schema_to_graphs.end()) {
cached_gradient_pairs.emplace_hint(
operation won't be differentiable.
"""
if out is None:
- denom = input.norm(p, dim, True).clamp(min=eps).expand_as(input)
+ denom = input.norm(p, dim, True).clamp_min(eps).expand_as(input)
ret = input / denom
else:
- denom = input.norm(p, dim, True).clamp_(min=eps).expand_as(input)
+ denom = input.norm(p, dim, True).clamp_min(eps).expand_as(input)
ret = torch.div(input, denom, out=out)
return ret