Fixes #1320: transport adapters stored in ordered form
authorŁukasz Langa <lukasz@langa.pl>
Wed, 15 May 2013 11:34:09 +0000 (13:34 +0200)
committerŁukasz Langa <lukasz@langa.pl>
Wed, 15 May 2013 11:34:09 +0000 (13:34 +0200)
AUTHORS.rst
requests/sessions.py
test_requests.py

index d31983438a47570cdf60a2f63c04982db57c839b..fce2cf9c5b5e89b7e90e30ec42f78622637a91d8 100644 (file)
@@ -126,3 +126,4 @@ Patches and Suggestions
 - Bryce Boe <bbzbryce@gmail.com> @bboe
 - Colin Dunklau <colin.dunklau@gmail.com> @cdunklau
 - Hugo Osvaldo Barrera <hugo@osvaldobarrera.com.ar> @hobarrera
+- Łukasz Langa <lukasz@langa.pl> @llanga
index 185d5df73add1d02622880e33d7b2af75da27fd3..77df0e82dd119340161b4599fde9bf87711582e2 100644 (file)
@@ -11,14 +11,13 @@ requests (cookies, auth, proxies).
 import os
 from datetime import datetime
 
-from .compat import cookielib
+from .compat import cookielib, OrderedDict, urljoin, urlparse
 from .cookies import cookiejar_from_dict, extract_cookies_to_jar, RequestsCookieJar
 from .models import Request, PreparedRequest
 from .hooks import default_hooks, dispatch_hook
 from .utils import from_key_val_list, default_headers
 from .exceptions import TooManyRedirects, InvalidSchema
 
-from .compat import urlparse, urljoin
 from .adapters import HTTPAdapter
 
 from .utils import requote_uri, get_environ_proxies, get_netrc_auth
@@ -223,9 +222,9 @@ class Session(SessionRedirectMixin):
         self.cookies = cookiejar_from_dict({})
 
         # Default connection adapters.
-        self.adapters = {}
-        self.mount('http://', HTTPAdapter())
+        self.adapters = OrderedDict()
         self.mount('https://', HTTPAdapter())
+        self.mount('http://', HTTPAdapter())
 
     def __enter__(self):
         return self
@@ -490,8 +489,13 @@ class Session(SessionRedirectMixin):
             v.close()
 
     def mount(self, prefix, adapter):
-        """Registers a connection adapter to a prefix."""
+        """Registers a connection adapter to a prefix.
+
+        Adapters are sorted in descending order by key length."""
         self.adapters[prefix] = adapter
+        keys_to_move = [k for k in self.adapters if len(k) < len(prefix)]
+        for key in keys_to_move:
+            self.adapters[key] = self.adapters.pop(key)
 
     def __getstate__(self):
         return dict((attr, getattr(self, attr, None)) for attr in self.__attrs__)
index da0adb1f97906e05e9ae0b08244f1eb81ca5e33d..07a29554b511ab9e77b4077b00b7557cb012f67b 100644 (file)
@@ -11,6 +11,7 @@ import pickle
 
 import requests
 from requests.auth import HTTPDigestAuth
+from requests.adapters import HTTPAdapter
 from requests.compat import str, cookielib
 from requests.cookies import cookiejar_from_dict
 from requests.structures import CaseInsensitiveDict
@@ -482,6 +483,44 @@ class RequestsTestCase(unittest.TestCase):
             'application/json'
         )
 
+    def test_transport_adapter_ordering(self):
+        s = requests.Session()
+        order = ['https://', 'http://']
+        self.assertEqual(order, list(s.adapters))
+        s.mount('http://git', HTTPAdapter())
+        s.mount('http://github', HTTPAdapter())
+        s.mount('http://github.com', HTTPAdapter())
+        s.mount('http://github.com/about/', HTTPAdapter())
+        order = [
+            'http://github.com/about/',
+            'http://github.com',
+            'http://github',
+            'http://git',
+            'https://',
+            'http://',
+        ]
+        self.assertEqual(order, list(s.adapters))
+        s.mount('http://gittip', HTTPAdapter())
+        s.mount('http://gittip.com', HTTPAdapter())
+        s.mount('http://gittip.com/about/', HTTPAdapter())
+        order = [
+            'http://github.com/about/',
+            'http://gittip.com/about/',
+            'http://github.com',
+            'http://gittip.com',
+            'http://github',
+            'http://gittip',
+            'http://git',
+            'https://',
+            'http://',
+        ]
+        self.assertEqual(order, list(s.adapters))
+        s2 = requests.Session()
+        s2.adapters = {'http://': HTTPAdapter()}
+        s2.mount('https://', HTTPAdapter())
+        self.assertTrue('http://' in s2.adapters)
+        self.assertTrue('https://' in s2.adapters)
+
 
 class TestCaseInsensitiveDict(unittest.TestCase):
 
@@ -627,6 +666,5 @@ class TestCaseInsensitiveDict(unittest.TestCase):
         self.assertEqual(frozenset(cid), keyset)
 
 
-
 if __name__ == '__main__':
     unittest.main()