from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.training import adam
+from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import gradient_descent
from tensorflow.python.util import tf_inspect
return self._required_tpu
+default_strategy = NamedDistribution(
+ "Default",
+ distribute_lib._default_distribution_strategy, # pylint: disable=protected-access
+ required_gpus=None)
one_device_strategy = NamedDistribution(
"OneDeviceCPU", one_device_strategy.OneDeviceStrategy("/cpu:0"),
- None)
+ required_gpus=None)
tpu_strategy_single_iteration = NamedDistribution(
"TPUSingleIteration",
tpu_strategy.TPUStrategy(iterations_per_step=1),
required_tpu=True)
tpu_strategy = NamedDistribution(
"TPU", tpu_strategy.TPUStrategy(), required_tpu=True)
+# Note that we disable prefetching for testing since prefetching makes
+# the input non-deterministic.
mirrored_strategy_with_gpu_and_cpu = NamedDistribution(
"MirroredCPUAndGPU",
- mirrored_strategy.MirroredStrategy(["/gpu:0", "/cpu:0"]), 1)
-mirrored_strategy_without_prefetch = NamedDistribution(
- "MirroredCPUAndGPUNoPrefetch",
mirrored_strategy.MirroredStrategy(
- ["/gpu:0", "/cpu:0"], prefetch_on_device=False), 1)
+ ["/gpu:0", "/cpu:0"], prefetch_on_device=False),
+ required_gpus=1)
mirrored_strategy_with_two_gpus = NamedDistribution(
"Mirrored2GPUs",
- mirrored_strategy.MirroredStrategy(["/gpu:0", "/gpu:1"]), 2)
+ mirrored_strategy.MirroredStrategy(
+ ["/gpu:0", "/gpu:1"], prefetch_on_device=False),
+ required_gpus=2)
adam_optimizer_v1_fn = NamedObject(
"AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1))
renorm=renorm,
update_ops_in_tower_mode=not update_ops_in_cross_tower_mode)
- # Disable prefetching since that makes the specific input on each device
- # to be non deterministic, and this test relies on specific input being
- # on each device.
+ # Make sure prefetching is disabled since that makes the
+ # specific input on each device to be non deterministic, and
+ # this test relies on specific input being on each device.
if isinstance(distribution, mirrored_strategy.MirroredStrategy):
- distribution._prefetch_on_device = False
+ self.assertFalse(distribution._prefetch_on_device)
iterator = distribution.distribute_dataset(
dataset_fn).make_one_shot_iterator()