3 # Copyright 2012, Google Inc.
6 # Redistribution and use in source and binary forms, with or without
7 # modification, are permitted provided that the following conditions are
10 # * Redistributions of source code must retain the above copyright
11 # notice, this list of conditions and the following disclaimer.
12 # * Redistributions in binary form must reproduce the above
13 # copyright notice, this list of conditions and the following disclaimer
14 # in the documentation and/or other materials provided with the
16 # * Neither the name of Google Inc. nor the names of its
17 # contributors may be used to endorse or promote products derived from
18 # this software without specific prior written permission.
20 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33 """WebSocket client utility for testing mux extension.
35 This code should be independent from mod_pywebsocket. See the comment of
36 client_for_testing.py.
38 NOTE: This code is far from robust like client_for_testing.py.
56 from mod_pywebsocket import util
58 from test import client_for_testing
# Channel id whose payload is parsed as mux control blocks instead of being
# treated as an encapsulated frame (see MuxClient._dispatch_frame).
_CONTROL_CHANNEL_ID = 0
# Logical channel that the client registers itself right after the physical
# handshake, before any AddChannelRequest is sent.
_DEFAULT_CHANNEL_ID = 1

# Opcodes of mux control blocks. The code in this file places/reads the
# opcode in the top three bits of a control block's first byte.
_MUX_OPCODE_ADD_CHANNEL_REQUEST = 0
_MUX_OPCODE_ADD_CHANNEL_RESPONSE = 1
_MUX_OPCODE_FLOW_CONTROL = 2
_MUX_OPCODE_DROP_CHANNEL = 3
_MUX_OPCODE_NEW_CHANNEL_SLOT = 4
72 def __init__(self, opcode):
def _parse_handshake_response(response):
    """Parses the HTTP response carried by an AddChannelResponse.

    Args:
        response: the whole response as a string: the Status-Line followed
            by the header fields, lines separated by CRLF.

    Returns:
        The header fields as parsed by email.parser.Parser (lookups such as
        'upgrade' in the result are performed by callers).

    Raises:
        ValueError: if the Status-Line is missing or malformed, the version
            is not HTTP/1.1, or the status code is not 101.
    """

    # partition() never fails to unpack, so a response without any CRLF is
    # reported as a bad Status-Line instead of as a cryptic unpack error.
    status_line, separator, header_lines = response.partition('\r\n')
    if not separator:
        raise ValueError('Bad Status-Line syntax %r' % status_line)

    words = status_line.split(' ')
    if len(words) < 3:
        raise ValueError('Bad Status-Line syntax %r' % status_line)
    version, response_code = words[:2]
    if version != 'HTTP/1.1':
        raise ValueError('Bad response version %r' % version)

    if response_code != '101':
        raise ValueError('Bad response code %r ' % response_code)
    headers = email.parser.Parser().parsestr(header_lines)
    return headers
def _parse_channel_id(data, offset=0):
    """Decodes the variable-length channel id found at data[offset:].

    The leading bits of the first byte select the encoding:
        0xxxxxxx  1 byte,  7-bit id
        10xxxxxx  2 bytes, 14-bit id
        110xxxxx  3 bytes, 21-bit id
        111xxxxx  4 bytes, 29-bit id

    Args:
        data: byte string holding the encoded id (Python 2 str, or bytes).
        offset: index of the first byte of the id within data.

    Returns:
        A tuple (channel_id, number_of_bytes_consumed).

    Raises:
        Exception: if there is no byte at offset, or fewer bytes remain than
            the encoding selected by the first byte requires.
    """
    remaining = len(data) - offset
    if remaining <= 0:
        raise Exception('No channel id found')

    # struct yields integers for both Python 2 str and bytes input, which
    # ord(data[offset]) does not.
    first_byte = struct.unpack('!B', data[offset:offset + 1])[0]
    if first_byte & 0xe0 == 0xe0:
        id_length, id_mask = 4, 0x1fffffff
    elif first_byte & 0xc0 == 0xc0:
        id_length, id_mask = 3, 0x1fffff
    elif first_byte & 0x80 == 0x80:
        id_length, id_mask = 2, 0x3fff
    else:
        id_length, id_mask = 1, 0x7f

    if remaining < id_length:
        raise Exception('Invalid channel id format')

    # Assemble the big-endian value, then strip the length-marker bits.
    channel_id = 0
    for octet in struct.unpack('!%dB' % id_length,
                               data[offset:offset + id_length]):
        channel_id = (channel_id << 8) | octet
    return channel_id & id_mask, id_length
def _read_number(data, size_of_size, offset=0):
    """Reads a big-endian unsigned integer of 1 to 4 bytes from data.

    Args:
        data: byte string to read from (Python 2 str, or bytes).
        size_of_size: width of the integer in bytes; must be 1, 2, 3 or 4.
        offset: index of the integer's first byte within data.

    Returns:
        The decoded integer.

    Raises:
        Exception: if size_of_size is not in 1..4, or if data does not hold
            size_of_size bytes at offset.
    """
    if size_of_size not in (1, 2, 3, 4):
        raise Exception('Invalid "size of size" in control block')

    raw = data[offset:offset + size_of_size]
    # Report truncated input the same way as the other parsers here rather
    # than leaking an IndexError or struct.error.
    if len(raw) != size_of_size:
        raise Exception('Invalid control block received')

    number = 0
    for octet in struct.unpack('!%dB' % size_of_size, raw):
        number = (number << 8) | octet
    return number
def _parse_control_block_specific_data(data, size_of_size, offset=0):
    """Extracts a length-prefixed field from a control block.

    The field is a size_of_size-byte length (decoded by _read_number)
    followed by that many bytes of content.

    Args:
        data: byte string holding the control block.
        size_of_size: width in bytes of the length prefix.
        offset: index of the first byte of the length prefix.

    Returns:
        A tuple (content, consumed) where consumed counts the length prefix
        plus the content.

    Raises:
        Exception: if the prefix or the announced content does not fit in
            data.
    """
    total_length = len(data)
    if total_length - offset < size_of_size:
        raise Exception('Invalid control block received')

    content_length = _read_number(data, size_of_size, offset)

    content_start = offset + size_of_size
    content_end = content_start + content_length
    if content_end > total_length:
        raise Exception('Invalid size of control block (%d < %d)' % (
            total_length, content_end))
    content = data[content_start:content_end]
    return content, size_of_size + content_length
151 def _parse_control_blocks(data):
157 first_byte = ord(data[pos])
159 opcode = (first_byte >> 5) & 0x7
160 block = _ControlBlock(opcode)
162 # TODO(bashi): Support more opcode
163 if opcode == _MUX_OPCODE_ADD_CHANNEL_RESPONSE:
164 block.encode = (first_byte >> 2) & 3
165 block.rejected = (first_byte >> 4) & 1
167 channel_id, advance = _parse_channel_id(data, pos)
168 block.channel_id = channel_id
171 size_of_size = (first_byte & 3) + 1
172 encoded_handshake, advance = _parse_control_block_specific_data(
173 data, size_of_size, pos)
174 block.encoded_handshake = encoded_handshake
177 elif opcode == _MUX_OPCODE_DROP_CHANNEL:
178 block.mux_error = (first_byte >> 4) & 1
180 channel_id, channel_id_length = _parse_channel_id(data, pos)
181 block.channel_id = channel_id
182 pos += channel_id_length
184 size_of_size = first_byte & 3
185 reason, size = _parse_control_block_specific_data(
186 data, size_of_size, pos)
187 block.reason = reason
190 elif opcode == _MUX_OPCODE_FLOW_CONTROL:
191 channel_id, advance = _parse_channel_id(data, pos)
192 block.channel_id = channel_id
194 size_of_quota = (first_byte & 3) + 1
195 block.send_quota = _read_number(data, size_of_quota, pos)
198 elif opcode == _MUX_OPCODE_NEW_CHANNEL_SLOT:
199 size_of_slots = ((first_byte >> 2) & 3) + 1
200 size_of_quota = (first_byte & 3) + 1
201 block.slots = _read_number(data, size_of_slots, pos)
203 block.send_quota = _read_number(data, size_of_quota, pos)
208 'Unsupported mux opcode %d received' % opcode)
def _encode_channel_id(channel_id):
    """Encodes a channel id with the shortest variable-length form.

    The first byte's leading bits announce the length: 0 (1 byte, 7-bit id),
    10 (2 bytes, 14-bit id), 110 (3 bytes, 21-bit id) or 111 (4 bytes,
    29-bit id).

    Args:
        channel_id: non-negative integer below 2 ** 29.

    Returns:
        The encoded id as a byte string.

    Raises:
        ValueError: if channel_id is negative or does not fit in 29 bits.
    """
    if channel_id < 0:
        raise ValueError('Channel id %d must not be negative' % channel_id)

    # Everything goes through struct.pack so the result is one byte-string
    # type on every Python version (chr() + struct.pack() is not).
    if channel_id < 2 ** 7:
        return struct.pack('!B', channel_id)
    if channel_id < 2 ** 14:
        return struct.pack('!H', 0x8000 | channel_id)
    if channel_id < 2 ** 21:
        return struct.pack('!BH', 0xc0 | (channel_id >> 16),
                           channel_id & 0xffff)
    if channel_id < 2 ** 29:
        return struct.pack('!L', 0xe0000000 | channel_id)

    raise ValueError('Channel id %d is too large' % channel_id)
def _size_of_number_in_bytes_minus_1(number):
    """Returns how many bytes are needed to store number, minus one.

    The result (0..3) is the value placed in the two-bit size fields of a
    control block's first byte.

    Raises:
        ValueError: if number is negative or needs more than four bytes.
    """
    if number < 0:
        raise ValueError('Invalid number: %d' % number)

    # Try widths of one to four bytes, narrowest first.
    for size_minus_1 in range(4):
        if number < 2 ** (8 * (size_minus_1 + 1)):
            return size_minus_1

    raise ValueError('Invalid number %d' % number)
def _encode_number(number):
    """Encodes number as a big-endian integer of minimal width (1-4 bytes).

    Args:
        number: integer in the range 0 .. 2 ** 32 - 1.

    Returns:
        The encoded number as a byte string.

    Raises:
        ValueError: if number is negative or needs more than four bytes.
    """
    # Fail uniformly instead of with ValueError from chr() for negatives
    # and struct.error for oversized values.
    if number < 0 or number >= 2 ** 32:
        raise ValueError('Invalid number: %d' % number)

    # struct.pack keeps the result a single byte-string type on every
    # Python version, unlike chr() + struct.pack().
    if number < 2 ** 8:
        return struct.pack('!B', number)
    elif number < 2 ** 16:
        return struct.pack('!H', number)
    elif number < 2 ** 24:
        return struct.pack('!BH', number >> 16, number & 0xffff)
    else:
        return struct.pack('!L', number)
258 def _create_add_channel_request(channel_id, encoded_handshake,
260 length = len(encoded_handshake)
261 size_of_length = _size_of_number_in_bytes_minus_1(length)
263 first_byte = ((_MUX_OPCODE_ADD_CHANNEL_REQUEST << 5) | (encoding << 2) |
265 encoded_length = _encode_number(length)
267 return (chr(first_byte) + _encode_channel_id(channel_id) +
268 encoded_length + encoded_handshake)
def _create_flow_control(channel_id, replenished_quota):
    """Builds a FlowControl control block.

    Layout: one byte holding the opcode in its top three bits and the quota
    width (minus one) in its low bits, then the encoded channel id, then the
    encoded quota.

    Args:
        channel_id: id of the logical channel whose quota is replenished.
        replenished_quota: amount of quota to announce.

    Returns:
        The control block as a byte string.
    """
    quota_width_minus_1 = _size_of_number_in_bytes_minus_1(replenished_quota)
    header = chr((_MUX_OPCODE_FLOW_CONTROL << 5) | quota_width_minus_1)
    encoded_channel_id = _encode_channel_id(channel_id)
    encoded_quota = _encode_number(replenished_quota)
    return header + encoded_channel_id + encoded_quota
278 class _MuxReaderThread(threading.Thread):
279 """Mux reader thread.
281 Reads frames and passes them to the mux client. This thread accesses
282 private functions/variables of the mux client.
285 def __init__(self, mux):
286 threading.Thread.__init__(self)
289 self._stop_requested = False
291 def _receive_message(self):
294 while not self._stop_requested:
295 fin, rsv1, rsv2, rsv3, opcode, payload_length = (
296 client_for_testing.read_frame_header(self._mux._socket))
299 if opcode == client_for_testing.OPCODE_TEXT:
300 raise Exception('Received a text message on physical '
302 if opcode == client_for_testing.OPCODE_CONTINUATION:
303 raise Exception('Received an intermediate frame but '
304 'fragmentation was not started')
305 if (opcode == client_for_testing.OPCODE_BINARY or
306 opcode == client_for_testing.OPCODE_PONG or
307 opcode == client_for_testing.OPCODE_PONG or
308 opcode == client_for_testing.OPCODE_CLOSE):
309 first_opcode = opcode
311 raise Exception('Received an undefined opcode frame: %d' %
314 elif opcode != client_for_testing.OPCODE_CONTINUATION:
315 raise Exception('Received a new opcode before '
316 'terminating fragmentation')
318 payload = client_for_testing.receive_bytes(
319 self._mux._socket, payload_length)
321 if self._mux._incoming_frame_filter is not None:
322 payload = self._mux._incoming_frame_filter.filter(payload)
324 pending_payload.append(payload)
329 if self._stop_requested:
332 message = ''.join(pending_payload)
333 return first_opcode, message
335 def request_stop(self):
336 self._stop_requested = True
340 while not self._stop_requested:
341 # opcode is OPCODE_BINARY or control opcodes when a message
342 # is succesfully received.
343 opcode, message = self._receive_message()
346 if opcode == client_for_testing.OPCODE_BINARY:
347 channel_id, advance = _parse_channel_id(message)
348 self._mux._dispatch_frame(channel_id, message[advance:])
350 self._mux._process_control_message(opcode, message)
352 self._mux._notify_reader_done()
355 class _InnerFrame(object):
356 def __init__(self, fin, rsv1, rsv2, rsv3, opcode, payload):
362 self.payload = payload
365 class _LogicalChannelData(object):
367 self.queue = Queue.Queue()
369 self.receive_quota = 0
372 class MuxClient(object):
373 """WebSocket mux client.
375 Note that this class is NOT thread-safe. Do not access an instance of this
376 class from multiple threads at a same time.
379 def __init__(self, options):
380 self._logger = util.get_class_logger(self)
382 self._options = options
383 self._options.enable_mux()
386 self._handshake = client_for_testing.WebSocketHandshake(self._options)
387 self._incoming_frame_filter = None
388 self._outgoing_frame_filter = None
390 self._is_active = False
391 self._read_thread = None
392 self._control_blocks_condition = threading.Condition()
393 self._control_blocks = []
394 self._channel_slots = collections.deque()
395 self._logical_channels_condition = threading.Condition();
396 self._logical_channels = {}
398 self._physical_connection_close_event = None
399 self._physical_connection_close_message = None
def _parse_inner_frame(self, data):
    """Splits an encapsulated frame into its flag bits, opcode and payload.

    The first byte carries FIN, RSV1, RSV2 and RSV3 in bits 7..4 and the
    opcode in the low four bits; everything after it is the payload.

    Args:
        data: byte string of the encapsulated frame (channel id already
            stripped by the caller).

    Returns:
        An _InnerFrame built from the decoded fields.

    Raises:
        Exception: if data is empty.
    """
    if len(data) == 0:
        raise Exception('Invalid encapsulated frame received')

    first_byte = struct.unpack('!B', data[:1])[0]
    # The flags live in the high bits, so they must be shifted right. The
    # previous left shifts always produced 0 after masking with 1.
    fin = (first_byte >> 7) & 1
    rsv1 = (first_byte >> 6) & 1
    rsv2 = (first_byte >> 5) & 1
    rsv3 = (first_byte >> 4) & 1
    opcode = first_byte & 0xf

    payload = data[1:]
    # NOTE(review): received payload is passed through the *outgoing*
    # filter; kept as is, but confirm this is intended.
    if self._outgoing_frame_filter:
        payload = self._outgoing_frame_filter.filter(payload)

    return _InnerFrame(fin, rsv1, rsv2, rsv3, opcode, payload)
def _process_mux_control_blocks(self):
    """Applies pending FlowControl and NewChannelSlot control blocks.

    FlowControl adds the announced quota to the channel's send quota and
    wakes up waiters on the logical-channel condition. NewChannelSlot
    appends the announced slots to the slot queue. Blocks handled here are
    removed from self._control_blocks so that they are not applied again
    when the next control frame arrives; all other blocks (notably
    AddChannelResponse, which add_channel() consumes) are left in place.

    Called by _dispatch_frame() while it holds _control_blocks_condition.

    Raises:
        Exception: if a FlowControl block names an unknown channel. In that
            case self._control_blocks is left untouched.
    """
    unhandled_blocks = []
    for block in self._control_blocks:
        if block.opcode == _MUX_OPCODE_FLOW_CONTROL:
            self._logical_channels_condition.acquire()
            try:
                if not block.channel_id in self._logical_channels:
                    raise Exception('Invalid flow control received for '
                                    'channel id %d' % block.channel_id)
                self._logical_channels[block.channel_id].send_quota += (
                    block.send_quota)
                self._logical_channels_condition.notify()
            finally:
                self._logical_channels_condition.release()
        elif block.opcode == _MUX_OPCODE_NEW_CHANNEL_SLOT:
            self._channel_slots.extend([block.send_quota] * block.slots)
        else:
            # AddChannelResponse will be handled in add_channel(); other
            # opcodes are not interpreted here.
            unhandled_blocks.append(block)

    # Replace the contents in place so that other references to the list
    # stay valid.
    self._control_blocks[:] = unhandled_blocks
def _dispatch_frame(self, channel_id, payload):
    """Routes one demultiplexed payload to its destination.

    Payloads on the control channel are parsed into control blocks and
    processed, then waiters on the control-block condition are notified.
    Payloads on any other channel are parsed as an encapsulated frame,
    charged against the channel's receive quota and queued on the channel.

    Raises:
        Exception: if the channel is not established, or if the payload
            exceeds the channel's remaining receive quota.
    """
    if channel_id == _CONTROL_CHANNEL_ID:
        control_condition = self._control_blocks_condition
        control_condition.acquire()
        try:
            self._control_blocks += _parse_control_blocks(payload)
            self._process_mux_control_blocks()
            control_condition.notify()
        finally:
            control_condition.release()
        return

    channels_condition = self._logical_channels_condition
    channels_condition.acquire()
    try:
        if channel_id not in self._logical_channels:
            raise Exception('Received logical frame on channel id '
                            '%d, which is not established' %
                            channel_id)

        inner_frame = self._parse_inner_frame(payload)
        channel = self._logical_channels[channel_id]
        channel.receive_quota -= len(inner_frame.payload)
        if channel.receive_quota < 0:
            raise Exception('The server violates quota on '
                            'channel id %d' % channel_id)
    finally:
        channels_condition.release()
    # Queueing happens after the condition has been released.
    self._logical_channels[channel_id].queue.put(inner_frame)
466 def _process_control_message(self, opcode, message):
467 # Ping/Pong are not supported.
468 if opcode == client_for_testing.OPCODE_CLOSE:
469 self._physical_connection_close_message = message
471 self._stream.send_close(
472 code=client_for_testing.STATUS_NORMAL_CLOSURE, reason='')
473 self._read_thread.request_stop()
475 if self._physical_connection_close_event:
476 self._physical_connection_close_event.set()
478 def _notify_reader_done(self):
479 self._logger.debug('Read thread terminated.')
482 def _assert_channel_slot_available(self):
484 self._control_blocks_condition.acquire()
485 if len(self._channel_slots) == 0:
487 self._control_blocks_condition.wait(timeout=self._timeout)
489 self._control_blocks_condition.release()
491 if len(self._channel_slots) == 0:
492 raise Exception('Failed to receive NewChannelSlot')
494 def _assert_send_quota_available(self, channel_id):
496 self._logical_channels_condition.acquire()
497 if self._logical_channels[channel_id].send_quota == 0:
499 self._logical_channels_condition.wait(timeout=self._timeout)
501 self._logical_channels_condition.release()
503 if self._logical_channels[channel_id].send_quota == 0:
504 raise Exception('Failed to receive FlowControl for channel id %d' %
508 self._socket = socket.socket()
509 self._socket.settimeout(self._options.socket_timeout)
511 self._socket.connect((self._options.server_host,
512 self._options.server_port))
513 if self._options.use_tls:
514 self._socket = _TLSSocket(self._socket)
516 self._handshake.handshake(self._socket)
517 self._stream = client_for_testing.WebSocketStream(
518 self._socket, self._handshake)
520 self._logical_channels[_DEFAULT_CHANNEL_ID] = _LogicalChannelData()
522 self._read_thread = _MuxReaderThread(self)
523 self._read_thread.start()
525 self._assert_channel_slot_available()
526 self._assert_send_quota_available(_DEFAULT_CHANNEL_ID)
528 self._is_active = True
529 self._logger.info('Connection established')
531 def add_channel(self, channel_id, options):
532 if not self._is_active:
533 raise Exception('Mux client is not active')
535 if channel_id in self._logical_channels:
536 raise Exception('Channel id %d already exists' % channel_id)
539 send_quota = self._channel_slots.popleft()
540 except IndexError, e:
541 raise Exception('No channel slots: %r' % e)
543 # Create AddChannel request
544 request_line = 'GET %s HTTP/1.1\r\n' % options.resource
546 fields.append('Upgrade: websocket\r\n')
547 fields.append('Connection: Upgrade\r\n')
548 if options.server_port == client_for_testing.DEFAULT_PORT:
549 fields.append('Host: %s\r\n' % options.server_host.lower())
551 fields.append('Host: %s:%d\r\n' % (options.server_host.lower(),
552 options.server_port))
553 fields.append('Origin: %s\r\n' % options.origin.lower())
555 original_key = os.urandom(16)
556 key = base64.b64encode(original_key)
557 fields.append('Sec-WebSocket-Key: %s\r\n' % key)
559 fields.append('Sec-WebSocket-Version: 13\r\n')
561 if len(options.extensions) > 0:
562 fields.append('Sec-WebSocket-Extensions: %s\r\n' %
563 ', '.join(options.extensions))
565 handshake = request_line + ''.join(fields) + '\r\n'
566 add_channel_request = _create_add_channel_request(
567 channel_id, handshake)
568 payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + add_channel_request
569 self._stream.send_binary(payload)
571 # Wait AddChannelResponse
572 self._logger.debug('Waiting AddChannelResponse for the request...')
575 self._control_blocks_condition.acquire()
577 for block in self._control_blocks:
578 if block.opcode != _MUX_OPCODE_ADD_CHANNEL_RESPONSE:
580 if block.channel_id == channel_id:
582 self._control_blocks.remove(response)
586 self._control_blocks_condition.wait(self._timeout)
587 if not self._is_active:
588 raise Exception('AddChannelRequest timed out')
590 self._control_blocks_condition.release()
592 # Validate AddChannelResponse
593 if response.rejected:
594 raise Exception('The server rejected AddChannelRequest')
596 fields = _parse_handshake_response(response.encoded_handshake)
598 if not 'upgrade' in fields:
599 raise Exception('No Upgrade header')
600 if fields['upgrade'] != 'websocket':
601 raise Exception('Wrong Upgrade header')
602 if not 'connection' in fields:
603 raise Exception('No Connection header')
604 if fields['connection'] != 'Upgrade':
605 raise Exception('Wrong Connection header')
606 if not 'sec-websocket-accept' in fields:
607 raise Exception('No Sec-WebSocket-Accept header')
609 accept = fields['sec-websocket-accept']
611 decoded_accept = base64.b64decode(accept)
614 'Illegal value for header Sec-WebSocket-Accept: ' + accept)
616 if len(decoded_accept) != 20:
618 'Decoded value of Sec-WebSocket-Accept is not 20-byte long')
620 original_expected_accept = util.sha1_hash(
621 key + client_for_testing.WEBSOCKET_ACCEPT_UUID).digest()
622 expected_accept = base64.b64encode(original_expected_accept)
624 if accept != expected_accept:
626 'Invalid Sec-WebSocket-Accept header: %r (expected) != %r '
627 '(actual)' % (accept, expected_accept))
629 self._logical_channels_condition.acquire()
630 self._logical_channels[channel_id] = _LogicalChannelData()
631 self._logical_channels[channel_id].send_quota = send_quota
632 self._logical_channels_condition.release()
634 self._logger.debug('Logical channel %d established' % channel_id)
def _check_logical_channel_is_opened(self, channel_id):
    """Verifies that the client is active and channel_id is established.

    Args:
        channel_id: id of the logical channel about to be used.

    Raises:
        Exception: if the mux client is not active, or channel_id is not a
            key of self._logical_channels.
    """
    if not self._is_active:
        raise Exception('Mux client is not active')

    if channel_id not in self._logical_channels:
        # The channel id was previously never interpolated into the message.
        raise Exception('Logical channel %d is not established.' %
                        channel_id)
643 def drop_channel(self, channel_id):
644 # TODO(bashi): Implement
def send_flow_control(self, channel_id, replenished_quota):
    """Sends a FlowControl block and credits the local receive quota.

    Args:
        channel_id: established logical channel to replenish.
        replenished_quota: amount added to the channel's receive quota and
            announced to the peer.

    Raises:
        Exception: if the client is inactive or the channel is not
            established.
    """
    self._check_logical_channel_is_opened(channel_id)
    control_block = _create_flow_control(channel_id, replenished_quota)
    frame = _encode_channel_id(_CONTROL_CHANNEL_ID) + control_block

    # Replenish receive quota before the announcement goes out.
    channels_condition = self._logical_channels_condition
    channels_condition.acquire()
    try:
        channel = self._logical_channels[channel_id]
        channel.receive_quota += replenished_quota
    finally:
        channels_condition.release()
    self._stream.send_binary(frame)
def send_message(self, channel_id, message, end=True, binary=False):
    """Sends message as an encapsulated frame on a logical channel.

    Args:
        channel_id: established logical channel to send on.
        message: payload; encoded as UTF-8 unless binary is true.
        end: value placed in the FIN bit of the encapsulated frame.
        binary: send with the binary opcode instead of the text opcode.

    Raises:
        Exception: if the client is inactive, the channel is not
            established, or the message exceeds the channel's send quota.
    """
    self._check_logical_channel_is_opened(channel_id)

    if binary:
        inner_opcode = client_for_testing.OPCODE_BINARY
    else:
        inner_opcode = client_for_testing.OPCODE_TEXT
    first_byte = (end << 7) | inner_opcode
    if not binary:
        message = message.encode('utf-8')

    channels_condition = self._logical_channels_condition
    channels_condition.acquire()
    try:
        channel = self._logical_channels[channel_id]
        if channel.send_quota < len(message):
            raise Exception('Send quota violation: %d < %d' % (
                channel.send_quota,
                len(message)))

        channel.send_quota -= len(message)
    finally:
        channels_condition.release()
    frame = _encode_channel_id(channel_id) + chr(first_byte) + message
    self._stream.send_binary(frame)
def assert_receive(self, channel_id, payload, binary=False):
    """Asserts that the next frame queued on a channel matches.

    Args:
        channel_id: established logical channel to read from.
        payload: expected payload of the next encapsulated frame.
        binary: expect the binary opcode instead of the text opcode.

    Raises:
        Exception: if the client is inactive, the channel is not
            established, no frame arrives within self._timeout, or the
            frame's opcode or payload differs from what is expected.
    """
    self._check_logical_channel_is_opened(channel_id)

    try:
        inner_frame = self._logical_channels[channel_id].queue.get(
            timeout=self._timeout)
    except Queue.Empty:
        raise Exception('Cannot receive message from channel id %d' %
                        channel_id)

    # This name is also used by the error message below; the message used
    # to reference an undefined variable and raised NameError instead.
    if binary:
        expected_opcode = client_for_testing.OPCODE_BINARY
    else:
        expected_opcode = client_for_testing.OPCODE_TEXT

    if inner_frame.opcode != expected_opcode:
        raise Exception('Unexpected opcode received (%r != %r)' %
                        (expected_opcode, inner_frame.opcode))

    if inner_frame.payload != payload:
        raise Exception('Unexpected payload received')
704 def send_close(self, channel_id, code=None, reason=''):
705 self._check_logical_channel_is_opened(channel_id)
708 body = struct.pack('!H', code) + reason.encode('utf-8')
712 first_byte = (1 << 7) | client_for_testing.OPCODE_CLOSE
713 payload = _encode_channel_id(channel_id) + chr(first_byte) + body
714 self._stream.send_binary(payload)
716 def assert_receive_close(self, channel_id):
717 self._check_logical_channel_is_opened(channel_id)
720 inner_frame = self._logical_channels[channel_id].queue.get(
721 timeout=self._timeout)
722 except Queue.Empty, e:
723 raise Exception('Cannot receive message from channel id %d' %
725 if inner_frame.opcode != client_for_testing.OPCODE_CLOSE:
726 raise Exception('Didn\'t receive close frame')
def send_physical_connection_close(self, code=None, reason=''):
    """Sends a close frame on the physical connection.

    Args:
        code: passed through to the stream's send_close().
        reason: passed through to the stream's send_close().
    """
    # Create the event before sending: _process_control_message() sets it
    # when a close frame arrives, and
    # assert_physical_connection_receive_close() waits on it.
    self._physical_connection_close_event = threading.Event()
    self._stream.send_close(code, reason)
# Only valid after send_physical_connection_close() has been called, since
# that call creates the event this method waits on.
def assert_physical_connection_receive_close(
        self, code=client_for_testing.STATUS_NORMAL_CLOSURE, reason=''):
    """Asserts that the peer answered the physical closing handshake.

    Waits up to self._timeout for the close event, then requires that the
    event is set and that a non-empty close message was recorded.

    NOTE(review): code and reason are accepted but are not compared with
    the received close message.

    Raises:
        Exception: if no closing handshake was received in time.
    """
    close_event = self._physical_connection_close_event
    close_event.wait(timeout=self._timeout)
    if close_event.isSet() and self._physical_connection_close_message:
        return
    raise Exception('Didn\'t receive closing handshake')
741 def close_socket(self):
742 self._is_active = False