keras: Avoid unneccesary call to .call() when building models with subclassing.
authorAsim Shankar <ashankar@google.com>
Wed, 28 Feb 2018 23:18:29 +0000 (15:18 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 28 Feb 2018 23:23:36 +0000 (15:23 -0800)
This fixes a regression in the defun microbenchmarks
(ResNet50Benchmarks.eager_train_with_defun_gpu_batch_32_channels_first etc.)
in tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
seen after https://github.com/tensorflow/tensorflow/commit/9a84277be2cb8233c5c14270db6fcdff31ab4d93
(which embeds a model in model)

Without this change, converting a model call to a graph function using
something like:
model.call = tfe.defun(model.call)
could result in redundant nodes being added to the graph function
as the model._set_inputs() call would invoke model.call() again.

PiperOrigin-RevId: 187391494

tensorflow/python/keras/_impl/keras/engine/base_layer.py
tensorflow/python/keras/_impl/keras/engine/training.py

index 1423250..7f215f5 100644 (file)
@@ -240,9 +240,10 @@ class Layer(tf_base_layers.Layer):
     if context.in_eager_mode():
       return output
 
-    # Un-built subclassed network: build it
-    if hasattr(self, '_set_inputs') and not self.inputs:
-      self._set_inputs(inputs, training=kwargs.get('training'))
+    if hasattr(self, '_symbolic_set_inputs') and not self.inputs:
+      # Subclassed network: explicitly set metadata normally set by a call to
+      # self._set_inputs().
+      self._symbolic_set_inputs(inputs, output)
 
     # Update learning phase info.
     output_tensors = generic_utils.to_list(output)
index 63bea08..c121d81 100644 (file)
@@ -1835,14 +1835,17 @@ class Model(Network):
         'output_%d' % (i + 1) for i in range(len(dummy_output_values))]
     self.built = True
 
-  def _symbolic_set_inputs(self, inputs, training=None):
-    """Set model's inputs based on the input data received from the user.
+  def _symbolic_set_inputs(self, inputs, outputs=None, training=None):
+    """Set model's inputs and output specs based.
 
     This is to be used for Model subclasses, which do not know at instantiation
     time what their inputs look like.
 
     Args:
       inputs: Argument `x` (input data) passed by the user upon first model use.
+      outputs: None, a data tensor, or a list of data tensors. If None, the
+        outputs will be determined by invoking self.call(), otherwise the
+        provided value will be used.
       training: Boolean or None. Only relevant in symbolic mode. Specifies
         whether to build the model's graph in inference mode (False), training
         mode (True), or using the Keras learning phase (None).
@@ -1892,17 +1895,18 @@ class Model(Network):
           self._feed_input_names.append(name)
           self._feed_input_shapes.append(K.int_shape(v))
 
-    # Obtain symbolic outputs by calling the model.
-    if len(self.inputs) == 1:
-      if self._expects_training_arg:
-        outputs = self.call(self.inputs[0], training=training)
-      else:
-        outputs = self.call(self.inputs[0])
-    else:
-      if self._expects_training_arg:
-        outputs = self.call(self.inputs, training=training)
+    if outputs is None:
+      # Obtain symbolic outputs by calling the model.
+      if len(self.inputs) == 1:
+        if self._expects_training_arg:
+          outputs = self.call(self.inputs[0], training=training)
+        else:
+          outputs = self.call(self.inputs[0])
       else:
-        outputs = self.call(self.inputs)
+        if self._expects_training_arg:
+          outputs = self.call(self.inputs, training=training)
+        else:
+          outputs = self.call(self.inputs)
     if isinstance(outputs, (list, tuple)):
       outputs = list(outputs)
     else: