From: Emad El-Haraty
Date: Thu, 9 Sep 2021 17:32:22 +0000 (-0700)
Subject: fix acc topk's handling of the case when dim=0, fix tests as well (#64727)
X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~330
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=46c886e8a6d3d9f502fa8c0985784436ab7f9543;p=platform%2Fupstream%2Fpytorch.git

fix acc topk's handling of the case when dim=0, fix tests as well (#64727)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64727

The acc_ops converter for topk has a subtle bug (found while trying to introduce max/min): the code does not differentiate between dim == None and dim == 0, but these call for different computations.

Reviewed By: jfix71, 842974287

Differential Revision: D30833621

fbshipit-source-id: 6cd84e6ca4e95bb1a6d6465e61830b76808a9c78
---

diff --git a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
index e101b6b..e946a92 100644
--- a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
+++ b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
@@ -813,7 +813,7 @@ def acc_ops_topk(network, target, args, kwargs, name):
     num_dims = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
     k = kwargs["k"]
-    dim = (kwargs["dim"] if kwargs["dim"] else -1) % num_dims
+    dim = (kwargs["dim"] if kwargs["dim"] is not None else -1) % num_dims
     operation = trt.TopKOperation.MAX if kwargs["largest"] else trt.TopKOperation.MIN
     layer = network.add_topk(
         input_val, operation, k, get_axes_for_reduce_op(dim, network.has_implicit_batch_dimension)
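
For illustration only (not part of the commit), a minimal standalone Python sketch of the falsy-zero pitfall this one-line fix addresses; pick_dim_buggy and pick_dim_fixed are hypothetical helper names mirroring the old and new expressions:

    # Hypothetical helpers illustrating the bug: in Python, 0 is falsy,
    # so a bare truthiness check cannot tell dim=0 apart from dim=None.

    def pick_dim_buggy(kwargs, num_dims):
        # Old logic: dim=0 is falsy, so it silently falls back to -1
        # (the last dimension) and topk runs over the wrong axis.
        return (kwargs["dim"] if kwargs["dim"] else -1) % num_dims

    def pick_dim_fixed(kwargs, num_dims):
        # Fixed logic: only dim=None triggers the default of -1.
        return (kwargs["dim"] if kwargs["dim"] is not None else -1) % num_dims

    num_dims = 3
    assert pick_dim_buggy({"dim": 0}, num_dims) == 2    # wrong: dim=0 became the last axis
    assert pick_dim_fixed({"dim": 0}, num_dims) == 0    # correct: dim=0 is kept
    assert pick_dim_fixed({"dim": None}, num_dims) == 2 # default (-1) still works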