examples/webrtc/signalling: Fix compatibility with Python 3.10
[platform/upstream/gstreamer.git] / subprojects / gst-examples / webrtc / signalling / simple_server.py
1 #!/usr/bin/env python3
2 #
3 # Example 1-1 call signalling server
4 #
5 # Copyright (C) 2017 Centricular Ltd.
6 #
7 #  Author: Nirbheek Chauhan <nirbheek@centricular.com>
8 #
9
10 import os
11 import sys
12 import ssl
13 import logging
14 import asyncio
15 import websockets
16 import argparse
17 import http
18 import concurrent
19
20
21 class WebRTCSimpleServer(object):
22
23     def __init__(self, options):
24         ############### Global data ###############
25
26         # Format: {uid: (Peer WebSocketServerProtocol,
27         #                remote_address,
28         #                <'session'|room_id|None>)}
29         self.peers = dict()
30         # Format: {caller_uid: callee_uid,
31         #          callee_uid: caller_uid}
32         # Bidirectional mapping between the two peers
33         self.sessions = dict()
34         # Format: {room_id: {peer1_id, peer2_id, peer3_id, ...}}
35         # Room dict with a set of peers in each room
36         self.rooms = dict()
37
38         # Options
39         self.addr = options.addr
40         self.port = options.port
41         self.keepalive_timeout = options.keepalive_timeout
42         self.cert_restart = options.cert_restart
43         self.cert_path = options.cert_path
44         self.disable_ssl = options.disable_ssl
45         self.health_path = options.health
46
47         # Certificate mtime, used to detect when to restart the server
48         self.cert_mtime = -1
49
50     ############### Helper functions ###############
51
52     async def health_check(self, path, request_headers):
53         if path == self.health_path:
54             return http.HTTPStatus.OK, [], b"OK\n"
55         return None
56
57     async def recv_msg_ping(self, ws, raddr):
58         '''
59         Wait for a message forever, and send a regular ping to prevent bad routers
60         from closing the connection.
61         '''
62         msg = None
63         while msg is None:
64             try:
65                 msg = await asyncio.wait_for(ws.recv(), self.keepalive_timeout)
66             except (asyncio.TimeoutError, concurrent.futures._base.TimeoutError):
67                 print('Sending keepalive ping to {!r} in recv'.format(raddr))
68                 await ws.ping()
69         return msg
70
71     async def cleanup_session(self, uid):
72         if uid in self.sessions:
73             other_id = self.sessions[uid]
74             del self.sessions[uid]
75             print("Cleaned up {} session".format(uid))
76             if other_id in self.sessions:
77                 del self.sessions[other_id]
78                 print("Also cleaned up {} session".format(other_id))
79                 # If there was a session with this peer, also
80                 # close the connection to reset its state.
81                 if other_id in self.peers:
82                     print("Closing connection to {}".format(other_id))
83                     wso, oaddr, _ = self.peers[other_id]
84                     del self.peers[other_id]
85                     await wso.close()
86
87     async def cleanup_room(self, uid, room_id):
88         room_peers = self.rooms[room_id]
89         if uid not in room_peers:
90             return
91         room_peers.remove(uid)
92         for pid in room_peers:
93             wsp, paddr, _ = self.peers[pid]
94             msg = 'ROOM_PEER_LEFT {}'.format(uid)
95             print('room {}: {} -> {}: {}'.format(room_id, uid, pid, msg))
96             await wsp.send(msg)
97
98     async def remove_peer(self, uid):
99         await self.cleanup_session(uid)
100         if uid in self.peers:
101             ws, raddr, status = self.peers[uid]
102             if status and status != 'session':
103                 await self.cleanup_room(uid, status)
104             del self.peers[uid]
105             await ws.close()
106             print("Disconnected from peer {!r} at {!r}".format(uid, raddr))
107
108     ############### Handler functions ###############
109
110     async def connection_handler(self, ws, uid):
111         raddr = ws.remote_address
112         peer_status = None
113         self.peers[uid] = [ws, raddr, peer_status]
114         print("Registered peer {!r} at {!r}".format(uid, raddr))
115         while True:
116             # Receive command, wait forever if necessary
117             msg = await self.recv_msg_ping(ws, raddr)
118             # Update current status
119             peer_status = self.peers[uid][2]
120             # We are in a session or a room, messages must be relayed
121             if peer_status is not None:
122                 # We're in a session, route message to connected peer
123                 if peer_status == 'session':
124                     other_id = self.sessions[uid]
125                     wso, oaddr, status = self.peers[other_id]
126                     assert(status == 'session')
127                     print("{} -> {}: {}".format(uid, other_id, msg))
128                     await wso.send(msg)
129                 # We're in a room, accept room-specific commands
130                 elif peer_status:
131                     # ROOM_PEER_MSG peer_id MSG
132                     if msg.startswith('ROOM_PEER_MSG'):
133                         _, other_id, msg = msg.split(maxsplit=2)
134                         if other_id not in self.peers:
135                             await ws.send('ERROR peer {!r} not found'
136                                           ''.format(other_id))
137                             continue
138                         wso, oaddr, status = self.peers[other_id]
139                         if status != room_id:
140                             await ws.send('ERROR peer {!r} is not in the room'
141                                           ''.format(other_id))
142                             continue
143                         msg = 'ROOM_PEER_MSG {} {}'.format(uid, msg)
144                         print('room {}: {} -> {}: {}'.format(room_id, uid, other_id, msg))
145                         await wso.send(msg)
146                     elif msg == 'ROOM_PEER_LIST':
147                         room_id = self.peers[peer_id][2]
148                         room_peers = ' '.join([pid for pid in self.rooms[room_id] if pid != peer_id])
149                         msg = 'ROOM_PEER_LIST {}'.format(room_peers)
150                         print('room {}: -> {}: {}'.format(room_id, uid, msg))
151                         await ws.send(msg)
152                     else:
153                         await ws.send('ERROR invalid msg, already in room')
154                         continue
155                 else:
156                     raise AssertionError('Unknown peer status {!r}'.format(peer_status))
157             # Requested a session with a specific peer
158             elif msg.startswith('SESSION'):
159                 print("{!r} command {!r}".format(uid, msg))
160                 _, callee_id = msg.split(maxsplit=1)
161                 if callee_id not in self.peers:
162                     await ws.send('ERROR peer {!r} not found'.format(callee_id))
163                     continue
164                 if peer_status is not None:
165                     await ws.send('ERROR peer {!r} busy'.format(callee_id))
166                     continue
167                 await ws.send('SESSION_OK')
168                 wsc = self.peers[callee_id][0]
169                 print('Session from {!r} ({!r}) to {!r} ({!r})'
170                       ''.format(uid, raddr, callee_id, wsc.remote_address))
171                 # Register session
172                 self.peers[uid][2] = peer_status = 'session'
173                 self.sessions[uid] = callee_id
174                 self.peers[callee_id][2] = 'session'
175                 self.sessions[callee_id] = uid
176             # Requested joining or creation of a room
177             elif msg.startswith('ROOM'):
178                 print('{!r} command {!r}'.format(uid, msg))
179                 _, room_id = msg.split(maxsplit=1)
180                 # Room name cannot be 'session', empty, or contain whitespace
181                 if room_id == 'session' or room_id.split() != [room_id]:
182                     await ws.send('ERROR invalid room id {!r}'.format(room_id))
183                     continue
184                 if room_id in self.rooms:
185                     if uid in self.rooms[room_id]:
186                         raise AssertionError('How did we accept a ROOM command '
187                                              'despite already being in a room?')
188                 else:
189                     # Create room if required
190                     self.rooms[room_id] = set()
191                 room_peers = ' '.join([pid for pid in self.rooms[room_id]])
192                 await ws.send('ROOM_OK {}'.format(room_peers))
193                 # Enter room
194                 self.peers[uid][2] = peer_status = room_id
195                 self.rooms[room_id].add(uid)
196                 for pid in self.rooms[room_id]:
197                     if pid == uid:
198                         continue
199                     wsp, paddr, _ = self.peers[pid]
200                     msg = 'ROOM_PEER_JOINED {}'.format(uid)
201                     print('room {}: {} -> {}: {}'.format(room_id, uid, pid, msg))
202                     await wsp.send(msg)
203             else:
204                 print('Ignoring unknown message {!r} from {!r}'.format(msg, uid))
205
206     async def hello_peer(self, ws):
207         '''
208         Exchange hello, register peer
209         '''
210         raddr = ws.remote_address
211         hello = await ws.recv()
212         hello, uid = hello.split(maxsplit=1)
213         if hello != 'HELLO':
214             await ws.close(code=1002, reason='invalid protocol')
215             raise Exception("Invalid hello from {!r}".format(raddr))
216         if not uid or uid in self.peers or uid.split() != [uid]:  # no whitespace
217             await ws.close(code=1002, reason='invalid peer uid')
218             raise Exception("Invalid uid {!r} from {!r}".format(uid, raddr))
219         # Send back a HELLO
220         await ws.send('HELLO')
221         return uid
222
223     def get_ssl_certs(self):
224         if 'letsencrypt' in self.cert_path:
225             chain_pem = os.path.join(self.cert_path, 'fullchain.pem')
226             key_pem = os.path.join(self.cert_path, 'privkey.pem')
227         else:
228             chain_pem = os.path.join(self.cert_path, 'cert.pem')
229             key_pem = os.path.join(self.cert_path, 'key.pem')
230         return chain_pem, key_pem
231
232     def get_ssl_ctx(self):
233         if self.disable_ssl:
234             return None
235         # Create an SSL context to be used by the websocket server
236         print('Using TLS with keys in {!r}'.format(self.cert_path))
237         chain_pem, key_pem = self.get_ssl_certs()
238         sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
239         try:
240             sslctx.load_cert_chain(chain_pem, keyfile=key_pem)
241         except FileNotFoundError:
242             print("Certificates not found, did you run generate_cert.sh?")
243             sys.exit(1)
244         # FIXME
245         sslctx.check_hostname = False
246         sslctx.verify_mode = ssl.CERT_NONE
247         return sslctx
248
249     async def run(self):
250         async def handler(ws, path):
251             '''
252             All incoming messages are handled here. @path is unused.
253             '''
254             raddr = ws.remote_address
255             print("Connected to {!r}".format(raddr))
256             peer_id = await self.hello_peer(ws)
257             try:
258                 await self.connection_handler(ws, peer_id)
259             except websockets.ConnectionClosed:
260                 print("Connection to peer {!r} closed, exiting handler".format(raddr))
261             finally:
262                 await self.remove_peer(peer_id)
263
264         sslctx = self.get_ssl_ctx()
265
266         print("Listening on https://{}:{}".format(self.addr, self.port))
267         # Websocket server
268         wsd = websockets.serve(handler, self.addr, self.port, ssl=sslctx, process_request=self.health_check if self.health_path else None,
269                                # Maximum number of messages that websockets will pop
270                                # off the asyncio and OS buffers per connection. See:
271                                # https://websockets.readthedocs.io/en/stable/api.html#websockets.protocol.WebSocketCommonProtocol
272                                max_queue=16)
273
274         logger = logging.getLogger('websockets')
275         logger.setLevel(logging.INFO)
276         handler = logging.StreamHandler()
277         logger.addHandler(handler)
278
279         try:
280             self.exit_future = asyncio.Future()
281             task = asyncio.create_task(self.check_server_needs_restart())
282
283             # Run the server
284             async with wsd:
285                 await self.exit_future
286                 self.exit_future = None
287             print('Stopped.')
288         finally:
289             logger.removeHandler(handler)
290             self.peers = dict()
291             self.sessions = dict()
292             self.rooms = dict()
293
294     def stop(self):
295         if self.exit_future:
296             print('Stopping server... ', end='')
297             self.exit_future.set_result(None)
298
299     def check_cert_changed(self):
300         chain_pem, key_pem = self.get_ssl_certs()
301         mtime = max(os.stat(key_pem).st_mtime, os.stat(chain_pem).st_mtime)
302         if self.cert_mtime < 0:
303             self.cert_mtime = mtime
304             return False
305         if mtime > self.cert_mtime:
306             self.cert_mtime = mtime
307             return True
308         return False
309
310     async def check_server_needs_restart(self):
311         "When the certificate changes, we need to restart the server"
312         if not self.cert_restart:
313             return
314         while True:
315             await asyncio.sleep(10)
316             if self.check_cert_changed():
317                 print('Certificate changed, stopping server...')
318                 self.stop()
319                 return
320
321
322 def main():
323     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
324     # See: host, port in https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.create_server
325     parser.add_argument('--addr', default='', help='Address to listen on (default: all interfaces, both ipv4 and ipv6)')
326     parser.add_argument('--port', default=8443, type=int, help='Port to listen on')
327     parser.add_argument('--keepalive-timeout', dest='keepalive_timeout', default=30, type=int, help='Timeout for keepalive (in seconds)')
328     parser.add_argument('--cert-path', default=os.path.dirname(__file__))
329     parser.add_argument('--disable-ssl', default=False, help='Disable ssl', action='store_true')
330     parser.add_argument('--health', default='/health', help='Health check route')
331     parser.add_argument('--restart-on-cert-change', default=False, dest='cert_restart', action='store_true', help='Automatically restart if the SSL certificate changes')
332
333     options = parser.parse_args(sys.argv[1:])
334
335     print('Starting server...')
336     while True:
337         r = WebRTCSimpleServer(options)
338         asyncio.run(r.run())
339         print('Restarting server...')
340
341     print("Goodbye!")
342
343
344 if __name__ == "__main__":
345     main()