From 5a213116df09c19c3ee0eecb5fc79444e5671e80 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Wed, 28 Mar 2018 10:03:06 -0700 Subject: [PATCH] Allow positional arguments in tf.keras.Model subclasses Makes the tf.keras.Layer.__call__ signature identical to tf.layers.Layer.__call__, but makes passing positional arguments other than "inputs" an error in most cases. The only case it's allowed is subclassed Models which do not have an "inputs" argument to their call() method. This means subclassed Models no longer need to pass all but the first argument as a keyword argument (or do list packing/unpacking) when call() takes multiple Tensor arguments. Includes errors for cases where whether an argument indicates an input is ambiguous, but otherwise doesn't do much to support non-"inputs" call() signatures for shape inference or deferred Tensors. The definition of an input/non-input is pretty clear, so that cleanup will mostly be tracking down all of the users of "self.call" and getting them to pass inputs as positional arguments if necessary. PiperOrigin-RevId: 190787899 --- .../eager/python/examples/spinn/spinn_test.py | 13 +-- .../python/keras/_impl/keras/engine/base_layer.py | 90 +++++++++++++- .../python/keras/_impl/keras/engine/network.py | 9 +- .../python/keras/_impl/keras/engine/training.py | 5 + .../keras/_impl/keras/model_subclassing_test.py | 130 ++++++++++++++++++++- tensorflow/python/layers/base.py | 2 + third_party/examples/eager/spinn/spinn.py | 29 ++--- 7 files changed, 246 insertions(+), 32 deletions(-) diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py index 591d99e..9261823 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -173,7 +173,7 @@ class SpinnTest(test_util.TensorFlowTestCase): right_in.append(tf.random_normal((1, size * 2))) tracking.append(tf.random_normal((1, tracker_size * 2))) - out = reducer(left_in, right_in=right_in, tracking=tracking) + out = reducer(left_in, right_in, tracking=tracking) self.assertEqual(batch_size, len(out)) self.assertEqual(tf.float32, out[0].dtype) self.assertEqual((1, size * 2), out[0].shape) @@ -227,7 +227,7 @@ class SpinnTest(test_util.TensorFlowTestCase): self.assertEqual((batch_size, size * 2), stacks[0][0].shape) for _ in range(2): - out1, out2 = tracker(bufs, stacks=stacks) + out1, out2 = tracker(bufs, stacks) self.assertIsNone(out2) self.assertEqual(batch_size, len(out1)) self.assertEqual(tf.float32, out1[0].dtype) @@ -260,7 +260,7 @@ class SpinnTest(test_util.TensorFlowTestCase): self.assertEqual(tf.int64, transitions.dtype) self.assertEqual((num_transitions, 1), transitions.shape) - out = s(buffers, transitions=transitions, training=True) + out = s(buffers, transitions, training=True) self.assertEqual(tf.float32, out.dtype) self.assertEqual((1, embedding_dims), out.shape) @@ -286,15 +286,12 @@ class SpinnTest(test_util.TensorFlowTestCase): vocab_size) # Invoke model under non-training mode. - logits = model( - prem, premise_transition=prem_trans, hypothesis=hypo, - hypothesis_transition=hypo_trans, training=False) + logits = model(prem, prem_trans, hypo, hypo_trans, training=False) self.assertEqual(tf.float32, logits.dtype) self.assertEqual((batch_size, d_out), logits.shape) # Invoke model under training model. - logits = model(prem, premise_transition=prem_trans, hypothesis=hypo, - hypothesis_transition=hypo_trans, training=True) + logits = model(prem, prem_trans, hypo, hypo_trans, training=True) self.assertEqual(tf.float32, logits.dtype) self.assertEqual((batch_size, d_out), logits.shape) diff --git a/tensorflow/python/keras/_impl/keras/engine/base_layer.py b/tensorflow/python/keras/_impl/keras/engine/base_layer.py index 5615241..755607a 100644 --- a/tensorflow/python/keras/_impl/keras/engine/base_layer.py +++ b/tensorflow/python/keras/_impl/keras/engine/base_layer.py @@ -19,6 +19,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import inspect # Necessary supplement to tf_inspect to deal with variadic args. + from six.moves import zip # pylint: disable=redefined-builtin from tensorflow.python.eager import context @@ -30,6 +32,8 @@ from tensorflow.python.keras._impl.keras import regularizers from tensorflow.python.keras._impl.keras.utils import generic_utils from tensorflow.python.layers import base as tf_base_layers from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_inspect from tensorflow.python.util.tf_export import tf_export @@ -143,6 +147,7 @@ class Layer(tf_base_layers.Layer): super(Layer, self).__init__( name=name, dtype=dtype, trainable=trainable, activity_regularizer=kwargs.get('activity_regularizer')) + self._uses_inputs_arg = True # Add properties that are Keras-only for now. self.supports_masking = False @@ -213,7 +218,71 @@ class Layer(tf_base_layers.Layer): """ return inputs - def __call__(self, inputs, **kwargs): + def _inputs_from_call_args(self, call_args, call_kwargs): + """Get Layer inputs from __call__ *args and **kwargs. + + Args: + call_args: The positional arguments passed to __call__. + call_kwargs: The keyword argument dict passed to __call__. + + Returns: + A tuple of (inputs, non_input_kwargs). These may be the same objects as + were passed in (call_args and call_kwargs). + """ + if getattr(self, '_uses_inputs_arg', True): + assert len(call_args) == 1 # TypeError raised earlier in __call__. + return call_args[0], call_kwargs + else: + call_arg_spec = tf_inspect.getargspec(self.call) + # There is no explicit "inputs" argument expected or provided to + # call(). Arguments which have default values are considered non-inputs, + # and arguments without are considered inputs. + if call_arg_spec.defaults: + if call_arg_spec.varargs is not None: + raise TypeError( + 'Layer.call() may not accept both *args and arguments with ' + 'default values (unable to determine which are inputs to the ' + 'Layer).') + keyword_arg_names = set( + call_arg_spec.args[-len(call_arg_spec.defaults):]) + else: + keyword_arg_names = set() + # Training is never an input argument name, to allow signatures like + # call(x, training). + keyword_arg_names.add('training') + _, unwrapped_call = tf_decorator.unwrap(self.call) + bound_args = inspect.getcallargs( + unwrapped_call, *call_args, **call_kwargs) + if call_arg_spec.keywords is not None: + var_kwargs = bound_args.pop(call_arg_spec.keywords) + bound_args.update(var_kwargs) + keyword_arg_names = keyword_arg_names.union(var_kwargs.keys()) + all_args = call_arg_spec.args + if all_args and bound_args[all_args[0]] is self: + # Ignore the 'self' argument of methods + bound_args.pop(call_arg_spec.args[0]) + all_args = all_args[1:] + non_input_arg_values = {} + input_arg_values = [] + remaining_args_are_keyword = False + for argument_name in all_args: + if argument_name in keyword_arg_names: + remaining_args_are_keyword = True + else: + if remaining_args_are_keyword: + raise TypeError( + 'Found a positional argument to call() after a non-input ' + 'argument. All arguments after "training" must be keyword ' + 'arguments, and are not tracked as inputs to the Layer.') + if remaining_args_are_keyword: + non_input_arg_values[argument_name] = bound_args[argument_name] + else: + input_arg_values.append(bound_args[argument_name]) + if call_arg_spec.varargs is not None: + input_arg_values.extend(bound_args[call_arg_spec.varargs]) + return input_arg_values, non_input_arg_values + + def __call__(self, inputs, *args, **kwargs): """Wrapper around self.call(), for handling internal references. If a Keras tensor is passed: @@ -226,6 +295,10 @@ class Layer(tf_base_layers.Layer): Arguments: inputs: Can be a tensor or list/tuple of tensors. + *args: Additional positional arguments to be passed to `call()`. Only + allowed in subclassed Models with custom call() signatures. In other + cases, `Layer` inputs must be passed using the `inputs` argument and + non-inputs must be keyword arguments. **kwargs: Additional keyword arguments to be passed to `call()`. Returns: @@ -234,12 +307,25 @@ class Layer(tf_base_layers.Layer): Raises: ValueError: in case the layer is missing shape information for its `build` call. + TypeError: If positional arguments are passed and this `Layer` is not a + subclassed `Model`. """ # Actually call the layer (optionally building it). - output = super(Layer, self).__call__(inputs, **kwargs) + output = super(Layer, self).__call__(inputs, *args, **kwargs) + + if args and getattr(self, '_uses_inputs_arg', True): + raise TypeError( + 'This Layer takes an `inputs` argument to call(), and only the ' + '`inputs` argument may be specified as a positional argument. Pass ' + 'everything else as a keyword argument (those arguments will not be ' + 'tracked as inputs to the Layer).') + if context.executing_eagerly(): return output + inputs, kwargs = self._inputs_from_call_args( + call_args=(inputs,) + args, call_kwargs=kwargs) + if hasattr(self, '_symbolic_set_inputs') and not self.inputs: # Subclassed network: explicitly set metadata normally set by a call to # self._set_inputs(). diff --git a/tensorflow/python/keras/_impl/keras/engine/network.py b/tensorflow/python/keras/_impl/keras/engine/network.py index ea4be0d..9f1c7de 100644 --- a/tensorflow/python/keras/_impl/keras/engine/network.py +++ b/tensorflow/python/keras/_impl/keras/engine/network.py @@ -117,6 +117,7 @@ class Network(base_layer.Layer): self._inbound_nodes = [] def _init_graph_network(self, inputs, outputs, name=None): + self._uses_inputs_arg = True # Normalize and set self.inputs, self.outputs. if isinstance(inputs, (list, tuple)): self.inputs = list(inputs) # Tensor or list of tensors. @@ -274,11 +275,15 @@ class Network(base_layer.Layer): def _init_subclassed_network(self, name=None): self._base_init(name=name) self._is_graph_network = False - if 'training' in tf_inspect.getargspec(self.call).args: + call_args = tf_inspect.getargspec(self.call).args + if 'training' in call_args: self._expects_training_arg = True else: self._expects_training_arg = False - + if 'inputs' in call_args: + self._uses_inputs_arg = True + else: + self._uses_inputs_arg = False self.outputs = None self.inputs = None self.built = False diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py index 08288d3..971245c 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training.py +++ b/tensorflow/python/keras/_impl/keras/engine/training.py @@ -874,6 +874,11 @@ class Model(Network): whether to build the model's graph in inference mode (False), training mode (True), or using the Keras learning phase (None). """ + if not getattr(self, '_uses_inputs_arg', True): + raise NotImplementedError( + 'Subclassed Models without "inputs" in their call() signatures do ' + 'not yet support shape inference. File a feature request if this ' + 'limitation bothers you.') if self.__class__.__name__ == 'Sequential': # Note: we can't test whether the model is `Sequential` via `isinstance` # since `Sequential` depends on `Model`. diff --git a/tensorflow/python/keras/_impl/keras/model_subclassing_test.py b/tensorflow/python/keras/_impl/keras/model_subclassing_test.py index 58b1443..4445900 100644 --- a/tensorflow/python/keras/_impl/keras/model_subclassing_test.py +++ b/tensorflow/python/keras/_impl/keras/model_subclassing_test.py @@ -22,7 +22,9 @@ import os import tempfile import numpy as np +import six +from tensorflow.python.eager import context from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.keras._impl import keras @@ -36,6 +38,7 @@ except ImportError: h5py = None +# pylint: disable=not-callable class SimpleTestModel(keras.Model): def __init__(self, use_bn=False, use_dp=False, num_classes=10): @@ -104,7 +107,7 @@ class NestedTestModel1(keras.Model): def call(self, inputs): x = self.dense1(inputs) x = self.bn(x) - x = self.test_net(x) # pylint: disable=not-callable + x = self.test_net(x) return self.dense2(x) @@ -161,7 +164,7 @@ def get_nested_model_3(input_dim, num_classes): return tensor_shape.TensorShape((input_shape[0], 5)) test_model = Inner() - x = test_model(x) # pylint: disable=not-callable + x = test_model(x) outputs = keras.layers.Dense(num_classes)(x) return keras.Model(inputs, outputs, name='nested_model_3') @@ -574,5 +577,128 @@ class ModelSubclassingTest(test.TestCase): self.assertGreater(loss, 0.1) +class CustomCallModel(keras.Model): + + def __init__(self): + super(CustomCallModel, self).__init__() + self.dense1 = keras.layers.Dense(1, activation='relu') + self.dense2 = keras.layers.Dense(1, activation='softmax') + + def call(self, first, second, fiddle_with_output='no', training=True): + combined = self.dense1(first) + self.dense2(second) + if fiddle_with_output == 'yes': + return 10. * combined + else: + return combined + + +class CustomCallSignatureTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def test_no_inputs_in_signature(self): + model = CustomCallModel() + first = array_ops.ones([2, 3]) + second = array_ops.ones([2, 5]) + output = model(first, second) + self.evaluate([v.initializer for v in model.variables]) + expected_output = self.evaluate(model.dense1(first) + model.dense2(second)) + self.assertAllClose(expected_output, self.evaluate(output)) + output = model(first, second, fiddle_with_output='yes') + self.assertAllClose(10. * expected_output, self.evaluate(output)) + output = model(first, second=second, training=False) + self.assertAllClose(expected_output, self.evaluate(output)) + if not context.executing_eagerly(): + six.assertCountEqual(self, [first, second], model.inputs) + with self.assertRaises(TypeError): + # tf.layers.Layer expects an "inputs" argument, so all-keywords doesn't + # work at the moment. + model(first=first, second=second, fiddle_with_output='yes') + + @test_util.run_in_graph_and_eager_modes() + def test_inputs_in_signature(self): + + class HasInputsAndOtherPositional(keras.Model): + + def call(self, inputs, some_other_arg, training=False): + return inputs + + model = HasInputsAndOtherPositional() + with self.assertRaisesRegexp( + TypeError, 'everything else as a keyword argument'): + model(array_ops.ones([]), array_ops.ones([])) + + @test_util.run_in_graph_and_eager_modes() + def test_kwargs_in_signature(self): + + class HasKwargs(keras.Model): + + def call(self, x, y=3, **key_words): + return x + + model = HasKwargs() + arg = array_ops.ones([]) + model(arg, a=3) + if not context.executing_eagerly(): + six.assertCountEqual(self, [arg], model.inputs) + + @test_util.run_in_graph_and_eager_modes() + def test_args_in_signature(self): + + class HasArgs(keras.Model): + + def call(self, x, *args, **kwargs): + return [x] + list(args) + + model = HasArgs() + arg1 = array_ops.ones([]) + arg2 = array_ops.ones([]) + arg3 = array_ops.ones([]) + model(arg1, arg2, arg3, a=3) + if not context.executing_eagerly(): + six.assertCountEqual(self, [arg1, arg2, arg3], model.inputs) + + def test_args_and_keywords_in_signature(self): + + class HasArgs(keras.Model): + + def call(self, x, training=True, *args, **kwargs): + return x + + with context.graph_mode(): + model = HasArgs() + arg1 = array_ops.ones([]) + arg2 = array_ops.ones([]) + arg3 = array_ops.ones([]) + with self.assertRaisesRegexp(TypeError, 'args and arguments with'): + model(arg1, arg2, arg3, a=3) + + def test_training_no_default(self): + + class TrainingNoDefault(keras.Model): + + def call(self, x, training): + return x + + with context.graph_mode(): + model = TrainingNoDefault() + arg = array_ops.ones([]) + model(arg, True) + six.assertCountEqual(self, [arg], model.inputs) + + def test_training_no_default_with_positional(self): + + class TrainingNoDefaultWithPositional(keras.Model): + + def call(self, x, training, positional): + return x + + with context.graph_mode(): + model = TrainingNoDefaultWithPositional() + arg1 = array_ops.ones([]) + arg2 = array_ops.ones([]) + arg3 = array_ops.ones([]) + with self.assertRaisesRegexp(TypeError, 'after a non-input'): + model(arg1, arg2, arg3) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index 1e5f26a..242cdff 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -625,6 +625,8 @@ class Layer(checkpointable.CheckpointableBase): input_list = nest.flatten(inputs) build_graph = not context.executing_eagerly() + # TODO(fchollet, allenl): Make deferred mode work with subclassed Models + # which don't use an "inputs" argument. in_deferred_mode = isinstance(input_list[0], _DeferredTensor) # Ensure the Layer, if being reused, is working with inputs from # the same graph as where it was created. diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py index f8fb6ec..8a2b24a 100644 --- a/third_party/examples/eager/spinn/spinn.py +++ b/third_party/examples/eager/spinn/spinn.py @@ -266,8 +266,7 @@ class SPINN(tf.keras.Model): trackings.append(tracking) if rights: - reducer_output = self.reducer( - lefts, right_in=rights, tracking=trackings) + reducer_output = self.reducer(lefts, rights, trackings) reduced = iter(reducer_output) for transition, stack in zip(trans, stacks): @@ -388,10 +387,10 @@ class SNLIClassifier(tf.keras.Model): # Run the batch-normalized and dropout-processed word vectors through the # SPINN encoder. - premise = self.encoder( - premise_embed, transitions=premise_transition, training=training) - hypothesis = self.encoder( - hypothesis_embed, transitions=hypothesis_transition, training=training) + premise = self.encoder(premise_embed, premise_transition, + training=training) + hypothesis = self.encoder(hypothesis_embed, hypothesis_transition, + training=training) # Combine encoder outputs for premises and hypotheses into logits. # Then apply batch normalization and dropuout on the logits. @@ -465,11 +464,10 @@ class SNLIClassifierTrainer(tfe.Checkpointable): """ with tfe.GradientTape() as tape: tape.watch(self._model.variables) - # TODO(allenl): Allow passing Layer inputs as position arguments. logits = self._model(premise, - premise_transition=premise_transition, - hypothesis=hypothesis, - hypothesis_transition=hypothesis_transition, + premise_transition, + hypothesis, + hypothesis_transition, training=True) loss = self.loss(labels, logits) gradients = tape.gradient(loss, self._model.variables) @@ -533,9 +531,7 @@ def _evaluate_on_dataset(snli_data, batch_size, trainer, use_gpu): snli_data, batch_size): if use_gpu: label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu() - logits = trainer.model( - prem, premise_transition=prem_trans, hypothesis=hypo, - hypothesis_transition=hypo_trans, training=False) + logits = trainer.model(prem, prem_trans, hypo, hypo_trans, training=False) loss_val = trainer.loss(label, logits) batch_size = tf.shape(label)[0] mean_loss(loss_val, weights=batch_size.gpu() if use_gpu else batch_size) @@ -639,11 +635,8 @@ def train_or_infer_spinn(embed, hypo, hypo_trans = inference_sentence_pair[1] hypo_trans = inference_sentence_pair[1][1] inference_logits = model( - tf.constant(prem), - premise_transition=tf.constant(prem_trans), - hypothesis=tf.constant(hypo), - hypothesis_transition=tf.constant(hypo_trans), - training=False) + tf.constant(prem), tf.constant(prem_trans), + tf.constant(hypo), tf.constant(hypo_trans), training=False) inference_logits = inference_logits[0][1:] max_index = tf.argmax(inference_logits) print("\nInference logits:") -- 2.7.4