# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
-
"""TPUEstimator class."""
from __future__ import absolute_import
import copy
import threading
import time
+import traceback
import six
from six.moves import queue as Queue # pylint: disable=redefined-builtin
from tensorflow.python.training import training
from tensorflow.python.training import training_util
-
_INITIAL_LOSS = 1e7
_ZERO_LOSS = 0.
_TPU_ESTIMATOR = 'tpu_estimator'
initializer=init_ops.zeros_initializer(),
trainable=False,
use_resource=True,
- collections=[ops.GraphKeys.GLOBAL_VARIABLES,
- ops.GraphKeys.GLOBAL_STEP])
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP])
def _create_or_get_iterations_per_loop():
raise RuntimeError('Multiple iterations_per_loop_var in collection.')
with ops.colocate_with(training_util.get_global_step()):
- with variable_scope.variable_scope(_TPU_ESTIMATOR,
- reuse=variable_scope.AUTO_REUSE):
+ with variable_scope.variable_scope(
+ _TPU_ESTIMATOR, reuse=variable_scope.AUTO_REUSE):
return variable_scope.get_variable(
_ITERATIONS_PER_LOOP_VAR,
initializer=init_ops.zeros_initializer(),
return self._eval_batch_size
return None
- global_batch_size = (self._train_batch_size if
- mode == model_fn_lib.ModeKeys.TRAIN
- else self._eval_batch_size)
+ global_batch_size = (
+ self._train_batch_size
+ if mode == model_fn_lib.ModeKeys.TRAIN else self._eval_batch_size)
# On TPU
if self.is_input_sharded_per_core():
return global_batch_size // self.num_cores
# The tpu job is determined by the run_config. Right now, this method is
# required as tpu_config is not part of the RunConfig.
mode = self._assert_mode()
- master = (run_config.evaluation_master if mode == model_fn_lib.ModeKeys.EVAL
- else run_config.master)
+ master = (
+ run_config.evaluation_master
+ if mode == model_fn_lib.ModeKeys.EVAL else run_config.master)
if master in _LOCAL_MASTERS:
return None
def tpu_host_placement_function(self):
"""Returns the TPU host place function."""
master = self.master_job
+
def _placement_function(_sentinal=None, core_id=None, host_id=None): # pylint: disable=invalid-name
assert _sentinal is None
if core_id is not None and host_id is not None:
if core_id is not None:
host_id = core_id / 8
return '/job:%s/task:%d/device:CPU:0' % (master, host_id)
+
return _placement_function
@property
def tpu_device_placement_function(self):
master = self.master_job
job_device = '' if master is None else ('/job:%s' % master)
+
def _placement_function(i):
return '%s/task:%d/device:TPU:%d' % (job_device, i / 8, i % 8)
+
return _placement_function
@property
def tpu_ordinal_function(self):
"""Returns the TPU ordinal fn."""
+
def _tpu_ordinal_function(index):
"""Return the TPU ordinal associated with a shard.
The ordinal of the TPU device the shard's infeed should be placed on.
"""
return index % 8
+
return _tpu_ordinal_function
STOP = -2
-class TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [
- 'mode',
- 'predictions',
- 'loss',
- 'train_op',
- 'eval_metrics',
- 'export_outputs',
- 'scaffold_fn'])):
+class TPUEstimatorSpec(
+ collections.namedtuple('TPUEstimatorSpec', [
+ 'mode',
+ 'predictions',
+ 'loss',
+ 'train_op',
+ 'eval_metrics',
+ 'export_outputs',
+ 'scaffold_fn'
+ ])):
"""Ops and objects returned from a `model_fn` and passed to `TPUEstimator`.
See `EstimatorSpec` for `mode`, 'predictions, 'loss', 'train_op', and
"""Creates a validated `TPUEstimatorSpec` instance."""
if eval_metrics is not None:
_EvalMetrics.validate(eval_metrics)
- return super(TPUEstimatorSpec, cls).__new__(cls,
- mode=mode,
- predictions=predictions,
- loss=loss,
- train_op=train_op,
- eval_metrics=eval_metrics,
- export_outputs=export_outputs,
- scaffold_fn=scaffold_fn)
+ return super(TPUEstimatorSpec, cls).__new__(
+ cls,
+ mode=mode,
+ predictions=predictions,
+ loss=loss,
+ train_op=train_op,
+ eval_metrics=eval_metrics,
+ export_outputs=export_outputs,
+ scaffold_fn=scaffold_fn)
def as_estimator_spec(self):
"""Creates an equivalent `EstimatorSpec` used by CPU train/eval."""
eval_metric_ops = _EvalMetrics.to_metric_metric_ops_for_cpu(
self.eval_metrics)
scaffold = self.scaffold_fn() if self.scaffold_fn else None
- return model_fn_lib.EstimatorSpec(mode=self.mode,
- predictions=self.predictions,
- loss=self.loss,
- train_op=self.train_op,
- eval_metric_ops=eval_metric_ops,
- export_outputs=self.export_outputs,
- scaffold=scaffold)
+ return model_fn_lib.EstimatorSpec(
+ mode=self.mode,
+ predictions=self.predictions,
+ loss=self.loss,
+ train_op=self.train_op,
+ eval_metric_ops=eval_metric_ops,
+ export_outputs=self.export_outputs,
+ scaffold=scaffold)
+
+
+class _OpQueueContext(object):
+ """Manages work queue and thread for a infeed/outfeed thread."""
+
+ def __init__(self, name, target, args):
+ self._name = name
+ self._queue = Queue.Queue()
+ args = (self,) + args
+ self._thread = threading.Thread(name=name, target=target, args=args)
+ self._thread.daemon = True
+ self._thread.start()
+
+ def stop(self):
+ self._queue.put(_SIGNAL.STOP)
+
+ def send_next_batch_signal(self, iterations):
+ self._queue.put(iterations)
+
+ def read_iteration_counts(self):
+ while True:
+ signal = self._queue.get(block=True)
+ logging.debug('%s read signal %s', self._name, signal)
+ if signal == _SIGNAL.STOP:
+ logging.info('%s received signal, stopping.', self._name)
+ return
+ yield signal
+ def join(self):
+ logging.info('Shutting down %s thread.' % self._name)
+ self.stop()
+ self._thread.join()
-class _InfeedOutfeedThreadBaseController(object):
- """This wraps the infeed/outfeed thread and stops when Estimator finishes."""
- def __init__(self, thd):
- self._signal_queue = Queue.Queue()
- thd.daemon = True
- thd.start()
- self._thd = thd
+class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
+ """A Session hook setting up the TPU initialization, infeed, and outfeed.
- def block_and_get_signal(self):
- return self._signal_queue.get()
+ This hook does two major things:
+ 1. initialize and shutdown TPU system.
+ 2. launch and join the threads for infeed enqueue and (optional) outfeed
+ dequeue.
+ """
- def send_next_batch_signal(self, signal=_SIGNAL.NEXT_BATCH):
- self._signal_queue.put(signal)
+ def __init__(self, ctx, enqueue_ops, dequeue_ops=None):
+ self._master_job = ctx.master_job
+ self._enqueue_ops = enqueue_ops
+ self._dequeue_ops = dequeue_ops
+ self._initial_infeed_sleep_secs = (
+ ctx.config.tpu_config.initial_infeed_sleep_secs)
+ self._session_cancel_timer = None
- def join(self):
- self._signal_queue.put(_SIGNAL.STOP)
- self._thd.join()
+ self._feed_error = None
+ self._finished = False
+ def begin(self):
+ logging.info('TPU job name %s', self._master_job)
+ self._iterations_per_loop_var = _create_or_get_iterations_per_loop()
+ self._init_op = [tpu.initialize_system(job=self._master_job)]
+ self._finalize_op = [tpu.shutdown_system(job=self._master_job)]
-class _OutfeedThreadController(_InfeedOutfeedThreadBaseController):
- """This wraps the outfeed thread and stops when Estimator finishes."""
+ def _log_error(self, session, error):
+ """Log an infeed or outfeed error.
- def __init__(self, session, dequeue_ops):
- super(_OutfeedThreadController, self).__init__(
- threading.Thread(target=self._execute_dequeue_ops,
- args=(session, dequeue_ops)))
+ This logs a short error message immediately, and schedules a timer to
+ emit the full stack trace and error message after a short period of time.
+ If the main session has terminated by the time the timer triggers, we
+ assume the real source of the error was from the main session and avoid
+ emitting a stack trace for the infeed.
- def _execute_dequeue_ops(self, session, dequeue_ops):
- count = 0
- while True:
- signal = self.block_and_get_signal()
- if signal == _SIGNAL.STOP:
- logging.info('Stop outfeed thread.')
- return
+ Args:
+ session: `tf.Session`, session to be terminated
+ error: exception that triggered logging.
+ """
+ logging.warning(
+ '\n\n'
+ 'Error occurred during infeed/outfeed. This may be due to a compile '
+ 'error in the main session. Waiting for a short time for the main '
+ 'session to come back.\n\n%s', error)
- iterations = signal
- for i in range(iterations):
- logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i)
- session.run(dequeue_ops)
- count += 1
+ self._feed_error = traceback.format_exc()
- def join(self):
- logging.info('Waiting for Outfeed Thread to exit.')
- super(_OutfeedThreadController, self).join()
-
-
-class _InfeedThreadController(_InfeedOutfeedThreadBaseController):
- """This wraps the infeed thread and stops when Estimator finishes."""
-
- def __init__(self, session, enqueue_ops, initial_infeed_sleep_secs):
- super(_InfeedThreadController, self).__init__(
- threading.Thread(
- target=self._input_thread_fn_for_loading,
- args=(session, enqueue_ops, initial_infeed_sleep_secs)))
-
- def _input_thread_fn_for_loading(self, session, enqueue_ops,
- initial_infeed_sleep_secs):
- count = 0
- if initial_infeed_sleep_secs:
- logging.info('Infeed thread sleeping for %d seconds.',
- initial_infeed_sleep_secs)
- time.sleep(initial_infeed_sleep_secs)
- logging.info('Infeed thread starting after sleep')
- try:
- while True:
- signal = self._signal_queue.get()
- if signal == _SIGNAL.STOP:
- logging.info('Stop Infeed input thread.')
- return
-
- if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
- # Enqueue batches for next loop.
- session.run(enqueue_ops)
- else:
- iterations = signal
- for i in range(iterations):
- logging.debug('Infeed enqueue for iteration (%d, %d)', count, i)
- session.run(enqueue_ops)
- count += 1
+ # If we've already encountered a feed error, don't schedule another
+ # cancellation op.
+ if self._session_cancel_timer:
+ return
- except Exception: # pylint: disable=broad-except
+ def _cancel_session():
# Close the session to avoid the main thread from hanging. If input
# pipeline triggers any error, the infeed thread dies but the main thread
# for TPU computation waits for the infeed enqueue forever. Close the
# exception in the main thread, instead of the expected compile error.
# User code that depends on having the proper exception type will
# therefore be confused.
- logging.error(
- 'Failed running infeed, closing session.\n'
- 'You may see an exception from your main session after this. '
- 'Sleep for 2 minutes before close Session from infeed thread to '
- 'allow the main thread returning an error first, if any.',
- exc_info=1
- )
- time.sleep(120)
- logging.error('Closing the failed session.')
- session.close()
-
- def join(self):
- logging.info('Waiting for Infeed Thread to exit.')
- super(_InfeedThreadController, self).join()
-
-
-class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
- """A Session hook setting up the TPU initialization, infeed, and outfeed.
-
- This hook does two major things:
- 1. initialize and shutdown TPU system.
- 2. launch and join the threads for infeed enqueue and (optional) outfeed
- dequeue.
- """
+ time.sleep(5)
+
+ # If the main session is still running, the infeed/outfeed errors are
+ # legitimate, and should be logged.
+ if not self._finished:
+ logging.error('Feed error: %s', self._feed_error)
+ logging.error('Closing session. A RuntimeError should follow.')
+ session.close()
+
+ self._session_cancel_timer = threading.Thread(target=_cancel_session)
+ self._session_cancel_timer.daemon = True
+ self._session_cancel_timer.start()
+
+ def _run_infeed(self, queue_ctx, session):
+ logging.info('Starting infeed thread controller.')
+ if self._initial_infeed_sleep_secs:
+ logging.info('%s thread sleeping for %d seconds.', self._name,
+ self._initial_infeed_sleep_secs)
+ time.sleep(self._initial_infeed_sleep_secs)
+ logging.info('%s thread starting after sleep', self._name)
- def __init__(self, ctx, enqueue_ops, dequeue_ops=None):
- self._master_job = ctx.master_job
- self._enqueue_ops = enqueue_ops
- self._dequeue_ops = dequeue_ops
- self._initial_infeed_sleep_secs = (
- ctx.config.tpu_config.initial_infeed_sleep_secs)
+ try:
+ if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
+ for _ in queue_ctx.read_iteration_counts():
+ session.run(self._enqueue_ops)
+ else:
+ for count, steps in enumerate(queue_ctx.read_iteration_counts()):
+ for i in xrange(steps):
+ logging.debug('Infeed enqueue for iteration (%d, %d)', count, i)
+ session.run(self._enqueue_ops)
+ logging.debug('Infeed thread finished, shutting down.')
+ except Exception as e: # pylint: disable=broad-except
+ self._log_error(session, e)
- def begin(self):
- logging.info('TPU job name %s', self._master_job)
- self._iterations_per_loop_var = _create_or_get_iterations_per_loop()
- self._init_op = [tpu.initialize_system(job=self._master_job)]
- self._finalize_op = [tpu.shutdown_system(job=self._master_job)]
+ def _run_outfeed(self, queue_ctx, session):
+ logging.info('Starting outfeed thread controller.')
+ try:
+ for count, steps in enumerate(queue_ctx.read_iteration_counts()):
+ for i in xrange(steps):
+ logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i)
+ session.run(self._dequeue_ops)
+ except Exception as e: # pylint: disable=broad-except
+ self._log_error(session, e)
def after_create_session(self, session, coord):
logging.info('Init TPU system')
- session.run(self._init_op,
- options=config_pb2.RunOptions(timeout_in_ms=5*60*1000))
+ session.run(
+ self._init_op,
+ options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000))
logging.info('Start infeed thread controller')
- self._infeed_thd_controller = _InfeedThreadController(
- session, self._enqueue_ops, self._initial_infeed_sleep_secs)
+ self._infeed_controller = _OpQueueContext(
+ name='InfeedController', target=self._run_infeed, args=(session,))
if self._dequeue_ops is not None:
logging.info('Start outfeed thread controller')
- self._outfeed_thd_controller = _OutfeedThreadController(
- session, self._dequeue_ops)
+ self._outfeed_controller = _OpQueueContext(
+ name='OutfeedController', target=self._run_outfeed, args=(session,))
def before_run(self, run_context):
+ if self._feed_error:
+ logging.warning('Feed error occurred, terminating session.')
+ run_context.request_stop()
+ return
+
iterations = run_context.session.run(self._iterations_per_loop_var)
logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations)
+ self._infeed_controller.send_next_batch_signal(iterations)
- self._infeed_thd_controller.send_next_batch_signal(iterations)
if self._dequeue_ops is not None:
# TODO(xiejw): Refactor the outfeed dequeue into tf.while_loop.
- logging.info(
- 'Dequeue next (%d) batch(es) of data from outfeed.', iterations)
- self._outfeed_thd_controller.send_next_batch_signal(iterations)
+ logging.info('Dequeue next (%d) batch(es) of data from outfeed.',
+ iterations)
+ self._outfeed_controller.send_next_batch_signal(iterations)
def end(self, session):
+ if self._session_cancel_timer:
+ logging.warning('Feed error occurred; waiting for message.')
+ self._session_cancel_timer.join()
+
+ self._finished = True
logging.info('Stop infeed thread controller')
- self._infeed_thd_controller.join()
+ self._infeed_controller.join()
if self._dequeue_ops is not None:
logging.info('Stop output thread controller')
- self._outfeed_thd_controller.join()
+ self._outfeed_controller.join()
logging.info('Shutdown TPU system.')
session.run(self._finalize_op)
run_context.request_stop()
else:
iterations = self._next_iterations(global_step, self._last_step)
- self._iterations_per_loop_var.load(iterations,
- session=run_context.session)
+ self._iterations_per_loop_var.load(
+ iterations, session=run_context.session)
class _SetEvalIterationsHook(session_run_hook.SessionRunHook):
self._iterations_per_loop_var.load(self._num_steps, session=session)
-def generate_per_core_enqueue_ops_fn_for_host(
- ctx, input_fn, inputs_structure_recorder):
+def generate_per_core_enqueue_ops_fn_for_host(ctx, input_fn,
+ inputs_structure_recorder):
"""Generates infeed enqueue ops for per-core input_fn on a single host."""
captured_infeed_queue = _CapturedObject()
per_host_sharded_inputs)
per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
- per_host_sharded_inputs,
- tpu_ordinal_function=ctx.tpu_ordinal_function)
+ per_host_sharded_inputs, tpu_ordinal_function=ctx.tpu_ordinal_function)
return per_host_enqueue_ops
+
return enqueue_ops_fn, captured_infeed_queue
features, labels = inputs
else:
features, labels = inputs, None
- inputs_structure_recorder.validate_and_record_structure(
- features, labels)
+ inputs_structure_recorder.validate_and_record_structure(features, labels)
unsharded_tensor_list = (
inputs_structure_recorder.flatten_features_and_labels(
features, labels))
per_host_enqueue_ops = (
infeed_queue.split_inputs_and_generate_enqueue_ops(
- unsharded_tensor_list,
- placement_function=lambda x: device))
+ unsharded_tensor_list, placement_function=lambda x: device))
return per_host_enqueue_ops
+
return enqueue_ops_fn, captured_infeed_queue
def validate_and_record_structure(self, features, labels):
"""Validates and records the structure of features` and `labels`."""
+
def _extract_key_names(tensor_or_dict):
if tensor_or_dict is None:
return []
flattened_inputs = []
if self._feature_names:
# We need a fixed ordering for enqueueing and dequeueing.
- flattened_inputs.extend([features[name]
- for name in self._feature_names])
+ flattened_inputs.extend(
+ [features[name] for name in self._feature_names])
else:
flattened_inputs.append(features)
ValueError: If the number of expected tensors from `flattened_inputs`
mismatches the recorded structure.
"""
- expected_num_features = (len(self._feature_names) if self._feature_names
- else 1)
+ expected_num_features = (
+ len(self._feature_names) if self._feature_names else 1)
if self._has_labels:
- expected_num_labels = (len(self._label_names) if self._label_names
- else 1)
+ expected_num_labels = (
+ len(self._label_names) if self._label_names else 1)
else:
expected_num_labels = 0
if expected_num_labels == 0:
unflattened_label = None
elif self._label_names:
- unflattened_label = dict(zip(self._label_names,
- flattened_inputs[expected_num_features:]))
+ unflattened_label = dict(
+ zip(self._label_names, flattened_inputs[expected_num_features:]))
else:
# Single tensor case.
unflattened_label = flattened_inputs[expected_num_features]
self._ctx, self._input_fn, self._inputs_structure_recorder))
if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
- enqueue_ops.append(_wrap_computation_in_while_loop(
- device=host_device, op_fn=enqueue_ops_fn))
+ enqueue_ops.append(
+ _wrap_computation_in_while_loop(
+ device=host_device, op_fn=enqueue_ops_fn))
else:
enqueue_ops.append(enqueue_ops_fn())
# Infeed_queue_getter must be called after enqueue_ops_fn is called.
self._batch_axis, host_device))
if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
- enqueue_ops.append(_wrap_computation_in_while_loop(
- device=host_device, op_fn=enqueue_ops_fn))
+ enqueue_ops.append(
+ _wrap_computation_in_while_loop(
+ device=host_device, op_fn=enqueue_ops_fn))
else:
enqueue_ops.append(enqueue_ops_fn())
infeed_queues.append(captured_infeed_queue.get())
with ops.control_dependencies([train_op]):
return array_ops.identity(loss)
+
return train_step, captured_scaffold_fn
def convert_to_single_tpu_eval_step(self, dequeue_fn):
with ops.control_dependencies([outfeed_ops]):
return math_ops.add(total_loss, loss)
+
return eval_step, eval_metrics, captured_scaffold_fn
def _call_model_fn(self, features, labels):
kwargs['params'] = params
if 'params' not in model_fn_args:
- raise ValueError(
- 'model_fn ({}) does not include params argument, '
- 'required by TPUEstimator to pass batch size as '
- 'params[\'batch_size\']'.format(self._model_fn))
+ raise ValueError('model_fn ({}) does not include params argument, '
+ 'required by TPUEstimator to pass batch size as '
+ 'params[\'batch_size\']'.format(self._model_fn))
batch_size_for_model_fn = self._ctx.batch_size_for_model_fn
if batch_size_for_model_fn is not None:
def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
examples_per_sec = self._batch_size * elapsed_steps / elapsed_time
if self._summary_writer is not None:
- example_summary = Summary(value=[Summary.Value(
- tag='examples_sec', simple_value=examples_per_sec)])
+ example_summary = Summary(value=[
+ Summary.Value(tag='examples_sec', simple_value=examples_per_sec)
+ ])
self._summary_writer.add_summary(example_summary, global_step)
logging.info('examples/sec: %g', examples_per_sec)
'`config` must be provided with type `tpu_config.RunConfig`')
if params is not None and any(k in params for k in _RESERVED_PARAMS_KEYS):
- raise ValueError(
- '{} are reserved keys but existed in params {}.'.format(
- _RESERVED_PARAMS_KEYS, params))
+ raise ValueError('{} are reserved keys but existed in params {}.'.format(
+ _RESERVED_PARAMS_KEYS, params))
if use_tpu:
if train_batch_size is None:
if max_steps is not None:
util_lib.check_positive_integer(max_steps, 'Train max_steps')
- return [_TPUStopAtStepHook(self._iterations_per_training_loop, steps,
- max_steps)]
+ return [
+ _TPUStopAtStepHook(self._iterations_per_training_loop, steps, max_steps)
+ ]
def _convert_eval_steps_to_hooks(self, steps):
with self._ctx.with_mode(model_fn_lib.ModeKeys.EVAL) as ctx:
# `features` in `model_fn` signature.
def _input_fn():
return input_fn(**kwargs)
+
return _input_fn
def _augment_model_fn(self, model_fn, batch_axis):
total_loss, eval_metric_ops, scaffold = _eval_on_tpu_system(
ctx, model_fn_wrapper, dequeue_fn)
iterations_per_loop_var = _create_or_get_iterations_per_loop()
- mean_loss = math_ops.div(
- total_loss,
- math_ops.cast(iterations_per_loop_var, dtype=total_loss.dtype))
+ mean_loss = math_ops.div(total_loss,
+ math_ops.cast(
+ iterations_per_loop_var,
+ dtype=total_loss.dtype))
# Creates a dummy metric update_op for all metrics. Estimator expects
# all metrics in eval_metric_ops have update_op and calls them one by
evaluation_hooks=hooks,
eval_metric_ops=eval_metric_ops,
scaffold=scaffold)
+
return _model_fn
model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn))
def multi_tpu_eval_steps_on_single_shard():
- return training_loop.repeat(iterations_per_loop_var,
- single_tpu_eval_step,
- [_ZERO_LOSS],
- name='loop')
+ return training_loop.repeat(
+ iterations_per_loop_var,
+ single_tpu_eval_step, [_ZERO_LOSS],
+ name='loop')
- (loss,) = tpu.shard(multi_tpu_eval_steps_on_single_shard,
- inputs=[],
- num_shards=num_cores,
- outputs_from_all_shards=False)
+ (loss,) = tpu.shard(
+ multi_tpu_eval_steps_on_single_shard,
+ inputs=[],
+ num_shards=num_cores,
+ outputs_from_all_shards=False)
scaffold = _get_scaffold(captured_scaffold_fn)
return loss, eval_metric_ops, scaffold
def multi_tpu_train_steps_on_single_shard():
return training_loop.repeat(
iterations_per_loop_var,
- single_tpu_train_step,
- [_INITIAL_LOSS],
+ single_tpu_train_step, [_INITIAL_LOSS],
name=b'loop')
- (loss,) = tpu.shard(multi_tpu_train_steps_on_single_shard,
- inputs=[],
- num_shards=num_cores,
- outputs_from_all_shards=False)
+ (loss,) = tpu.shard(
+ multi_tpu_train_steps_on_single_shard,
+ inputs=[],
+ num_shards=num_cores,
+ outputs_from_all_shards=False)
scaffold = _get_scaffold(captured_scaffold_fn)
return loss, scaffold
def _wrap_computation_in_while_loop(device, op_fn):
"""Wraps the ops generated by `op_fn` in tf.while_loop."""
+
def computation(i):
with ops.control_dependencies(op_fn()):
return i + 1
iterations = array_ops.identity(iterations_per_loop_var)
return control_flow_ops.while_loop(
lambda i: i < iterations,
- computation, [constant_op.constant(0)], parallel_iterations=1)
+ computation, [constant_op.constant(0)],
+ parallel_iterations=1)
def _validate_tpu_training_graph():
# Check if there is atleast one CrossReplicaSum operation in the graph
# This should be introduced by using the CrossShardOptimizer wrapper
- cross_replica_sum_ops = [o for o in operations
- if o.type == _CROSS_REPLICA_SUM_OP]
+ cross_replica_sum_ops = [
+ o for o in operations if o.type == _CROSS_REPLICA_SUM_OP
+ ]
if not cross_replica_sum_ops:
raise ValueError(
'CrossShardOptimizer must be used for model training on TPUs.')
if scaffold:
wrapped_finalize = scaffold.finalize
+
def _finalize():
with _CapturingContext('Inside Scaffold.finalize'):
wrapped_finalize()
+
scaffold.finalize = _finalize
return scaffold
def AddOp(self, op): # pylint: disable=invalid-name
for c in op.inputs:
if tpu._TPU_REPLICATE_ATTR in c.op.node_def.attr: # pylint: disable=protected-access
- raise ValueError(
- '{}: Op {} depends on TPU computation {}, '
- 'which is not allowed.'.format(self._message, op, c))
+ raise ValueError('{}: Op {} depends on TPU computation {}, '
+ 'which is not allowed.'.format(self._message, op, c))
def __enter__(self):
# pylint: disable=protected-access