Add a single positional argument mode for shape inference in subclassed Models.
authorAllen Lavoie <allenl@google.com>
Fri, 1 Jun 2018 02:03:21 +0000 (19:03 -0700)
committerAlex Passos <apassos@google.com>
Thu, 21 Jun 2018 21:08:58 +0000 (14:08 -0700)
Allows fit() when call's signature looks something like call(x, training=True).

Calling conventions are "inputs", single positional, and multiple positional. Right now the distinction between "inputs" and single positional calling conventions is the text of one error message. Both support shape inference (which just hasn't been implemented for multiple positional input arguments yet).

PiperOrigin-RevId: 198815483

tensorflow/python/keras/engine/base_layer.py
tensorflow/python/keras/engine/network.py
tensorflow/python/keras/engine/training.py
tensorflow/python/keras/model_subclassing_test.py

index 24716cfbe4978c6fd5a7884a581524f804817d60..4814275fd5ba53e7845f383b3447a6ef9f47f6c2 100644 (file)
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
+import enum  # pylint: disable=g-bad-import-order
 import inspect  # Necessary supplement to tf_inspect to deal with variadic args.
 
 import numpy as np
@@ -50,6 +51,20 @@ from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.tf_export import tf_export
 
 
+class CallConvention(enum.Enum):
+  """Calling conventions for passing `Layer` inputs to `Layer.call`."""
+  # The Layer takes inputs as its first argument, named "inputs" for
+  # compatibility with the signature of Layer.__call__. This is the mode assumed
+  # for Layers which are not subclassed Models.
+  EXPLICIT_INPUTS_ARGUMENT = 1
+  # The Layer takes a single positional argument, not named "inputs". It's
+  # treated like an "inputs" argument.
+  SINGLE_POSITIONAL_ARGUMENT = 2
+  # The Layer has multiple positional arguments to which its inputs should be
+  # bound.
+  POSITIONAL_ARGUMENTS_ARE_INPUTS = 3
+
+
 @tf_export('keras.layers.Layer')
 class Layer(checkpointable.CheckpointableBase):
   """Base layer class.
@@ -149,7 +164,7 @@ class Layer(checkpointable.CheckpointableBase):
     self._call_fn_args = function_utils.fn_args(self.call)
     self._compute_previous_mask = ('mask' in self._call_fn_args or
                                    hasattr(self, 'compute_mask'))
-    self._uses_inputs_arg = True
+    self._call_convention = CallConvention.EXPLICIT_INPUTS_ARGUMENT
 
     # These lists will be filled via successive calls
     # to self._add_inbound_node().
@@ -793,12 +808,22 @@ class Layer(checkpointable.CheckpointableBase):
           pass  # C type such as dict. Masking not supported in this case.
 
   def _set_connectivity_metadata_(self, inputs, outputs, args, kwargs):
-    if args and getattr(self, '_uses_inputs_arg', True):
-      raise TypeError(
-          'This Layer takes an `inputs` argument to call(), and only the '
-          '`inputs` argument may be specified as a positional argument. '
-          'Pass everything else as a keyword argument (those arguments will'
-          ' not be tracked as inputs to the Layer).')
+    call_convention = getattr(self, '_call_convention',
+                              CallConvention.EXPLICIT_INPUTS_ARGUMENT)
+    if args:
+      if call_convention == CallConvention.EXPLICIT_INPUTS_ARGUMENT:
+        raise TypeError(
+            'This Layer takes an `inputs` argument to call(), and only the '
+            '`inputs` argument may be specified as a positional argument. '
+            'Pass everything else as a keyword argument (those arguments will'
+            ' not be tracked as inputs to the Layer).')
+      elif call_convention == CallConvention.SINGLE_POSITIONAL_ARGUMENT:
+        raise TypeError(
+            'This Layer takes a single positional argument to call(), which is '
+            'by convention the inputs argument, and only this argument may be '
+            'specified as a positional argument. Pass everything else as a '
+            'keyword argument (those arguments will not be tracked as inputs '
+            'to the Layer).')
 
     # If the layer returns tensors from its inputs, unmodified,
     # we copy them to avoid loss of tensor metadata.
@@ -834,7 +859,11 @@ class Layer(checkpointable.CheckpointableBase):
       A tuple of (inputs, non_input_kwargs). These may be the same objects as
       were passed in (call_args and call_kwargs).
     """
-    if getattr(self, '_uses_inputs_arg', True):
+    call_convention = getattr(self, '_call_convention',
+                              CallConvention.EXPLICIT_INPUTS_ARGUMENT)
+    if (call_convention in (
+        CallConvention.EXPLICIT_INPUTS_ARGUMENT,
+        CallConvention.SINGLE_POSITIONAL_ARGUMENT)):
       assert len(call_args) == 1  # TypeError raised earlier in __call__.
       return call_args[0], call_kwargs
     else:
index 3d567b8378390eddf021781ef1175bfe488a6450..6f27eea1e7921e7381aea82b00bfd6e76699f468 100644 (file)
@@ -135,7 +135,7 @@ class Network(base_layer.Layer):
     self._in_progress_restore_finalizer = None
 
   def _init_graph_network(self, inputs, outputs, name=None):
-    self._uses_inputs_arg = True
+    self._call_convention = base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT
     # Normalize and set self.inputs, self.outputs.
     if isinstance(inputs, (list, tuple)):
       self.inputs = list(inputs)  # Tensor or list of tensors.
@@ -295,19 +295,55 @@ class Network(base_layer.Layer):
   def _init_subclassed_network(self, name=None):
     self._base_init(name=name)
     self._is_graph_network = False
-    call_args = tf_inspect.getargspec(self.call).args
-    if 'training' in call_args:
+    call_argspec = tf_inspect.getargspec(self.call)
+    if 'training' in call_argspec.args:
       self._expects_training_arg = True
     else:
       self._expects_training_arg = False
-    if 'inputs' in call_args:
-      self._uses_inputs_arg = True
-    else:
-      self._uses_inputs_arg = False
+    self._call_convention = self._determine_call_convention(call_argspec)
     self.outputs = None
     self.inputs = None
     self.built = False
 
+  def _determine_call_convention(self, call_argspec):
+    """Decides how `self.call()` is invoked. See base_layer.CallConvention."""
+    if call_argspec.varargs:
+      may_take_single_argument = False
+    else:
+      try:
+        # Note: tf_inspect doesn't raise a TypeError when regular inspect would,
+        # so we need to keep in mind that "getcallargs" may have returned
+        # something even though we under-specified positional arguments.
+        all_args = tf_inspect.getcallargs(self.call, None)
+        self_args = set()
+        for arg_name, obj in all_args.items():
+          if obj is self:
+            self_args.add(arg_name)
+        may_take_single_argument = True
+      except TypeError:
+        may_take_single_argument = False
+    if may_take_single_argument:
+      # A single positional argument (plus "self") is considered equivalent to
+      # an "inputs" argument.
+      all_positional_args = len(call_argspec.args)
+      if call_argspec.defaults is not None:
+        all_positional_args -= len(call_argspec.defaults)
+      non_self_positional_args = all_positional_args
+      for positional_arg_name in call_argspec.args[:all_positional_args]:
+        if positional_arg_name in self_args:
+          non_self_positional_args -= 1
+      if non_self_positional_args == 1:
+        if 'inputs' in call_argspec.args[all_positional_args:]:
+          raise TypeError(
+              "Model.call() takes a single positional argument (to which "
+              "inputs are passed by convention) and a separate 'inputs' "
+              "argument. Unable to determine which arguments are inputs.")
+        return base_layer.CallConvention.SINGLE_POSITIONAL_ARGUMENT
+    if 'inputs' in call_argspec.args:
+      return base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT
+    else:
+      return base_layer.CallConvention.POSITIONAL_ARGUMENTS_ARE_INPUTS
+
   def _track_layers(self, layers):
     """Add Checkpointable dependencies on a list of Layers."""
     weight_layer_index = 0
index 6d625f16c2b04544af135af94dc25641f11c41e0..04a2aa766444ca44016f37ed3d20d955de831fc4 100644 (file)
@@ -31,12 +31,11 @@ from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import losses
 from tensorflow.python.keras import metrics as metrics_module
 from tensorflow.python.keras import optimizers
+from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.engine import training_arrays
 from tensorflow.python.keras.engine import training_eager
 from tensorflow.python.keras.engine import training_generator
 from tensorflow.python.keras.engine import training_utils
-from tensorflow.python.keras.engine.base_layer import DeferredTensor
-from tensorflow.python.keras.engine.base_layer import Layer
 from tensorflow.python.keras.engine.network import Network
 from tensorflow.python.keras.utils.generic_utils import slice_arrays
 from tensorflow.python.ops import array_ops
@@ -523,7 +522,7 @@ class Model(Network):
 
             # Keep track of state updates created by
             # stateful metrics (i.e. metrics layers).
-            if isinstance(metric_fn, Layer) and metric_fn.stateful:
+            if isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful:
               self.stateful_metric_names.append(metric_name)
               self.stateful_metric_functions.append(metric_fn)
               self.metrics_updates += metric_fn.updates
@@ -959,11 +958,17 @@ class Model(Network):
         whether to build the model's graph in inference mode (False), training
         mode (True), or using the Keras learning phase (None).
     """
-    if not getattr(self, '_uses_inputs_arg', True):
+    call_convention = getattr(
+        self,
+        '_call_convention',
+        base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT)
+    if call_convention not in (
+        base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT,
+        base_layer.CallConvention.SINGLE_POSITIONAL_ARGUMENT):
       raise NotImplementedError(
-          'Subclassed Models without "inputs" in their call() signatures do '
-          'not yet support shape inference. File a feature request if this '
-          'limitation bothers you.')
+          'Subclassed Models without "inputs" (or single positional arguments) '
+          'in their call() signatures do not yet support shape inference. File '
+          'a feature request if this limitation bothers you.')
     if self.__class__.__name__ == 'Sequential':
       # Note: we can't test whether the model is `Sequential` via `isinstance`
       # since `Sequential` depends on `Model`.
@@ -1020,11 +1025,11 @@ class Model(Network):
     else:
       dummy_output_values = [dummy_output_values]
     self.outputs = [
-        DeferredTensor(shape=(None for _ in v.shape),
-                       dtype=v.dtype) for v in dummy_output_values]
+        base_layer.DeferredTensor(shape=(None for _ in v.shape),
+                                  dtype=v.dtype) for v in dummy_output_values]
     self.inputs = [
-        DeferredTensor(shape=(None for _ in v.shape),
-                       dtype=v.dtype) for v in dummy_input_values]
+        base_layer.DeferredTensor(shape=(None for _ in v.shape),
+                                  dtype=v.dtype) for v in dummy_input_values]
     self.input_names = [
         'input_%d' % (i + 1) for i in range(len(dummy_input_values))]
     self.output_names = [
index 86f7e20bec2800b2da92c390dc1b0885d0825097..8fb957da439dd490bc3378df96f611733335c809 100644 (file)
@@ -56,8 +56,8 @@ class SimpleTestModel(keras.Model):
     if self.use_bn:
       self.bn = keras.layers.BatchNormalization(axis=-1)
 
-  def call(self, inputs):
-    x = self.dense1(inputs)
+  def call(self, x):
+    x = self.dense1(x)
     if self.use_dp:
       x = self.dp(x)
     if self.use_bn: