caffe2 - set up correct inheritance structure for remaining operator test classes...
authorDuc Ngo <duc@fb.com>
Mon, 1 Apr 2019 22:49:56 +0000 (15:49 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 1 Apr 2019 22:53:22 +0000 (15:53 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18622

Set up correct inheritance structure for remaining operator test classes

Reviewed By: ezyang

Differential Revision: D14685941

fbshipit-source-id: a6b1b3be325935b7fec7515be13a4994b3016bf0

caffe2/python/operator_test/blobs_queue_db_test.py
caffe2/python/operator_test/checkpoint_test.py
caffe2/python/operator_test/copy_ops_test.py
caffe2/python/operator_test/heatmap_max_keypoint_op_test.py
caffe2/python/operator_test/recurrent_net_executor_test.py

index 6112635..6e4c25c 100644 (file)
@@ -7,10 +7,10 @@ import unittest
 import numpy as np
 
 import caffe2.proto.caffe2_pb2 as caffe2_pb2
-from caffe2.python import core, workspace, timeout_guard
+from caffe2.python import core, workspace, timeout_guard, test_util
 
 
-class BlobsQueueDBTest(unittest.TestCase):
+class BlobsQueueDBTest(test_util.TestCase):
     def test_create_blobs_queue_db_string(self):
         def add_blobs(queue, num_samples):
             blob = core.BlobReference("blob")
index dc42d47..831b30c 100644 (file)
@@ -3,14 +3,14 @@ from __future__ import division
 from __future__ import print_function
 from __future__ import unicode_literals
 
-from caffe2.python import core, workspace
+from caffe2.python import core, workspace, test_util
 import os
 import shutil
 import tempfile
 import unittest
 
 
-class CheckpointTest(unittest.TestCase):
+class CheckpointTest(test_util.TestCase):
     """A simple test case to make sure that the checkpoint behavior is correct.
     """
 
@@ -23,8 +23,8 @@ class CheckpointTest(unittest.TestCase):
         net.Iter([], "iter")
         net.ConstantFill([], "value", shape=[1, 2, 3])
         net.Checkpoint(["iter", "value"], [],
-                     db=os.path.join(temp_root, "test_checkpoint_at_%05d"),
-                     db_type="leveldb", every=10, absolute_path=True)
+                       db=os.path.join(temp_root, "test_checkpoint_at_%05d"),
+                       db_type="leveldb", every=10, absolute_path=True)
         self.assertTrue(workspace.CreateNet(net))
         for i in range(100):
             self.assertTrue(workspace.RunNet("test_checkpoint"))
@@ -40,5 +40,4 @@ class CheckpointTest(unittest.TestCase):
 
 
 if __name__ == "__main__":
-    import unittest
     unittest.main()
index 04e9358..4efec57 100644 (file)
@@ -7,10 +7,10 @@ import numpy as np
 
 import unittest
 from caffe2.proto import caffe2_pb2
-from caffe2.python import workspace, core, model_helper, brew
+from caffe2.python import workspace, core, model_helper, brew, test_util
 
 
-class CopyOpsTest(unittest.TestCase):
+class CopyOpsTest(test_util.TestCase):
 
     def tearDown(self):
         # Reset workspace after each test
index a1aa0aa..8cff1dc 100644 (file)
@@ -29,6 +29,7 @@ def heatmap_approx_keypoint_ref(maps, rois):
 
 class TestHeatmapMaxKeypointOp(hu.HypothesisTestCase):
     def setUp(self):
+        super(TestHeatmapMaxKeypointOp, self).setUp()
         np.random.seed(0)
 
         # initial coordinates and interpolate HEATMAP_SIZE from it
index f36c22a..8ee846a 100644 (file)
@@ -4,7 +4,7 @@ from __future__ import print_function
 from __future__ import unicode_literals
 
 from caffe2.proto import caffe2_pb2
-from caffe2.python import model_helper, workspace, core, rnn_cell
+from caffe2.python import model_helper, workspace, core, rnn_cell, test_util
 from caffe2.python.attention import AttentionType
 
 import numpy as np
@@ -15,9 +15,10 @@ import hypothesis.strategies as st
 from hypothesis import given
 
 
-class TestRNNExecutor(unittest.TestCase):
+class TestRNNExecutor(test_util.TestCase):
 
     def setUp(self):
+        super(TestRNNExecutor, self).setUp()
         self.batch_size = 8
         self.input_dim = 20
         self.hidden_dim = 30
@@ -295,7 +296,6 @@ class TestRNNExecutor(unittest.TestCase):
         self.assertEqual(1 if forward_only else 2, num_found)
 
 if __name__ == "__main__":
-    import unittest
     import random
     random.seed(2603)
     workspace.GlobalInit([