Fixed an issue when adding the context into params.
author Jianwei Xie <xiejw@google.com>
Fri, 18 May 2018 21:59:01 +0000 (14:59 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Fri, 18 May 2018 22:01:36 +0000 (15:01 -0700)
PiperOrigin-RevId: 197205327

tensorflow/contrib/tpu/python/tpu/tpu_estimator.py

index 808545b..77d117b 100644
@@ -1314,10 +1314,7 @@ class _ModelFnWrapper(object):
       batch_size_for_model_fn = self._ctx.batch_size_for_model_fn
 
     if batch_size_for_model_fn is not None:
-      if isinstance(params, hparam.HParams):
-        params.add_hparam(_BATCH_SIZE_KEY, batch_size_for_model_fn)
-      else:
-        params[_BATCH_SIZE_KEY] = batch_size_for_model_fn
+      _add_item_to_params(params, _BATCH_SIZE_KEY, batch_size_for_model_fn)
 
     estimator_spec = self._model_fn(features=features, **kwargs)
     if (self._ctx.is_running_on_cpu(is_export_mode) and
@@ -1983,10 +1980,8 @@ class TPUEstimator(estimator_lib.Estimator):
       # input_fn for use_tpu=True/False.
       batch_size_for_input_fn = ctx.batch_size_for_input_fn
       if batch_size_for_input_fn is not None:
-        if isinstance(kwargs['params'], hparam.HParams):
-          kwargs['params'].add_hparam(_BATCH_SIZE_KEY, batch_size_for_input_fn)
-        else:
-          kwargs['params'][_BATCH_SIZE_KEY] = batch_size_for_input_fn
+        _add_item_to_params(kwargs['params'],
+                            _BATCH_SIZE_KEY, batch_size_for_input_fn)
 
       # For export_savedmodel, input_fn is never passed to Estimator. So,
       # `is_export_mode` must be False.
@@ -2005,7 +2000,7 @@ class TPUEstimator(estimator_lib.Estimator):
       # dequeue_fn to model_fn. Here, `input_fn` is passed directly as
       # `features` in `model_fn` signature.
       def _input_fn(ctx):
-        kwargs['params'][_CTX_KEY] = ctx
+        _add_item_to_params(kwargs['params'], _CTX_KEY, ctx)
         return input_fn(**kwargs)
 
       return _input_fn
@@ -2823,3 +2818,17 @@ def _verify_cross_hosts_transfer_size(tensor_dict, message):
         '{}'.format(message, '\n'.join([
             ' -- Key: {}, Shape: {}'.format(k, v)
             for k, v in tensor_structure.items()])))
+
+
+def _add_item_to_params(params, key, value):
+  """Adds a new item into `params`."""
+  if isinstance(params, hparam.HParams):
+    # For HParams, we need to use its dedicated setter APIs; item
+    # assignment is not supported.
+    if key in params:
+      params.set_hparam(key, value)
+    else:
+      params.add_hparam(key, value)
+  else:
+    # Otherwise, `params` is a plain Python dict.
+    params[key] = value
+
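For illustration only (not part of the commit), a minimal sketch of how the new _add_item_to_params helper behaves for the two supported `params` types. It assumes the HParams class from tensorflow.contrib.training and uses hypothetical key/value names; _add_item_to_params refers to the function added in the diff above.

    from tensorflow.contrib.training import HParams

    # Plain dict: the helper falls back to ordinary item assignment.
    params_dict = {'learning_rate': 0.01}
    _add_item_to_params(params_dict, 'batch_size', 128)
    assert params_dict['batch_size'] == 128

    # HParams: the helper uses add_hparam for new keys and set_hparam for
    # keys that already exist, since item assignment is not supported.
    hparams = HParams(learning_rate=0.01)
    _add_item_to_params(hparams, 'batch_size', 128)   # added via add_hparam
    _add_item_to_params(hparams, 'batch_size', 256)   # updated via set_hparam
    assert hparams.batch_size == 256

This is why the previous inline isinstance checks mishandled the second write of the same key (such as _CTX_KEY when the input_fn is re-invoked): add_hparam raises on an existing key, so the existing-key case has to go through set_hparam.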