from tensorflow.contrib.tpu.python.tpu import tpu_feed
from tensorflow.contrib.tpu.python.tpu import training_loop
from tensorflow.contrib.tpu.python.tpu import util as util_lib
from tensorflow.core.framework import variable_pb2
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
_BATCH_SIZE_KEY = 'batch_size'
_CROSS_REPLICA_SUM_OP = 'CrossReplicaSum'
_ONE_GIGABYTE = 1024 * 1024 * 1024
# Graph-collection keys under which the TPU infeed enqueue ops and the train
# op are stored (see the model_fn wiring below), so they can be retrieved
# later, e.g. after a MetaGraphDef import.
_TPU_ENQUEUE_OPS = '_tpu_enqueue_ops'
_TPU_TRAIN_OP = '_tpu_train_op'
_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY]
_WRAP_INPUT_FN_INTO_WHILE_LOOP = False

# Register to_proto/from_proto converters for the iterations-per-loop
# resource variable so it serializes as a VariableDef and presumably
# round-trips through MetaGraphDef export/import — confirm against the
# definition of _TPU_ESTIMATOR/_ITERATIONS_PER_LOOP_VAR elsewhere in this
# file.
ops.register_proto_function(
    '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR),
    proto_type=variable_pb2.VariableDef,
    to_proto=resource_variable_ops._to_proto_fn,  # pylint: disable=protected-access
    from_proto=resource_variable_ops._from_proto_fn)  # pylint: disable=protected-access
def _create_global_step(graph):
graph = graph or ops.get_default_graph()
if training.get_global_step(graph) is not None:
enqueue_ops, dequeue_fn, input_hooks, run_infeed_loop_on_coordinator = (
input_holders.generate_infeed_enqueue_ops_and_dequeue_fn())
+ graph = ops.get_default_graph()
+ for enqueue_op in enqueue_ops:
+ if isinstance(enqueue_op, list):
+ graph.get_collection_ref(_TPU_ENQUEUE_OPS).extend(enqueue_op)
+ else:
+ graph.add_to_collection(_TPU_ENQUEUE_OPS, enqueue_op)
+
if mode == model_fn_lib.ModeKeys.TRAIN:
loss, host_call, scaffold = (
_train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn))
# Validate the TPU training graph to catch basic errors
_validate_tpu_training_graph()
+ train_op = control_flow_ops.group(*update_ops)
+ graph.add_to_collection(_TPU_TRAIN_OP, train_op)
+
return model_fn_lib.EstimatorSpec(
mode,
loss=loss,
training_hooks=hooks,
- train_op=control_flow_ops.group(*update_ops),
+ train_op=train_op,
scaffold=scaffold)
if mode == model_fn_lib.ModeKeys.EVAL: