TFTS: Add exporting to SavedModel to the LSTM example
authorAllen Lavoie <allenl@google.com>
Fri, 9 Feb 2018 17:25:10 +0000 (09:25 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 9 Feb 2018 17:28:52 +0000 (09:28 -0800)
Was previously not using a state manager, which among other issues prevented
exporting.

Fixes #16590.

PiperOrigin-RevId: 185150900

tensorflow/contrib/timeseries/examples/lstm.py
tensorflow/contrib/timeseries/examples/lstm_test.py

index c834430..630f4fc 100644 (file)
@@ -20,12 +20,14 @@ from __future__ import print_function
 
 import functools
 from os import path
+import tempfile
 
 import numpy
 import tensorflow as tf
 
 from tensorflow.contrib.timeseries.python.timeseries import estimators as ts_estimators
 from tensorflow.contrib.timeseries.python.timeseries import model as ts_model
+from tensorflow.contrib.timeseries.python.timeseries import state_management
 
 try:
   import matplotlib  # pylint: disable=g-import-not-at-top
@@ -70,7 +72,7 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
     self._lstm_cell_run = None
     self._predict_from_lstm_output = None
 
-  def initialize_graph(self, input_statistics):
+  def initialize_graph(self, input_statistics=None):
     """Save templates for components, which can then be used repeatedly.
 
     This method is called every time a new graph is created. It's safe to start
@@ -168,12 +170,15 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
 
 
 def train_and_predict(
-    csv_file_name=_DATA_FILE, training_steps=200, estimator_config=None):
+    csv_file_name=_DATA_FILE, training_steps=200, estimator_config=None,
+    export_directory=None):
   """Train and predict using a custom time series model."""
   # Construct an Estimator from our LSTM model.
   estimator = ts_estimators.TimeSeriesRegressor(
       model=_LSTMModel(num_features=5, num_units=128),
-      optimizer=tf.train.AdamOptimizer(0.001), config=estimator_config)
+      optimizer=tf.train.AdamOptimizer(0.001), config=estimator_config,
+      # Set state to be saved across windows.
+      state_manager=state_management.ChainingStateManager())
   reader = tf.contrib.timeseries.CSVReader(
       csv_file_name,
       column_names=((tf.contrib.timeseries.TrainEvalFeatures.TIMES,)
@@ -192,6 +197,28 @@ def train_and_predict(
   predicted_mean = numpy.squeeze(numpy.concatenate(
       [evaluation["mean"][0], predictions["mean"]], axis=0))
   all_times = numpy.concatenate([times, predictions["times"]], axis=0)
+
+  # Export the model in SavedModel format.
+  if export_directory is None:
+    export_directory = tempfile.mkdtemp()
+  input_receiver_fn = estimator.build_raw_serving_input_receiver_fn()
+  export_location = estimator.export_savedmodel(
+      export_directory, input_receiver_fn)
+
+  # Predict using the SavedModel
+  with tf.Graph().as_default():
+    with tf.Session() as session:
+      signatures = tf.saved_model.loader.load(
+          session, [tf.saved_model.tag_constants.SERVING], export_location)
+      saved_model_output = (
+          tf.contrib.timeseries.saved_model_utils.predict_continuation(
+              continue_from=evaluation, signatures=signatures,
+              session=session, steps=100))
+      # The exported model gives the same results as the Estimator.predict()
+      # call above.
+      numpy.testing.assert_allclose(
+          predictions["mean"],
+          numpy.squeeze(saved_model_output["mean"], axis=0))
   return times, observed, all_times, predicted_mean
 
 
index 3cace56..ca56e38 100644 (file)
@@ -36,7 +36,8 @@ class LSTMExampleTest(test.TestCase):
   def test_periodicity_learned(self):
     (observed_times, observed_values,
      all_times, predicted_values) = lstm.train_and_predict(
-         training_steps=100, estimator_config=_SeedRunConfig())
+         training_steps=100, estimator_config=_SeedRunConfig(),
+         export_directory=self.get_temp_dir())
     self.assertAllEqual([100], observed_times.shape)
     self.assertAllEqual([100, 5], observed_values.shape)
     self.assertAllEqual([200], all_times.shape)