""" The :mod:`websockets.extensions.permessage_deflate` module implements the Compression Extensions for WebSocket as specified in :rfc:`7692`. """ import zlib from ..exceptions import ( DuplicateParameter, InvalidParameterName, InvalidParameterValue, NegotiationError, PayloadTooBig ) from ..framing import CTRL_OPCODES, OP_CONT __all__ = [ 'ClientPerMessageDeflateFactory', 'ServerPerMessageDeflateFactory', 'PerMessageDeflate', ] _EMPTY_UNCOMPRESSED_BLOCK = b'\x00\x00\xff\xff' _MAX_WINDOW_BITS_VALUES = [str(bits) for bits in range(8, 16)] def _build_parameters( server_no_context_takeover, client_no_context_takeover, server_max_window_bits, client_max_window_bits, ): """ Build a list of ``(name, value)`` pairs for some compression parameters. """ params = [] if server_no_context_takeover: params.append(('server_no_context_takeover', None)) if client_no_context_takeover: params.append(('client_no_context_takeover', None)) if server_max_window_bits: params.append(('server_max_window_bits', str(server_max_window_bits))) if client_max_window_bits is True: # only in handshake requests params.append(('client_max_window_bits', None)) elif client_max_window_bits: params.append(('client_max_window_bits', str(client_max_window_bits))) return params def _extract_parameters(params, *, is_server): """ Extract compression parameters from a list of ``(name, value)`` pairs. If ``is_server`` is ``True``, ``client_max_window_bits`` may be provided without a value. This is only allow in handshake requests. """ server_no_context_takeover = False client_no_context_takeover = False server_max_window_bits = None client_max_window_bits = None for name, value in params: if name == 'server_no_context_takeover': if server_no_context_takeover: raise DuplicateParameter(name) if value is None: server_no_context_takeover = True else: raise InvalidParameterValue(name, value) elif name == 'client_no_context_takeover': if client_no_context_takeover: raise DuplicateParameter(name) if value is None: client_no_context_takeover = True else: raise InvalidParameterValue(name, value) elif name == 'server_max_window_bits': if server_max_window_bits is not None: raise DuplicateParameter(name) if value in _MAX_WINDOW_BITS_VALUES: server_max_window_bits = int(value) else: raise InvalidParameterValue(name, value) elif name == 'client_max_window_bits': if client_max_window_bits is not None: raise DuplicateParameter(name) if is_server and value is None: # only in handshake requests client_max_window_bits = True elif value in _MAX_WINDOW_BITS_VALUES: client_max_window_bits = int(value) else: raise InvalidParameterValue(name, value) else: raise InvalidParameterName(name) return ( server_no_context_takeover, client_no_context_takeover, server_max_window_bits, client_max_window_bits, ) class ClientPerMessageDeflateFactory: """ Client-side extension factory for permessage-deflate extension. """ name = 'permessage-deflate' def __init__( self, server_no_context_takeover=False, client_no_context_takeover=False, server_max_window_bits=None, client_max_window_bits=None, compress_settings=None, ): """ Configure permessage-deflate extension factory. See https://tools.ietf.org/html/rfc7692#section-7.1. """ if not (server_max_window_bits is None or 8 <= server_max_window_bits <= 15): raise ValueError("server_max_window_bits must be between 8 and 15") if not (client_max_window_bits is None or client_max_window_bits is True or 8 <= client_max_window_bits <= 15): raise ValueError("client_max_window_bits must be between 8 and 15") if compress_settings is not None and 'wbits' in compress_settings: raise ValueError("compress_settings must not include wbits, " "set client_max_window_bits instead") self.server_no_context_takeover = server_no_context_takeover self.client_no_context_takeover = client_no_context_takeover self.server_max_window_bits = server_max_window_bits self.client_max_window_bits = client_max_window_bits self.compress_settings = compress_settings def get_request_params(self): """ Build request parameters. """ return _build_parameters( self.server_no_context_takeover, self.client_no_context_takeover, self.server_max_window_bits, self.client_max_window_bits, ) def process_response_params(self, params, accepted_extensions): """" Process response parameters. Return an extension instance. """ if any(other.name == self.name for other in accepted_extensions): raise NegotiationError("Received duplicate {}".format(self.name)) # Request parameters are available in instance variables. # Load response parameters in local variables. ( server_no_context_takeover, client_no_context_takeover, server_max_window_bits, client_max_window_bits, ) = _extract_parameters(params, is_server=False) # After comparing the request and the response, the final # configuration must be available in the local variables. # server_no_context_takeover # # Req. Resp. Result # ------ ------ -------------------------------------------------- # False False False # False True True # True False Error! # True True True if self.server_no_context_takeover: if not server_no_context_takeover: raise NegotiationError("Expected server_no_context_takeover") # client_no_context_takeover # # Req. Resp. Result # ------ ------ -------------------------------------------------- # False False False # False True True # True False True - must change value # True True True if self.client_no_context_takeover: if not client_no_context_takeover: client_no_context_takeover = True # server_max_window_bits # Req. Resp. Result # ------ ------ -------------------------------------------------- # None None None # None 8≤M≤15 M # 8≤N≤15 None Error! # 8≤N≤15 8≤M≤N M # 8≤N≤15 N self.server_max_window_bits: raise NegotiationError("Unsupported server_max_window_bits") # client_max_window_bits # Req. Resp. Result # ------ ------ -------------------------------------------------- # None None None # None 8≤M≤15 Error! # True None None # True 8≤M≤15 M # 8≤N≤15 None N - must change value # 8≤N≤15 8≤M≤N M # 8≤N≤15 N self.client_max_window_bits: raise NegotiationError("Unsupported client_max_window_bits") return PerMessageDeflate( server_no_context_takeover, # remote_no_context_takeover client_no_context_takeover, # local_no_context_takeover server_max_window_bits or 15, # remote_max_window_bits client_max_window_bits or 15, # local_max_window_bits self.compress_settings, ) class ServerPerMessageDeflateFactory: """ Server-side extension factory for permessage-deflate extension. """ name = 'permessage-deflate' def __init__( self, server_no_context_takeover=False, client_no_context_takeover=False, server_max_window_bits=None, client_max_window_bits=None, compress_settings=None, ): """ Configure permessage-deflate extension factory. See https://tools.ietf.org/html/rfc7692#section-7.1. """ if not (server_max_window_bits is None or 8 <= server_max_window_bits <= 15): raise ValueError("server_max_window_bits must be between 8 and 15") if not (client_max_window_bits is None or 8 <= client_max_window_bits <= 15): raise ValueError("client_max_window_bits must be between 8 and 15") if compress_settings is not None and 'wbits' in compress_settings: raise ValueError("compress_settings must not include wbits, " "set server_max_window_bits instead") self.server_no_context_takeover = server_no_context_takeover self.client_no_context_takeover = client_no_context_takeover self.server_max_window_bits = server_max_window_bits self.client_max_window_bits = client_max_window_bits self.compress_settings = compress_settings def process_request_params(self, params, accepted_extensions): """" Process request parameters. Return response params and an extension instance. """ if any(other.name == self.name for other in accepted_extensions): raise NegotiationError("Skipped duplicate {}".format(self.name)) # Load request parameters in local variables. ( server_no_context_takeover, client_no_context_takeover, server_max_window_bits, client_max_window_bits, ) = _extract_parameters(params, is_server=True) # Configuration parameters are available in instance variables. # After comparing the request and the configuration, the response must # be available in the local variables. # server_no_context_takeover # # Config Req. Resp. # ------ ------ -------------------------------------------------- # False False False # False True True # True False True - must change value to True # True True True if self.server_no_context_takeover: if not server_no_context_takeover: server_no_context_takeover = True # client_no_context_takeover # # Config Req. Resp. # ------ ------ -------------------------------------------------- # False False False # False True True (or False) # True False True - must change value to True # True True True (or False) if self.client_no_context_takeover: if not client_no_context_takeover: client_no_context_takeover = True # server_max_window_bits # Config Req. Resp. # ------ ------ -------------------------------------------------- # None None None # None 8≤M≤15 M # 8≤N≤15 None N - must change value # 8≤N≤15 8≤M≤N M # 8≤N≤15 N self.server_max_window_bits: server_max_window_bits = self.server_max_window_bits # client_max_window_bits # Config Req. Resp. # ------ ------ -------------------------------------------------- # None None None # None True None - must change value # None 8≤M≤15 M (or None) # 8≤N≤15 None Error! # 8≤N≤15 True N - must change value # 8≤N≤15 8≤M≤N M (or None) # 8≤N≤15 N {} bytes)" .format(max_size)) # Allow garbage collection of the decoder if it won't be reused. if frame.fin and self.remote_no_context_takeover: self.decoder = None return frame._replace(data=data, rsv1=False) def encode(self, frame): """ Encode an outgoing frame. """ # Skip control frames. if frame.opcode in CTRL_OPCODES: return frame # Since we always encode and never fragment messages, there's no logic # similar to decode() here at this time. if frame.opcode != OP_CONT: # Re-initialize per-message decoder. if self.local_no_context_takeover: self.encoder = zlib.compressobj( wbits=-self.local_max_window_bits, **self.compress_settings) # Compress data frames. data = ( self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH) ) if frame.fin and data.endswith(_EMPTY_UNCOMPRESSED_BLOCK): data = data[:-4] # Allow garbage collection of the encoder if it won't be reused. if frame.fin and self.local_no_context_takeover: self.encoder = None return frame._replace(data=data, rsv1=True)