0cae63321102b48b84e4933a8752d8b88969291d
[platform/upstream/gstreamer.git] / subprojects / gst-examples / webrtc / signalling / simple_server.py
1 #!/usr/bin/env python3
2 #
3 # Example 1-1 call signalling server
4 #
5 # Copyright (C) 2017 Centricular Ltd.
6 #
7 #  Author: Nirbheek Chauhan <nirbheek@centricular.com>
8 #
9
10 import os
11 import sys
12 import ssl
13 import logging
14 import asyncio
15 import websockets
16 import argparse
17 import http
18 import concurrent
19
20
21 class WebRTCSimpleServer(object):
22
23     def __init__(self, loop, options):
24         ############### Global data ###############
25
26         # Format: {uid: (Peer WebSocketServerProtocol,
27         #                remote_address,
28         #                <'session'|room_id|None>)}
29         self.peers = dict()
30         # Format: {caller_uid: callee_uid,
31         #          callee_uid: caller_uid}
32         # Bidirectional mapping between the two peers
33         self.sessions = dict()
34         # Format: {room_id: {peer1_id, peer2_id, peer3_id, ...}}
35         # Room dict with a set of peers in each room
36         self.rooms = dict()
37
38         # Event loop
39         self.loop = loop
40         # Websocket Server Instance
41         self.server = None
42
43         # Options
44         self.addr = options.addr
45         self.port = options.port
46         self.keepalive_timeout = options.keepalive_timeout
47         self.cert_restart = options.cert_restart
48         self.cert_path = options.cert_path
49         self.disable_ssl = options.disable_ssl
50         self.health_path = options.health
51
52         # Certificate mtime, used to detect when to restart the server
53         self.cert_mtime = -1
54
55     ############### Helper functions ###############
56
57     async def health_check(self, path, request_headers):
58         if path == self.health_path:
59             return http.HTTPStatus.OK, [], b"OK\n"
60         return None
61
62     async def recv_msg_ping(self, ws, raddr):
63         '''
64         Wait for a message forever, and send a regular ping to prevent bad routers
65         from closing the connection.
66         '''
67         msg = None
68         while msg is None:
69             try:
70                 msg = await asyncio.wait_for(ws.recv(), self.keepalive_timeout)
71             except (asyncio.TimeoutError, concurrent.futures._base.TimeoutError):
72                 print('Sending keepalive ping to {!r} in recv'.format(raddr))
73                 await ws.ping()
74         return msg
75
76     async def cleanup_session(self, uid):
77         if uid in self.sessions:
78             other_id = self.sessions[uid]
79             del self.sessions[uid]
80             print("Cleaned up {} session".format(uid))
81             if other_id in self.sessions:
82                 del self.sessions[other_id]
83                 print("Also cleaned up {} session".format(other_id))
84                 # If there was a session with this peer, also
85                 # close the connection to reset its state.
86                 if other_id in self.peers:
87                     print("Closing connection to {}".format(other_id))
88                     wso, oaddr, _ = self.peers[other_id]
89                     del self.peers[other_id]
90                     await wso.close()
91
92     async def cleanup_room(self, uid, room_id):
93         room_peers = self.rooms[room_id]
94         if uid not in room_peers:
95             return
96         room_peers.remove(uid)
97         for pid in room_peers:
98             wsp, paddr, _ = self.peers[pid]
99             msg = 'ROOM_PEER_LEFT {}'.format(uid)
100             print('room {}: {} -> {}: {}'.format(room_id, uid, pid, msg))
101             await wsp.send(msg)
102
103     async def remove_peer(self, uid):
104         await self.cleanup_session(uid)
105         if uid in self.peers:
106             ws, raddr, status = self.peers[uid]
107             if status and status != 'session':
108                 await self.cleanup_room(uid, status)
109             del self.peers[uid]
110             await ws.close()
111             print("Disconnected from peer {!r} at {!r}".format(uid, raddr))
112
113     ############### Handler functions ###############
114
115     
116     async def connection_handler(self, ws, uid):
117         raddr = ws.remote_address
118         peer_status = None
119         self.peers[uid] = [ws, raddr, peer_status]
120         print("Registered peer {!r} at {!r}".format(uid, raddr))
121         while True:
122             # Receive command, wait forever if necessary
123             msg = await self.recv_msg_ping(ws, raddr)
124             # Update current status
125             peer_status = self.peers[uid][2]
126             # We are in a session or a room, messages must be relayed
127             if peer_status is not None:
128                 # We're in a session, route message to connected peer
129                 if peer_status == 'session':
130                     other_id = self.sessions[uid]
131                     wso, oaddr, status = self.peers[other_id]
132                     assert(status == 'session')
133                     print("{} -> {}: {}".format(uid, other_id, msg))
134                     await wso.send(msg)
135                 # We're in a room, accept room-specific commands
136                 elif peer_status:
137                     # ROOM_PEER_MSG peer_id MSG
138                     if msg.startswith('ROOM_PEER_MSG'):
139                         _, other_id, msg = msg.split(maxsplit=2)
140                         if other_id not in self.peers:
141                             await ws.send('ERROR peer {!r} not found'
142                                           ''.format(other_id))
143                             continue
144                         wso, oaddr, status = self.peers[other_id]
145                         if status != room_id:
146                             await ws.send('ERROR peer {!r} is not in the room'
147                                           ''.format(other_id))
148                             continue
149                         msg = 'ROOM_PEER_MSG {} {}'.format(uid, msg)
150                         print('room {}: {} -> {}: {}'.format(room_id, uid, other_id, msg))
151                         await wso.send(msg)
152                     elif msg == 'ROOM_PEER_LIST':
153                         room_id = self.peers[peer_id][2]
154                         room_peers = ' '.join([pid for pid in self.rooms[room_id] if pid != peer_id])
155                         msg = 'ROOM_PEER_LIST {}'.format(room_peers)
156                         print('room {}: -> {}: {}'.format(room_id, uid, msg))
157                         await ws.send(msg)
158                     else:
159                         await ws.send('ERROR invalid msg, already in room')
160                         continue
161                 else:
162                     raise AssertionError('Unknown peer status {!r}'.format(peer_status))
163             # Requested a session with a specific peer
164             elif msg.startswith('SESSION'):
165                 print("{!r} command {!r}".format(uid, msg))
166                 _, callee_id = msg.split(maxsplit=1)
167                 if callee_id not in self.peers:
168                     await ws.send('ERROR peer {!r} not found'.format(callee_id))
169                     continue
170                 if peer_status is not None:
171                     await ws.send('ERROR peer {!r} busy'.format(callee_id))
172                     continue
173                 await ws.send('SESSION_OK')
174                 wsc = self.peers[callee_id][0]
175                 print('Session from {!r} ({!r}) to {!r} ({!r})'
176                       ''.format(uid, raddr, callee_id, wsc.remote_address))
177                 # Register session
178                 self.peers[uid][2] = peer_status = 'session'
179                 self.sessions[uid] = callee_id
180                 self.peers[callee_id][2] = 'session'
181                 self.sessions[callee_id] = uid
182             # Requested joining or creation of a room
183             elif msg.startswith('ROOM'):
184                 print('{!r} command {!r}'.format(uid, msg))
185                 _, room_id = msg.split(maxsplit=1)
186                 # Room name cannot be 'session', empty, or contain whitespace
187                 if room_id == 'session' or room_id.split() != [room_id]:
188                     await ws.send('ERROR invalid room id {!r}'.format(room_id))
189                     continue
190                 if room_id in self.rooms:
191                     if uid in self.rooms[room_id]:
192                         raise AssertionError('How did we accept a ROOM command '
193                                              'despite already being in a room?')
194                 else:
195                     # Create room if required
196                     self.rooms[room_id] = set()
197                 room_peers = ' '.join([pid for pid in self.rooms[room_id]])
198                 await ws.send('ROOM_OK {}'.format(room_peers))
199                 # Enter room
200                 self.peers[uid][2] = peer_status = room_id
201                 self.rooms[room_id].add(uid)
202                 for pid in self.rooms[room_id]:
203                     if pid == uid:
204                         continue
205                     wsp, paddr, _ = self.peers[pid]
206                     msg = 'ROOM_PEER_JOINED {}'.format(uid)
207                     print('room {}: {} -> {}: {}'.format(room_id, uid, pid, msg))
208                     await wsp.send(msg)
209             else:
210                 print('Ignoring unknown message {!r} from {!r}'.format(msg, uid))
211
212     async def hello_peer(self, ws):
213         '''
214         Exchange hello, register peer
215         '''
216         raddr = ws.remote_address
217         hello = await ws.recv()
218         hello, uid = hello.split(maxsplit=1)
219         if hello != 'HELLO':
220             await ws.close(code=1002, reason='invalid protocol')
221             raise Exception("Invalid hello from {!r}".format(raddr))
222         if not uid or uid in self.peers or uid.split() != [uid]: # no whitespace
223             await ws.close(code=1002, reason='invalid peer uid')
224             raise Exception("Invalid uid {!r} from {!r}".format(uid, raddr))
225         # Send back a HELLO
226         await ws.send('HELLO')
227         return uid
228
229     def get_ssl_certs(self):
230         if 'letsencrypt' in self.cert_path:
231             chain_pem = os.path.join(self.cert_path, 'fullchain.pem')
232             key_pem = os.path.join(self.cert_path, 'privkey.pem')
233         else:
234             chain_pem = os.path.join(self.cert_path, 'cert.pem')
235             key_pem = os.path.join(self.cert_path, 'key.pem')
236         return chain_pem, key_pem
237
238     def get_ssl_ctx(self):
239         if self.disable_ssl:
240             return None
241         # Create an SSL context to be used by the websocket server
242         print('Using TLS with keys in {!r}'.format(self.cert_path))
243         chain_pem, key_pem = self.get_ssl_certs()
244         sslctx = ssl.create_default_context()
245         try:
246             sslctx.load_cert_chain(chain_pem, keyfile=key_pem)
247         except FileNotFoundError:
248             print("Certificates not found, did you run generate_cert.sh?")
249             sys.exit(1)
250         # FIXME
251         sslctx.check_hostname = False
252         sslctx.verify_mode = ssl.CERT_NONE
253         return sslctx
254
255     def run(self):
256         async def handler(ws, path):
257             '''
258             All incoming messages are handled here. @path is unused.
259             '''
260             raddr = ws.remote_address
261             print("Connected to {!r}".format(raddr))
262             peer_id = await self.hello_peer(ws)
263             try:
264                 await self.connection_handler(ws, peer_id)
265             except websockets.ConnectionClosed:
266                 print("Connection to peer {!r} closed, exiting handler".format(raddr))
267             finally:
268                 await self.remove_peer(peer_id)
269
270         sslctx = self.get_ssl_ctx()
271
272         print("Listening on https://{}:{}".format(self.addr, self.port))
273         # Websocket server
274         wsd = websockets.serve(handler, self.addr, self.port, ssl=sslctx, process_request=self.health_check if self.health_path else None,
275                                # Maximum number of messages that websockets will pop
276                                # off the asyncio and OS buffers per connection. See:
277                                # https://websockets.readthedocs.io/en/stable/api.html#websockets.protocol.WebSocketCommonProtocol
278                                max_queue=16)
279
280         # Setup logging
281         logger = logging.getLogger('websockets')
282         logger.setLevel(logging.INFO)
283         logger.addHandler(logging.StreamHandler())
284
285         # Run the server
286         self.server = self.loop.run_until_complete(wsd)
287         # Stop the server if certificate changes
288         self.loop.run_until_complete(self.check_server_needs_restart())
289
290     async def stop(self):
291         print('Stopping server... ', end='')
292         self.server.close()
293         await self.server.wait_closed()
294         self.loop.stop()
295         print('Stopped.')
296
297     def check_cert_changed(self):
298         chain_pem, key_pem = self.get_ssl_certs()
299         mtime = max(os.stat(key_pem).st_mtime, os.stat(chain_pem).st_mtime)
300         if self.cert_mtime < 0:
301             self.cert_mtime = mtime
302             return False
303         if mtime > self.cert_mtime:
304             self.cert_mtime = mtime
305             return True
306         return False
307
308     async def check_server_needs_restart(self):
309         "When the certificate changes, we need to restart the server"
310         if not self.cert_restart:
311             return
312         while True:
313             await asyncio.sleep(10)
314             if self.check_cert_changed():
315                 print('Certificate changed, stopping server...')
316                 await self.stop()
317                 return
318
319
320 def main():
321     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
322     # See: host, port in https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.create_server
323     parser.add_argument('--addr', default='', help='Address to listen on (default: all interfaces, both ipv4 and ipv6)')
324     parser.add_argument('--port', default=8443, type=int, help='Port to listen on')
325     parser.add_argument('--keepalive-timeout', dest='keepalive_timeout', default=30, type=int, help='Timeout for keepalive (in seconds)')
326     parser.add_argument('--cert-path', default=os.path.dirname(__file__))
327     parser.add_argument('--disable-ssl', default=False, help='Disable ssl', action='store_true')
328     parser.add_argument('--health', default='/health', help='Health check route')
329     parser.add_argument('--restart-on-cert-change', default=False, dest='cert_restart', action='store_true', help='Automatically restart if the SSL certificate changes')
330
331     options = parser.parse_args(sys.argv[1:])
332
333     loop = asyncio.get_event_loop()
334
335     r = WebRTCSimpleServer(loop, options)
336
337     print('Starting server...')
338     while True:
339         r.run()
340         loop.run_forever()
341         print('Restarting server...')
342     print("Goodbye!")
343
344 if __name__ == "__main__":
345     main()