Save some useful TPU estimator's ops into collections for performance measurement.
authorRui Zhao <rzhao@google.com>
Sat, 7 Apr 2018 21:43:08 +0000 (14:43 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 7 Apr 2018 21:45:49 +0000 (14:45 -0700)
PiperOrigin-RevId: 192016099

tensorflow/contrib/tpu/python/tpu/tpu_estimator.py

index 6834600..47365b7 100644 (file)
@@ -38,6 +38,7 @@ from tensorflow.contrib.tpu.python.tpu import tpu_context
 from tensorflow.contrib.tpu.python.tpu import tpu_feed
 from tensorflow.contrib.tpu.python.tpu import training_loop
 from tensorflow.contrib.tpu.python.tpu import util as util_lib
+from tensorflow.core.framework import variable_pb2
 from tensorflow.core.framework.summary_pb2 import Summary
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.data.ops import dataset_ops
@@ -53,6 +54,7 @@ from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
@@ -73,6 +75,8 @@ _ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop'
 _BATCH_SIZE_KEY = 'batch_size'
 _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum'
 _ONE_GIGABYTE = 1024 * 1024 * 1024
+_TPU_ENQUEUE_OPS = '_tpu_enqueue_ops'
+_TPU_TRAIN_OP = '_tpu_train_op'
 
 _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY]
 
@@ -85,6 +89,13 @@ _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY]
 _WRAP_INPUT_FN_INTO_WHILE_LOOP = False
 
 
+ops.register_proto_function(
+    '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR),
+    proto_type=variable_pb2.VariableDef,
+    to_proto=resource_variable_ops._to_proto_fn,  # pylint: disable=protected-access
+    from_proto=resource_variable_ops._from_proto_fn)  # pylint: disable=protected-access
+
+
 def _create_global_step(graph):
   graph = graph or ops.get_default_graph()
   if training.get_global_step(graph) is not None:
@@ -2006,6 +2017,13 @@ class TPUEstimator(estimator_lib.Estimator):
         enqueue_ops, dequeue_fn, input_hooks, run_infeed_loop_on_coordinator = (
             input_holders.generate_infeed_enqueue_ops_and_dequeue_fn())
 
+        graph = ops.get_default_graph()
+        for enqueue_op in enqueue_ops:
+          if isinstance(enqueue_op, list):
+            graph.get_collection_ref(_TPU_ENQUEUE_OPS).extend(enqueue_op)
+          else:
+            graph.add_to_collection(_TPU_ENQUEUE_OPS, enqueue_op)
+
         if mode == model_fn_lib.ModeKeys.TRAIN:
           loss, host_call, scaffold = (
               _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn))
@@ -2036,11 +2054,14 @@ class TPUEstimator(estimator_lib.Estimator):
           # Validate the TPU training graph to catch basic errors
           _validate_tpu_training_graph()
 
+          train_op = control_flow_ops.group(*update_ops)
+          graph.add_to_collection(_TPU_TRAIN_OP, train_op)
+
           return model_fn_lib.EstimatorSpec(
               mode,
               loss=loss,
               training_hooks=hooks,
-              train_op=control_flow_ops.group(*update_ops),
+              train_op=train_op,
               scaffold=scaffold)
 
         if mode == model_fn_lib.ModeKeys.EVAL: