Adds support for tf.HParams to TPUEstimator.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 9 Apr 2018 18:11:10 +0000 (11:11 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 9 Apr 2018 18:13:33 +0000 (11:13 -0700)
PiperOrigin-RevId: 192154504

tensorflow/contrib/tpu/BUILD
tensorflow/contrib/tpu/python/tpu/tpu_estimator.py

index 4de09dd..2f4a767 100644 (file)
@@ -47,6 +47,7 @@ py_library(
         ":tpu_lib",
         ":tpu_py",
         "//tensorflow/contrib/summary:summary_ops",
+        "//tensorflow/contrib/training:training_py",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:control_flow_ops",
index 47365b7..1332108 100644 (file)
@@ -38,6 +38,7 @@ from tensorflow.contrib.tpu.python.tpu import tpu_context
 from tensorflow.contrib.tpu.python.tpu import tpu_feed
 from tensorflow.contrib.tpu.python.tpu import training_loop
 from tensorflow.contrib.tpu.python.tpu import util as util_lib
+from tensorflow.contrib.training.python.training import hparam
 from tensorflow.core.framework import variable_pb2
 from tensorflow.core.framework.summary_pb2 import Summary
 from tensorflow.core.protobuf import config_pb2
@@ -1308,7 +1309,10 @@ class _ModelFnWrapper(object):
       batch_size_for_model_fn = self._ctx.batch_size_for_model_fn
 
     if batch_size_for_model_fn is not None:
-      params[_BATCH_SIZE_KEY] = batch_size_for_model_fn
+      if isinstance(params, hparam.HParams):
+        params.add_hparam(_BATCH_SIZE_KEY, batch_size_for_model_fn)
+      else:
+        params[_BATCH_SIZE_KEY] = batch_size_for_model_fn
 
     estimator_spec = self._model_fn(features=features, **kwargs)
     if (self._ctx.is_running_on_cpu(is_export_mode) and
@@ -1947,7 +1951,10 @@ class TPUEstimator(estimator_lib.Estimator):
       # input_fn for use_tpu=True/False.
       batch_size_for_input_fn = ctx.batch_size_for_input_fn
       if batch_size_for_input_fn is not None:
-        kwargs['params'][_BATCH_SIZE_KEY] = batch_size_for_input_fn
+        if isinstance(kwargs['params'], hparam.HParams):
+          kwargs['params'].add_hparam(_BATCH_SIZE_KEY, batch_size_for_input_fn)
+        else:
+          kwargs['params'][_BATCH_SIZE_KEY] = batch_size_for_input_fn
 
       # For export_savedmodel, input_fn is never passed to Estimator. So,
       # `is_export_mode` must be False.