Pipe through warm_start_from parameter
author A. Unique TensorFlower <gardener@tensorflow.org>
Fri, 11 May 2018 00:07:21 +0000 (17:07 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Fri, 11 May 2018 00:11:42 +0000 (17:11 -0700)
PiperOrigin-RevId: 196194069

tensorflow/contrib/tpu/python/tpu/tpu_estimator.py

index a624ece..afc8c7d 100644 (file)
@@ -1759,7 +1759,8 @@ class TPUEstimator(estimator_lib.Estimator):
                train_batch_size=None,
                eval_batch_size=None,
                predict_batch_size=None,
-               batch_axis=None):
+               batch_axis=None,
+               warm_start_from=None):
     """Constructs an `TPUEstimator` instance.
 
     Args:
@@ -1798,6 +1799,12 @@ class TPUEstimator(estimator_lib.Estimator):
         and per_host_input_for_training is True, batches will be sharded based
         on the major dimension. If tpu_config.per_host_input_for_training is
         False or `PER_HOST_V2`, batch_axis is ignored.
+      warm_start_from: Optional string filepath to a checkpoint or SavedModel to
+                       warm-start from, or a `tf.estimator.WarmStartSettings`
+                       object to fully configure warm-starting.  If the string
+                       filepath is provided instead of a `WarmStartSettings`,
+                       then all variables are warm-started, and it is assumed
+                       that vocabularies and Tensor names are unchanged.
 
     Raises:
       ValueError: `params` has reserved keys already.
@@ -1850,7 +1857,8 @@ class TPUEstimator(estimator_lib.Estimator):
         model_fn=model_function,
         model_dir=model_dir,
         config=config,
-        params=params)
+        params=params,
+        warm_start_from=warm_start_from)
     self._iterations_per_training_loop = (
         self._config.tpu_config.iterations_per_loop)