Enable efficient feeding of symbolic tensors to placeholders in the Keras backend.
author Francois Chollet <fchollet@google.com>
Fri, 13 Apr 2018 02:01:10 +0000 (19:01 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Fri, 13 Apr 2018 02:04:02 +0000 (19:04 -0700)
PiperOrigin-RevId: 192707345

tensorflow/python/keras/_impl/keras/backend.py
tensorflow/python/keras/_impl/keras/backend_test.py
tensorflow/python/keras/_impl/keras/integration_test.py

index 096db8d..6647cc5 100644 (file)
@@ -2760,8 +2760,7 @@ class Function(object):
       outputs: Output tensors to fetch.
       updates: Additional update ops to be run at function call.
       name: A name to help users identify what this function does.
-      session_kwargs: Arguments to `tf.Session.run()`: `fetches`, `feed_dict`,
-        `options`, `run_metadata`
+      session_kwargs: Arguments to `tf.Session.run()`: `fetches`, `feed_dict`.
   """
 
   def __init__(self, inputs, outputs, updates=None, name=None,
@@ -2795,19 +2794,74 @@ class Function(object):
     self.fetches = session_kwargs.pop('fetches', [])
     if not isinstance(self.fetches, list):
       self.fetches = [self.fetches]
+    # The main use case of `fetches` being passed to a model is the ability
+    # to run custom updates (since the outputs of fetches are never returned).
+    # This requires us to wrap fetches in `identity` ops.
+    self.fetches = [array_ops.identity(x) for x in self.fetches]
     self.session_kwargs = session_kwargs
 
+    if session_kwargs:
+      raise ValueError('Some keys in session_kwargs are not supported at this '
+                       'time: %s', session_kwargs.keys())
+
+    self._callable_fn = None
+    self._feed_arrays = None
+    self._feed_symbols = None
+    self._symbol_vals = None
+    self._session = None
+
+  def _make_callable(self, feed_arrays, feed_symbols, symbol_vals, session):
+    """Generates a callable that runs the graph.
+
+    Arguments:
+      feed_arrays: List of input tensors to be fed Numpy arrays at runtime.
+      feed_symbols: List of input tensors to be fed symbolic tensors at runtime.
+      symbol_vals: List of symbolic tensors to be fed to `feed_symbols`.
+      session: Session to use to generate the callable.
+
+    Returns:
+      Function that runs the graph according to the above options.
+    """
+    # Prepare callable options.
+    callable_opts = config_pb2.CallableOptions()
+    # Handle external-data feed.
+    for x in feed_arrays:
+      callable_opts.feed.append(x.name)
+    if self.feed_dict:
+      for key in sorted(self.feed_dict.keys()):
+        callable_opts.feed.append(key.name)
+    # Handle symbolic feed.
+    for x, y in zip(feed_symbols, symbol_vals):
+      connection = callable_opts.tensor_connection.add()
+      from_tensor = ops._as_graph_element(y)
+      if from_tensor is None:
+        from_tensor = y
+      connection.from_tensor = from_tensor.name  # Data tensor
+      connection.to_tensor = x.name  # Placeholder
+    # Handle fetches.
+    for x in self.outputs + self.fetches:
+      callable_opts.fetch.append(x.name)
+    # Handle updates.
+    callable_opts.target.append(self.updates_op.name)
+    # Create callable.
+    callable_fn = session._make_callable_from_options(callable_opts)
+    # Cache parameters corresponding to the generated callable, so that
+    # we can detect future mismatches and refresh the callable.
+    self._callable_fn = callable_fn
+    self._feed_arrays = feed_arrays
+    self._feed_symbols = feed_symbols
+    self._symbol_vals = symbol_vals
+    self._session = session
+
   def __call__(self, inputs):
     if not isinstance(inputs, (list, tuple)):
       raise TypeError('`inputs` should be a list or tuple.')
 
-    if self.feed_dict:
-      feed_dict = self.feed_dict.copy()
-    else:
-      feed_dict = {}
-
     session = get_session()
-    data_tensors_to_feed = []
+    feed_arrays = []
+    array_vals = []
+    feed_symbols = []
+    symbol_vals = []
     for tensor, value in zip(self.inputs, inputs):
       if value is None:
         continue
@@ -2816,23 +2870,31 @@ class Function(object):
         indices = np.concatenate((np.expand_dims(sparse_coo.row, 1),
                                   np.expand_dims(sparse_coo.col, 1)), 1)
         value = (indices, sparse_coo.data, sparse_coo.shape)
-      elif tensor_util.is_tensor(value):
-        data_tensors_to_feed.append((tensor, value))
+      if tensor_util.is_tensor(value):
+        # Case: feeding symbolic tensor.
+        feed_symbols.append(tensor)
+        symbol_vals.append(value)
       else:
-        feed_dict[tensor] = value
-
-    if data_tensors_to_feed:
-      # This is a *temporary* workaround (i.e. hack) to feed a symbolic tensor
-      # to `feed_dict`. It is very inefficient. It will be removed as soon
-      # as it becomes possible to pass symbolic tensors to `feed_dict`.
-      data_tensor_values = session.run([x[1] for x in data_tensors_to_feed])
-      for i, v in enumerate(data_tensor_values):
-        feed_dict[data_tensors_to_feed[i][0]] = v
-
-    fetches = self.outputs + [self.updates_op] + self.fetches
-    updated = session.run(
-        fetches=fetches, feed_dict=feed_dict, **self.session_kwargs)
-    return updated[:len(self.outputs)]
+        # Case: feeding Numpy array.
+        feed_arrays.append(tensor)
+        # We need to do array conversion and type casting at this level, since
+        # `callable_fn` only supports exact matches.
+        array_vals.append(np.asarray(value, dtype=tensor.dtype.base_dtype.name))
+    if self.feed_dict:
+      for key in sorted(self.feed_dict.keys()):
+        array_vals.append(
+            np.asarray(self.feed_dict[key], dtype=key.dtype.base_dtype.name))
+
+    # Refresh callable if anything has changed.
+    if (self._callable_fn is None or
+        feed_arrays != self._feed_arrays or
+        symbol_vals != self._symbol_vals or
+        feed_symbols != self._feed_symbols or
+        session != self._session):
+      self._make_callable(feed_arrays, feed_symbols, symbol_vals, session)
+
+    fetched = self._callable_fn(*array_vals)
+    return fetched[:len(self.outputs)]
 
 
 @tf_export('keras.backend.function')
index fb4b2a0..0193fc6 100644 (file)
@@ -189,6 +189,34 @@ class BackendUtilsTest(test.TestCase):
     for y in ys:
       self.assertEqual(y.op.name[:12], 'StopGradient')
 
+  def test_function_tf_feed_symbols(self):
+    with self.test_session():
+      # Test feeding a resource variable to `function`.
+      x1 = keras.backend.placeholder(shape=())
+      x2 = keras.backend.placeholder(shape=())
+      lr = keras.backend.learning_phase()  # Include a placeholder_with_default.
+
+      y1 = keras.backend.variable(10.)
+      y2 = 3
+
+      f = keras.backend.function(
+          inputs=[x1, x2, lr],
+          outputs=[x1 + 1,
+                   keras.backend.in_train_phase(x2 + 2, x2 - 1)])
+      outs = f([y1, y2, None])  # Use default learning_phase value.
+      self.assertEqual(outs, [11., 2.])
+      outs = f([y1, y2, 1])  # Set learning phase value.
+      self.assertEqual(outs, [11., 5.])
+
+      # Test triggering a callable refresh by changing the input.
+      y3 = keras.backend.constant(20.)  # Test with tensor
+      outs = f([y3, y2, None])
+      self.assertEqual(outs, [21., 2.])
+
+      y4 = 4  # Test with non-symbol
+      outs = f([y4, y2, None])
+      self.assertEqual(outs, [5., 2.])
+
   def test_function_tf_fetches(self):
     # Additional operations can be passed to tf.Session().run() via its
     # `fetches` arguments. In contrast to `updates` argument of
@@ -206,8 +234,9 @@ class BackendUtilsTest(test.TestCase):
                                  updates=[(x, x_placeholder + 1.)],
                                  fetches=[keras.backend.update(y, 5.)])
       output = f([10., 20.])
-      assert output == [30.]
-      assert keras.backend.get_session().run(fetches=[x, y]) == [11., 5.]
+      self.assertEqual(output, [30.])
+      self.assertEqual(
+          keras.backend.get_session().run(fetches=[x, y]), [11., 5.])
 
   def test_function_tf_feed_dict(self):
     # Additional substitutions can be passed to `tf.Session().run()` via its
@@ -229,14 +258,16 @@ class BackendUtilsTest(test.TestCase):
                                  feed_dict=feed_dict,
                                  fetches=fetches)
       output = f([10.])
-      assert output == [11.]
-      assert keras.backend.get_session().run(fetches=[x, y]) == [20., 30.]
+      self.assertEqual(output, [11.])
+      self.assertEqual(
+          keras.backend.get_session().run(fetches=[x, y]), [20., 30.])
 
       # updated value in feed_dict will be modified within the K.function()
       feed_dict[y_placeholder] = 4.
       output = f([20.])
-      assert output == [21.]
-      assert keras.backend.get_session().run(fetches=[x, y]) == [30., 40.]
+      self.assertEqual(output, [21.])
+      self.assertEqual(
+          keras.backend.get_session().run(fetches=[x, y]), [30., 40.])
 
 
 class BackendVariableTest(test.TestCase):
index c448084..43aff67 100644 (file)
@@ -95,7 +95,7 @@ class KerasIntegrationTest(test.TestCase):
       model.compile(loss='categorical_crossentropy',
                     optimizer=keras.optimizers.Adam(lr=0.1),
                     metrics=['accuracy'])
-      history = model.fit(x_train, y_train, epochs=10, batch_size=16,
+      history = model.fit(x_train, y_train, epochs=15, batch_size=16,
                           validation_data=(x_train, y_train),
                           verbose=2)
       self.assertGreater(history.history['val_acc'][-1], 0.7)