From c99277e177cf16736262251c7e92ea5e9ba2c5c2 Mon Sep 17 00:00:00 2001 From: Emad El-Haraty Date: Mon, 13 Sep 2021 14:22:53 -0700 Subject: [PATCH] handle the case in acc_ops.sum when dim == 0, differentiating it from the case when dim is None (#64869) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64869 handle the case in acc_ops.sum when dim == 0, differentiating it from the case when dim is None Reviewed By: 842974287 Differential Revision: D30872739 fbshipit-source-id: 2755d3230804a16ef1c9289f804138c6dd7766b3 --- torch/fx/experimental/fx_acc/acc_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/fx/experimental/fx_acc/acc_ops.py b/torch/fx/experimental/fx_acc/acc_ops.py index 22de9ee..2ffc993 100644 --- a/torch/fx/experimental/fx_acc/acc_ops.py +++ b/torch/fx/experimental/fx_acc/acc_ops.py @@ -567,7 +567,7 @@ def add_sum_mapper(node: torch.fx.Node, mod: torch.fx.GraphModule) -> torch.fx.N @register_acc_op def sum(*, input, dim=None, keepdim=False, dtype=None): - if dim: + if dim is not None: return torch.sum(**locals()) else: return input.sum(dtype=dtype) -- 2.7.4