Allow combinations to be used on the class level. Make "mode" optional.
authorIgor Saprykin <isaprykin@google.com>
Thu, 24 May 2018 22:28:03 +0000 (15:28 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 24 May 2018 22:32:37 +0000 (15:32 -0700)
Applying a generator to a class is the same as applying that generator to every member of that class.  It is meant to allow avoiding repetition in some cases.

The implementation relies on some internals of parameterized tests and how it works with a class level declaration:  https://github.com/abseil/abseil-py/blob/master/absl/testing/parameterized.py#L319.

The "mode" argument is required before this change.  To accommodate cases where execution mode isn't the point of the test, "mode" became optional with "graph" mode being default.  Another idea I had was to pick a random mode by default.

PiperOrigin-RevId: 197964501

tensorflow/contrib/distribute/python/combinations.py
tensorflow/contrib/distribute/python/combinations_test.py

index 1593581..e400fa5 100644 (file)
@@ -41,7 +41,10 @@ from __future__ import print_function
 
 from collections import OrderedDict
 import sys
+import types
+import unittest
 from absl.testing import parameterized
+import six
 
 from tensorflow.contrib.distribute.python import mirrored_strategy
 from tensorflow.contrib.distribute.python import one_device_strategy
@@ -67,8 +70,8 @@ def generate(combinations):
     combinations: a list of dictionaries created using combine() and times().
 
   Restrictions:
-   -- there should always be a "mode" argument.  Accepted values are "eager"
-      and "graph".
+   -- the "mode" argument can be either "eager" or "graph".  It's "graph" by
+      default.
    -- arguments of the test method must match by name to get the corresponding
       value of the combination.  Tests must accept all arguments except the
       "mode", "required_tpu" and "required_gpus".
@@ -83,14 +86,15 @@ def generate(combinations):
       test will be skipped if the specified number of GPUs aren't available.
 
   Returns:
-    a decorator that will cause the test method to be run under the specified
-    conditions.
+    a decorator that will cause the test method or the test class to be run
+    under the specified conditions.
 
   Raises:
-    ValueError - if "mode" argument wasn't either "eager" or "graph".
+    ValueError - if "mode" argument wasn't either "eager" or "graph" or if other
+      arguments were not accepted by the test method.
   """
 
-  def decorator(test_function):
+  def decorator(test_method_or_class):
     """The decorator to be returned."""
 
     # Generate good test names that can be used with --test_filter.
@@ -110,70 +114,91 @@ def generate(combinations):
               list(combination.items()) + [("testcase_name",
                                             "_test{}".format(name))]))
 
-    @parameterized.named_parameters(*named_combinations)
-    def decorated(self, **kwargs):
-      """A wrapped test method that sets up `test_function`."""
-      assert "mode" in kwargs
-      mode = kwargs["mode"]
-
-      distribution = kwargs.pop("distribution", None)
-      required_tpu = kwargs.pop("required_tpu", False)
-      required_gpus = kwargs.pop("required_gpus", None)
-
-      if distribution:
-        assert required_gpus is None, (
-            "Do not use `required_gpus` and `distribution` together.")
-        assert required_tpu is False, (
-            "Do not use `required_tpu` and `distribution` together.")
-        kwargs["distribution"] = distribution.strategy
-        required_gpus = distribution.required_gpus
-        required_tpu = distribution.required_tpu
-
-      if required_tpu and not TPU_TEST:
-        self.skipTest("Test requires a TPU, but it's not available.")
-      if not required_tpu and TPU_TEST:
-        self.skipTest("Test that doesn't require a TPU.")
-
-      if not required_gpus:
-        if GPU_TEST:
-          self.skipTest("Test that doesn't require GPUs.")
-      elif context.num_gpus() < required_gpus:
-        self.skipTest(
-            "{} GPUs are not available for this test. {} GPUs are available".
-            format(required_gpus, context.num_gpus()))
-
-      # At this point, `kwargs` doesn't have `required_gpus` or `required_tpu`
-      # that the user might have specified.  `kwargs` still has `mode`, which
-      # the test is allowed to accept or ignore.
-      requested_arguments = tf_inspect.getfullargspec(test_function).args
-      missing_arguments = set(list(kwargs.keys()) + ["self"]).difference(
-          set(requested_arguments + ["mode"]))
-      if missing_arguments:
-        raise ValueError("The test is missing arguments {} .".format(
-            missing_arguments))
-
-      kwargs_to_pass = {}
-      for arg in requested_arguments:
-        if arg == "self":
-          kwargs_to_pass[arg] = self
-        else:
-          kwargs_to_pass[arg] = kwargs[arg]
-
-      if mode == "eager":
-        with context.eager_mode(), ops.Graph().as_default():
-          test_function(**kwargs_to_pass)
-      elif mode == "graph":
-        with context.graph_mode(), ops.Graph().as_default():
-          test_function(**kwargs_to_pass)
-      else:
-        raise ValueError(
-            "'mode' has to be either 'eager' or 'graph' and not {}".format(
-                mode))
+    if isinstance(test_method_or_class, type):
+      class_object = test_method_or_class
+      class_object._test_method_ids = test_method_ids = {}
+      for name, test_method in six.iteritems(class_object.__dict__.copy()):
+        if (name.startswith(unittest.TestLoader.testMethodPrefix) and
+            isinstance(test_method, types.FunctionType)):
+          delattr(class_object, name)
+          methods = {}
+          parameterized._update_class_dict_for_param_test_case(
+              class_object.__name__, methods, test_method_ids, name,
+              parameterized._ParameterizedTestIter(
+                  _augment_with_special_arguments(test_method),
+                  named_combinations, parameterized._NAMED, name))
+          for method_name, method in six.iteritems(methods):
+            setattr(class_object, method_name, method)
+
+      return class_object
+    else:
+      test_method = _augment_with_special_arguments(test_method_or_class)
+      return parameterized.named_parameters(*named_combinations)(test_method)
 
-    return decorated
   return decorator
 
 
+def _augment_with_special_arguments(test_method):
+  def decorated(self, **kwargs):
+    """A wrapped test method that treats some arguments in a special way."""
+    mode = kwargs.pop("mode", "graph")
+
+    distribution = kwargs.pop("distribution", None)
+    required_tpu = kwargs.pop("required_tpu", False)
+    required_gpus = kwargs.pop("required_gpus", None)
+
+    if distribution:
+      assert required_gpus is None, (
+          "Do not use `required_gpus` and `distribution` together.")
+      assert required_tpu is False, (
+          "Do not use `required_tpu` and `distribution` together.")
+      kwargs["distribution"] = distribution.strategy
+      required_gpus = distribution.required_gpus
+      required_tpu = distribution.required_tpu
+
+    if required_tpu and not TPU_TEST:
+      self.skipTest("Test requires a TPU, but it's not available.")
+    if not required_tpu and TPU_TEST:
+      self.skipTest("Test that doesn't require a TPU.")
+
+    if not required_gpus:
+      if GPU_TEST:
+        self.skipTest("Test that doesn't require GPUs.")
+    elif context.num_gpus() < required_gpus:
+      self.skipTest(
+          "{} GPUs are not available for this test. {} GPUs are available".
+          format(required_gpus, context.num_gpus()))
+
+    # At this point, `kwargs` doesn't have `required_gpus` or `required_tpu`
+    # that the user might have specified.  `kwargs` still has `mode`, which
+    # the test is allowed to accept or ignore.
+    requested_arguments = tf_inspect.getfullargspec(test_method).args
+    missing_arguments = set(list(kwargs.keys()) + ["self"]).difference(
+        set(requested_arguments + ["mode"]))
+    if missing_arguments:
+      raise ValueError("The test is missing arguments {} .".format(
+          missing_arguments))
+
+    kwargs_to_pass = {}
+    for arg in requested_arguments:
+      if arg == "self":
+        kwargs_to_pass[arg] = self
+      else:
+        kwargs_to_pass[arg] = kwargs[arg]
+
+    if mode == "eager":
+      with ops.Graph().as_default(), context.eager_mode():
+        test_method(**kwargs_to_pass)
+    elif mode == "graph":
+      with ops.Graph().as_default(), context.graph_mode():
+        test_method(**kwargs_to_pass)
+    else:
+      raise ValueError(
+          "'mode' has to be either 'eager' or 'graph' and not {}".format(
+              mode))
+  return decorated
+
+
 def combine(**kwargs):
   """Generate combinations based on its keyword arguments.
 
index 184bcf2..86aa48c 100644 (file)
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 from collections import OrderedDict
+from absl.testing import parameterized
 
 from tensorflow.contrib.distribute.python import combinations
 from tensorflow.python.eager import test
@@ -120,5 +121,28 @@ class TestingCombinationsTest(test.TestCase):
       _ = combinations.times(c1, c2)
 
 
+@combinations.generate(combinations.combine(a=[1, 0], b=[2, 3], c=[1]))
+class CombineTheTestSuite(parameterized.TestCase):
+
+  def test_add_things(self, a, b, c):
+    self.assertLessEqual(3, a + b + c)
+    self.assertLessEqual(a + b + c, 5)
+
+  def test_add_things_one_more(self, a, b, c):
+    self.assertLessEqual(3, a + b + c)
+    self.assertLessEqual(a + b + c, 5)
+
+  def not_a_test(self, a=0, b=0, c=0):
+    del a, b, c
+    self.fail()
+
+  def _test_but_private(self, a=0, b=0, c=0):
+    del a, b, c
+    self.fail()
+
+  # Check that nothing funny happens to a non-callable that starts with "_test".
+  test_member = 0
+
+
 if __name__ == "__main__":
   test.main()