From 195cb4efa8b204ee1c9cb908308d516b812790e2 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Wed, 8 Sep 2021 09:52:53 -0700 Subject: [PATCH] update scatter formula (#64546) Summary: Fixes https://github.com/pytorch/pytorch/issues/63430 Already tested OpInfo gradient tests https://github.com/pytorch/pytorch/blob/544c8e6a5d26efdf1cf679b313893fe119825930/torch/testing/_internal/common_methods_invocations.py#L8575-L8577 Pull Request resolved: https://github.com/pytorch/pytorch/pull/64546 Reviewed By: saketh-are Differential Revision: D30768759 Pulled By: albanD fbshipit-source-id: 27d144971c51a956a232fc7d02df5c9d2706d565 --- tools/autograd/derivatives.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 4bdb565..505130f 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1186,12 +1186,12 @@ result: auto_element_wise - name: scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor - self: grad.clone().scatter_(dim, index, 0) + self: grad.scatter(dim, index, 0) index: non_differentiable src: grad.gather(dim, index) - name: scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor - self: grad.clone().scatter_(dim, index, 0) + self: grad.scatter(dim, index, 0) index: non_differentiable - name: scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor -- 2.7.4