from collections import OrderedDict
import sys
+import types
+import unittest
from absl.testing import parameterized
+import six
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import one_device_strategy
combinations: a list of dictionaries created using combine() and times().
Restrictions:
- -- there should always be a "mode" argument. Accepted values are "eager"
- and "graph".
+ -- the "mode" argument can be either "eager" or "graph". It's "graph" by
+ default.
-- arguments of the test method must match by name to get the corresponding
value of the combination. Tests must accept all arguments except the
"mode", "required_tpu" and "required_gpus".
test will be skipped if the specified number of GPUs aren't available.
Returns:
- a decorator that will cause the test method to be run under the specified
- conditions.
+ a decorator that will cause the test method or the test class to be run
+ under the specified conditions.
Raises:
- ValueError - if "mode" argument wasn't either "eager" or "graph".
+ ValueError - if "mode" argument wasn't either "eager" or "graph" or if other
+ arguments were not accepted by the test method.
"""
- def decorator(test_function):
+ def decorator(test_method_or_class):
"""The decorator to be returned."""
# Generate good test names that can be used with --test_filter.
list(combination.items()) + [("testcase_name",
"_test{}".format(name))]))
- @parameterized.named_parameters(*named_combinations)
- def decorated(self, **kwargs):
- """A wrapped test method that sets up `test_function`."""
- assert "mode" in kwargs
- mode = kwargs["mode"]
-
- distribution = kwargs.pop("distribution", None)
- required_tpu = kwargs.pop("required_tpu", False)
- required_gpus = kwargs.pop("required_gpus", None)
-
- if distribution:
- assert required_gpus is None, (
- "Do not use `required_gpus` and `distribution` together.")
- assert required_tpu is False, (
- "Do not use `required_tpu` and `distribution` together.")
- kwargs["distribution"] = distribution.strategy
- required_gpus = distribution.required_gpus
- required_tpu = distribution.required_tpu
-
- if required_tpu and not TPU_TEST:
- self.skipTest("Test requires a TPU, but it's not available.")
- if not required_tpu and TPU_TEST:
- self.skipTest("Test that doesn't require a TPU.")
-
- if not required_gpus:
- if GPU_TEST:
- self.skipTest("Test that doesn't require GPUs.")
- elif context.num_gpus() < required_gpus:
- self.skipTest(
- "{} GPUs are not available for this test. {} GPUs are available".
- format(required_gpus, context.num_gpus()))
-
- # At this point, `kwargs` doesn't have `required_gpus` or `required_tpu`
- # that the user might have specified. `kwargs` still has `mode`, which
- # the test is allowed to accept or ignore.
- requested_arguments = tf_inspect.getfullargspec(test_function).args
- missing_arguments = set(list(kwargs.keys()) + ["self"]).difference(
- set(requested_arguments + ["mode"]))
- if missing_arguments:
- raise ValueError("The test is missing arguments {} .".format(
- missing_arguments))
-
- kwargs_to_pass = {}
- for arg in requested_arguments:
- if arg == "self":
- kwargs_to_pass[arg] = self
- else:
- kwargs_to_pass[arg] = kwargs[arg]
-
- if mode == "eager":
- with context.eager_mode(), ops.Graph().as_default():
- test_function(**kwargs_to_pass)
- elif mode == "graph":
- with context.graph_mode(), ops.Graph().as_default():
- test_function(**kwargs_to_pass)
- else:
- raise ValueError(
- "'mode' has to be either 'eager' or 'graph' and not {}".format(
- mode))
+ if isinstance(test_method_or_class, type):
+ class_object = test_method_or_class
+ class_object._test_method_ids = test_method_ids = {}
+ for name, test_method in six.iteritems(class_object.__dict__.copy()):
+ if (name.startswith(unittest.TestLoader.testMethodPrefix) and
+ isinstance(test_method, types.FunctionType)):
+ delattr(class_object, name)
+ methods = {}
+ parameterized._update_class_dict_for_param_test_case(
+ class_object.__name__, methods, test_method_ids, name,
+ parameterized._ParameterizedTestIter(
+ _augment_with_special_arguments(test_method),
+ named_combinations, parameterized._NAMED, name))
+ for method_name, method in six.iteritems(methods):
+ setattr(class_object, method_name, method)
+
+ return class_object
+ else:
+ test_method = _augment_with_special_arguments(test_method_or_class)
+ return parameterized.named_parameters(*named_combinations)(test_method)
- return decorated
return decorator
+def _augment_with_special_arguments(test_method):
+ def decorated(self, **kwargs):
+ """A wrapped test method that treats some arguments in a special way."""
+ mode = kwargs.pop("mode", "graph")
+
+ distribution = kwargs.pop("distribution", None)
+ required_tpu = kwargs.pop("required_tpu", False)
+ required_gpus = kwargs.pop("required_gpus", None)
+
+ if distribution:
+ assert required_gpus is None, (
+ "Do not use `required_gpus` and `distribution` together.")
+ assert required_tpu is False, (
+ "Do not use `required_tpu` and `distribution` together.")
+ kwargs["distribution"] = distribution.strategy
+ required_gpus = distribution.required_gpus
+ required_tpu = distribution.required_tpu
+
+ if required_tpu and not TPU_TEST:
+ self.skipTest("Test requires a TPU, but it's not available.")
+ if not required_tpu and TPU_TEST:
+ self.skipTest("Test that doesn't require a TPU.")
+
+ if not required_gpus:
+ if GPU_TEST:
+ self.skipTest("Test that doesn't require GPUs.")
+ elif context.num_gpus() < required_gpus:
+ self.skipTest(
+ "{} GPUs are not available for this test. {} GPUs are available".
+ format(required_gpus, context.num_gpus()))
+
+ # At this point, `kwargs` doesn't have `required_gpus` or `required_tpu`
+ # that the user might have specified. `kwargs` still has `mode`, which
+ # the test is allowed to accept or ignore.
+ requested_arguments = tf_inspect.getfullargspec(test_method).args
+ missing_arguments = set(list(kwargs.keys()) + ["self"]).difference(
+ set(requested_arguments + ["mode"]))
+ if missing_arguments:
+ raise ValueError("The test is missing arguments {} .".format(
+ missing_arguments))
+
+ kwargs_to_pass = {}
+ for arg in requested_arguments:
+ if arg == "self":
+ kwargs_to_pass[arg] = self
+ else:
+ kwargs_to_pass[arg] = kwargs[arg]
+
+ if mode == "eager":
+ with ops.Graph().as_default(), context.eager_mode():
+ test_method(**kwargs_to_pass)
+ elif mode == "graph":
+ with ops.Graph().as_default(), context.graph_mode():
+ test_method(**kwargs_to_pass)
+ else:
+ raise ValueError(
+ "'mode' has to be either 'eager' or 'graph' and not {}".format(
+ mode))
+ return decorated
+
+
def combine(**kwargs):
"""Generate combinations based on its keyword arguments.