Warn when creating a `tf.InteractiveSession` if another is active.
authorDerek Murray <mrry@google.com>
Thu, 15 Mar 2018 19:33:10 +0000 (12:33 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 15 Mar 2018 19:37:47 +0000 (12:37 -0700)
Fixes #13202 (as far as possible without breaking backwards compatibility).

PiperOrigin-RevId: 189228094

tensorflow/python/client/session.py
tensorflow/python/client/session_test.py

index 924d629..29f06c8 100644 (file)
@@ -21,6 +21,7 @@ from __future__ import print_function
 import functools
 import re
 import threading
+import warnings
 
 import numpy as np
 
@@ -1624,6 +1625,9 @@ class InteractiveSession(BaseSession):
   ```
   """
 
+  _count_lock = threading.Lock()
+  _active_session_count = 0  # GUARDED_BY(_count_lock)
+
   def __init__(self, target='', graph=None, config=None):
     """Creates a new interactive TensorFlow session.
 
@@ -1652,6 +1656,15 @@ class InteractiveSession(BaseSession):
     config.graph_options.place_pruned_graph = True
 
     super(InteractiveSession, self).__init__(target, graph, config)
+    with InteractiveSession._count_lock:
+      if InteractiveSession._active_session_count > 0:
+        warnings.warn('An interactive session is already active. This can '
+                      'cause out-of-memory errors in some cases. You must '
+                      'explicitly call `InteractiveSession.close()` to release '
+                      'resources held by the other session(s).')
+      InteractiveSession._active_session_count += 1
+    self._closed = False
+
     self._default_session = self.as_default()
     self._default_session.enforce_nesting = False
     self._default_session.__enter__()
@@ -1664,6 +1677,12 @@ class InteractiveSession(BaseSession):
   def close(self):
     """Closes an `InteractiveSession`."""
     super(InteractiveSession, self).close()
+    with InteractiveSession._count_lock:
+      if not self._closed:
+        InteractiveSession._active_session_count -= 1
+        self._closed = True
+      else:
+        return
     if self._explicit_graph is not None:
       self._default_graph.__exit__(None, None, None)
     self._default_session.__exit__(None, None, None)
index 781725d..6c7339f 100644 (file)
@@ -22,6 +22,7 @@ import os
 import sys
 import threading
 import time
+import warnings
 
 import numpy as np
 import six
@@ -66,6 +67,10 @@ ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape)
 # @test_util.with_c_api
 class SessionTest(test_util.TensorFlowTestCase):
 
+  def setUp(self):
+    super(SessionTest, self).setUp()
+    warnings.simplefilter('always')
+
   def testUseExistingGraph(self):
     with ops.Graph().as_default() as g, ops.device('/cpu:0'):
       a = constant_op.constant(6.0, shape=[1, 1])
@@ -1191,6 +1196,32 @@ class SessionTest(test_util.TensorFlowTestCase):
       self.assertAllEqual([[24.0]], e.eval())
       sess.close()
 
+  def testMultipleInteractiveSessionsWarning(self):
+    # Reinitialize the global state to ensure that the expected warnings will
+    # be emitted.
+    session.InteractiveSession._active_session_count = 0  # pylint: disable=protected-access
+
+    sess = session.InteractiveSession()
+    sess.close()
+    # Opening and closing interactive sessions serially should not warn.
+    with warnings.catch_warnings(record=True) as w:
+      sess = session.InteractiveSession()
+      sess.close()
+    self.assertEqual(0, len(w))
+
+    with warnings.catch_warnings(record=True) as w:
+      sess = session.InteractiveSession()
+    self.assertEqual(0, len(w))
+    with warnings.catch_warnings(record=True) as w:
+      sess2 = session.InteractiveSession()
+    self.assertEqual(1, len(w))
+    self.assertTrue('An interactive session is already active. This can cause '
+                    'out-of-memory errors in some cases. You must explicitly '
+                    'call `InteractiveSession.close()` to release resources '
+                    'held by the other session(s).' in str(w[0].message))
+    sess2.close()
+    sess.close()
+
   def testInteractivePlacePrunedGraph(self):
     sess = session.InteractiveSession()