handle the case in acc_ops.sum when dim == 0, differentiating it from the case when...
authorEmad El-Haraty <elharaty@fb.com>
Mon, 13 Sep 2021 21:22:53 +0000 (14:22 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 13 Sep 2021 21:24:16 +0000 (14:24 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64869

handle the case in acc_ops.sum when dim == 0, differentiating it from the case when dim is None

Reviewed By: 842974287

Differential Revision: D30872739

fbshipit-source-id: 2755d3230804a16ef1c9289f804138c6dd7766b3

torch/fx/experimental/fx_acc/acc_ops.py

index 22de9ee..2ffc993 100644 (file)
@@ -567,7 +567,7 @@ def add_sum_mapper(node: torch.fx.Node, mod: torch.fx.GraphModule) -> torch.fx.N
 
 @register_acc_op
 def sum(*, input, dim=None, keepdim=False, dtype=None):
-    if dim:
+    if dim is not None:
         return torch.sum(**locals())
     else:
         return input.sum(dtype=dtype)