fix acc topk's handling of the case when dim=0, fix tests as well (#64727)
authorEmad El-Haraty <elharaty@fb.com>
Thu, 9 Sep 2021 17:32:22 +0000 (10:32 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 9 Sep 2021 17:35:23 +0000 (10:35 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64727

The acc_ops converter for topk has a subtle bug (I found this while trying to introduce max/min):
the code does not differentiate between dim == None and dim == 0, but these are different computations.
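To illustrate the pitfall (a minimal sketch, not the converter code itself): 0 is falsy in Python, so the old expression silently remapped dim=0 to the last dimension.

    num_dims = 2
    dim = 0
    old = (dim if dim else -1) % num_dims               # evaluates to 1: dim=0 is treated like None
    new = (dim if dim is not None else -1) % num_dims   # evaluates to 0: only None falls back to -1
    assert old == 1 and new == 0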

Reviewed By: jfix71, 842974287

Differential Revision: D30833621

fbshipit-source-id: 6cd84e6ca4e95bb1a6d6465e61830b76808a9c78

torch/fx/experimental/fx2trt/converters/acc_ops_converters.py

index e101b6b..e946a92 100644 (file)
@@ -813,7 +813,7 @@ def acc_ops_topk(network, target, args, kwargs, name):
 
     num_dims = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
     k = kwargs["k"]
-    dim = (kwargs["dim"] if kwargs["dim"] else -1) % num_dims
+    dim = (kwargs["dim"] if kwargs["dim"] is not None else -1) % num_dims
     operation = trt.TopKOperation.MAX if kwargs["largest"] else trt.TopKOperation.MIN
     layer = network.add_topk(
         input_val, operation, k, get_axes_for_reduce_op(dim, network.has_implicit_batch_dimension)
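For context, a small example (using eager torch.topk on a 2-D tensor, not the TensorRT converter) of why dim=None and dim=0 are different computations:

    import torch

    x = torch.tensor([[1., 9.],
                      [8., 2.]])
    # dim omitted (None): topk reduces over the last dimension (per row)
    torch.topk(x, k=1).values         # tensor([[9.], [8.]])
    # dim=0: topk reduces over the first dimension (per column)
    torch.topk(x, k=1, dim=0).values  # tensor([[8., 9.]])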