Pickling of Session and HTTPAdapter + a test
authorDarjus Loktevic <darjus@amazon.com>
Mon, 11 Mar 2013 18:12:34 +0000 (18:12 +0000)
committerDarjus Loktevic <darjus@amazon.com>
Mon, 11 Mar 2013 18:12:34 +0000 (18:12 +0000)
This is for issue #1088

requests/adapters.py
requests/sessions.py
test_requests.py

index 6f507da..1472749 100644 (file)
@@ -44,6 +44,8 @@ class BaseAdapter(object):
 
 class HTTPAdapter(BaseAdapter):
     """Built-In HTTP Adapter for Urllib3."""
+    __attrs__ = ['max_retries', 'config', '_pool_connections', '_pool_maxsize']
+
     def __init__(self, pool_connections=DEFAULT_POOLSIZE, pool_maxsize=DEFAULT_POOLSIZE):
         self.max_retries = DEFAULT_RETRIES
         self.config = {}
@@ -52,7 +54,23 @@ class HTTPAdapter(BaseAdapter):
 
         self.init_poolmanager(pool_connections, pool_maxsize)
 
+    def __getstate__(self):
+        return dict((attr, getattr(self, attr, None)) for attr in
+                    self.__attrs__)
+
+    def __setstate__(self, state):
+        for attr, value in state.items():
+            setattr(self, attr, value)
+
+        # setup a new poolmanager after unpickling
+        if self._pool_connections is not None:
+            self.init_poolmanager(self._pool_connections, self._pool_maxsize)
+
     def init_poolmanager(self, connections, maxsize):
+        # save these values for pickling
+        self._pool_connections = connections
+        self._pool_maxsize = maxsize
+
         self.poolmanager = PoolManager(num_pools=connections, maxsize=maxsize)
 
     def cert_verify(self, conn, url, verify, cert):
index e5fdf67..36d90a1 100644 (file)
@@ -177,8 +177,8 @@ class Session(SessionRedirectMixin):
     """
 
     __attrs__ = [
-        'headers', 'cookies', 'auth', 'timeout', 'proxies', 'hooks',   
-        'params', 'verify', 'cert', 'prefetch']
+        'headers', 'cookies', 'auth', 'timeout', 'proxies', 'hooks',
+        'params', 'verify', 'cert', 'prefetch', 'adapters']
 
     def __init__(self):
 
index 52de11c..31d9fbf 100644 (file)
@@ -7,6 +7,7 @@ from __future__ import division
 import json
 import os
 import unittest
+import pickle
 
 import requests
 from requests.auth import HTTPDigestAuth
@@ -376,6 +377,15 @@ class RequestsTestCase(unittest.TestCase):
         self.assertEqual(str(error), 'message')
         self.assertEqual(error.response, response)
 
+    def test_session_pickling(self):
+        r = requests.Request('GET', httpbin('get'))
+        s = requests.Session()
+
+        s = pickle.loads(pickle.dumps(s))
+
+        r = s.send(r.prepare())
+        self.assertEqual(r.status_code, 200)
+
 
 if __name__ == '__main__':
     unittest.main()