simple_server: Restart when the certificate changes
authorNirbheek Chauhan <nirbheek@centricular.com>
Mon, 25 May 2020 18:34:11 +0000 (18:34 +0000)
committerMatthew Waters <matthew@centricular.com>
Thu, 18 Jun 2020 13:34:48 +0000 (23:34 +1000)
Reload the SSL context and restart the server if the certificate
changes. Without this, new connections will continue to use the old
expired certificate.

webrtc/signalling/simple_server.py

index f7ee545..12153a0 100755 (executable)
@@ -44,10 +44,14 @@ class WebRTCSimpleServer(object):
         self.addr = options.addr
         self.port = options.port
         self.keepalive_timeout = options.keepalive_timeout
+        self.cert_restart = options.cert_restart
         self.cert_path = options.cert_path
         self.disable_ssl = options.disable_ssl
         self.health_path = options.health
 
+        # Certificate mtime, used to detect when to restart the server
+        self.cert_mtime = -1
+
     ############### Helper functions ###############
 
     async def health_check(self, path, request_headers):
@@ -280,6 +284,38 @@ class WebRTCSimpleServer(object):
 
         # Run the server
         self.server = self.loop.run_until_complete(wsd)
+        # Stop the server if certificate changes
+        self.loop.run_until_complete(self.check_server_needs_restart())
+
+    async def stop(self):
+        print('Stopping server... ', end='')
+        self.server.close()
+        await self.server.wait_closed()
+        self.loop.stop()
+        print('Stopped.')
+
+    def check_cert_changed(self):
+        chain_pem, key_pem = self.get_ssl_certs()
+        mtime = max(os.stat(key_pem).st_mtime, os.stat(chain_pem).st_mtime)
+        if self.cert_mtime < 0:
+            self.cert_mtime = mtime
+            return False
+        if mtime > self.cert_mtime:
+            self.cert_mtime = mtime
+            return True
+        return False
+
+    async def check_server_needs_restart(self):
+        "When the certificate changes, we need to restart the server"
+        if not self.cert_restart:
+            return
+        while True:
+            await asyncio.sleep(10)
+            if self.check_cert_changed():
+                print('Certificate changed, stopping server...')
+                await self.stop()
+                return
+
 
 def main():
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
@@ -290,6 +326,7 @@ def main():
     parser.add_argument('--cert-path', default=os.path.dirname(__file__))
     parser.add_argument('--disable-ssl', default=False, help='Disable ssl', action='store_true')
     parser.add_argument('--health', default='/health', help='Health check route')
+    parser.add_argument('--restart-on-cert-change', default=False, dest='cert_restart', action='store_true', help='Automatically restart if the SSL certificate changes')
 
     options = parser.parse_args(sys.argv[1:])
 
@@ -298,8 +335,10 @@ def main():
     r = WebRTCSimpleServer(loop, options)
 
     print('Starting server...')
-    r.run()
-    loop.run_forever()
+    while True:
+        r.run()
+        loop.run_forever()
+        print('Restarting server...')
     print("Goodbye!")
 
 if __name__ == "__main__":