3 # Example 1-1 call signalling server
5 # Copyright (C) 2017 Centricular Ltd.
7 # Author: Nirbheek Chauhan <nirbheek@centricular.com>
21 class WebRTCSimpleServer(object):
def __init__(self, options):
    """Initialise server state from the parsed command-line *options*.

    *options* is the namespace produced by this module's argparse parser;
    only the attributes read below are used.
    """
    ############### Global data ###############

    # Format: {uid: [Peer WebSocketServerProtocol,
    #                remote_address,
    #                <'session'|room_id|None>]}
    # (a mutable list so the status slot can be updated in place)
    self.peers = dict()
    # Format: {caller_uid: callee_uid,
    #          callee_uid: caller_uid}
    # Bidirectional mapping between the two peers
    self.sessions = dict()
    # Format: {room_id: {peer1_id, peer2_id, peer3_id, ...}}
    # Room dict with a set of peers in each room
    self.rooms = dict()

    # Options
    self.addr = options.addr
    self.port = options.port
    self.keepalive_timeout = options.keepalive_timeout
    self.cert_restart = options.cert_restart
    self.cert_path = options.cert_path
    self.disable_ssl = options.disable_ssl
    self.health_path = options.health

    # Certificate mtime, used to detect when to restart the server.
    # Negative means no mtime has been recorded yet; check_cert_changed()
    # stores the first one it sees.
    self.cert_mtime = -1
50 ############### Helper functions ###############
async def health_check(self, path, request_headers):
    """Answer HTTP health probes before the websocket handshake.

    Passed to ``websockets.serve`` as ``process_request`` when a health path
    is configured. For a request to ``self.health_path`` it returns a
    ``(status, headers, body)`` triple with ``HTTPStatus.OK``; any other
    path yields None. *request_headers* is not inspected.
    """
    is_probe = path == self.health_path
    if not is_probe:
        return None
    status, headers, body = http.HTTPStatus.OK, [], b"OK\n"
    return status, headers, body
async def recv_msg_ping(self, ws, raddr):
    '''
    Wait for a message forever, and send a regular ping to prevent bad routers
    from closing the connection.

    Each wait on ``ws.recv()`` is bounded by ``self.keepalive_timeout``
    seconds; when it expires a ping is sent on *ws* and the wait restarts.
    Returns the first message received. Any exception from ``ws.recv()``
    other than a timeout (e.g. the connection being closed) propagates.
    '''
    while True:
        try:
            return await asyncio.wait_for(ws.recv(), self.keepalive_timeout)
        # Both names are listed because, depending on the Python version,
        # they may be distinct classes. Use the public alias rather than the
        # private concurrent.futures._base module.
        except (asyncio.TimeoutError, concurrent.futures.TimeoutError):
            print('Sending keepalive ping to {!r} in recv'.format(raddr))
            await ws.ping()
async def cleanup_session(self, uid):
    """Tear down the 1-1 session that *uid* is part of, if any.

    Removes both directions of the mapping from ``self.sessions``. The other
    peer is also dropped from ``self.peers`` and its websocket is closed, so
    it must reconnect and start again from a clean state. Does nothing when
    *uid* has no session.
    """
    if uid not in self.sessions:
        return
    other_id = self.sessions.pop(uid)
    print("Cleaned up {} session".format(uid))
    if other_id not in self.sessions:
        return
    del self.sessions[other_id]
    print("Also cleaned up {} session".format(other_id))
    # If there was a session with this peer, also
    # close the connection to reset its state.
    if other_id in self.peers:
        print("Closing connection to {}".format(other_id))
        wso, oaddr, _ = self.peers.pop(other_id)
        await wso.close()
async def cleanup_room(self, uid, room_id):
    """Remove *uid* from room *room_id* and notify the remaining members.

    Every peer still in the room is sent ``ROOM_PEER_LEFT <uid>``. Does
    nothing if *uid* is not in the room.
    """
    room_peers = self.rooms[room_id]
    if uid not in room_peers:
        return
    room_peers.remove(uid)
    msg = 'ROOM_PEER_LEFT {}'.format(uid)
    # Iterate over a snapshot: each send() yields to the event loop, and other
    # handlers may join or leave this room (mutating the set) in the meantime.
    for pid in list(room_peers):
        # Skip peers that left while we were awaiting an earlier send.
        if pid not in room_peers or pid not in self.peers:
            continue
        wsp, paddr, _ = self.peers[pid]
        print('room {}: {} -> {}: {}'.format(room_id, uid, pid, msg))
        await wsp.send(msg)
async def remove_peer(self, uid):
    """Unregister *uid* and release everything associated with it.

    Tears down its 1-1 session (via ``cleanup_session``), leaves its room
    (via ``cleanup_room``), drops it from ``self.peers`` and closes its
    websocket. Calling it for an unregistered uid only runs the session
    cleanup.
    """
    await self.cleanup_session(uid)
    entry = self.peers.pop(uid, None)
    if entry is None:
        return
    # Unregister before notifying room members: if a notification fails, the
    # uid is still freed and can reconnect.
    ws, raddr, status = entry
    try:
        # status is None (idle), 'session', or a room id
        if status and status != 'session':
            await self.cleanup_room(uid, status)
    finally:
        await ws.close()
    print("Disconnected from peer {!r} at {!r}".format(uid, raddr))
108 ############### Handler functions ###############
async def connection_handler(self, ws, uid):
    """Register peer *uid* on *ws* and serve its messages until the link drops.

    The peer is stored in ``self.peers`` as ``[ws, remote_address, status]``
    where status is None (idle), ``'session'`` or a room id. Protocol:

    * an idle peer may send ``SESSION <callee_id>`` to start a 1-1 session
      with another idle peer, or ``ROOM <room_id>`` to join (creating it if
      needed) a room;
    * in a session every message is relayed verbatim to the other peer;
    * in a room only ``ROOM_PEER_MSG <peer_id> <msg>`` and ``ROOM_PEER_LIST``
      are accepted.

    Errors are reported to the client as ``ERROR ...`` messages. The loop only
    ends when ``recv_msg_ping`` raises (e.g. the connection is closed);
    unregistering the peer is left to the caller.
    """
    raddr = ws.remote_address
    peer_status = None
    self.peers[uid] = [ws, raddr, peer_status]
    print("Registered peer {!r} at {!r}".format(uid, raddr))
    while True:
        # Receive command, wait forever if necessary
        msg = await self.recv_msg_ping(ws, raddr)
        # Update current status: another peer's handler may have changed it
        # (e.g. by starting a session with us) while we were waiting.
        peer_status = self.peers[uid][2]
        # We are in a session or a room, messages must be relayed
        if peer_status is not None:
            # We're in a session, route message to connected peer
            if peer_status == 'session':
                other_id = self.sessions[uid]
                wso, oaddr, status = self.peers[other_id]
                assert(status == 'session')
                print("{} -> {}: {}".format(uid, other_id, msg))
                await wso.send(msg)
            # We're in a room, accept room-specific commands
            elif peer_status:
                room_id = peer_status
                # ROOM_PEER_MSG peer_id MSG
                if msg.startswith('ROOM_PEER_MSG'):
                    parts = msg.split(maxsplit=2)
                    if len(parts) != 3:
                        await ws.send('ERROR invalid msg {!r}'.format(msg))
                        continue
                    _, other_id, msg = parts
                    if other_id not in self.peers:
                        await ws.send('ERROR peer {!r} not found'
                                      ''.format(other_id))
                        continue
                    wso, oaddr, status = self.peers[other_id]
                    if status != room_id:
                        await ws.send('ERROR peer {!r} is not in the room'
                                      ''.format(other_id))
                        continue
                    msg = 'ROOM_PEER_MSG {} {}'.format(uid, msg)
                    print('room {}: {} -> {}: {}'.format(room_id, uid, other_id, msg))
                    await wso.send(msg)
                elif msg == 'ROOM_PEER_LIST':
                    room_peers = ' '.join([pid for pid in self.rooms[room_id] if pid != uid])
                    msg = 'ROOM_PEER_LIST {}'.format(room_peers)
                    print('room {}: -> {}: {}'.format(room_id, uid, msg))
                    await ws.send(msg)
                else:
                    await ws.send('ERROR invalid msg, already in room')
            else:
                raise AssertionError('Unknown peer status {!r}'.format(peer_status))
        # Requested a session with a specific peer
        elif msg.startswith('SESSION'):
            print("{!r} command {!r}".format(uid, msg))
            parts = msg.split(maxsplit=1)
            if len(parts) != 2:
                await ws.send('ERROR invalid msg {!r}'.format(msg))
                continue
            _, callee_id = parts
            if callee_id not in self.peers:
                await ws.send('ERROR peer {!r} not found'.format(callee_id))
                continue
            # The *callee* must be idle; our own status is known to be None here.
            if self.peers[callee_id][2] is not None:
                await ws.send('ERROR peer {!r} busy'.format(callee_id))
                continue
            wsc = self.peers[callee_id][0]
            print('Session from {!r} ({!r}) to {!r} ({!r})'
                  ''.format(uid, raddr, callee_id, wsc.remote_address))
            # Register the session before the first await so no other peer
            # can claim the callee in between.
            self.peers[uid][2] = peer_status = 'session'
            self.sessions[uid] = callee_id
            self.peers[callee_id][2] = 'session'
            self.sessions[callee_id] = uid
            await ws.send('SESSION_OK')
        # Requested joining or creation of a room
        elif msg.startswith('ROOM'):
            print('{!r} command {!r}'.format(uid, msg))
            parts = msg.split(maxsplit=1)
            room_id = parts[1] if len(parts) == 2 else ''
            # Room name cannot be 'session', empty, or contain whitespace
            if room_id == 'session' or room_id.split() != [room_id]:
                await ws.send('ERROR invalid room id {!r}'.format(room_id))
                continue
            # Create room if required
            room = self.rooms.setdefault(room_id, set())
            if uid in room:
                raise AssertionError('How did we accept a ROOM command '
                                     'despite already being in a room?')
            # Peers already present: they are listed in ROOM_OK and are the
            # ones that must be told about the newcomer.
            existing = list(room)
            # Enter the room before the first await so that peers joining
            # concurrently see us in their own ROOM_OK list.
            self.peers[uid][2] = peer_status = room_id
            room.add(uid)
            await ws.send('ROOM_OK {}'.format(' '.join(existing)))
            msg = 'ROOM_PEER_JOINED {}'.format(uid)
            for pid in existing:
                # Skip peers that left while we were awaiting a send.
                if pid not in room or pid not in self.peers:
                    continue
                wsp, paddr, _ = self.peers[pid]
                print('room {}: {} -> {}: {}'.format(room_id, uid, pid, msg))
                await wsp.send(msg)
        else:
            print('Ignoring unknown message {!r} from {!r}'.format(msg, uid))
async def hello_peer(self, ws):
    '''
    Exchange hello and validate the peer's uid.

    Expects the first message on *ws* to be ``HELLO <uid>``. On success it
    replies ``HELLO`` and returns the uid; registering it in ``self.peers`` is
    done afterwards by ``connection_handler``. On a malformed greeting, or a
    uid that is already taken or contains whitespace, the connection is
    closed with code 1002 and an ``Exception`` is raised.
    '''
    raddr = ws.remote_address
    hello = await ws.recv()
    # A bare "HELLO" or an empty message must be rejected cleanly rather than
    # crashing on tuple unpacking.
    parts = hello.split(maxsplit=1)
    if len(parts) != 2 or parts[0] != 'HELLO':
        await ws.close(code=1002, reason='invalid protocol')
        raise Exception("Invalid hello from {!r}".format(raddr))
    uid = parts[1]
    if not uid or uid in self.peers or uid.split() != [uid]:  # no whitespace
        await ws.close(code=1002, reason='invalid peer uid')
        raise Exception("Invalid uid {!r} from {!r}".format(uid, raddr))
    await ws.send('HELLO')
    return uid
def get_ssl_certs(self):
    """Return the ``(chain_pem, key_pem)`` file paths inside ``self.cert_path``.

    When the cert path contains ``letsencrypt`` the files are
    ``fullchain.pem`` / ``privkey.pem``; otherwise ``cert.pem`` / ``key.pem``.
    """
    if 'letsencrypt' in self.cert_path:
        chain_name, key_name = 'fullchain.pem', 'privkey.pem'
    else:
        chain_name, key_name = 'cert.pem', 'key.pem'
    chain_pem = os.path.join(self.cert_path, chain_name)
    key_pem = os.path.join(self.cert_path, key_name)
    return chain_pem, key_pem
232 def get_ssl_ctx(self):
235 # Create an SSL context to be used by the websocket server
236 print('Using TLS with keys in {!r}'.format(self.cert_path))
237 chain_pem, key_pem = self.get_ssl_certs()
238 sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
240 sslctx.load_cert_chain(chain_pem, keyfile=key_pem)
241 except FileNotFoundError:
242 print("Certificates not found, did you run generate_cert.sh?")
245 sslctx.check_hostname = False
246 sslctx.verify_mode = ssl.CERT_NONE
250 async def handler(ws, path):
252 All incoming messages are handled here. @path is unused.
254 raddr = ws.remote_address
255 print("Connected to {!r}".format(raddr))
256 peer_id = await self.hello_peer(ws)
258 await self.connection_handler(ws, peer_id)
259 except websockets.ConnectionClosed:
260 print("Connection to peer {!r} closed, exiting handler".format(raddr))
262 await self.remove_peer(peer_id)
264 sslctx = self.get_ssl_ctx()
266 print("Listening on https://{}:{}".format(self.addr, self.port))
268 wsd = websockets.serve(handler, self.addr, self.port, ssl=sslctx, process_request=self.health_check if self.health_path else None,
269 # Maximum number of messages that websockets will pop
270 # off the asyncio and OS buffers per connection. See:
271 # https://websockets.readthedocs.io/en/stable/api.html#websockets.protocol.WebSocketCommonProtocol
274 logger = logging.getLogger('websockets')
275 logger.setLevel(logging.INFO)
276 handler = logging.StreamHandler()
277 logger.addHandler(handler)
280 self.exit_future = asyncio.Future()
281 task = asyncio.create_task(self.check_server_needs_restart())
285 await self.exit_future
286 self.exit_future = None
289 logger.removeHandler(handler)
291 self.sessions = dict()
296 print('Stopping server... ', end='')
297 self.exit_future.set_result(None)
def check_cert_changed(self):
    """Return True if the TLS certificate files changed since the last check.

    The newest mtime of the chain and key files (from ``get_ssl_certs``) is
    compared with ``self.cert_mtime``. The first successful call only records
    that mtime as a baseline and returns False.
    """
    chain_pem, key_pem = self.get_ssl_certs()
    try:
        mtime = max(os.stat(key_pem).st_mtime, os.stat(chain_pem).st_mtime)
    except OSError:
        # E.g. a file is temporarily missing. Report "no change" so the
        # periodic caller retries on its next poll instead of its task dying
        # on an unhandled exception.
        return False
    if self.cert_mtime < 0:
        self.cert_mtime = mtime
        return False
    if mtime > self.cert_mtime:
        self.cert_mtime = mtime
        return True
    return False
310 async def check_server_needs_restart(self):
311 "When the certificate changes, we need to restart the server"
312 if not self.cert_restart:
315 await asyncio.sleep(10)
316 if self.check_cert_changed():
317 print('Certificate changed, stopping server...')
323 parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
324 # See: host, port in https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.create_server
325 parser.add_argument('--addr', default='', help='Address to listen on (default: all interfaces, both ipv4 and ipv6)')
326 parser.add_argument('--port', default=8443, type=int, help='Port to listen on')
327 parser.add_argument('--keepalive-timeout', dest='keepalive_timeout', default=30, type=int, help='Timeout for keepalive (in seconds)')
328 parser.add_argument('--cert-path', default=os.path.dirname(__file__))
329 parser.add_argument('--disable-ssl', default=False, help='Disable ssl', action='store_true')
330 parser.add_argument('--health', default='/health', help='Health check route')
331 parser.add_argument('--restart-on-cert-change', default=False, dest='cert_restart', action='store_true', help='Automatically restart if the SSL certificate changes')
333 options = parser.parse_args(sys.argv[1:])
335 print('Starting server...')
337 r = WebRTCSimpleServer(options)
339 print('Restarting server...')
344 if __name__ == "__main__":