[DataPipe] fixing tests related fork() to remove warnings (#64827)
authorKevin Tse <ktse@fb.com>
Fri, 10 Sep 2021 18:00:01 +0000 (11:00 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 10 Sep 2021 18:01:42 +0000 (11:01 -0700)
Summary:
There are two warnings produced by `test_fork_datapipe`. This PR addresses the issues raised by those warnings without impacting the test cases.

cc VitalyFedyunin ejguan

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64827

Reviewed By: ejguan

Differential Revision: D30870528

Pulled By: NivekT

fbshipit-source-id: 580a001c6fa3ff6f8b04a7e5183e58861938204b

test/test_datapipe.py

index 15cb059..b77d0a1 100644 (file)
@@ -606,8 +606,7 @@ class TestFunctionalIterDataPipe(TestCase):
 
         # Test Case: making sure all child DataPipe shares the same reference
         dp1, dp2, dp3 = input_dp.fork(num_instances=3)
-        self.assertTrue(all(n1 is n2 for n1, n2 in zip(dp1, dp2)))
-        self.assertTrue(all(n1 is n3 for n1, n3 in zip(dp1, dp3)))
+        self.assertTrue(all(n1 is n2 and n1 is n3 for n1, n2, n3 in zip(dp1, dp2, dp3)))
 
         # Test Case: one child DataPipe yields all value at a time
         output1, output2, output3 = list(dp1), list(dp2), list(dp3)
@@ -680,7 +679,6 @@ class TestFunctionalIterDataPipe(TestCase):
         output1, output2 = list(dp1), list(dp2)
         self.assertEqual(list(range(10)), output1)
         self.assertEqual(list(range(10)), output2)
-        output1, output2 = list(dp1), list(dp2)
         with warnings.catch_warnings(record=True) as wa:
             self.assertEqual(list(range(10)), list(dp1))  # Resets even though dp3 has not been read
             self.assertEqual(len(wa), 1)