TFTS: Fix a bug in the SavedModel cold-start export
authorAllen Lavoie <allenl@google.com>
Tue, 27 Mar 2018 22:55:04 +0000 (15:55 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 27 Mar 2018 22:57:43 +0000 (15:57 -0700)
It now correctly broadcasts start state across whatever batch dimension it is
passed rather than sqishing it down to a batch dimension of 1.

PiperOrigin-RevId: 190688855

tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
tensorflow/contrib/timeseries/python/timeseries/head.py

index f4304f2..51d0c0c 100644 (file)
@@ -126,6 +126,27 @@ class TimeSeriesRegressorTest(test.TestCase):
             signatures=signatures,
             session=sess)
 
+        # Test cold starting
+        batch_numpy_times = numpy.tile(
+            numpy.arange(30, dtype=numpy.int64)[None, :], (10, 1))
+        batch_numpy_values = numpy.ones([10, 30, 1])
+        state = saved_model_utils.cold_start_filter(
+            signatures=signatures,
+            session=sess,
+            features={
+                feature_keys.FilteringFeatures.TIMES: batch_numpy_times,
+                feature_keys.FilteringFeatures.VALUES: batch_numpy_values
+            }
+        )
+        predict_times = numpy.tile(
+            numpy.arange(30, 45, dtype=numpy.int64)[None, :], (10, 1))
+        predictions = saved_model_utils.predict_continuation(
+            continue_from=state,
+            times=predict_times,
+            signatures=signatures,
+            session=sess)
+        self.assertAllEqual([10, 15, 1], predictions["mean"].shape)
+
   def test_fit_restore_fit_ar_regressor(self):
     def _estimator_fn(model_dir):
       return estimators.ARRegressor(
index 3d7e615..4cf6bbc 100644 (file)
@@ -154,8 +154,10 @@ class _TimeSeriesRegressionHead(head_lib._Head):  # pylint:disable=protected-acc
       no_state_features = {
           k: v for k, v in features.items()
           if not k.startswith(feature_keys.State.STATE_PREFIX)}
-      cold_filtering_outputs = self.create_loss(
-          no_state_features, estimator_lib.ModeKeys.EVAL)
+      # Ignore any state management when cold-starting. The model's default
+      # start state is replicated across the batch.
+      cold_filtering_outputs = self.model.define_loss(
+          features=no_state_features, mode=estimator_lib.ModeKeys.EVAL)
     return estimator_lib.EstimatorSpec(
         mode=estimator_lib.ModeKeys.PREDICT,
         export_outputs={