548 lines
19 KiB
Python
548 lines
19 KiB
Python
"""
|
|
The :mod:`websockets.extensions.permessage_deflate` module implements the
|
|
Compression Extensions for WebSocket as specified in :rfc:`7692`.
|
|
|
|
"""
|
|
|
|
import zlib
|
|
|
|
from ..exceptions import (
|
|
DuplicateParameter, InvalidParameterName, InvalidParameterValue,
|
|
NegotiationError, PayloadTooBig
|
|
)
|
|
from ..framing import CTRL_OPCODES, OP_CONT
|
|
|
|
|
|
__all__ = [
|
|
'ClientPerMessageDeflateFactory',
|
|
'ServerPerMessageDeflateFactory',
|
|
'PerMessageDeflate',
|
|
]
|
|
|
|
_EMPTY_UNCOMPRESSED_BLOCK = b'\x00\x00\xff\xff'
|
|
|
|
_MAX_WINDOW_BITS_VALUES = [str(bits) for bits in range(8, 16)]
|
|
|
|
|
|
def _build_parameters(
|
|
server_no_context_takeover,
|
|
client_no_context_takeover,
|
|
server_max_window_bits,
|
|
client_max_window_bits,
|
|
):
|
|
"""
|
|
Build a list of ``(name, value)`` pairs for some compression parameters.
|
|
|
|
"""
|
|
params = []
|
|
if server_no_context_takeover:
|
|
params.append(('server_no_context_takeover', None))
|
|
if client_no_context_takeover:
|
|
params.append(('client_no_context_takeover', None))
|
|
if server_max_window_bits:
|
|
params.append(('server_max_window_bits', str(server_max_window_bits)))
|
|
if client_max_window_bits is True: # only in handshake requests
|
|
params.append(('client_max_window_bits', None))
|
|
elif client_max_window_bits:
|
|
params.append(('client_max_window_bits', str(client_max_window_bits)))
|
|
return params
|
|
|
|
|
|
def _extract_parameters(params, *, is_server):
|
|
"""
|
|
Extract compression parameters from a list of ``(name, value)`` pairs.
|
|
|
|
If ``is_server`` is ``True``, ``client_max_window_bits`` may be provided
|
|
without a value. This is only allow in handshake requests.
|
|
|
|
"""
|
|
server_no_context_takeover = False
|
|
client_no_context_takeover = False
|
|
server_max_window_bits = None
|
|
client_max_window_bits = None
|
|
|
|
for name, value in params:
|
|
|
|
if name == 'server_no_context_takeover':
|
|
if server_no_context_takeover:
|
|
raise DuplicateParameter(name)
|
|
if value is None:
|
|
server_no_context_takeover = True
|
|
else:
|
|
raise InvalidParameterValue(name, value)
|
|
|
|
elif name == 'client_no_context_takeover':
|
|
if client_no_context_takeover:
|
|
raise DuplicateParameter(name)
|
|
if value is None:
|
|
client_no_context_takeover = True
|
|
else:
|
|
raise InvalidParameterValue(name, value)
|
|
|
|
elif name == 'server_max_window_bits':
|
|
if server_max_window_bits is not None:
|
|
raise DuplicateParameter(name)
|
|
if value in _MAX_WINDOW_BITS_VALUES:
|
|
server_max_window_bits = int(value)
|
|
else:
|
|
raise InvalidParameterValue(name, value)
|
|
|
|
elif name == 'client_max_window_bits':
|
|
if client_max_window_bits is not None:
|
|
raise DuplicateParameter(name)
|
|
if is_server and value is None: # only in handshake requests
|
|
client_max_window_bits = True
|
|
elif value in _MAX_WINDOW_BITS_VALUES:
|
|
client_max_window_bits = int(value)
|
|
else:
|
|
raise InvalidParameterValue(name, value)
|
|
|
|
else:
|
|
raise InvalidParameterName(name)
|
|
|
|
return (
|
|
server_no_context_takeover,
|
|
client_no_context_takeover,
|
|
server_max_window_bits,
|
|
client_max_window_bits,
|
|
)
|
|
|
|
|
|
class ClientPerMessageDeflateFactory:
|
|
"""
|
|
Client-side extension factory for permessage-deflate extension.
|
|
|
|
"""
|
|
name = 'permessage-deflate'
|
|
|
|
def __init__(
|
|
self,
|
|
server_no_context_takeover=False,
|
|
client_no_context_takeover=False,
|
|
server_max_window_bits=None,
|
|
client_max_window_bits=None,
|
|
compress_settings=None,
|
|
):
|
|
"""
|
|
Configure permessage-deflate extension factory.
|
|
|
|
See https://tools.ietf.org/html/rfc7692#section-7.1.
|
|
|
|
"""
|
|
if not (server_max_window_bits is None or
|
|
8 <= server_max_window_bits <= 15):
|
|
raise ValueError("server_max_window_bits must be between 8 and 15")
|
|
if not (client_max_window_bits is None or
|
|
client_max_window_bits is True or
|
|
8 <= client_max_window_bits <= 15):
|
|
raise ValueError("client_max_window_bits must be between 8 and 15")
|
|
if compress_settings is not None and 'wbits' in compress_settings:
|
|
raise ValueError("compress_settings must not include wbits, "
|
|
"set client_max_window_bits instead")
|
|
|
|
self.server_no_context_takeover = server_no_context_takeover
|
|
self.client_no_context_takeover = client_no_context_takeover
|
|
self.server_max_window_bits = server_max_window_bits
|
|
self.client_max_window_bits = client_max_window_bits
|
|
self.compress_settings = compress_settings
|
|
|
|
def get_request_params(self):
|
|
"""
|
|
Build request parameters.
|
|
|
|
"""
|
|
return _build_parameters(
|
|
self.server_no_context_takeover, self.client_no_context_takeover,
|
|
self.server_max_window_bits, self.client_max_window_bits,
|
|
)
|
|
|
|
def process_response_params(self, params, accepted_extensions):
|
|
""""
|
|
Process response parameters.
|
|
|
|
Return an extension instance.
|
|
|
|
"""
|
|
if any(other.name == self.name for other in accepted_extensions):
|
|
raise NegotiationError("Received duplicate {}".format(self.name))
|
|
|
|
# Request parameters are available in instance variables.
|
|
|
|
# Load response parameters in local variables.
|
|
(
|
|
server_no_context_takeover,
|
|
client_no_context_takeover,
|
|
server_max_window_bits,
|
|
client_max_window_bits,
|
|
) = _extract_parameters(params, is_server=False)
|
|
|
|
# After comparing the request and the response, the final
|
|
# configuration must be available in the local variables.
|
|
|
|
# server_no_context_takeover
|
|
#
|
|
# Req. Resp. Result
|
|
# ------ ------ --------------------------------------------------
|
|
# False False False
|
|
# False True True
|
|
# True False Error!
|
|
# True True True
|
|
|
|
if self.server_no_context_takeover:
|
|
if not server_no_context_takeover:
|
|
raise NegotiationError("Expected server_no_context_takeover")
|
|
|
|
# client_no_context_takeover
|
|
#
|
|
# Req. Resp. Result
|
|
# ------ ------ --------------------------------------------------
|
|
# False False False
|
|
# False True True
|
|
# True False True - must change value
|
|
# True True True
|
|
|
|
if self.client_no_context_takeover:
|
|
if not client_no_context_takeover:
|
|
client_no_context_takeover = True
|
|
|
|
# server_max_window_bits
|
|
|
|
# Req. Resp. Result
|
|
# ------ ------ --------------------------------------------------
|
|
# None None None
|
|
# None 8≤M≤15 M
|
|
# 8≤N≤15 None Error!
|
|
# 8≤N≤15 8≤M≤N M
|
|
# 8≤N≤15 N<M≤15 Error!
|
|
|
|
if self.server_max_window_bits is None:
|
|
pass
|
|
|
|
else:
|
|
if server_max_window_bits is None:
|
|
raise NegotiationError("Expected server_max_window_bits")
|
|
elif server_max_window_bits > self.server_max_window_bits:
|
|
raise NegotiationError("Unsupported server_max_window_bits")
|
|
|
|
# client_max_window_bits
|
|
|
|
# Req. Resp. Result
|
|
# ------ ------ --------------------------------------------------
|
|
# None None None
|
|
# None 8≤M≤15 Error!
|
|
# True None None
|
|
# True 8≤M≤15 M
|
|
# 8≤N≤15 None N - must change value
|
|
# 8≤N≤15 8≤M≤N M
|
|
# 8≤N≤15 N<M≤15 Error!
|
|
|
|
if self.client_max_window_bits is None:
|
|
if client_max_window_bits is not None:
|
|
raise NegotiationError("Unexpected client_max_window_bits")
|
|
|
|
elif self.client_max_window_bits is True:
|
|
pass
|
|
|
|
else:
|
|
if client_max_window_bits is None:
|
|
client_max_window_bits = self.client_max_window_bits
|
|
elif client_max_window_bits > self.client_max_window_bits:
|
|
raise NegotiationError("Unsupported client_max_window_bits")
|
|
|
|
return PerMessageDeflate(
|
|
server_no_context_takeover, # remote_no_context_takeover
|
|
client_no_context_takeover, # local_no_context_takeover
|
|
server_max_window_bits or 15, # remote_max_window_bits
|
|
client_max_window_bits or 15, # local_max_window_bits
|
|
self.compress_settings,
|
|
)
|
|
|
|
|
|
class ServerPerMessageDeflateFactory:
|
|
"""
|
|
Server-side extension factory for permessage-deflate extension.
|
|
|
|
"""
|
|
name = 'permessage-deflate'
|
|
|
|
def __init__(
|
|
self,
|
|
server_no_context_takeover=False,
|
|
client_no_context_takeover=False,
|
|
server_max_window_bits=None,
|
|
client_max_window_bits=None,
|
|
compress_settings=None,
|
|
):
|
|
"""
|
|
Configure permessage-deflate extension factory.
|
|
|
|
See https://tools.ietf.org/html/rfc7692#section-7.1.
|
|
|
|
"""
|
|
if not (server_max_window_bits is None or
|
|
8 <= server_max_window_bits <= 15):
|
|
raise ValueError("server_max_window_bits must be between 8 and 15")
|
|
if not (client_max_window_bits is None or
|
|
8 <= client_max_window_bits <= 15):
|
|
raise ValueError("client_max_window_bits must be between 8 and 15")
|
|
if compress_settings is not None and 'wbits' in compress_settings:
|
|
raise ValueError("compress_settings must not include wbits, "
|
|
"set server_max_window_bits instead")
|
|
|
|
self.server_no_context_takeover = server_no_context_takeover
|
|
self.client_no_context_takeover = client_no_context_takeover
|
|
self.server_max_window_bits = server_max_window_bits
|
|
self.client_max_window_bits = client_max_window_bits
|
|
self.compress_settings = compress_settings
|
|
|
|
def process_request_params(self, params, accepted_extensions):
|
|
""""
|
|
Process request parameters.
|
|
|
|
Return response params and an extension instance.
|
|
|
|
"""
|
|
if any(other.name == self.name for other in accepted_extensions):
|
|
raise NegotiationError("Skipped duplicate {}".format(self.name))
|
|
|
|
# Load request parameters in local variables.
|
|
(
|
|
server_no_context_takeover,
|
|
client_no_context_takeover,
|
|
server_max_window_bits,
|
|
client_max_window_bits,
|
|
) = _extract_parameters(params, is_server=True)
|
|
|
|
# Configuration parameters are available in instance variables.
|
|
|
|
# After comparing the request and the configuration, the response must
|
|
# be available in the local variables.
|
|
|
|
# server_no_context_takeover
|
|
#
|
|
# Config Req. Resp.
|
|
# ------ ------ --------------------------------------------------
|
|
# False False False
|
|
# False True True
|
|
# True False True - must change value to True
|
|
# True True True
|
|
|
|
if self.server_no_context_takeover:
|
|
if not server_no_context_takeover:
|
|
server_no_context_takeover = True
|
|
|
|
# client_no_context_takeover
|
|
#
|
|
# Config Req. Resp.
|
|
# ------ ------ --------------------------------------------------
|
|
# False False False
|
|
# False True True (or False)
|
|
# True False True - must change value to True
|
|
# True True True (or False)
|
|
|
|
if self.client_no_context_takeover:
|
|
if not client_no_context_takeover:
|
|
client_no_context_takeover = True
|
|
|
|
# server_max_window_bits
|
|
|
|
# Config Req. Resp.
|
|
# ------ ------ --------------------------------------------------
|
|
# None None None
|
|
# None 8≤M≤15 M
|
|
# 8≤N≤15 None N - must change value
|
|
# 8≤N≤15 8≤M≤N M
|
|
# 8≤N≤15 N<M≤15 N - must change value
|
|
|
|
if self.server_max_window_bits is None:
|
|
pass
|
|
|
|
else:
|
|
if server_max_window_bits is None:
|
|
server_max_window_bits = self.server_max_window_bits
|
|
elif server_max_window_bits > self.server_max_window_bits:
|
|
server_max_window_bits = self.server_max_window_bits
|
|
|
|
# client_max_window_bits
|
|
|
|
# Config Req. Resp.
|
|
# ------ ------ --------------------------------------------------
|
|
# None None None
|
|
# None True None - must change value
|
|
# None 8≤M≤15 M (or None)
|
|
# 8≤N≤15 None Error!
|
|
# 8≤N≤15 True N - must change value
|
|
# 8≤N≤15 8≤M≤N M (or None)
|
|
# 8≤N≤15 N<M≤15 N
|
|
|
|
if self.client_max_window_bits is None:
|
|
if client_max_window_bits is True:
|
|
client_max_window_bits = self.client_max_window_bits
|
|
|
|
else:
|
|
if client_max_window_bits is None:
|
|
raise NegotiationError("Required client_max_window_bits")
|
|
elif client_max_window_bits is True:
|
|
client_max_window_bits = self.client_max_window_bits
|
|
elif self.client_max_window_bits < client_max_window_bits:
|
|
client_max_window_bits = self.client_max_window_bits
|
|
|
|
return (
|
|
_build_parameters(
|
|
server_no_context_takeover, client_no_context_takeover,
|
|
server_max_window_bits, client_max_window_bits,
|
|
),
|
|
PerMessageDeflate(
|
|
client_no_context_takeover, # remote_no_context_takeover
|
|
server_no_context_takeover, # local_no_context_takeover
|
|
client_max_window_bits or 15, # remote_max_window_bits
|
|
server_max_window_bits or 15, # local_max_window_bits
|
|
self.compress_settings,
|
|
)
|
|
)
|
|
|
|
|
|
class PerMessageDeflate:
|
|
"""
|
|
permessage-deflate extension.
|
|
|
|
"""
|
|
name = 'permessage-deflate'
|
|
|
|
def __init__(
|
|
self,
|
|
remote_no_context_takeover,
|
|
local_no_context_takeover,
|
|
remote_max_window_bits,
|
|
local_max_window_bits,
|
|
compress_settings=None,
|
|
):
|
|
"""
|
|
Configure permessage-deflate extension.
|
|
|
|
"""
|
|
if compress_settings is None:
|
|
compress_settings = {}
|
|
|
|
assert remote_no_context_takeover in [False, True]
|
|
assert local_no_context_takeover in [False, True]
|
|
assert 8 <= remote_max_window_bits <= 15
|
|
assert 8 <= local_max_window_bits <= 15
|
|
assert 'wbits' not in compress_settings
|
|
|
|
self.remote_no_context_takeover = remote_no_context_takeover
|
|
self.local_no_context_takeover = local_no_context_takeover
|
|
self.remote_max_window_bits = remote_max_window_bits
|
|
self.local_max_window_bits = local_max_window_bits
|
|
self.compress_settings = compress_settings
|
|
|
|
if not self.remote_no_context_takeover:
|
|
self.decoder = zlib.decompressobj(
|
|
wbits=-self.remote_max_window_bits)
|
|
|
|
if not self.local_no_context_takeover:
|
|
self.encoder = zlib.compressobj(
|
|
wbits=-self.local_max_window_bits,
|
|
**self.compress_settings)
|
|
|
|
# To handle continuation frames properly, we must keep track of
|
|
# whether that initial frame was encoded.
|
|
self.decode_cont_data = False
|
|
# There's no need for self.encode_cont_data because we always encode
|
|
# outgoing frames, so it would always be True.
|
|
|
|
def __repr__(self):
|
|
return 'PerMessageDeflate({})'.format(', '.join([
|
|
'remote_no_context_takeover={}'.format(
|
|
self.remote_no_context_takeover),
|
|
'local_no_context_takeover={}'.format(
|
|
self.local_no_context_takeover),
|
|
'remote_max_window_bits={}'.format(
|
|
self.remote_max_window_bits),
|
|
'local_max_window_bits={}'.format(
|
|
self.local_max_window_bits),
|
|
]))
|
|
|
|
def decode(self, frame, *, max_size=None):
|
|
"""
|
|
Decode an incoming frame.
|
|
|
|
"""
|
|
# Skip control frames.
|
|
if frame.opcode in CTRL_OPCODES:
|
|
return frame
|
|
|
|
# Handle continuation data frames:
|
|
# - skip if the initial data frame wasn't encoded
|
|
# - reset "decode continuation data" flag if it's a final frame
|
|
if frame.opcode == OP_CONT:
|
|
if not self.decode_cont_data:
|
|
return frame
|
|
if frame.fin:
|
|
self.decode_cont_data = False
|
|
|
|
# Handle text and binary data frames:
|
|
# - skip if the frame isn't encoded
|
|
# - set "decode continuation data" flag if it's a non-final frame
|
|
else:
|
|
if not frame.rsv1:
|
|
return frame
|
|
if not frame.fin: # frame.rsv1 is True at this point
|
|
self.decode_cont_data = True
|
|
|
|
# Re-initialize per-message decoder.
|
|
if self.remote_no_context_takeover:
|
|
self.decoder = zlib.decompressobj(
|
|
wbits=-self.remote_max_window_bits)
|
|
|
|
# Uncompress compressed frames. Protect against zip bombs by
|
|
# preventing zlib from decompressing more than max_length bytes
|
|
# (except when the limit is disabled with max_size = None).
|
|
data = frame.data
|
|
if frame.fin:
|
|
data += _EMPTY_UNCOMPRESSED_BLOCK
|
|
max_length = 0 if max_size is None else max_size
|
|
data = self.decoder.decompress(data, max_length)
|
|
if self.decoder.unconsumed_tail:
|
|
raise PayloadTooBig(
|
|
"Uncompressed payload length exceeds size limit (? > {} bytes)"
|
|
.format(max_size))
|
|
|
|
# Allow garbage collection of the decoder if it won't be reused.
|
|
if frame.fin and self.remote_no_context_takeover:
|
|
self.decoder = None
|
|
|
|
return frame._replace(data=data, rsv1=False)
|
|
|
|
def encode(self, frame):
|
|
"""
|
|
Encode an outgoing frame.
|
|
|
|
"""
|
|
# Skip control frames.
|
|
if frame.opcode in CTRL_OPCODES:
|
|
return frame
|
|
|
|
# Since we always encode and never fragment messages, there's no logic
|
|
# similar to decode() here at this time.
|
|
|
|
if frame.opcode != OP_CONT:
|
|
# Re-initialize per-message decoder.
|
|
if self.local_no_context_takeover:
|
|
self.encoder = zlib.compressobj(
|
|
wbits=-self.local_max_window_bits,
|
|
**self.compress_settings)
|
|
|
|
# Compress data frames.
|
|
data = (
|
|
self.encoder.compress(frame.data) +
|
|
self.encoder.flush(zlib.Z_SYNC_FLUSH)
|
|
)
|
|
if frame.fin and data.endswith(_EMPTY_UNCOMPRESSED_BLOCK):
|
|
data = data[:-4]
|
|
|
|
# Allow garbage collection of the encoder if it won't be reused.
|
|
if frame.fin and self.local_no_context_takeover:
|
|
self.encoder = None
|
|
|
|
return frame._replace(data=data, rsv1=True)
|