From 77ae10ab663bda858c03a5becd452a269c70ca19 Mon Sep 17 00:00:00 2001 From: Nirbheek Chauhan Date: Mon, 25 May 2020 18:34:11 +0000 Subject: [PATCH] simple_server: Restart when the certificate changes Reload the SSL context and restart the server if the certificate changes. Without this, new connections will continue to use the old expired certificate. --- webrtc/signalling/simple_server.py | 43 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/webrtc/signalling/simple_server.py b/webrtc/signalling/simple_server.py index f7ee545..12153a0 100755 --- a/webrtc/signalling/simple_server.py +++ b/webrtc/signalling/simple_server.py @@ -44,10 +44,14 @@ class WebRTCSimpleServer(object): self.addr = options.addr self.port = options.port self.keepalive_timeout = options.keepalive_timeout + self.cert_restart = options.cert_restart self.cert_path = options.cert_path self.disable_ssl = options.disable_ssl self.health_path = options.health + # Certificate mtime, used to detect when to restart the server + self.cert_mtime = -1 + ############### Helper functions ############### async def health_check(self, path, request_headers): @@ -280,6 +284,38 @@ class WebRTCSimpleServer(object): # Run the server self.server = self.loop.run_until_complete(wsd) + # Stop the server if certificate changes + self.loop.run_until_complete(self.check_server_needs_restart()) + + async def stop(self): + print('Stopping server... ', end='') + self.server.close() + await self.server.wait_closed() + self.loop.stop() + print('Stopped.') + + def check_cert_changed(self): + chain_pem, key_pem = self.get_ssl_certs() + mtime = max(os.stat(key_pem).st_mtime, os.stat(chain_pem).st_mtime) + if self.cert_mtime < 0: + self.cert_mtime = mtime + return False + if mtime > self.cert_mtime: + self.cert_mtime = mtime + return True + return False + + async def check_server_needs_restart(self): + "When the certificate changes, we need to restart the server" + if not self.cert_restart: + return + while True: + await asyncio.sleep(10) + if self.check_cert_changed(): + print('Certificate changed, stopping server...') + await self.stop() + return + def main(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -290,6 +326,7 @@ def main(): parser.add_argument('--cert-path', default=os.path.dirname(__file__)) parser.add_argument('--disable-ssl', default=False, help='Disable ssl', action='store_true') parser.add_argument('--health', default='/health', help='Health check route') + parser.add_argument('--restart-on-cert-change', default=False, dest='cert_restart', action='store_true', help='Automatically restart if the SSL certificate changes') options = parser.parse_args(sys.argv[1:]) @@ -298,8 +335,10 @@ def main(): r = WebRTCSimpleServer(loop, options) print('Starting server...') - r.run() - loop.run_forever() + while True: + r.run() + loop.run_forever() + print('Restarting server...') print("Goodbye!") if __name__ == "__main__": -- 2.7.4