Fixes a number of usability issues with model_to_estimator, in particular:
authorFrancois Chollet <fchollet@google.com>
Tue, 6 Mar 2018 02:49:53 +0000 (18:49 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 6 Mar 2018 02:53:40 +0000 (18:53 -0800)
- make it possible to use a model that was compiled with a TF optimizer (do not require a Keras optimizer)
- do not require input to be dict (input_fn supports plain arrays)
- do not require `config` to be a RunConfig instance, can now be a dict (better UX)
- make it possible to use a subclassed model (caveat: weights are not preserved, yet)
- clear error message when model isn't compiled; improve various error messages

PiperOrigin-RevId: 187959927

tensorflow/python/keras/_impl/keras/estimator.py
tensorflow/python/keras/_impl/keras/estimator_test.py
tensorflow/python/layers/base.py

index 5697771..081f25e 100644 (file)
@@ -25,11 +25,15 @@ from tensorflow.python.client import session
 from tensorflow.python.estimator import estimator as estimator_lib
 from tensorflow.python.estimator import export as export_lib
 from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator import run_config as run_config_lib
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
 from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
 from tensorflow.python.keras._impl.keras import backend as K
 from tensorflow.python.keras._impl.keras import models
+from tensorflow.python.keras._impl.keras import optimizers
+from tensorflow.python.keras._impl.keras.engine.base_layer import Layer
+from tensorflow.python.keras._impl.keras.engine.network import Network
 from tensorflow.python.keras._impl.keras.utils.generic_utils import CustomObjectScope
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import metrics as metrics_module
@@ -50,36 +54,174 @@ def _cast_tensor_to_floatx(x):
     return math_ops.cast(x, K.floatx())
 
 
-def _create_ordered_io(keras_model, estimator_io_dict, is_input=True):
+def _create_ordered_io(keras_model, estimator_io, is_input=True):
   """Create a list of tensors from IO dictionary based on Keras IO order.
 
   Args:
-    keras_model: an instance of compiled keras model.
-    estimator_io_dict: features or labels dictionary from model_fn.
+    keras_model: An instance of compiled keras model.
+    estimator_io: The features or labels (dict or plain array) from model_fn.
     is_input: True if dictionary is for inputs.
 
   Returns:
-    a list of tensors based on Keras IO order.
+    A list of tensors based on Keras IO order.
 
   Raises:
     ValueError: if dictionary keys cannot be found in Keras model input_names
       or output_names.
   """
-  if is_input:
-    keras_io_names = keras_model.input_names
+  if isinstance(estimator_io, (list, tuple)):
+    # Case currently not supported by most built-in input_fn,
+    # but it's good to have for sanity
+    return [_cast_tensor_to_floatx(x) for x in estimator_io]
+  elif isinstance(estimator_io, dict):
+    if is_input:
+      if keras_model._is_graph_network:
+        keras_io_names = keras_model.input_names
+      else:
+        keras_io_names = [
+            'input_%d' % i for i in range(1, len(estimator_io) + 1)]
+    else:
+      if keras_model._is_graph_network:
+        keras_io_names = keras_model.output_names
+      else:
+        keras_io_names = [
+            'output_%d' % i for i in range(1, len(estimator_io) + 1)]
+
+    for key in estimator_io:
+      if key not in keras_io_names:
+        raise ValueError(
+            'Cannot find %s with name "%s" in Keras Model. '
+            'It needs to match one '
+            'of the following: %s' % ('input' if is_input else 'output', key,
+                                      ', '.join(keras_io_names)))
+      tensors = [_cast_tensor_to_floatx(estimator_io[io_name])
+                 for io_name in keras_io_names]
+    return tensors
   else:
-    keras_io_names = keras_model.output_names
+    # Plain array.
+    return _cast_tensor_to_floatx(estimator_io)
 
-  for key in estimator_io_dict:
-    if key not in keras_io_names:
-      raise ValueError(
-          'Cannot find %s with name "%s" in Keras Model. It needs to match '
-          'one of the following: %s' % ('input' if is_input else 'output', key,
-                                        ', '.join(keras_io_names)))
-  tensors = []
-  for io_name in keras_io_names:
-    tensors.append(_cast_tensor_to_floatx(estimator_io_dict[io_name]))
-  return tensors
+
+def _in_place_subclassed_model_reset(model):
+  """Substitute for model cloning that works for subclassed models.
+
+  Subclassed models cannot be cloned because their topology is not serializable.
+  To "instantiate" an identical model in a new TF graph, we reuse the original
+  model object, but we clear its state.
+
+  After calling this function on a model intance, you can use the model instance
+  as if it were a model clone (in particular you can use it in a new graph).
+
+  This method clears the state of the input model. It is thus destructive.
+  However the original state can be restored fully by calling
+  `_in_place_subclassed_model_state_restoration`.
+
+  Args:
+    model: Instance of a Keras model created via subclassing.
+
+  Raises:
+    ValueError: In case the model uses a subclassed model as inner layer.
+  """
+  assert not model._is_graph_network  # Only makes sense for subclassed networks
+  # Retrieve all layers tracked by the model as well as their attribute names
+  attributes_cache = {}
+  for name in dir(model):
+    try:
+      value = getattr(model, name)
+    except (AttributeError, ValueError, TypeError):
+      continue
+    if isinstance(value, Layer):
+      attributes_cache[name] = value
+      assert value in model._layers
+    elif isinstance(value, (list, tuple)) and name not in ('layers', '_layers'):
+      # Handle case: list/tuple of layers (also tracked by the Network API).
+      if value and all(isinstance(val, Layer) for val in value):
+        raise ValueError('We do not support the use of list-of-layers '
+                         'attributes in subclassed models used with '
+                         '`model_to_estimator` at this time. Found list '
+                         'model: %s' % name)
+
+  # Replace layers on the model with fresh layers
+  layers_to_names = {value: key for key, value in attributes_cache.items()}
+  original_layers = model._layers[:]
+  model._layers = []
+  for layer in original_layers:  # We preserve layer order.
+    config = layer.get_config()
+    # This will not work for nested subclassed models used as layers.
+    # This would be theoretically possible to support, but would add complexity.
+    # Only do it if users complain.
+    if isinstance(layer, Network) and not layer._is_graph_network:
+      raise ValueError('We do not support the use of nested subclassed models '
+                       'in `model_to_estimator` at this time. Found nested '
+                       'model: %s' % layer)
+    fresh_layer = layer.__class__.from_config(config)
+    name = layers_to_names[layer]
+    setattr(model, name, fresh_layer)
+
+  # Cache original model build attributes (in addition to layers)
+  if (not hasattr(model, '_original_attributes_cache') or
+      model._original_attributes_cache is None):
+    if model.built:
+      attributes_to_cache = [
+          'inputs',
+          'outputs',
+          '_feed_outputs',
+          '_feed_output_names',
+          '_feed_output_shapes',
+          '_feed_loss_fns',
+          'loss_weights_list',
+          'targets',
+          '_feed_targets',
+          'sample_weight_modes',
+          'weighted_metrics',
+          'metrics_names',
+          'metrics_tensors',
+          'metrics_updates',
+          'stateful_metric_names',
+          'total_loss',
+          'sample_weights',
+          '_feed_sample_weights',
+          'train_function',
+          'test_function',
+          'predict_function',
+          '_collected_trainable_weights',
+          '_feed_inputs',
+          '_feed_input_names',
+          '_feed_input_shapes',
+          'optimizer',
+      ]
+      for name in attributes_to_cache:
+        attributes_cache[name] = getattr(model, name)
+  model._original_attributes_cache = attributes_cache
+
+  # Reset built state
+  model.built = False
+  model.inputs = None
+  model.outputs = None
+
+
+def _in_place_subclassed_model_state_restoration(model):
+  """Restores the original state of a model after it was "reset".
+
+  This undoes this action of `_in_place_subclassed_model_reset`.
+
+  Args:
+    model: Instance of a Keras model created via subclassing, on which
+      `_in_place_subclassed_model_reset` was previously called.
+  """
+  assert not model._is_graph_network
+  # Restore layers and build attributes
+  if (hasattr(model, '_original_attributes_cache') and
+      model._original_attributes_cache is not None):
+    model._layers = []
+    for name, value in model._original_attributes_cache.items():
+      setattr(model, name, value)
+    model._original_attributes_cache = None
+  else:
+    # Restore to the state of a never-called model.
+    model.built = False
+    model.inputs = None
+    model.outputs = None
 
 
 def _clone_and_build_model(mode,
@@ -93,8 +235,8 @@ def _clone_and_build_model(mode,
     mode: training mode.
     keras_model: an instance of compiled keras model.
     custom_objects: Dictionary for custom objects.
-    features:
-    labels:
+    features: Dict of tensors.
+    labels: Dict of tensors, or single tensor instance.
 
   Returns:
     The newly built model.
@@ -102,33 +244,49 @@ def _clone_and_build_model(mode,
   # Set to True during training, False for inference.
   K.set_learning_phase(mode == model_fn_lib.ModeKeys.TRAIN)
 
-  # Clone keras model.
-  input_tensors = None if features is None else _create_ordered_io(
-      keras_model, features)
-  if custom_objects:
-    with CustomObjectScope(custom_objects):
+  # Get list of inputs.
+  if features is None:
+    input_tensors = None
+  else:
+    input_tensors = _create_ordered_io(keras_model,
+                                       estimator_io=features,
+                                       is_input=True)
+  # Get list of outputs.
+  if labels is None:
+    target_tensors = None
+  elif isinstance(labels, dict):
+    target_tensors = _create_ordered_io(keras_model,
+                                        estimator_io=labels,
+                                        is_input=False)
+  else:
+    target_tensors = [
+        _cast_tensor_to_floatx(
+            sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(labels))
+    ]
+
+  if keras_model._is_graph_network:
+    if custom_objects:
+      with CustomObjectScope(custom_objects):
+        model = models.clone_model(keras_model, input_tensors=input_tensors)
+    else:
       model = models.clone_model(keras_model, input_tensors=input_tensors)
   else:
-    model = models.clone_model(keras_model, input_tensors=input_tensors)
+    model = keras_model
+    _in_place_subclassed_model_reset(model)
+    if input_tensors is not None:
+      model._set_inputs(input_tensors)
 
   # Compile/Build model
-  if mode is model_fn_lib.ModeKeys.PREDICT and not model.built:
-    model.build()
+  if mode is model_fn_lib.ModeKeys.PREDICT:
+    if isinstance(model, models.Sequential):
+      model.build()
   else:
-    optimizer_config = keras_model.optimizer.get_config()
-    optimizer = keras_model.optimizer.__class__.from_config(optimizer_config)
-    optimizer.iterations = training_util.get_or_create_global_step()
-
-    # Get list of outputs.
-    if labels is None:
-      target_tensors = None
-    elif isinstance(labels, dict):
-      target_tensors = _create_ordered_io(keras_model, labels, is_input=False)
+    if isinstance(keras_model.optimizer, optimizers.TFOptimizer):
+      optimizer = keras_model.optimizer
     else:
-      target_tensors = [
-          _cast_tensor_to_floatx(
-              sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(labels))
-      ]
+      optimizer_config = keras_model.optimizer.get_config()
+      optimizer = keras_model.optimizer.__class__.from_config(optimizer_config)
+    optimizer.iterations = training_util.get_or_create_global_step()
 
     model.compile(
         optimizer,
@@ -168,10 +326,14 @@ def _create_keras_model_fn(keras_model, custom_objects=None):
 
     # Set loss and metric only during train and evaluate.
     if mode is not model_fn_lib.ModeKeys.PREDICT:
-      model._make_train_function()  # pylint: disable=protected-access
+      if mode is model_fn_lib.ModeKeys.TRAIN:
+        model._make_train_function()  # pylint: disable=protected-access
+      else:
+        model._make_test_function()  # pylint: disable=protected-access
       loss = model.total_loss
 
       if model.metrics:
+        # TODO(fchollet): support stateful metrics
         eval_metric_ops = {}
         # When each metric maps to an output
         if isinstance(model.metrics, dict):
@@ -195,6 +357,10 @@ def _create_keras_model_fn(keras_model, custom_objects=None):
     if mode is model_fn_lib.ModeKeys.TRAIN:
       train_op = model.train_function.updates_op
 
+    if not model._is_graph_network:
+      # Reset model state to original state,
+      # to avoid `model_fn` being destructive for the initial model argument.
+      _in_place_subclassed_model_state_restoration(keras_model)
     return model_fn_lib.EstimatorSpec(
         mode=mode,
         predictions=predictions,
@@ -274,10 +440,11 @@ def model_to_estimator(keras_model=None,
   """
   if (not keras_model) and (not keras_model_path):
     raise ValueError(
-        'Either keras_model or keras_model_path needs to be provided.')
+        'Either `keras_model` or `keras_model_path` needs to be provided.')
   if keras_model and keras_model_path:
     raise ValueError(
-        'Please specity either keras_model or keras_model_path but not both.')
+        'Please specity either `keras_model` or `keras_model_path`, '
+        'but not both.')
 
   if not keras_model:
     if keras_model_path.startswith(
@@ -288,22 +455,42 @@ def model_to_estimator(keras_model=None,
     logging.info('Loading models from %s', keras_model_path)
     keras_model = models.load_model(keras_model_path)
   else:
-    logging.info('Using the Keras model from memory.')
+    logging.info('Using the Keras model provided.')
     keras_model = keras_model
 
-  if not hasattr(keras_model, 'optimizer'):
+  if not hasattr(keras_model, 'optimizer') or not keras_model.optimizer:
     raise ValueError(
-        'Given keras model has not been compiled yet. Please compile first '
-        'before creating the estimator.')
+        'The given keras model has not been compiled yet. Please compile first '
+        'before calling `model_to_estimator`.')
+
+  if isinstance(config, dict):
+    config = run_config_lib.RunConfig(**config)
 
   keras_model_fn = _create_keras_model_fn(keras_model, custom_objects)
-  est = estimator_lib.Estimator(
+  estimator = estimator_lib.Estimator(
       keras_model_fn, model_dir=model_dir, config=config)
+
   # Pass the config into keras backend's default session.
-  with session.Session(config=est._session_config) as sess:
+  with session.Session(config=estimator._session_config) as sess:
     K.set_session(sess)
 
   keras_weights = keras_model.get_weights()
-  # TODO(yifeif): move checkpoint initialization to scaffold.init_fn
-  _save_first_checkpoint(keras_model, est, custom_objects, keras_weights)
-  return est
+  if keras_model._is_graph_network:
+    # TODO(yifeif): move checkpoint initialization to scaffold.init_fn
+    _save_first_checkpoint(keras_model,
+                           estimator,
+                           custom_objects,
+                           keras_weights)
+  elif keras_model.built:
+    logging.warning('You are creating an Estimator from a Keras model '
+                    'manually subclassed from `Model`, that was '
+                    'already called on some inputs (and thus already had '
+                    'weights). We are currently unable to preserve '
+                    'the model\'s state (its weights) '
+                    'as part of the estimator '
+                    'in this case. Be warned that the estimator '
+                    'has been created using '
+                    'a freshly initialized version of your model.\n'
+                    'Note that this doesn\'t affect the state of the '
+                    'model instance you passed as `keras_model` argument.')
+  return estimator
index a9de5dd..e076dc2 100644 (file)
@@ -34,6 +34,7 @@ from tensorflow.python.keras._impl.keras.applications import mobilenet
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import test
 from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import rmsprop
 
 
 try:
@@ -64,12 +65,42 @@ def simple_functional_model():
   return model
 
 
-def get_resource_for_simple_model(is_sequential=True, is_evaluate=False):
-  model = simple_sequential_model(
-  ) if is_sequential else simple_functional_model()
-  if is_sequential:
+def simple_subclassed_model():
+
+  class SimpleModel(keras.Model):
+
+    def __init__(self):
+      super(SimpleModel, self).__init__()
+      self.dense1 = keras.layers.Dense(16, activation='relu')
+      self.dp = keras.layers.Dropout(0.1)
+      self.dense2 = keras.layers.Dense(_NUM_CLASS, activation='softmax')
+
+    def call(self, inputs):
+      x = self.dense1(inputs)
+      x = self.dp(x)
+      return self.dense2(x)
+
+  return SimpleModel()
+
+
+def get_resource_for_simple_model(model_type='sequential',
+                                  is_evaluate=False,):
+  if model_type == 'sequential':
+    model = simple_sequential_model()
     model.build()
-  input_name = model.input_names[0]
+  elif model_type == 'subclass':
+    model = simple_subclassed_model()
+  else:
+    assert model_type == 'functional'
+    model = simple_functional_model()
+
+  if model_type == 'subclass':
+    input_name = 'input_1'
+    output_name = 'output_1'
+  else:
+    input_name = model.input_names[0]
+    output_name = model.output_names[0]
+
   np.random.seed(_RANDOM_SEED)
   (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
       train_samples=_TRAIN_SIZE,
@@ -80,17 +111,19 @@ def get_resource_for_simple_model(is_sequential=True, is_evaluate=False):
   y_test = keras.utils.to_categorical(y_test)
 
   train_input_fn = numpy_io.numpy_input_fn(
-      x={input_name: x_train},
-      y=y_train,
+      x=randomize_io_type(x_train, input_name),
+      y=randomize_io_type(y_train, output_name),
       shuffle=False,
       num_epochs=None,
       batch_size=16)
 
   evaluate_input_fn = numpy_io.numpy_input_fn(
-      x={input_name: x_test}, y=y_test, num_epochs=1, shuffle=False)
+      x=randomize_io_type(x_test, input_name),
+      y=randomize_io_type(y_test, output_name),
+      num_epochs=1, shuffle=False)
 
   predict_input_fn = numpy_io.numpy_input_fn(
-      x={input_name: x_test}, num_epochs=1, shuffle=False)
+      x=randomize_io_type(x_test, input_name), num_epochs=1, shuffle=False)
 
   inference_input_fn = evaluate_input_fn if is_evaluate else predict_input_fn
 
@@ -98,6 +131,14 @@ def get_resource_for_simple_model(is_sequential=True, is_evaluate=False):
                                      y_test), train_input_fn, inference_input_fn
 
 
+def randomize_io_type(array, name):
+  switch = np.random.random()
+  if switch > 0.5:
+    return array
+  else:
+    return {name: array}
+
+
 def multi_inputs_multi_outputs_model():
   # test multi-input layer
   a = keras.layers.Input(shape=(16,), name='input_a')
@@ -134,10 +175,10 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
       gfile.DeleteRecursively(self._base_dir)
 
   def test_train(self):
-    for is_sequential in [True, False]:
+    for model_type in ['sequential', 'functional']:
       keras_model, (_, _), (
           _, _), train_input_fn, eval_input_fn = get_resource_for_simple_model(
-              is_sequential=is_sequential, is_evaluate=True)
+              model_type=model_type, is_evaluate=True)
       keras_model.compile(
           loss='categorical_crossentropy',
           optimizer='rmsprop',
@@ -155,10 +196,87 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
       writer_cache.FileWriterCache.clear()
       gfile.DeleteRecursively(self._config.model_dir)
 
+  def test_train_with_tf_optimizer(self):
+    for model_type in ['sequential', 'functional']:
+      keras_model, (_, _), (
+          _, _), train_input_fn, eval_input_fn = get_resource_for_simple_model(
+              model_type=model_type, is_evaluate=True)
+      keras_model.compile(
+          loss='categorical_crossentropy',
+          optimizer=rmsprop.RMSPropOptimizer(1e-3),
+          metrics=['mse', keras.metrics.categorical_accuracy])
+
+      with self.test_session():
+        est_keras = keras.estimator.model_to_estimator(
+            keras_model=keras_model,
+            # Also use dict config argument to get test coverage for that line.
+            config={
+                'tf_random_seed': _RANDOM_SEED,
+                'model_dir': self._base_dir,
+            })
+        before_eval_results = est_keras.evaluate(
+            input_fn=eval_input_fn, steps=1)
+        est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
+        after_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
+        self.assertLess(after_eval_results['loss'], before_eval_results['loss'])
+
+      writer_cache.FileWriterCache.clear()
+      gfile.DeleteRecursively(self._config.model_dir)
+
+  def test_train_with_subclassed_model(self):
+    keras_model, (_, _), (
+        _, _), train_input_fn, eval_input_fn = get_resource_for_simple_model(
+            model_type='subclass', is_evaluate=True)
+    keras_model.compile(
+        loss='categorical_crossentropy',
+        optimizer=rmsprop.RMSPropOptimizer(1e-3),
+        metrics=['mse', keras.metrics.categorical_accuracy])
+
+    with self.test_session():
+      est_keras = keras.estimator.model_to_estimator(
+          keras_model=keras_model, config=self._config)
+      est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
+      before_eval_results = est_keras.evaluate(
+          input_fn=eval_input_fn, steps=1)
+      est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
+      after_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
+      self.assertLess(after_eval_results['loss'], before_eval_results['loss'])
+
+  def test_train_with_subclassed_model_with_existing_state(self):
+    keras_model, (_, _), (
+        _, _), train_input_fn, eval_input_fn = get_resource_for_simple_model(
+            model_type='subclass', is_evaluate=True)
+    keras_model.compile(
+        loss='categorical_crossentropy',
+        optimizer=rmsprop.RMSPropOptimizer(1e-3),
+        metrics=['mse', keras.metrics.categorical_accuracy])
+
+    with self.test_session():
+      # Create state
+      keras_model.train_on_batch(np.random.random((10,) + _INPUT_SIZE),
+                                 np.random.random((10, _NUM_CLASS)))
+      original_preds = keras_model.predict(np.ones((10,) + _INPUT_SIZE))
+
+      est_keras = keras.estimator.model_to_estimator(
+          keras_model=keras_model, config=self._config)
+      est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
+      before_eval_results = est_keras.evaluate(
+          input_fn=eval_input_fn, steps=1)
+      est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
+      after_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
+      self.assertLess(after_eval_results['loss'], before_eval_results['loss'])
+
+      # Check that original model state was not altered
+      preds = keras_model.predict(np.ones((10,) + _INPUT_SIZE))
+      self.assertAllClose(original_preds, preds, atol=1e-5)
+      # Check that the original model compilation did not break
+      keras_model.train_on_batch(np.random.random((10,) + _INPUT_SIZE),
+                                 np.random.random((10, _NUM_CLASS)))
+
   def test_evaluate(self):
     keras_model, (x_train, y_train), (
         x_test, y_test), _, eval_input_fn = get_resource_for_simple_model(
-            is_sequential=False, is_evaluate=True)
+            model_type='functional', is_evaluate=True)
 
     with self.test_session():
       metrics = [
@@ -200,7 +318,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
     # Check that predict on a pretrained model yield the same result.
     keras_model, (x_train, y_train), (
         x_test, _), _, pred_input_fn = get_resource_for_simple_model(
-            is_sequential=True, is_evaluate=False)
+            model_type='sequential', is_evaluate=False)
 
     with self.test_session():
       keras_model.compile(
@@ -262,7 +380,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
 
     keras_model, (x_train, y_train), (
         x_test, _), _, pred_input_fn = get_resource_for_simple_model(
-            is_sequential=False, is_evaluate=False)
+            model_type='functional', is_evaluate=False)
 
     with self.test_session():
       keras_model.compile(
index 2ec9971..c6d16a3 100644 (file)
@@ -127,7 +127,7 @@ class Layer(checkpointable.CheckpointableBase):
     # return tensors. When using graph execution, _losses is a list of ops.
     self._losses = []
     self._reuse = kwargs.get('_reuse')
-    self._graph = ops.get_default_graph()
+    self._graph = None  # Will be set at build time.
     self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name
     call_fn_args = estimator_util.fn_args(self.call)
     self._compute_previous_mask = ('mask' in call_fn_args or
@@ -630,7 +630,8 @@ class Layer(checkpointable.CheckpointableBase):
     # the same graph as where it was created.
     if in_graph_mode:
       try:
-        ops._get_graph_from_inputs(input_list, graph=self.graph)  # pylint: disable=protected-access
+        # Set layer's "graph" at build time
+        self._graph = ops._get_graph_from_inputs(input_list, graph=self._graph)  # pylint: disable=protected-access
       except ValueError as e:
         raise ValueError('Input graph and Layer graph are not the same: %s' % e)
     if in_graph_mode or in_deferred_mode: