train_batch_size=None,
eval_batch_size=None,
predict_batch_size=None,
- batch_axis=None):
+ batch_axis=None,
+ warm_start_from=None):
"""Constructs an `TPUEstimator` instance.
Args:
and per_host_input_for_training is True, batches will be sharded based
on the major dimension. If tpu_config.per_host_input_for_training is
False or `PER_HOST_V2`, batch_axis is ignored.
+ warm_start_from: Optional string filepath to a checkpoint or SavedModel to
+ warm-start from, or a `tf.estimator.WarmStartSettings`
+ object to fully configure warm-starting. If the string
+ filepath is provided instead of a `WarmStartSettings`,
+ then all variables are warm-started, and it is assumed
+ that vocabularies and Tensor names are unchanged.
Raises:
ValueError: `params` has reserved keys already.
model_fn=model_function,
model_dir=model_dir,
config=config,
- params=params)
+ params=params,
+ warm_start_from=warm_start_from)
self._iterations_per_training_loop = (
self._config.tpu_config.iterations_per_loop)