# Example 1-1 call signalling server
# Copyright (C) 2017 Centricular Ltd.
# Author: Nirbheek Chauhan <nirbheek@centricular.com>

import argparse
import asyncio
import concurrent.futures
import http
import logging
import os
import ssl
import sys

import websockets
class WebRTCSimpleServer(object):
    """Websocket signalling server for 1-1 calls and multi-party rooms.

    Protocol, one text message per command:

    * ``HELLO <uid>`` registers the connecting peer; the server answers
      ``HELLO``.
    * ``SESSION <peer_id>`` pairs the sender with an idle peer; afterwards
      every message from either side is relayed verbatim to the other.
    * ``ROOM <room_id>`` joins (creating if needed) a room; members may then
      send ``ROOM_PEER_MSG <peer_id> <msg>`` and ``ROOM_PEER_LIST``.

    Invalid requests are answered with an ``ERROR ...`` message.
    """

    def __init__(self, loop, options):
        """Store the event loop and the parsed command-line *options*."""
        ############### Global data ###############

        # Format: {uid: [Peer WebSocketServerProtocol,
        #                remote_address,
        #                <'session'|room_id|None>]}
        # (a list, because the status slot is updated in place)
        self.peers = dict()
        # Format: {caller_uid: callee_uid,
        #          callee_uid: caller_uid}
        # Bidirectional mapping between the two peers
        self.sessions = dict()
        # Format: {room_id: {peer1_id, peer2_id, peer3_id, ...}}
        # Room dict with a set of peers in each room
        self.rooms = dict()

        # Event loop
        self.loop = loop
        # Websocket Server Instance
        self.server = None

        # Options
        self.addr = options.addr
        self.port = options.port
        self.keepalive_timeout = options.keepalive_timeout
        self.cert_restart = options.cert_restart
        self.cert_path = options.cert_path
        self.disable_ssl = options.disable_ssl
        self.health_path = options.health

        # Certificate mtime, used to detect when to restart the server;
        # negative means "not sampled yet".
        self.cert_mtime = -1

    ############### Helper functions ###############

    async def health_check(self, path, request_headers):
        """``process_request`` hook: answer ``200 OK`` on the health route.

        Returning None lets the normal websocket handshake proceed.
        """
        if path == self.health_path:
            return http.HTTPStatus.OK, [], b"OK\n"
        return None

    async def recv_msg_ping(self, ws, raddr):
        '''
        Wait for a message forever, and send a regular ping to prevent bad routers
        from closing the connection.
        '''
        msg = None
        while msg is None:
            try:
                msg = await asyncio.wait_for(ws.recv(), self.keepalive_timeout)
            except (asyncio.TimeoutError, concurrent.futures.TimeoutError):
                print('Sending keepalive ping to {!r} in recv'.format(raddr))
                await ws.ping()
        return msg

    async def cleanup_session(self, uid):
        """Tear down the 1-1 session *uid* is part of, if any.

        Both directions of the session mapping are removed, and the other
        peer is unregistered and disconnected so that it starts afresh.
        """
        if uid not in self.sessions:
            return
        other_id = self.sessions.pop(uid)
        print("Cleaned up {} session".format(uid))
        if other_id not in self.sessions:
            return
        del self.sessions[other_id]
        print("Also cleaned up {} session".format(other_id))
        # If there was a session with this peer, also
        # close the connection to reset its state.
        if other_id in self.peers:
            print("Closing connection to {}".format(other_id))
            wso, oaddr, _ = self.peers.pop(other_id)
            await wso.close()

    async def cleanup_room(self, uid, room_id):
        """Remove *uid* from *room_id* and tell the remaining members."""
        room_peers = self.rooms.get(room_id)
        if room_peers is None or uid not in room_peers:
            return
        room_peers.remove(uid)
        if not room_peers:
            # Last one out: drop the room instead of keeping an empty set
            # around forever.
            del self.rooms[room_id]
            return
        msg = 'ROOM_PEER_LEFT {}'.format(uid)
        # Iterate over a snapshot: membership may change while we await.
        for pid in list(room_peers):
            if pid not in self.peers:
                continue
            wsp, paddr, _ = self.peers[pid]
            print('room {}: {} -> {}: {}'.format(room_id, uid, pid, msg))
            await wsp.send(msg)

    async def remove_peer(self, uid):
        """Unregister *uid*: end its session or leave its room, then close it."""
        await self.cleanup_session(uid)
        if uid not in self.peers:
            return
        ws, raddr, status = self.peers[uid]
        try:
            if status and status != 'session':
                await self.cleanup_room(uid, status)
        finally:
            # Forget the peer even if notifying the room failed; otherwise
            # its uid could never be registered again.
            self.peers.pop(uid, None)
            await ws.close()
            print("Disconnected from peer {!r} at {!r}".format(uid, raddr))

    ############### Handler functions ###############

    async def connection_handler(self, ws, uid):
        """Register *uid* and serve its commands until *ws* raises.

        Errors from *ws* (e.g. the connection closing) propagate to the
        caller, which is responsible for calling ``remove_peer``.
        """
        raddr = ws.remote_address
        peer_status = None
        self.peers[uid] = [ws, raddr, peer_status]
        print("Registered peer {!r} at {!r}".format(uid, raddr))
        while True:
            # Receive command, wait forever if necessary
            msg = await self.recv_msg_ping(ws, raddr)
            # Update current status
            peer_status = self.peers[uid][2]
            # We are in a session or a room, messages must be relayed
            if peer_status is not None:
                # We're in a session, route message to connected peer
                if peer_status == 'session':
                    other_id = self.sessions[uid]
                    wso, oaddr, status = self.peers[other_id]
                    assert status == 'session'
                    print("{} -> {}: {}".format(uid, other_id, msg))
                    await wso.send(msg)
                # We're in a room, accept room-specific commands
                elif peer_status:
                    # The status slot holds the id of the room we joined
                    room_id = peer_status
                    # ROOM_PEER_MSG peer_id MSG
                    if msg.startswith('ROOM_PEER_MSG'):
                        parts = msg.split(maxsplit=2)
                        if len(parts) != 3:
                            await ws.send('ERROR invalid msg {!r}'.format(msg))
                            continue
                        _, other_id, msg = parts
                        if other_id not in self.peers:
                            await ws.send('ERROR peer {!r} not found'
                                          ''.format(other_id))
                            continue
                        wso, oaddr, status = self.peers[other_id]
                        if status != room_id:
                            await ws.send('ERROR peer {!r} is not in the room'
                                          ''.format(other_id))
                            continue
                        msg = 'ROOM_PEER_MSG {} {}'.format(uid, msg)
                        print('room {}: {} -> {}: {}'.format(room_id, uid, other_id, msg))
                        await wso.send(msg)
                    elif msg == 'ROOM_PEER_LIST':
                        room_peers = ' '.join([pid for pid in self.rooms[room_id] if pid != uid])
                        msg = 'ROOM_PEER_LIST {}'.format(room_peers)
                        print('room {}: -> {}: {}'.format(room_id, uid, msg))
                        await ws.send(msg)
                    else:
                        await ws.send('ERROR invalid msg, already in room')
                        continue
                else:
                    raise AssertionError('Unknown peer status {!r}'.format(peer_status))
            # Requested a session with a specific peer
            elif msg.startswith('SESSION'):
                print("{!r} command {!r}".format(uid, msg))
                parts = msg.split(maxsplit=1)
                # A bare 'SESSION' yields an empty id, reported as not found
                callee_id = parts[1] if len(parts) == 2 else ''
                if callee_id not in self.peers:
                    await ws.send('ERROR peer {!r} not found'.format(callee_id))
                    continue
                # The *callee* must be idle (we are, or we'd be in the
                # relay branch above).
                if self.peers[callee_id][2] is not None:
                    await ws.send('ERROR peer {!r} busy'.format(callee_id))
                    continue
                await ws.send('SESSION_OK')
                wsc = self.peers[callee_id][0]
                print('Session from {!r} ({!r}) to {!r} ({!r})'
                      ''.format(uid, raddr, callee_id, wsc.remote_address))
                # Register session
                self.peers[uid][2] = peer_status = 'session'
                self.sessions[uid] = callee_id
                self.peers[callee_id][2] = 'session'
                self.sessions[callee_id] = uid
            # Requested joining or creation of a room
            elif msg.startswith('ROOM'):
                print('{!r} command {!r}'.format(uid, msg))
                parts = msg.split(maxsplit=1)
                room_id = parts[1] if len(parts) == 2 else ''
                # Room name cannot be 'session', empty, or contain whitespace
                if room_id == 'session' or room_id.split() != [room_id]:
                    await ws.send('ERROR invalid room id {!r}'.format(room_id))
                    continue
                if uid in self.rooms.get(room_id, ()):
                    raise AssertionError('How did we accept a ROOM command '
                                         'despite already being in a room?')
                room_peers = ' '.join(self.rooms.get(room_id, ()))
                await ws.send('ROOM_OK {}'.format(room_peers))
                # Enter room, creating it if required. Look it up again: an
                # empty room may have been dropped while we awaited send().
                self.peers[uid][2] = peer_status = room_id
                members = self.rooms.setdefault(room_id, set())
                members.add(uid)
                msg = 'ROOM_PEER_JOINED {}'.format(uid)
                # Snapshot: membership may change while we await each send
                for pid in list(members):
                    if pid == uid or pid not in self.peers:
                        continue
                    wsp, paddr, _ = self.peers[pid]
                    print('room {}: {} -> {}: {}'.format(room_id, uid, pid, msg))
                    await wsp.send(msg)
            else:
                print('Ignoring unknown message {!r} from {!r}'.format(msg, uid))

    async def hello_peer(self, ws):
        '''
        Exchange hello ('HELLO <uid>' -> 'HELLO') and return the peer's uid.

        Closes the connection with code 1002 and raises Exception when the
        greeting or the uid is invalid.
        '''
        raddr = ws.remote_address
        hello = await ws.recv()
        parts = hello.split(maxsplit=1)
        hello = parts[0] if parts else ''
        # A missing uid is reported as an invalid uid, not a crash
        uid = parts[1] if len(parts) == 2 else ''
        if hello != 'HELLO':
            await ws.close(code=1002, reason='invalid protocol')
            raise Exception("Invalid hello from {!r}".format(raddr))
        if not uid or uid in self.peers or uid.split() != [uid]:  # no whitespace
            await ws.close(code=1002, reason='invalid peer uid')
            raise Exception("Invalid uid {!r} from {!r}".format(uid, raddr))
        # Send back a HELLO
        await ws.send('HELLO')
        return uid

    def get_ssl_certs(self):
        """Return ``(chain_pem, key_pem)`` paths inside ``self.cert_path``.

        Let's Encrypt directories use fullchain.pem/privkey.pem; anything
        else is expected to contain cert.pem/key.pem.
        """
        if 'letsencrypt' in self.cert_path:
            chain_pem = os.path.join(self.cert_path, 'fullchain.pem')
            key_pem = os.path.join(self.cert_path, 'privkey.pem')
        else:
            chain_pem = os.path.join(self.cert_path, 'cert.pem')
            key_pem = os.path.join(self.cert_path, 'key.pem')
        return chain_pem, key_pem

    def get_ssl_ctx(self):
        """Return the server-side TLS context, or None when TLS is disabled.

        Exits the process if the certificate files cannot be found.
        """
        if self.disable_ssl:
            return None
        # Create an SSL context to be used by the websocket server.
        # Purpose.CLIENT_AUTH builds a *server-side* context; the default
        # purpose yields a PROTOCOL_TLS_CLIENT context on Python 3.10+,
        # which cannot be used for server sockets.
        print('Using TLS with keys in {!r}'.format(self.cert_path))
        chain_pem, key_pem = self.get_ssl_certs()
        sslctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        try:
            sslctx.load_cert_chain(chain_pem, keyfile=key_pem)
        except FileNotFoundError:
            print("Certificates not found, did you run generate_cert.sh?")
            sys.exit(1)
        # Clients are not asked for certificates
        sslctx.check_hostname = False
        sslctx.verify_mode = ssl.CERT_NONE
        return sslctx

    def run(self):
        """Start the websocket server on ``self.loop``.

        Returns once ``check_server_needs_restart`` finishes: immediately
        when certificate restarts are disabled, otherwise after the
        certificate changed and the server was stopped.
        """
        async def handler(ws, path):
            '''
            All incoming messages are handled here. @path is unused.
            '''
            raddr = ws.remote_address
            print("Connected to {!r}".format(raddr))
            peer_id = await self.hello_peer(ws)
            try:
                await self.connection_handler(ws, peer_id)
            except websockets.ConnectionClosed:
                print("Connection to peer {!r} closed, exiting handler".format(raddr))
            finally:
                await self.remove_peer(peer_id)

        sslctx = self.get_ssl_ctx()

        scheme = 'https' if sslctx is not None else 'http'
        print("Listening on {}://{}:{}".format(scheme, self.addr, self.port))
        # Websocket server
        wsd = websockets.serve(
            handler, self.addr, self.port, ssl=sslctx,
            process_request=self.health_check if self.health_path else None,
            # Maximum number of messages that websockets will pop
            # off the asyncio and OS buffers per connection. See:
            # https://websockets.readthedocs.io/en/stable/api.html#websockets.protocol.WebSocketCommonProtocol
            max_queue=16)

        # Setup logging. run() is called again on every restart, so attach
        # the handler only once; otherwise log lines would be duplicated.
        logger = logging.getLogger('websockets')
        logger.setLevel(logging.INFO)
        if not logger.handlers:
            logger.addHandler(logging.StreamHandler())

        # Run the server
        self.server = self.loop.run_until_complete(wsd)
        # Stop the server if certificate changes
        self.loop.run_until_complete(self.check_server_needs_restart())

    async def stop(self):
        """Close the websocket server, wait for it, and stop the event loop."""
        print('Stopping server... ', end='')
        self.server.close()
        await self.server.wait_closed()
        self.loop.stop()
        print('Stopped.')

    def check_cert_changed(self):
        """Return True if the certificate files changed since the last call.

        The first successful call only records the current mtime. Files that
        cannot be stat'ed are reported and treated as "no change", so the
        watcher keeps polling instead of crashing the server.
        """
        chain_pem, key_pem = self.get_ssl_certs()
        try:
            mtime = max(os.stat(key_pem).st_mtime, os.stat(chain_pem).st_mtime)
        except OSError as e:
            print('Could not stat certificates: {}'.format(e))
            return False
        if self.cert_mtime < 0:
            self.cert_mtime = mtime
            return False
        if mtime > self.cert_mtime:
            self.cert_mtime = mtime
            return True
        return False

    async def check_server_needs_restart(self):
        "When the certificate changes, we need to restart the server"
        if not self.cert_restart:
            return
        while True:
            await asyncio.sleep(10)
            if self.check_cert_changed():
                print('Certificate changed, stopping server...')
                await self.stop()
                return
def main():
    """Parse the command line and run the signalling server.

    The server is (re)started in a loop: ``WebRTCSimpleServer.run()`` starts
    it, then the event loop runs until it is stopped, after which the server
    is started again.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # See: host, port in https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.create_server
    parser.add_argument('--addr', default='', help='Address to listen on (default: all interfaces, both ipv4 and ipv6)')
    parser.add_argument('--port', default=8443, type=int, help='Port to listen on')
    parser.add_argument('--keepalive-timeout', dest='keepalive_timeout', default=30, type=int, help='Timeout for keepalive (in seconds)')
    parser.add_argument('--cert-path', default=os.path.dirname(__file__))
    parser.add_argument('--disable-ssl', default=False, help='Disable ssl', action='store_true')
    parser.add_argument('--health', default='/health', help='Health check route')
    parser.add_argument('--restart-on-cert-change', default=False, dest='cert_restart', action='store_true', help='Automatically restart if the SSL certificate changes')

    options = parser.parse_args(sys.argv[1:])

    # asyncio.get_event_loop() is deprecated (and an error on newer Pythons)
    # when no loop is set; create and install the loop explicitly instead.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    r = WebRTCSimpleServer(loop, options)

    print('Starting server...')
    while True:
        r.run()
        loop.run_forever()
        print('Restarting server...')
# Script entry point: run the signalling server when executed directly.
if __name__ == "__main__":
    main()