From a16761483ec55095158b1b11118d93ea00a538f4 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Tue, 27 Mar 2018 15:55:04 -0700 Subject: [PATCH] TFTS: Fix a bug in the SavedModel cold-start export It now correctly broadcasts start state across whatever batch dimension it is passed rather than sqishing it down to a batch dimension of 1. PiperOrigin-RevId: 190688855 --- .../timeseries/python/timeseries/estimators_test.py | 21 +++++++++++++++++++++ .../contrib/timeseries/python/timeseries/head.py | 6 ++++-- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py index f4304f2..51d0c0c 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py @@ -126,6 +126,27 @@ class TimeSeriesRegressorTest(test.TestCase): signatures=signatures, session=sess) + # Test cold starting + batch_numpy_times = numpy.tile( + numpy.arange(30, dtype=numpy.int64)[None, :], (10, 1)) + batch_numpy_values = numpy.ones([10, 30, 1]) + state = saved_model_utils.cold_start_filter( + signatures=signatures, + session=sess, + features={ + feature_keys.FilteringFeatures.TIMES: batch_numpy_times, + feature_keys.FilteringFeatures.VALUES: batch_numpy_values + } + ) + predict_times = numpy.tile( + numpy.arange(30, 45, dtype=numpy.int64)[None, :], (10, 1)) + predictions = saved_model_utils.predict_continuation( + continue_from=state, + times=predict_times, + signatures=signatures, + session=sess) + self.assertAllEqual([10, 15, 1], predictions["mean"].shape) + def test_fit_restore_fit_ar_regressor(self): def _estimator_fn(model_dir): return estimators.ARRegressor( diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py index 3d7e615..4cf6bbc 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head.py @@ -154,8 +154,10 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc no_state_features = { k: v for k, v in features.items() if not k.startswith(feature_keys.State.STATE_PREFIX)} - cold_filtering_outputs = self.create_loss( - no_state_features, estimator_lib.ModeKeys.EVAL) + # Ignore any state management when cold-starting. The model's default + # start state is replicated across the batch. + cold_filtering_outputs = self.model.define_loss( + features=no_state_features, mode=estimator_lib.ModeKeys.EVAL) return estimator_lib.EstimatorSpec( mode=estimator_lib.ModeKeys.PREDICT, export_outputs={ -- 2.7.4