Fix allow_inf in assertEqual (#16959)
authorSsnL <tongzhou.wang.1994@gmail.com>
Tue, 12 Feb 2019 15:49:48 +0000 (07:49 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 12 Feb 2019 15:56:34 +0000 (07:56 -0800)
Summary:
gchanan pointed out in https://github.com/pytorch/pytorch/pull/16389 that `allow_inf` is treating `-inf` and `inf` as equal. This fixes it.

Also fixing #16448 since it's near and 2.1 has released.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16959

Differential Revision: D14025297

Pulled By: gchanan

fbshipit-source-id: 95348309492e7ab65aa4d7aabb5a1800de66c5d6

test/common_utils.py

index 5dce3d8..d7a7569 100644 (file)
@@ -410,7 +410,8 @@ class TestCase(expecttest.TestCase):
                         # inf check if allow_inf=True
                         if allow_inf:
                             inf_mask = torch.isinf(a)
-                            self.assertTrue(torch.equal(inf_mask, torch.isinf(b)), message)
+                            inf_sign = inf_mask.sign()
+                            self.assertTrue(torch.equal(inf_sign, torch.isinf(b).sign()), message)
                             diff[inf_mask] = 0
                     # TODO: implement abs on CharTensor (int8)
                     if diff.is_signed() and diff.dtype != torch.int8: