From: Allen Lavoie Date: Sat, 17 Mar 2018 00:00:17 +0000 (-0700) Subject: TFTS: Allow cold-starting from SavedModels X-Git-Tag: tflite-v0.1.7~149^2~2^2~29 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=bd33984641fda2f892b77bb2a1ac8c33c7a2211a;p=platform%2Fupstream%2Ftensorflow.git TFTS: Allow cold-starting from SavedModels This means the model starts from its default start state and is fed a series (filtering) to warm up its state. This warmed up state can then be used to make predictions. Some shape fiddling with the receiver_fn to make feeding state optional, and a new signature for cold-starting which uses the model's default start state. Some other shape fiddling to make feeding strings to SavedModels work more smoothly in the cold-start part of the LSTM example. I was squeezing out the last dimension of "scalar" exogenous features, now I'm leaving them, which matches the placeholder generation logic. PiperOrigin-RevId: 189414869 --- diff --git a/tensorflow/contrib/timeseries/examples/known_anomaly.py b/tensorflow/contrib/timeseries/examples/known_anomaly.py index c08c0b0..e77628d 100644 --- a/tensorflow/contrib/timeseries/examples/known_anomaly.py +++ b/tensorflow/contrib/timeseries/examples/known_anomaly.py @@ -53,6 +53,15 @@ def train_and_evaluate_exogenous(csv_file_name=_DATA_FILE, train_steps=300): one_hot_feature = tf.feature_column.indicator_column( categorical_column=string_feature) + def _exogenous_update_condition(times, features): + del times # unused + # Make exogenous updates sparse by setting an update condition. This in + # effect allows missing exogenous features: if the condition evaluates to + # False, no update is performed. Otherwise we sometimes end up with "leaky" + # updates which add unnecessary uncertainty to the model even when there is + # no changepoint. + return tf.equal(tf.squeeze(features["is_changepoint"], axis=-1), "yes") + estimator = tf.contrib.timeseries.StructuralEnsembleRegressor( periodicities=12, # Extract a smooth period by constraining the number of latent values @@ -60,13 +69,7 @@ def train_and_evaluate_exogenous(csv_file_name=_DATA_FILE, train_steps=300): cycle_num_latent_values=3, num_features=1, exogenous_feature_columns=[one_hot_feature], - # Make exogenous updates sparse by setting an update condition. This in - # effect allows missing exogenous features: if the condition evaluates to - # False, no update is performed. Otherwise we sometimes end up with - # "leaky" updates which add unnecessary uncertainty to the model even when - # there is no changepoint. - exogenous_update_condition= - lambda times, features: tf.equal(features["is_changepoint"], "yes")) + exogenous_update_condition=_exogenous_update_condition) reader = tf.contrib.timeseries.CSVReader( csv_file_name, # Indicate the format of our CSV file. First we have two standard columns, diff --git a/tensorflow/contrib/timeseries/examples/lstm.py b/tensorflow/contrib/timeseries/examples/lstm.py index 2eee878..b1c7475 100644 --- a/tensorflow/contrib/timeseries/examples/lstm.py +++ b/tensorflow/contrib/timeseries/examples/lstm.py @@ -236,20 +236,36 @@ def train_and_predict( [evaluation["mean"][0], predictions["mean"]], axis=0)) all_times = numpy.concatenate([times, predictions["times"]], axis=0) - # Export the model in SavedModel format. + # Export the model in SavedModel format. We include a bit of extra boilerplate + # for "cold starting" as if we didn't have any state from the Estimator, which + # is the case when serving from a SavedModel. If Estimator output is + # available, the result of "Estimator.evaluate" can be passed directly to + # `tf.contrib.timeseries.saved_model_utils.predict_continuation` as the + # `continue_from` argument. + with tf.Graph().as_default(): + filter_feature_tensors, _ = evaluation_input_fn() + with tf.train.MonitoredSession() as session: + # Fetch the series to "warm up" our state, which will allow us to make + # predictions for its future values. This is just a dictionary of times, + # values, and exogenous features mapping to numpy arrays. The use of an + # input_fn is just a convenience for the example; they can also be + # specified manually. + filter_features = session.run(filter_feature_tensors) if export_directory is None: export_directory = tempfile.mkdtemp() input_receiver_fn = estimator.build_raw_serving_input_receiver_fn() export_location = estimator.export_savedmodel( export_directory, input_receiver_fn) - # Predict using the SavedModel + # Warm up and predict using the SavedModel with tf.Graph().as_default(): with tf.Session() as session: signatures = tf.saved_model.loader.load( session, [tf.saved_model.tag_constants.SERVING], export_location) + state = tf.contrib.timeseries.saved_model_utils.cold_start_filter( + signatures=signatures, session=session, features=filter_features) saved_model_output = ( tf.contrib.timeseries.saved_model_utils.predict_continuation( - continue_from=evaluation, signatures=signatures, + continue_from=state, signatures=signatures, session=session, steps=100, exogenous_features=predict_exogenous_features)) # The exported model gives the same results as the Estimator.predict() diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py index 8d13343..469cea4 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py @@ -33,9 +33,11 @@ from tensorflow.python.feature_column import feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.training import training as train +from tensorflow.python.util import nest class TimeSeriesRegressor(estimator_lib.Estimator): @@ -98,11 +100,11 @@ class TimeSeriesRegressor(estimator_lib.Estimator): def _serving_input_receiver_fn(): """A receiver function to be passed to export_savedmodel.""" placeholders = {} - placeholders[feature_keys.TrainEvalFeatures.TIMES] = ( - array_ops.placeholder( - name=feature_keys.TrainEvalFeatures.TIMES, - dtype=dtypes.int64, - shape=[default_batch_size, default_series_length])) + time_placeholder = array_ops.placeholder( + name=feature_keys.TrainEvalFeatures.TIMES, + dtype=dtypes.int64, + shape=[default_batch_size, default_series_length]) + placeholders[feature_keys.TrainEvalFeatures.TIMES] = time_placeholder # Values are only necessary when filtering. For prediction the default # value will be ignored. placeholders[feature_keys.TrainEvalFeatures.VALUES] = ( @@ -145,15 +147,29 @@ class TimeSeriesRegressor(estimator_lib.Estimator): # use only static metadata from the returned Tensors. with ops.Graph().as_default(): self._model.initialize_graph() - model_start_state = self._model.get_start_state() - for prefixed_state_name, state_tensor in ts_head_lib.state_to_dictionary( - model_start_state).items(): + # Evaluate the initial state as same-dtype "zero" values. These zero + # constants aren't used, but are necessary for feeding to + # placeholder_with_default for the "cold start" case where state is not + # fed to the model. + def _zeros_like_constant(tensor): + return tensor_util.constant_value(array_ops.zeros_like(tensor)) + start_state = nest.map_structure( + _zeros_like_constant, self._model.get_start_state()) + batch_size_tensor = array_ops.shape(time_placeholder)[0] + for prefixed_state_name, state in ts_head_lib.state_to_dictionary( + start_state).items(): state_shape_with_batch = tensor_shape.TensorShape( - (default_batch_size,)).concatenate(state_tensor.get_shape()) - placeholders[prefixed_state_name] = array_ops.placeholder( + (default_batch_size,)).concatenate(state.shape) + default_state_broadcast = array_ops.tile( + state[None, ...], + multiples=array_ops.concat( + [batch_size_tensor[None], + array_ops.ones(len(state.shape), dtype=dtypes.int32)], + axis=0)) + placeholders[prefixed_state_name] = array_ops.placeholder_with_default( + input=default_state_broadcast, name=prefixed_state_name, - shape=state_shape_with_batch, - dtype=state_tensor.dtype) + shape=state_shape_with_batch) return export_lib.ServingInputReceiver(placeholders, placeholders) return _serving_input_receiver_fn diff --git a/tensorflow/contrib/timeseries/python/timeseries/feature_keys.py b/tensorflow/contrib/timeseries/python/timeseries/feature_keys.py index 970b9aa..56566ee 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/feature_keys.py +++ b/tensorflow/contrib/timeseries/python/timeseries/feature_keys.py @@ -72,3 +72,4 @@ class SavedModelLabels(object): """Names of signatures exported with export_savedmodel.""" PREDICT = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY FILTER = "filter" + COLD_START_FILTER = "cold_start_filter" diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py index f4d9351..3d7e615 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head.py @@ -150,6 +150,12 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc with variable_scope.variable_scope("model", reuse=True): filtering_outputs = self.create_loss( features, estimator_lib.ModeKeys.EVAL) + with variable_scope.variable_scope("model", reuse=True): + no_state_features = { + k: v for k, v in features.items() + if not k.startswith(feature_keys.State.STATE_PREFIX)} + cold_filtering_outputs = self.create_loss( + no_state_features, estimator_lib.ModeKeys.EVAL) return estimator_lib.EstimatorSpec( mode=estimator_lib.ModeKeys.PREDICT, export_outputs={ @@ -157,7 +163,10 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc export_lib.PredictOutput(prediction_outputs), feature_keys.SavedModelLabels.FILTER: export_lib.PredictOutput( - state_to_dictionary(filtering_outputs.end_state)) + state_to_dictionary(filtering_outputs.end_state)), + feature_keys.SavedModelLabels.COLD_START_FILTER: + export_lib.PredictOutput( + state_to_dictionary(cold_filtering_outputs.end_state)) }, # Likely unused, but it is necessary to return `predictions` to satisfy # the Estimator's error checking. diff --git a/tensorflow/contrib/timeseries/python/timeseries/input_pipeline.py b/tensorflow/contrib/timeseries/python/timeseries/input_pipeline.py index 0422533..403c6e2 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/input_pipeline.py +++ b/tensorflow/contrib/timeseries/python/timeseries/input_pipeline.py @@ -492,8 +492,7 @@ class CSVReader(ReaderBaseTimeSeriesParser): features_lists.setdefault(column_name, []).append(value) features = {} for column_name, values in features_lists.items(): - if (len(values) == 1 and - column_name != feature_keys.TrainEvalFeatures.VALUES): + if column_name == feature_keys.TrainEvalFeatures.TIMES: features[column_name] = values[0] else: features[column_name] = array_ops.stack(values, axis=1) diff --git a/tensorflow/contrib/timeseries/python/timeseries/saved_model_utils.py b/tensorflow/contrib/timeseries/python/timeseries/saved_model_utils.py index 97f6d36..0461abd 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/saved_model_utils.py +++ b/tensorflow/contrib/timeseries/python/timeseries/saved_model_utils.py @@ -15,6 +15,7 @@ """Convenience functions for working with time series saved_models. @@predict_continuation +@@cold_start_filter @@filter_continuation """ @@ -30,10 +31,12 @@ from tensorflow.contrib.timeseries.python.timeseries import model_utils as _mode from tensorflow.python.util.all_util import remove_undocumented -def _colate_features_to_feeds_and_fetches(continue_from, signature, features, - graph): +def _colate_features_to_feeds_and_fetches(signature, features, graph, + continue_from=None): """Uses a saved model signature to construct feed and fetch dictionaries.""" - if _feature_keys.FilteringResults.STATE_TUPLE in continue_from: + if continue_from is None: + state_values = {} + elif _feature_keys.FilteringResults.STATE_TUPLE in continue_from: # We're continuing from an evaluation, so we need to unpack/flatten state. state_values = _head.state_to_dictionary( continue_from[_feature_keys.FilteringResults.STATE_TUPLE]) @@ -115,6 +118,55 @@ def predict_continuation(continue_from, return output +def cold_start_filter(signatures, session, features): + """Perform filtering using an exported saved model. + + Filtering refers to updating model state based on new observations. + Predictions based on the returned model state will be conditioned on these + observations. + + Starts from the model's default/uninformed state. + + Args: + signatures: The `MetaGraphDef` protocol buffer returned from + `tf.saved_model.loader.load`. Used to determine the names of Tensors to + feed and fetch. Must be from the same model as `continue_from`. + session: The session to use. The session's graph must be the one into which + `tf.saved_model.loader.load` loaded the model. + features: A dictionary mapping keys to Numpy arrays, with several possible + shapes (requires keys `FilteringFeatures.TIMES` and + `FilteringFeatures.VALUES`): + Single example; `TIMES` is a scalar and `VALUES` is either a scalar or a + vector of length [number of features]. + Sequence; `TIMES` is a vector of shape [series length], `VALUES` either + has shape [series length] (univariate) or [series length x number of + features] (multivariate). + Batch of sequences; `TIMES` is a vector of shape [batch size x series + length], `VALUES` has shape [batch size x series length] or [batch + size x series length x number of features]. + In any case, `VALUES` and any exogenous features must have their shapes + prefixed by the shape of the value corresponding to the `TIMES` key. + Returns: + A dictionary containing model state updated to account for the observations + in `features`. + """ + filter_signature = signatures.signature_def[ + _feature_keys.SavedModelLabels.COLD_START_FILTER] + features = _input_pipeline._canonicalize_numpy_data( # pylint: disable=protected-access + data=features, + require_single_batch=False) + output_tensors_by_name, feed_dict = _colate_features_to_feeds_and_fetches( + signature=filter_signature, + features=features, + graph=session.graph) + output = session.run(output_tensors_by_name, feed_dict=feed_dict) + # Make it easier to chain filter -> predict by keeping track of the current + # time. + output[_feature_keys.FilteringResults.TIMES] = features[ + _feature_keys.FilteringFeatures.TIMES] + return output + + def filter_continuation(continue_from, signatures, session, features): """Perform filtering using an exported saved model. @@ -124,8 +176,8 @@ def filter_continuation(continue_from, signatures, session, features): Args: continue_from: A dictionary containing the results of either an Estimator's - evaluate method or a previous filter_continuation. Used to determine the - model state to start filtering from. + evaluate method or a previous filter step (cold start or + continuation). Used to determine the model state to start filtering from. signatures: The `MetaGraphDef` protocol buffer returned from `tf.saved_model.loader.load`. Used to determine the names of Tensors to feed and fetch. Must be from the same model as `continue_from`.