TFTS: Allow cold-starting from SavedModels
authorAllen Lavoie <allenl@google.com>
Sat, 17 Mar 2018 00:00:17 +0000 (17:00 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 17 Mar 2018 00:04:34 +0000 (17:04 -0700)
This means the model starts from its default start state and is fed a series
(filtering) to warm up its state. This warmed up state can then be used to
make predictions.

Some shape fiddling with the receiver_fn to make feeding state optional, and a
new signature for cold-starting which uses the model's default start state.

Some other shape fiddling to make feeding strings to SavedModels work more
smoothly in the cold-start part of the LSTM example. I was squeezing out the
last dimension of "scalar" exogenous features, now I'm leaving them, which
matches the placeholder generation logic.

PiperOrigin-RevId: 189414869

tensorflow/contrib/timeseries/examples/known_anomaly.py
tensorflow/contrib/timeseries/examples/lstm.py
tensorflow/contrib/timeseries/python/timeseries/estimators.py
tensorflow/contrib/timeseries/python/timeseries/feature_keys.py
tensorflow/contrib/timeseries/python/timeseries/head.py
tensorflow/contrib/timeseries/python/timeseries/input_pipeline.py
tensorflow/contrib/timeseries/python/timeseries/saved_model_utils.py

index c08c0b0..e77628d 100644 (file)
@@ -53,6 +53,15 @@ def train_and_evaluate_exogenous(csv_file_name=_DATA_FILE, train_steps=300):
   one_hot_feature = tf.feature_column.indicator_column(
       categorical_column=string_feature)
 
+  def _exogenous_update_condition(times, features):
+    del times  # unused
+    # Make exogenous updates sparse by setting an update condition. This in
+    # effect allows missing exogenous features: if the condition evaluates to
+    # False, no update is performed. Otherwise we sometimes end up with "leaky"
+    # updates which add unnecessary uncertainty to the model even when there is
+    # no changepoint.
+    return tf.equal(tf.squeeze(features["is_changepoint"], axis=-1), "yes")
+
   estimator = tf.contrib.timeseries.StructuralEnsembleRegressor(
       periodicities=12,
       # Extract a smooth period by constraining the number of latent values
@@ -60,13 +69,7 @@ def train_and_evaluate_exogenous(csv_file_name=_DATA_FILE, train_steps=300):
       cycle_num_latent_values=3,
       num_features=1,
       exogenous_feature_columns=[one_hot_feature],
-      # Make exogenous updates sparse by setting an update condition. This in
-      # effect allows missing exogenous features: if the condition evaluates to
-      # False, no update is performed. Otherwise we sometimes end up with
-      # "leaky" updates which add unnecessary uncertainty to the model even when
-      # there is no changepoint.
-      exogenous_update_condition=
-      lambda times, features: tf.equal(features["is_changepoint"], "yes"))
+      exogenous_update_condition=_exogenous_update_condition)
   reader = tf.contrib.timeseries.CSVReader(
       csv_file_name,
       # Indicate the format of our CSV file. First we have two standard columns,
index 2eee878..b1c7475 100644 (file)
@@ -236,20 +236,36 @@ def train_and_predict(
       [evaluation["mean"][0], predictions["mean"]], axis=0))
   all_times = numpy.concatenate([times, predictions["times"]], axis=0)
 
-  # Export the model in SavedModel format.
+  # Export the model in SavedModel format. We include a bit of extra boilerplate
+  # for "cold starting" as if we didn't have any state from the Estimator, which
+  # is the case when serving from a SavedModel. If Estimator output is
+  # available, the result of "Estimator.evaluate" can be passed directly to
+  # `tf.contrib.timeseries.saved_model_utils.predict_continuation` as the
+  # `continue_from` argument.
+  with tf.Graph().as_default():
+    filter_feature_tensors, _ = evaluation_input_fn()
+    with tf.train.MonitoredSession() as session:
+      # Fetch the series to "warm up" our state, which will allow us to make
+      # predictions for its future values. This is just a dictionary of times,
+      # values, and exogenous features mapping to numpy arrays. The use of an
+      # input_fn is just a convenience for the example; they can also be
+      # specified manually.
+      filter_features = session.run(filter_feature_tensors)
   if export_directory is None:
     export_directory = tempfile.mkdtemp()
   input_receiver_fn = estimator.build_raw_serving_input_receiver_fn()
   export_location = estimator.export_savedmodel(
       export_directory, input_receiver_fn)
-  # Predict using the SavedModel
+  # Warm up and predict using the SavedModel
   with tf.Graph().as_default():
     with tf.Session() as session:
       signatures = tf.saved_model.loader.load(
           session, [tf.saved_model.tag_constants.SERVING], export_location)
+      state = tf.contrib.timeseries.saved_model_utils.cold_start_filter(
+          signatures=signatures, session=session, features=filter_features)
       saved_model_output = (
           tf.contrib.timeseries.saved_model_utils.predict_continuation(
-              continue_from=evaluation, signatures=signatures,
+              continue_from=state, signatures=signatures,
               session=session, steps=100,
               exogenous_features=predict_exogenous_features))
       # The exported model gives the same results as the Estimator.predict()
index 8d13343..469cea4 100644 (file)
@@ -33,9 +33,11 @@ from tensorflow.python.feature_column import feature_column
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.training import training as train
+from tensorflow.python.util import nest
 
 
 class TimeSeriesRegressor(estimator_lib.Estimator):
@@ -98,11 +100,11 @@ class TimeSeriesRegressor(estimator_lib.Estimator):
     def _serving_input_receiver_fn():
       """A receiver function to be passed to export_savedmodel."""
       placeholders = {}
-      placeholders[feature_keys.TrainEvalFeatures.TIMES] = (
-          array_ops.placeholder(
-              name=feature_keys.TrainEvalFeatures.TIMES,
-              dtype=dtypes.int64,
-              shape=[default_batch_size, default_series_length]))
+      time_placeholder = array_ops.placeholder(
+          name=feature_keys.TrainEvalFeatures.TIMES,
+          dtype=dtypes.int64,
+          shape=[default_batch_size, default_series_length])
+      placeholders[feature_keys.TrainEvalFeatures.TIMES] = time_placeholder
       # Values are only necessary when filtering. For prediction the default
       # value will be ignored.
       placeholders[feature_keys.TrainEvalFeatures.VALUES] = (
@@ -145,15 +147,29 @@ class TimeSeriesRegressor(estimator_lib.Estimator):
       # use only static metadata from the returned Tensors.
       with ops.Graph().as_default():
         self._model.initialize_graph()
-        model_start_state = self._model.get_start_state()
-      for prefixed_state_name, state_tensor in ts_head_lib.state_to_dictionary(
-          model_start_state).items():
+        # Evaluate the initial state as same-dtype "zero" values. These zero
+        # constants aren't used, but are necessary for feeding to
+        # placeholder_with_default for the "cold start" case where state is not
+        # fed to the model.
+        def _zeros_like_constant(tensor):
+          return tensor_util.constant_value(array_ops.zeros_like(tensor))
+        start_state = nest.map_structure(
+            _zeros_like_constant, self._model.get_start_state())
+      batch_size_tensor = array_ops.shape(time_placeholder)[0]
+      for prefixed_state_name, state in ts_head_lib.state_to_dictionary(
+          start_state).items():
         state_shape_with_batch = tensor_shape.TensorShape(
-            (default_batch_size,)).concatenate(state_tensor.get_shape())
-        placeholders[prefixed_state_name] = array_ops.placeholder(
+            (default_batch_size,)).concatenate(state.shape)
+        default_state_broadcast = array_ops.tile(
+            state[None, ...],
+            multiples=array_ops.concat(
+                [batch_size_tensor[None],
+                 array_ops.ones(len(state.shape), dtype=dtypes.int32)],
+                axis=0))
+        placeholders[prefixed_state_name] = array_ops.placeholder_with_default(
+            input=default_state_broadcast,
             name=prefixed_state_name,
-            shape=state_shape_with_batch,
-            dtype=state_tensor.dtype)
+            shape=state_shape_with_batch)
       return export_lib.ServingInputReceiver(placeholders, placeholders)
 
     return _serving_input_receiver_fn
index 970b9aa..56566ee 100644 (file)
@@ -72,3 +72,4 @@ class SavedModelLabels(object):
   """Names of signatures exported with export_savedmodel."""
   PREDICT = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
   FILTER = "filter"
+  COLD_START_FILTER = "cold_start_filter"
index f4d9351..3d7e615 100644 (file)
@@ -150,6 +150,12 @@ class _TimeSeriesRegressionHead(head_lib._Head):  # pylint:disable=protected-acc
     with variable_scope.variable_scope("model", reuse=True):
       filtering_outputs = self.create_loss(
           features, estimator_lib.ModeKeys.EVAL)
+    with variable_scope.variable_scope("model", reuse=True):
+      no_state_features = {
+          k: v for k, v in features.items()
+          if not k.startswith(feature_keys.State.STATE_PREFIX)}
+      cold_filtering_outputs = self.create_loss(
+          no_state_features, estimator_lib.ModeKeys.EVAL)
     return estimator_lib.EstimatorSpec(
         mode=estimator_lib.ModeKeys.PREDICT,
         export_outputs={
@@ -157,7 +163,10 @@ class _TimeSeriesRegressionHead(head_lib._Head):  # pylint:disable=protected-acc
                 export_lib.PredictOutput(prediction_outputs),
             feature_keys.SavedModelLabels.FILTER:
                 export_lib.PredictOutput(
-                    state_to_dictionary(filtering_outputs.end_state))
+                    state_to_dictionary(filtering_outputs.end_state)),
+            feature_keys.SavedModelLabels.COLD_START_FILTER:
+                export_lib.PredictOutput(
+                    state_to_dictionary(cold_filtering_outputs.end_state))
         },
         # Likely unused, but it is necessary to return `predictions` to satisfy
         # the Estimator's error checking.
index 0422533..403c6e2 100644 (file)
@@ -492,8 +492,7 @@ class CSVReader(ReaderBaseTimeSeriesParser):
       features_lists.setdefault(column_name, []).append(value)
     features = {}
     for column_name, values in features_lists.items():
-      if (len(values) == 1 and
-          column_name != feature_keys.TrainEvalFeatures.VALUES):
+      if column_name == feature_keys.TrainEvalFeatures.TIMES:
         features[column_name] = values[0]
       else:
         features[column_name] = array_ops.stack(values, axis=1)
index 97f6d36..0461abd 100644 (file)
@@ -15,6 +15,7 @@
 """Convenience functions for working with time series saved_models.
 
 @@predict_continuation
+@@cold_start_filter
 @@filter_continuation
 """
 
@@ -30,10 +31,12 @@ from tensorflow.contrib.timeseries.python.timeseries import model_utils as _mode
 from tensorflow.python.util.all_util import remove_undocumented
 
 
-def _colate_features_to_feeds_and_fetches(continue_from, signature, features,
-                                          graph):
+def _colate_features_to_feeds_and_fetches(signature, features, graph,
+                                          continue_from=None):
   """Uses a saved model signature to construct feed and fetch dictionaries."""
-  if _feature_keys.FilteringResults.STATE_TUPLE in continue_from:
+  if continue_from is None:
+    state_values = {}
+  elif _feature_keys.FilteringResults.STATE_TUPLE in continue_from:
     # We're continuing from an evaluation, so we need to unpack/flatten state.
     state_values = _head.state_to_dictionary(
         continue_from[_feature_keys.FilteringResults.STATE_TUPLE])
@@ -115,6 +118,55 @@ def predict_continuation(continue_from,
   return output
 
 
+def cold_start_filter(signatures, session, features):
+  """Perform filtering using an exported saved model.
+
+  Filtering refers to updating model state based on new observations.
+  Predictions based on the returned model state will be conditioned on these
+  observations.
+
+  Starts from the model's default/uninformed state.
+
+  Args:
+    signatures: The `MetaGraphDef` protocol buffer returned from
+      `tf.saved_model.loader.load`. Used to determine the names of Tensors to
+      feed and fetch. Must be from the same model as `continue_from`.
+    session: The session to use. The session's graph must be the one into which
+      `tf.saved_model.loader.load` loaded the model.
+    features: A dictionary mapping keys to Numpy arrays, with several possible
+      shapes (requires keys `FilteringFeatures.TIMES` and
+      `FilteringFeatures.VALUES`):
+        Single example; `TIMES` is a scalar and `VALUES` is either a scalar or a
+          vector of length [number of features].
+        Sequence; `TIMES` is a vector of shape [series length], `VALUES` either
+          has shape [series length] (univariate) or [series length x number of
+          features] (multivariate).
+        Batch of sequences; `TIMES` is a vector of shape [batch size x series
+          length], `VALUES` has shape [batch size x series length] or [batch
+          size x series length x number of features].
+      In any case, `VALUES` and any exogenous features must have their shapes
+      prefixed by the shape of the value corresponding to the `TIMES` key.
+  Returns:
+    A dictionary containing model state updated to account for the observations
+    in `features`.
+  """
+  filter_signature = signatures.signature_def[
+      _feature_keys.SavedModelLabels.COLD_START_FILTER]
+  features = _input_pipeline._canonicalize_numpy_data(  # pylint: disable=protected-access
+      data=features,
+      require_single_batch=False)
+  output_tensors_by_name, feed_dict = _colate_features_to_feeds_and_fetches(
+      signature=filter_signature,
+      features=features,
+      graph=session.graph)
+  output = session.run(output_tensors_by_name, feed_dict=feed_dict)
+  # Make it easier to chain filter -> predict by keeping track of the current
+  # time.
+  output[_feature_keys.FilteringResults.TIMES] = features[
+      _feature_keys.FilteringFeatures.TIMES]
+  return output
+
+
 def filter_continuation(continue_from, signatures, session, features):
   """Perform filtering using an exported saved model.
 
@@ -124,8 +176,8 @@ def filter_continuation(continue_from, signatures, session, features):
 
   Args:
     continue_from: A dictionary containing the results of either an Estimator's
-      evaluate method or a previous filter_continuation. Used to determine the
-      model state to start filtering from.
+      evaluate method or a previous filter step (cold start or
+      continuation). Used to determine the model state to start filtering from.
     signatures: The `MetaGraphDef` protocol buffer returned from
       `tf.saved_model.loader.load`. Used to determine the names of Tensors to
       feed and fetch. Must be from the same model as `continue_from`.