From: Jianwei Xie
Date: Fri, 18 May 2018 21:59:01 +0000 (-0700)
Subject: Fixed an issue when adding context into params.
X-Git-Tag: upstream/v1.9.0_rc1~94^2^2~5
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5c107c91323fa5ce5b2df4de923a8c689d19cdcd;p=platform%2Fupstream%2Ftensorflow.git

Fixed an issue when adding context into params.

PiperOrigin-RevId: 197205327
---

diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 808545b..77d117b 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -1314,10 +1314,7 @@ class _ModelFnWrapper(object):
 
     batch_size_for_model_fn = self._ctx.batch_size_for_model_fn
     if batch_size_for_model_fn is not None:
-      if isinstance(params, hparam.HParams):
-        params.add_hparam(_BATCH_SIZE_KEY, batch_size_for_model_fn)
-      else:
-        params[_BATCH_SIZE_KEY] = batch_size_for_model_fn
+      _add_item_to_params(params, _BATCH_SIZE_KEY, batch_size_for_model_fn)
 
     estimator_spec = self._model_fn(features=features, **kwargs)
     if (self._ctx.is_running_on_cpu(is_export_mode) and
@@ -1983,10 +1980,8 @@ class TPUEstimator(estimator_lib.Estimator):
       # input_fn for use_tpu=True/False.
       batch_size_for_input_fn = ctx.batch_size_for_input_fn
       if batch_size_for_input_fn is not None:
-        if isinstance(kwargs['params'], hparam.HParams):
-          kwargs['params'].add_hparam(_BATCH_SIZE_KEY, batch_size_for_input_fn)
-        else:
-          kwargs['params'][_BATCH_SIZE_KEY] = batch_size_for_input_fn
+        _add_item_to_params(kwargs['params'],
+                            _BATCH_SIZE_KEY, batch_size_for_input_fn)
 
       # For export_savedmodel, input_fn is never passed to Estimator. So,
       # `is_export_mode` must be False.
@@ -2005,7 +2000,7 @@ class TPUEstimator(estimator_lib.Estimator):
       # dequeue_fn to model_fn. Here, `input_fn` is passed directly as
       # `features` in `model_fn` signature.
       def _input_fn(ctx):
-        kwargs['params'][_CTX_KEY] = ctx
+        _add_item_to_params(kwargs['params'], _CTX_KEY, ctx)
         return input_fn(**kwargs)
 
       return _input_fn
@@ -2823,3 +2818,17 @@ def _verify_cross_hosts_transfer_size(tensor_dict, message):
         '{}'.format(message, '\n'.join([
             ' -- Key: {}, Shape: {}'.format(k, v)
             for k, v in tensor_structure.items()])))
+
+
+def _add_item_to_params(params, key, value):
+  """Adds a new item into `params`."""
+  if isinstance(params, hparam.HParams):
+    # For HParams, we need to use the special API.
+    if key in params:
+      params.set_hparam(key, value)
+    else:
+      params.add_hparam(key, value)
+  else:
+    # Here `params` is a plain Python dict.
+    params[key] = value
+
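
For readers skimming the patch, here is a minimal, self-contained sketch of the dispatch that the new _add_item_to_params helper performs: HParams needs different calls for new versus existing keys, while a plain dict takes a single assignment. The _FakeHParams class below is a hypothetical stand-in (introduced here only so the snippet runs without TensorFlow); the real helper checks isinstance(params, hparam.HParams) against tf.contrib.training.HParams, as shown in the diff above.

# Sketch only. Assumption: _FakeHParams mimics the three HParams calls the
# patch relies on (__contains__, add_hparam, set_hparam); it is not the real
# tf.contrib.training.HParams class.
class _FakeHParams(object):
  """Tiny stand-in for HParams, for illustration only."""

  def __init__(self, **kwargs):
    self._values = dict(kwargs)

  def __contains__(self, key):
    return key in self._values

  def add_hparam(self, key, value):
    # Real HParams raises if the hyperparameter already exists.
    if key in self._values:
      raise ValueError('Hyperparameter %s already exists.' % key)
    self._values[key] = value

  def set_hparam(self, key, value):
    # Real HParams only allows overwriting an existing hyperparameter here.
    self._values[key] = value


def _add_item_to_params(params, key, value):
  """Adds or overwrites `key` in `params`, for HParams-like objects or dicts."""
  if isinstance(params, _FakeHParams):  # the patch checks hparam.HParams
    # HParams path: overwrite an existing key, otherwise add a new one.
    if key in params:
      params.set_hparam(key, value)
    else:
      params.add_hparam(key, value)
  else:
    # Plain dict path: a single assignment covers both cases.
    params[key] = value


if __name__ == '__main__':
  hp = _FakeHParams(batch_size=64)
  _add_item_to_params(hp, 'batch_size', 128)  # existing key -> set_hparam
  _add_item_to_params(hp, 'context', 'ctx')   # new key -> add_hparam
  d = {'batch_size': 64}
  _add_item_to_params(d, 'context', 'ctx')    # dict path
  print(hp._values, d)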