Use the Keras session for saving/loading in TensorFlow format
authorAllen Lavoie <allenl@google.com>
Mon, 11 Jun 2018 18:55:34 +0000 (11:55 -0700)
committerMark Daoust <markdaoust@google.com>
Wed, 13 Jun 2018 18:48:49 +0000 (11:48 -0700)
Fixes issues when there's no default session

PiperOrigin-RevId: 200088574

tensorflow/python/keras/engine/network.py
tensorflow/python/keras/engine/saving_test.py

index 9dbf94a..3d567b8 100644 (file)
@@ -20,6 +20,7 @@ from __future__ import division
 from __future__ import print_function
 
 import copy
+import functools
 import json
 import os
 import weakref
@@ -1264,7 +1265,11 @@ class Network(base_layer.Layer):
       with h5py.File(filepath, 'w') as f:
         saving.save_weights_to_hdf5_group(f, self.layers)
     else:
-      self._checkpointable_saver.save(filepath)
+      if context.executing_eagerly():
+        session = None
+      else:
+        session = backend.get_session()
+      self._checkpointable_saver.save(filepath, session=session)
 
   def load_weights(self, filepath, by_name=False):
     """Loads all layer weights, either from a TensorFlow or an HDF5 weight file.
@@ -1324,7 +1329,8 @@ class Network(base_layer.Layer):
             'loading TensorFlow-formatted weights (got by_name=True to '
             'load_weights).')
       if not context.executing_eagerly():
-        finalizer = status.run_restore_ops
+        session = backend.get_session()
+        finalizer = functools.partial(status.run_restore_ops, session=session)
         if self.built:
           finalizer()
         else:
index 30bcd3d..b5448a9 100644 (file)
@@ -404,26 +404,27 @@ class TestWholeModelSaving(test.TestCase):
       os.remove(fname)
 
   def test_saving_lambda_numpy_array_arguments(self):
-    if h5py is None:
-      self.skipTest('h5py required to run this test')
+    with self.test_session():
+      if h5py is None:
+        self.skipTest('h5py required to run this test')
 
-    mean = np.random.random((4, 2, 3))
-    std = np.abs(np.random.random((4, 2, 3))) + 1e-5
-    inputs = keras.layers.Input(shape=(4, 2, 3))
-    output = keras.layers.Lambda(lambda image, mu, std: (image - mu) / std,
-                                 arguments={'mu': mean, 'std': std})(inputs)
-    model = keras.models.Model(inputs, output)
-    model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
+      mean = np.random.random((4, 2, 3))
+      std = np.abs(np.random.random((4, 2, 3))) + 1e-5
+      inputs = keras.layers.Input(shape=(4, 2, 3))
+      output = keras.layers.Lambda(lambda image, mu, std: (image - mu) / std,
+                                   arguments={'mu': mean, 'std': std})(inputs)
+      model = keras.models.Model(inputs, output)
+      model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
 
-    fd, fname = tempfile.mkstemp('.h5')
-    keras.models.save_model(model, fname)
+      fd, fname = tempfile.mkstemp('.h5')
+      keras.models.save_model(model, fname)
 
-    model = keras.models.load_model(fname)
-    os.close(fd)
-    os.remove(fname)
+      model = keras.models.load_model(fname)
+      os.close(fd)
+      os.remove(fname)
 
-    self.assertAllClose(mean, model.layers[1].arguments['mu'])
-    self.assertAllClose(std, model.layers[1].arguments['std'])
+      self.assertAllClose(mean, model.layers[1].arguments['mu'])
+      self.assertAllClose(std, model.layers[1].arguments['std'])
 
   def test_saving_model_with_long_layer_names(self):
     if h5py is None:
@@ -580,6 +581,25 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
         # Indirectly tests that the user is prompted
         model.save_weights(prefix, save_format='tensorflow', overwrite=False)
 
+  def test_no_default_session(self):
+    with ops.Graph().as_default():
+      self.assertFalse(ops.get_default_session())
+      data = np.random.random((1000, 32)).astype(np.float32)
+      labels = np.random.random((1000, 10)).astype(np.float32)
+
+      model = keras.models.Sequential([
+          keras.layers.Dense(10, activation='softmax'),
+          keras.layers.Dense(10, activation='softmax')])
+
+      model.compile(optimizer=training_module.RMSPropOptimizer(0.001),
+                    loss='categorical_crossentropy',
+                    metrics=['accuracy'])
+
+      model.fit(data, labels)
+      fname = os.path.join(self.get_temp_dir(), 'weights', 'ckpt')
+      model.save_weights(fname)
+      model.load_weights(fname)
+
   def test_no_graph_pollution(self):
     with context.graph_mode():
       graph = ops.Graph()