simple_server: Abstract out ssl context generation
authorNirbheek Chauhan <nirbheek@centricular.com>
Mon, 25 May 2020 18:33:32 +0000 (18:33 +0000)
committerMatthew Waters <matthew@centricular.com>
Thu, 18 Jun 2020 13:34:48 +0000 (23:34 +1000)
webrtc/signalling/simple_server.py

index 2e43786..f7ee545 100755 (executable)
@@ -222,28 +222,33 @@ class WebRTCSimpleServer(object):
         await ws.send('HELLO')
         return uid
 
-    def run(self):
-        sslctx = None
-        if not self.disable_ssl:
-            # Create an SSL context to be used by the websocket server
-            print('Using TLS with keys in {!r}'.format(self.certpath))
-            if 'letsencrypt' in self.certpath:
-                chain_pem = os.path.join(self.certpath, 'fullchain.pem')
-                key_pem = os.path.join(self.certpath, 'privkey.pem')
-            else:
-                chain_pem = os.path.join(self.certpath, 'cert.pem')
-                key_pem = os.path.join(self.certpath, 'key.pem')
+    def get_ssl_certs(self):
+        if 'letsencrypt' in self.cert_path:
+            chain_pem = os.path.join(self.cert_path, 'fullchain.pem')
+            key_pem = os.path.join(self.cert_path, 'privkey.pem')
+        else:
+            chain_pem = os.path.join(self.cert_path, 'cert.pem')
+            key_pem = os.path.join(self.cert_path, 'key.pem')
+        return chain_pem, key_pem
 
-            sslctx = ssl.create_default_context()
-            try:
-                sslctx.load_cert_chain(chain_pem, keyfile=key_pem)
-            except FileNotFoundError:
-                print("Certificates not found, did you run generate_cert.sh?")
-                sys.exit(1)
-            # FIXME
-            sslctx.check_hostname = False
-            sslctx.verify_mode = ssl.CERT_NONE
+    def get_ssl_ctx(self):
+        if self.disable_ssl:
+            return None
+        # Create an SSL context to be used by the websocket server
+        print('Using TLS with keys in {!r}'.format(self.cert_path))
+        chain_pem, key_pem = self.get_ssl_certs()
+        sslctx = ssl.create_default_context()
+        try:
+            sslctx.load_cert_chain(chain_pem, keyfile=key_pem)
+        except FileNotFoundError:
+            print("Certificates not found, did you run generate_cert.sh?")
+            sys.exit(1)
+        # FIXME
+        sslctx.check_hostname = False
+        sslctx.verify_mode = ssl.CERT_NONE
+        return sslctx
 
+    def run(self):
         async def handler(ws, path):
             '''
             All incoming messages are handled here. @path is unused.
@@ -258,6 +263,8 @@ class WebRTCSimpleServer(object):
             finally:
                 await self.remove_peer(peer_id)
 
+        sslctx = self.get_ssl_ctx()
+
         print("Listening on https://{}:{}".format(self.addr, self.port))
         # Websocket server
         wsd = websockets.serve(handler, self.addr, self.port, ssl=sslctx, process_request=self.health_check if self.health_path else None,