":tpu_lib",
":tpu_py",
"//tensorflow/contrib/summary:summary_ops",
+ "//tensorflow/contrib/training:training_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
from tensorflow.contrib.tpu.python.tpu import tpu_feed
from tensorflow.contrib.tpu.python.tpu import training_loop
from tensorflow.contrib.tpu.python.tpu import util as util_lib
+from tensorflow.contrib.training.python.training import hparam
from tensorflow.core.framework import variable_pb2
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.protobuf import config_pb2
batch_size_for_model_fn = self._ctx.batch_size_for_model_fn
if batch_size_for_model_fn is not None:
- params[_BATCH_SIZE_KEY] = batch_size_for_model_fn
+ if isinstance(params, hparam.HParams):
+ params.add_hparam(_BATCH_SIZE_KEY, batch_size_for_model_fn)
+ else:
+ params[_BATCH_SIZE_KEY] = batch_size_for_model_fn
estimator_spec = self._model_fn(features=features, **kwargs)
if (self._ctx.is_running_on_cpu(is_export_mode) and
# input_fn for use_tpu=True/False.
batch_size_for_input_fn = ctx.batch_size_for_input_fn
if batch_size_for_input_fn is not None:
- kwargs['params'][_BATCH_SIZE_KEY] = batch_size_for_input_fn
+ if isinstance(kwargs['params'], hparam.HParams):
+ kwargs['params'].add_hparam(_BATCH_SIZE_KEY, batch_size_for_input_fn)
+ else:
+ kwargs['params'][_BATCH_SIZE_KEY] = batch_size_for_input_fn
# For export_savedmodel, input_fn is never passed to Estimator. So,
# `is_export_mode` must be False.