outputs: Output tensors to fetch.
updates: Additional update ops to be run at function call.
name: A name to help users identify what this function does.
- session_kwargs: Arguments to `tf.Session.run()`: `fetches`, `feed_dict`,
- `options`, `run_metadata`
+ session_kwargs: Arguments to `tf.Session.run()`: `fetches`, `feed_dict`.
"""
def __init__(self, inputs, outputs, updates=None, name=None,
self.fetches = session_kwargs.pop('fetches', [])
if not isinstance(self.fetches, list):
self.fetches = [self.fetches]
+ # The main use case of `fetches` being passed to a model is the ability
+ # to run custom updates (since the outputs of fetches are never returned).
+ # This requires us to wrap fetches in `identity` ops.
+ self.fetches = [array_ops.identity(x) for x in self.fetches]
self.session_kwargs = session_kwargs
+ if session_kwargs:
+ raise ValueError('Some keys in session_kwargs are not supported at this '
+ 'time: %s', session_kwargs.keys())
+
+ self._callable_fn = None
+ self._feed_arrays = None
+ self._feed_symbols = None
+ self._symbol_vals = None
+ self._session = None
+
+ def _make_callable(self, feed_arrays, feed_symbols, symbol_vals, session):
+ """Generates a callable that runs the graph.
+
+ Arguments:
+ feed_arrays: List of input tensors to be fed Numpy arrays at runtime.
+ feed_symbols: List of input tensors to be fed symbolic tensors at runtime.
+ symbol_vals: List of symbolic tensors to be fed to `feed_symbols`.
+ session: Session to use to generate the callable.
+
+ Returns:
+ Function that runs the graph according to the above options.
+ """
+ # Prepare callable options.
+ callable_opts = config_pb2.CallableOptions()
+ # Handle external-data feed.
+ for x in feed_arrays:
+ callable_opts.feed.append(x.name)
+ if self.feed_dict:
+ for key in sorted(self.feed_dict.keys()):
+ callable_opts.feed.append(key.name)
+ # Handle symbolic feed.
+ for x, y in zip(feed_symbols, symbol_vals):
+ connection = callable_opts.tensor_connection.add()
+ from_tensor = ops._as_graph_element(y)
+ if from_tensor is None:
+ from_tensor = y
+ connection.from_tensor = from_tensor.name # Data tensor
+ connection.to_tensor = x.name # Placeholder
+ # Handle fetches.
+ for x in self.outputs + self.fetches:
+ callable_opts.fetch.append(x.name)
+ # Handle updates.
+ callable_opts.target.append(self.updates_op.name)
+ # Create callable.
+ callable_fn = session._make_callable_from_options(callable_opts)
+ # Cache parameters corresponding to the generated callable, so that
+ # we can detect future mismatches and refresh the callable.
+ self._callable_fn = callable_fn
+ self._feed_arrays = feed_arrays
+ self._feed_symbols = feed_symbols
+ self._symbol_vals = symbol_vals
+ self._session = session
+
def __call__(self, inputs):
if not isinstance(inputs, (list, tuple)):
raise TypeError('`inputs` should be a list or tuple.')
- if self.feed_dict:
- feed_dict = self.feed_dict.copy()
- else:
- feed_dict = {}
-
session = get_session()
- data_tensors_to_feed = []
+ feed_arrays = []
+ array_vals = []
+ feed_symbols = []
+ symbol_vals = []
for tensor, value in zip(self.inputs, inputs):
if value is None:
continue
indices = np.concatenate((np.expand_dims(sparse_coo.row, 1),
np.expand_dims(sparse_coo.col, 1)), 1)
value = (indices, sparse_coo.data, sparse_coo.shape)
- elif tensor_util.is_tensor(value):
- data_tensors_to_feed.append((tensor, value))
+ if tensor_util.is_tensor(value):
+ # Case: feeding symbolic tensor.
+ feed_symbols.append(tensor)
+ symbol_vals.append(value)
else:
- feed_dict[tensor] = value
-
- if data_tensors_to_feed:
- # This is a *temporary* workaround (i.e. hack) to feed a symbolic tensor
- # to `feed_dict`. It is very inefficient. It will be removed as soon
- # as it becomes possible to pass symbolic tensors to `feed_dict`.
- data_tensor_values = session.run([x[1] for x in data_tensors_to_feed])
- for i, v in enumerate(data_tensor_values):
- feed_dict[data_tensors_to_feed[i][0]] = v
-
- fetches = self.outputs + [self.updates_op] + self.fetches
- updated = session.run(
- fetches=fetches, feed_dict=feed_dict, **self.session_kwargs)
- return updated[:len(self.outputs)]
+ # Case: feeding Numpy array.
+ feed_arrays.append(tensor)
+ # We need to do array conversion and type casting at this level, since
+ # `callable_fn` only supports exact matches.
+ array_vals.append(np.asarray(value, dtype=tensor.dtype.base_dtype.name))
+ if self.feed_dict:
+ for key in sorted(self.feed_dict.keys()):
+ array_vals.append(
+ np.asarray(self.feed_dict[key], dtype=key.dtype.base_dtype.name))
+
+ # Refresh callable if anything has changed.
+ if (self._callable_fn is None or
+ feed_arrays != self._feed_arrays or
+ symbol_vals != self._symbol_vals or
+ feed_symbols != self._feed_symbols or
+ session != self._session):
+ self._make_callable(feed_arrays, feed_symbols, symbol_vals, session)
+
+ fetched = self._callable_fn(*array_vals)
+ return fetched[:len(self.outputs)]
@tf_export('keras.backend.function')
for y in ys:
self.assertEqual(y.op.name[:12], 'StopGradient')
+ def test_function_tf_feed_symbols(self):
+ with self.test_session():
+ # Test feeding a resource variable to `function`.
+ x1 = keras.backend.placeholder(shape=())
+ x2 = keras.backend.placeholder(shape=())
+ lr = keras.backend.learning_phase() # Include a placeholder_with_default.
+
+ y1 = keras.backend.variable(10.)
+ y2 = 3
+
+ f = keras.backend.function(
+ inputs=[x1, x2, lr],
+ outputs=[x1 + 1,
+ keras.backend.in_train_phase(x2 + 2, x2 - 1)])
+ outs = f([y1, y2, None]) # Use default learning_phase value.
+ self.assertEqual(outs, [11., 2.])
+ outs = f([y1, y2, 1]) # Set learning phase value.
+ self.assertEqual(outs, [11., 5.])
+
+ # Test triggering a callable refresh by changing the input.
+ y3 = keras.backend.constant(20.) # Test with tensor
+ outs = f([y3, y2, None])
+ self.assertEqual(outs, [21., 2.])
+
+ y4 = 4 # Test with non-symbol
+ outs = f([y4, y2, None])
+ self.assertEqual(outs, [5., 2.])
+
def test_function_tf_fetches(self):
# Additional operations can be passed to tf.Session().run() via its
# `fetches` arguments. In contrast to `updates` argument of
updates=[(x, x_placeholder + 1.)],
fetches=[keras.backend.update(y, 5.)])
output = f([10., 20.])
- assert output == [30.]
- assert keras.backend.get_session().run(fetches=[x, y]) == [11., 5.]
+ self.assertEqual(output, [30.])
+ self.assertEqual(
+ keras.backend.get_session().run(fetches=[x, y]), [11., 5.])
def test_function_tf_feed_dict(self):
# Additional substitutions can be passed to `tf.Session().run()` via its
feed_dict=feed_dict,
fetches=fetches)
output = f([10.])
- assert output == [11.]
- assert keras.backend.get_session().run(fetches=[x, y]) == [20., 30.]
+ self.assertEqual(output, [11.])
+ self.assertEqual(
+ keras.backend.get_session().run(fetches=[x, y]), [20., 30.])
# updated value in feed_dict will be modified within the K.function()
feed_dict[y_placeholder] = 4.
output = f([20.])
- assert output == [21.]
- assert keras.backend.get_session().run(fetches=[x, y]) == [30., 40.]
+ self.assertEqual(output, [21.])
+ self.assertEqual(
+ keras.backend.get_session().run(fetches=[x, y]), [30., 40.])
class BackendVariableTest(test.TestCase):