address connection leak issue from #520
authorShivaram Lingamneni <slingamn@cs.stanford.edu>
Wed, 18 Jul 2012 02:47:13 +0000 (19:47 -0700)
committerShivaram Lingamneni <slingamn@cs.stanford.edu>
Mon, 6 Aug 2012 04:09:13 +0000 (21:09 -0700)
* prefetch now defaults to True, ensuring that by default, sockets
  are returned to the urllib3 connection pool on request end
* sessions now have a close() method, notifying urllib3 to close pooled
  connections
* the module-level API, e.g., `requests.get('http://www.google.com')`,
  explicitly closes its session when finished

When prefetch is False, the open socket becomes part of the state of the
Response object, and it's the client's responsibility to read the whole
body, at which point the socket will be returned to the pool.

requests/api.py
requests/models.py
requests/sessions.py
tests/informal/test_leaked_connections.py [new file with mode: 0644]
tests/test_requests.py

index 9cea79afc09b3f3edf316d14c9aeee1d53fc0a3d..f192b8f522dc2e20c3dcbc10e9f5fc95ac35ac5d 100644 (file)
@@ -38,10 +38,19 @@ def request(method, url, **kwargs):
     :param cert: (optional) if String, path to ssl client cert file (.pem). If Tuple, ('cert', 'key') pair.
     """
 
-    s = kwargs.pop('session') if 'session' in kwargs else sessions.session()
-    return s.request(method=method, url=url, **kwargs)
-
-
+    # if this session was passed in, leave it open (and retain pooled connections);
+    # if we're making it just for this call, then close it when we're done.
+    adhoc_session = False
+    session = kwargs.pop('session', None)
+    if session is None:
+        session = sessions.session()
+        adhoc_session = True
+
+    try:
+        return session.request(method=method, url=url, **kwargs)
+    finally:
+        if adhoc_session:
+            session.close()
 
 def get(url, **kwargs):
     """Sends a GET request. Returns :class:`Response` object.
index 2c0c7bdd7d2b53bcfc7b640064fdc01e9cb688c2..e42ab366cb3d4874cde73394ba9c62b427ba1dc3 100644 (file)
@@ -57,7 +57,7 @@ class Request(object):
         proxies=None,
         hooks=None,
         config=None,
-        prefetch=False,
+        prefetch=True,
         _poolmanager=None,
         verify=None,
         session=None,
@@ -458,7 +458,7 @@ class Request(object):
         except ValueError:
             return False
 
-    def send(self, anyway=False, prefetch=False):
+    def send(self, anyway=False, prefetch=True):
         """Sends the request. Returns True if successful, False if not.
         If there was an HTTPError during transmission,
         self.response.status_code will contain the HTTPError code.
@@ -774,6 +774,8 @@ class Response(object):
                 self._content = None
 
         self._content_consumed = True
+        # don't need to release the connection; that's been handled by urllib3
+        # since we exhausted the data.
         return self._content
 
     @property
index 3113c787d6f375eec31d47e8576f247512b57688..73c7b17b8cc98919df631889ed31b47331ae6849 100644 (file)
@@ -66,7 +66,7 @@ class Session(object):
         hooks=None,
         params=None,
         config=None,
-        prefetch=False,
+        prefetch=True,
         verify=True,
         cert=None):
 
@@ -105,7 +105,15 @@ class Session(object):
         return self
 
     def __exit__(self, *args):
-        pass
+        self.close()
+
+    def close(self):
+        """Dispose of any internal state.
+
+        Currently, this just closes the PoolManager, which closes pooled
+        connections.
+        """
+        self.poolmanager.clear()
 
     def request(self, method, url,
         params=None,
@@ -120,7 +128,7 @@ class Session(object):
         hooks=None,
         return_response=True,
         config=None,
-        prefetch=False,
+        prefetch=None,
         verify=None,
         cert=None):
 
@@ -140,7 +148,7 @@ class Session(object):
         :param proxies: (optional) Dictionary mapping protocol to the URL of the proxy.
         :param return_response: (optional) If False, an un-sent Request object will returned.
         :param config: (optional) A configuration dictionary. See ``request.defaults`` for allowed keys and their default values.
-        :param prefetch: (optional) if ``True``, the response content will be immediately downloaded.
+        :param prefetch: (optional) whether to immediately download the response content. Defaults to ``True``.
         :param verify: (optional) if ``True``, the SSL cert will be verified. A CA_BUNDLE path can also be provided.
         :param cert: (optional) if String, path to ssl client cert file (.pem). If Tuple, ('cert', 'key') pair.
         """
@@ -153,7 +161,7 @@ class Session(object):
         headers = {} if headers is None else headers
         params = {} if params is None else params
         hooks = {} if hooks is None else hooks
-        prefetch = self.prefetch or prefetch
+        prefetch = prefetch if prefetch is not None else self.prefetch
 
         # use session's hooks as defaults
         for key, cb in list(self.hooks.items()):
diff --git a/tests/informal/test_leaked_connections.py b/tests/informal/test_leaked_connections.py
new file mode 100644 (file)
index 0000000..5357bf2
--- /dev/null
@@ -0,0 +1,26 @@
+"""
+This is an informal test originally written by Bluehorn;
+it verifies that Requests does not leak connections when
+the body of the request is not read.
+"""
+
+import gc, os, subprocess, requests, sys
+
+def main():
+    gc.disable()
+
+    for x in range(20):
+        requests.head("http://www.google.com/")
+
+    print("Open sockets after 20 head requests:")
+    pid = os.getpid()
+    subprocess.call("lsof -p%d -a -iTCP" % (pid,), shell=True)
+
+    gcresult = gc.collect()
+    print("Garbage collection result: %s" % (gcresult,))
+
+    print("Open sockets after garbage collection:")
+    subprocess.call("lsof -p%d -a -iTCP" % (pid,), shell=True)
+
+if __name__ == '__main__':
+    sys.exit(main())
index 9ddc58b4e2964d1105504cd02719a97f92fa568d..6b7b57a362afdcaea670c69b84c0d5677f6957d7 100755 (executable)
@@ -807,9 +807,9 @@ class RequestsTestSuite(TestSetup, TestBaseMixin, unittest.TestCase):
         assert 'k' in c
 
         ds1 = pickle.loads(pickle.dumps(requests.session()))
-        ds2 = pickle.loads(pickle.dumps(requests.session(prefetch=True)))
-        assert not ds1.prefetch
-        assert ds2.prefetch
+        ds2 = pickle.loads(pickle.dumps(requests.session(prefetch=False)))
+        assert ds1.prefetch
+        assert not ds2.prefetch
 
     # def test_invalid_content(self):
     #     # WARNING: if you're using a terrible DNS provider (comcast),
@@ -858,7 +858,7 @@ class RequestsTestSuite(TestSetup, TestBaseMixin, unittest.TestCase):
         )
 
         # Make a request and monkey-patch its contents
-        r = get(httpbin('get'))
+        r = get(httpbin('get'), prefetch=False)
         r.raw = StringIO(quote)
 
         lines = list(r.iter_lines())