signatures=signatures,
session=sess)
+ # Test cold starting
+ batch_numpy_times = numpy.tile(
+ numpy.arange(30, dtype=numpy.int64)[None, :], (10, 1))
+ batch_numpy_values = numpy.ones([10, 30, 1])
+ state = saved_model_utils.cold_start_filter(
+ signatures=signatures,
+ session=sess,
+ features={
+ feature_keys.FilteringFeatures.TIMES: batch_numpy_times,
+ feature_keys.FilteringFeatures.VALUES: batch_numpy_values
+ }
+ )
+ predict_times = numpy.tile(
+ numpy.arange(30, 45, dtype=numpy.int64)[None, :], (10, 1))
+ predictions = saved_model_utils.predict_continuation(
+ continue_from=state,
+ times=predict_times,
+ signatures=signatures,
+ session=sess)
+ self.assertAllEqual([10, 15, 1], predictions["mean"].shape)
+
def test_fit_restore_fit_ar_regressor(self):
def _estimator_fn(model_dir):
return estimators.ARRegressor(
no_state_features = {
k: v for k, v in features.items()
if not k.startswith(feature_keys.State.STATE_PREFIX)}
- cold_filtering_outputs = self.create_loss(
- no_state_features, estimator_lib.ModeKeys.EVAL)
+ # Ignore any state management when cold-starting. The model's default
+ # start state is replicated across the batch.
+ cold_filtering_outputs = self.model.define_loss(
+ features=no_state_features, mode=estimator_lib.ModeKeys.EVAL)
return estimator_lib.EstimatorSpec(
mode=estimator_lib.ModeKeys.PREDICT,
export_outputs={