update scatter formula (#64546)
authorkshitij12345 <kshitijkalambarkar@gmail.com>
Wed, 8 Sep 2021 16:52:53 +0000 (09:52 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 8 Sep 2021 17:02:35 +0000 (10:02 -0700)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/63430

Already tested OpInfo gradient tests
https://github.com/pytorch/pytorch/blob/544c8e6a5d26efdf1cf679b313893fe119825930/torch/testing/_internal/common_methods_invocations.py#L8575-L8577

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64546

Reviewed By: saketh-are

Differential Revision: D30768759

Pulled By: albanD

fbshipit-source-id: 27d144971c51a956a232fc7d02df5c9d2706d565

tools/autograd/derivatives.yaml

index 4bdb565..505130f 100644 (file)
   result: auto_element_wise
 
 - name: scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
-  self: grad.clone().scatter_(dim, index, 0)
+  self: grad.scatter(dim, index, 0)
   index: non_differentiable
   src: grad.gather(dim, index)
 
 - name: scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor
-  self: grad.clone().scatter_(dim, index, 0)
+  self: grad.scatter(dim, index, 0)
   index: non_differentiable
 
 - name: scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor