TFTS: Remove a race condition in lstm_test (switch to resource variables)
authorAllen Lavoie <allenl@google.com>
Mon, 29 Jan 2018 17:44:35 +0000 (09:44 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 29 Jan 2018 17:48:39 +0000 (09:48 -0800)
PiperOrigin-RevId: 183679060

tensorflow/contrib/timeseries/examples/lstm.py

index c7193ce..c834430 100644 (file)
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import functools
 from os import path
 
 import numpy
@@ -80,18 +81,19 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
       input_statistics: A math_utils.InputStatistics object.
     """
     super(_LSTMModel, self).initialize_graph(input_statistics=input_statistics)
-    self._lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=self._num_units)
-    # Create templates so we don't have to worry about variable reuse.
-    self._lstm_cell_run = tf.make_template(
-        name_="lstm_cell",
-        func_=self._lstm_cell,
-        create_scope_now_=True)
-    # Transforms LSTM output into mean predictions.
-    self._predict_from_lstm_output = tf.make_template(
-        name_="predict_from_lstm_output",
-        func_=
-        lambda inputs: tf.layers.dense(inputs=inputs, units=self.num_features),
-        create_scope_now_=True)
+    with tf.variable_scope("", use_resource=True):
+      # Use ResourceVariables to avoid race conditions.
+      self._lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=self._num_units)
+      # Create templates so we don't have to worry about variable reuse.
+      self._lstm_cell_run = tf.make_template(
+          name_="lstm_cell",
+          func_=self._lstm_cell,
+          create_scope_now_=True)
+      # Transforms LSTM output into mean predictions.
+      self._predict_from_lstm_output = tf.make_template(
+          name_="predict_from_lstm_output",
+          func_=functools.partial(tf.layers.dense, units=self.num_features),
+          create_scope_now_=True)
 
   def get_start_state(self):
     """Return initial state for the time series model."""