if context.in_eager_mode():
return output
- # Un-built subclassed network: build it
- if hasattr(self, '_set_inputs') and not self.inputs:
- self._set_inputs(inputs, training=kwargs.get('training'))
+ if hasattr(self, '_symbolic_set_inputs') and not self.inputs:
+ # Subclassed network: explicitly set metadata normally set by a call to
+ # self._set_inputs().
+ self._symbolic_set_inputs(inputs, output)
# Update learning phase info.
output_tensors = generic_utils.to_list(output)
'output_%d' % (i + 1) for i in range(len(dummy_output_values))]
self.built = True
- def _symbolic_set_inputs(self, inputs, training=None):
- """Set model's inputs based on the input data received from the user.
+ def _symbolic_set_inputs(self, inputs, outputs=None, training=None):
+ """Set model's inputs and output specs based.
This is to be used for Model subclasses, which do not know at instantiation
time what their inputs look like.
Args:
inputs: Argument `x` (input data) passed by the user upon first model use.
+ outputs: None, a data tensor, or a list of data tensors. If None, the
+ outputs will be determined by invoking self.call(), otherwise the
+ provided value will be used.
training: Boolean or None. Only relevant in symbolic mode. Specifies
whether to build the model's graph in inference mode (False), training
mode (True), or using the Keras learning phase (None).
self._feed_input_names.append(name)
self._feed_input_shapes.append(K.int_shape(v))
- # Obtain symbolic outputs by calling the model.
- if len(self.inputs) == 1:
- if self._expects_training_arg:
- outputs = self.call(self.inputs[0], training=training)
- else:
- outputs = self.call(self.inputs[0])
- else:
- if self._expects_training_arg:
- outputs = self.call(self.inputs, training=training)
+ if outputs is None:
+ # Obtain symbolic outputs by calling the model.
+ if len(self.inputs) == 1:
+ if self._expects_training_arg:
+ outputs = self.call(self.inputs[0], training=training)
+ else:
+ outputs = self.call(self.inputs[0])
else:
- outputs = self.call(self.inputs)
+ if self._expects_training_arg:
+ outputs = self.call(self.inputs, training=training)
+ else:
+ outputs = self.call(self.inputs)
if isinstance(outputs, (list, tuple)):
outputs = list(outputs)
else: