use datatype-dependent tolerance in data parallel tests
author Natalia Gimelshein <ngimelshein@nvidia.com>
Tue, 11 Dec 2018 06:48:16 +0000 (22:48 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 11 Dec 2018 06:50:27 +0000 (22:50 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/14856

Differential Revision: D13413560

Pulled By: soumith

fbshipit-source-id: b3a0cfe93477ed332e6eaa2e39ef5f4cc8b36481
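
These data parallel tests run under repeat_test_for_types(ALL_TENSORTYPES), so the same assertion is exercised for half precision as well as for float and double, and a single default comparison tolerance is too tight for fp16. The change passes a per-dtype precision, dtype2prec[dtype], as the third positional argument of assertEqual. Below is a minimal sketch of this kind of dtype-to-tolerance lookup; the concrete values in dtype2prec and the exact assertEqual signature of the test framework are assumptions for illustration, not taken from this diff.

import torch

# Hypothetical dtype-to-tolerance table; the real dtype2prec in
# test/test_nn.py may use different values.
dtype2prec = {
    torch.float: 1e-5,
    torch.double: 1e-5,
    torch.half: 1e-2,  # fp16 accumulates far larger rounding error
}

def assert_close(actual, expected, dtype):
    # Compare with an absolute tolerance chosen per dtype, mirroring
    # assertEqual(out.data, expected_out, dtype2prec[dtype]) in the tests.
    tol = dtype2prec[dtype]
    max_err = (actual.float() - expected.float()).abs().max().item()
    assert max_err <= tol, "max abs error %g exceeds %g" % (max_err, tol)

For example, assert_close(out.data, expected_out, torch.half) tolerates roughly 1e-2 of error, while the float and double runs keep the tighter 1e-5 bound.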

test/test_nn.py

index 0d71e43..9469486 100644
@@ -3360,7 +3360,7 @@ class TestNN(NNTestCase):
         net = nn.DataParallel(l)
         out = net(i)
         self.assertEqual(out.get_device(), 0)
-        self.assertEqual(out.data, expected_out)
+        self.assertEqual(out.data, expected_out, dtype2prec[dtype])
 
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     @repeat_test_for_types(ALL_TENSORTYPES)
@@ -3379,7 +3379,7 @@ class TestNN(NNTestCase):
         n = nn.DataParallel(Net())
         out = n(input=i)
         self.assertEqual(out.get_device(), 0)
-        self.assertEqual(out.data, expected_out)
+        self.assertEqual(out.data, expected_out, dtype2prec[dtype])
 
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     @repeat_test_for_types(ALL_TENSORTYPES)
@@ -3398,7 +3398,7 @@ class TestNN(NNTestCase):
         n = nn.DataParallel(Net())
         out = n(input={'data': i, 'unused': []})
         self.assertEqual(out.get_device(), 0)
-        self.assertEqual(out.data, expected_out)
+        self.assertEqual(out.data, expected_out, dtype2prec[dtype])
 
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     @repeat_test_for_types(ALL_TENSORTYPES)
@@ -3417,7 +3417,7 @@ class TestNN(NNTestCase):
         n = nn.DataParallel(Net())
         out = n(input={'data': i, 'unused': {}})
         self.assertEqual(out.get_device(), 0)
-        self.assertEqual(out.data, expected_out)
+        self.assertEqual(out.data, expected_out, dtype2prec[dtype])
 
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     @repeat_test_for_types(ALL_TENSORTYPES)
@@ -3436,7 +3436,7 @@ class TestNN(NNTestCase):
         n = nn.DataParallel(Net())
         out = n(input={'data': i, 'unused': ()})
         self.assertEqual(out.get_device(), 0)
-        self.assertEqual(out.data, expected_out)
+        self.assertEqual(out.data, expected_out, dtype2prec[dtype])
 
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_data_parallel_device_args(self):