"""
The :mod:`websockets.extensions.permessage_deflate` module implements the
Compression Extensions for WebSocket as specified in :rfc:`7692`.

"""

import zlib

from ..exceptions import (
    DuplicateParameter, InvalidParameterName, InvalidParameterValue,
    NegotiationError, PayloadTooBig
)
from ..framing import CTRL_OPCODES, OP_CONT


# Names exported by ``from ... import *``.
__all__ = [
    'ClientPerMessageDeflateFactory',
    'ServerPerMessageDeflateFactory',
    'PerMessageDeflate',
]

# Empty stored block emitted by a Z_SYNC_FLUSH. It is stripped from the end of
# the final frame of outgoing messages and appended back to the final frame of
# incoming messages before decompressing (RFC 7692, sections 7.2.1 and 7.2.2).
_EMPTY_UNCOMPRESSED_BLOCK = bytes([0x00, 0x00, 0xFF, 0xFF])

# Acceptable string values of the *_max_window_bits parameters: '8' to '15'.
_MAX_WINDOW_BITS_VALUES = list(map(str, range(8, 16)))


def _build_parameters(
    server_no_context_takeover,
    client_no_context_takeover,
    server_max_window_bits,
    client_max_window_bits,
):
    """
    Return compression parameters as a list of ``(name, value)`` pairs.

    Truthy ``*_no_context_takeover`` flags become parameters without a value.
    Truthy window sizes become parameters whose value is the size as a string.
    ``client_max_window_bits=True`` becomes a parameter without a value, which
    is only meaningful in handshake requests. Falsy arguments are omitted.
    Pairs are emitted in the order of the arguments.

    """
    flags = (
        ('server_no_context_takeover', server_no_context_takeover),
        ('client_no_context_takeover', client_no_context_takeover),
    )
    result = [(flag_name, None) for flag_name, enabled in flags if enabled]

    if server_max_window_bits:
        result.append(('server_max_window_bits', str(server_max_window_bits)))

    if client_max_window_bits is True:
        # Bare parameter: the client lets the server pick the window size.
        result.append(('client_max_window_bits', None))
    elif client_max_window_bits:
        result.append(('client_max_window_bits', str(client_max_window_bits)))

    return result


def _extract_parameters(params, *, is_server):
    """
    Extract compression parameters from a list of ``(name, value)`` pairs.

    Return a 4-tuple ``(server_no_context_takeover,
    client_no_context_takeover, server_max_window_bits,
    client_max_window_bits)``. Flags default to ``False`` and window sizes
    default to ``None``.

    When ``is_server`` is ``True``, ``client_max_window_bits`` may appear
    without a value, which is reported as ``True``. This form is only allowed
    in handshake requests.

    Raise :exc:`DuplicateParameter` if a parameter appears twice,
    :exc:`InvalidParameterValue` if a value is not acceptable, and
    :exc:`InvalidParameterName` for unknown parameter names.

    """
    flags = {
        'server_no_context_takeover': False,
        'client_no_context_takeover': False,
    }
    window_bits = {
        'server_max_window_bits': None,
        'client_max_window_bits': None,
    }

    for name, value in params:
        if name in flags:
            if flags[name]:
                raise DuplicateParameter(name)
            if value is not None:
                # no_context_takeover parameters never carry a value.
                raise InvalidParameterValue(name, value)
            flags[name] = True

        elif name in window_bits:
            if window_bits[name] is not None:
                raise DuplicateParameter(name)
            bare_client_bits = (
                is_server and value is None and
                name == 'client_max_window_bits'
            )
            if bare_client_bits:        # only in handshake requests
                window_bits[name] = True
            elif value in _MAX_WINDOW_BITS_VALUES:
                window_bits[name] = int(value)
            else:
                raise InvalidParameterValue(name, value)

        else:
            raise InvalidParameterName(name)

    return (
        flags['server_no_context_takeover'],
        flags['client_no_context_takeover'],
        window_bits['server_max_window_bits'],
        window_bits['client_max_window_bits'],
    )


class ClientPerMessageDeflateFactory:
    """
    Client-side extension factory for permessage-deflate extension.

    It builds the parameters offered in the handshake request. Once the
    server has responded, it checks the accepted parameters against that
    offer and creates a :class:`PerMessageDeflate` instance.

    """
    name = 'permessage-deflate'

    def __init__(
        self,
        server_no_context_takeover=False,
        client_no_context_takeover=False,
        server_max_window_bits=None,
        client_max_window_bits=None,
        compress_settings=None,
    ):
        """
        Configure permessage-deflate extension factory.

        ``server_max_window_bits`` must be ``None`` or between 8 and 15.
        ``client_max_window_bits`` must be ``None``, ``True`` (offer the
        parameter without a value) or between 8 and 15.
        ``compress_settings`` is forwarded to :class:`PerMessageDeflate`. It
        must not contain ``wbits`` because the window size is negotiated.

        See https://tools.ietf.org/html/rfc7692#section-7.1.

        :raises ValueError: if a window size is out of range or if
            ``compress_settings`` contains ``wbits``

        """
        if not (server_max_window_bits is None or
                8 <= server_max_window_bits <= 15):
            raise ValueError("server_max_window_bits must be between 8 and 15")
        if not (client_max_window_bits is None or
                client_max_window_bits is True or
                8 <= client_max_window_bits <= 15):
            raise ValueError("client_max_window_bits must be between 8 and 15")
        if compress_settings is not None and 'wbits' in compress_settings:
            raise ValueError("compress_settings must not include wbits, "
                             "set client_max_window_bits instead")

        self.server_no_context_takeover = server_no_context_takeover
        self.client_no_context_takeover = client_no_context_takeover
        self.server_max_window_bits = server_max_window_bits
        self.client_max_window_bits = client_max_window_bits
        self.compress_settings = compress_settings

    def get_request_params(self):
        """
        Build request parameters.

        Return a list of ``(name, value)`` pairs reflecting this factory's
        configuration.

        """
        return _build_parameters(
            self.server_no_context_takeover, self.client_no_context_takeover,
            self.server_max_window_bits, self.client_max_window_bits,
        )

    def process_response_params(self, params, accepted_extensions):
        """
        Process response parameters.

        ``params`` is the list of ``(name, value)`` pairs the server sent for
        this extension. ``accepted_extensions`` holds the extensions already
        accepted during this handshake.

        Return an extension instance.

        :raises NegotiationError: if this extension was already accepted or
            if the response is incompatible with the request; errors raised
            while parsing ``params`` propagate as well

        """
        if any(other.name == self.name for other in accepted_extensions):
            raise NegotiationError("Received duplicate {}".format(self.name))

        # Request parameters are available in instance variables.

        # Load response parameters in local variables.
        (
            server_no_context_takeover,
            client_no_context_takeover,
            server_max_window_bits,
            client_max_window_bits,
        ) = _extract_parameters(params, is_server=False)

        # After comparing the request and the response, the final
        # configuration must be available in the local variables.

        # server_no_context_takeover
        #
        #   Req.    Resp.   Result
        #   ------  ------  --------------------------------------------------
        #   False   False   False
        #   False   True    True
        #   True    False   Error!
        #   True    True    True

        if self.server_no_context_takeover:
            if not server_no_context_takeover:
                raise NegotiationError("Expected server_no_context_takeover")

        # client_no_context_takeover
        #
        #   Req.    Resp.   Result
        #   ------  ------  --------------------------------------------------
        #   False   False   False
        #   False   True    True
        #   True    False   True - must change value
        #   True    True    True

        if self.client_no_context_takeover:
            if not client_no_context_takeover:
                client_no_context_takeover = True

        # server_max_window_bits

        #   Req.    Resp.   Result
        #   ------  ------  --------------------------------------------------
        #   None    None    None
        #   None    8≤M≤15  M
        #   8≤N≤15  None    Error!
        #   8≤N≤15  8≤M≤N   M
        #   8≤N≤15  N<M≤15  Error!

        if self.server_max_window_bits is None:
            pass

        else:
            if server_max_window_bits is None:
                raise NegotiationError("Expected server_max_window_bits")
            elif server_max_window_bits > self.server_max_window_bits:
                raise NegotiationError("Unsupported server_max_window_bits")

        # client_max_window_bits

        #   Req.    Resp.   Result
        #   ------  ------  --------------------------------------------------
        #   None    None    None
        #   None    8≤M≤15  Error!
        #   True    None    None
        #   True    8≤M≤15  M
        #   8≤N≤15  None    N - must change value
        #   8≤N≤15  8≤M≤N   M
        #   8≤N≤15  N<M≤15  Error!

        if self.client_max_window_bits is None:
            if client_max_window_bits is not None:
                raise NegotiationError("Unexpected client_max_window_bits")

        elif self.client_max_window_bits is True:
            pass

        else:
            if client_max_window_bits is None:
                client_max_window_bits = self.client_max_window_bits
            elif client_max_window_bits > self.client_max_window_bits:
                raise NegotiationError("Unsupported client_max_window_bits")

        # A window size left at None means the default maximum, 15 bits.
        return PerMessageDeflate(
            server_no_context_takeover,     # remote_no_context_takeover
            client_no_context_takeover,     # local_no_context_takeover
            server_max_window_bits or 15,   # remote_max_window_bits
            client_max_window_bits or 15,   # local_max_window_bits
            self.compress_settings,
        )


class ServerPerMessageDeflateFactory:
    """
    Server-side extension factory for permessage-deflate extension.

    It reconciles the parameters offered by a client with this factory's
    configuration. It returns the response parameters together with a
    :class:`PerMessageDeflate` instance.

    """
    name = 'permessage-deflate'

    def __init__(
        self,
        server_no_context_takeover=False,
        client_no_context_takeover=False,
        server_max_window_bits=None,
        client_max_window_bits=None,
        compress_settings=None,
    ):
        """
        Configure permessage-deflate extension factory.

        ``server_max_window_bits`` and ``client_max_window_bits`` must be
        ``None`` or between 8 and 15. ``compress_settings`` is forwarded to
        :class:`PerMessageDeflate`. It must not contain ``wbits`` because the
        window size is negotiated.

        See https://tools.ietf.org/html/rfc7692#section-7.1.

        :raises ValueError: if a window size is out of range or if
            ``compress_settings`` contains ``wbits``

        """
        if not (server_max_window_bits is None or
                8 <= server_max_window_bits <= 15):
            raise ValueError("server_max_window_bits must be between 8 and 15")
        if not (client_max_window_bits is None or
                8 <= client_max_window_bits <= 15):
            raise ValueError("client_max_window_bits must be between 8 and 15")
        if compress_settings is not None and 'wbits' in compress_settings:
            raise ValueError("compress_settings must not include wbits, "
                             "set server_max_window_bits instead")

        self.server_no_context_takeover = server_no_context_takeover
        self.client_no_context_takeover = client_no_context_takeover
        self.server_max_window_bits = server_max_window_bits
        self.client_max_window_bits = client_max_window_bits
        self.compress_settings = compress_settings

    def process_request_params(self, params, accepted_extensions):
        """
        Process request parameters.

        ``params`` is the list of ``(name, value)`` pairs the client offered
        for this extension. ``accepted_extensions`` holds the extensions
        already accepted during this handshake.

        Return response params and an extension instance.

        :raises NegotiationError: if this extension was already accepted or
            if the request lacks a required ``client_max_window_bits``;
            errors raised while parsing ``params`` propagate as well

        """
        if any(other.name == self.name for other in accepted_extensions):
            raise NegotiationError("Skipped duplicate {}".format(self.name))

        # Load request parameters in local variables.
        (
            server_no_context_takeover,
            client_no_context_takeover,
            server_max_window_bits,
            client_max_window_bits,
        ) = _extract_parameters(params, is_server=True)

        # Configuration parameters are available in instance variables.

        # After comparing the request and the configuration, the response must
        # be available in the local variables.

        # server_no_context_takeover
        #
        #   Config  Req.    Resp.
        #   ------  ------  --------------------------------------------------
        #   False   False   False
        #   False   True    True
        #   True    False   True - must change value to True
        #   True    True    True

        if self.server_no_context_takeover:
            if not server_no_context_takeover:
                server_no_context_takeover = True

        # client_no_context_takeover
        #
        #   Config  Req.    Resp.
        #   ------  ------  --------------------------------------------------
        #   False   False   False
        #   False   True    True (or False)
        #   True    False   True - must change value to True
        #   True    True    True (or False)

        if self.client_no_context_takeover:
            if not client_no_context_takeover:
                client_no_context_takeover = True

        # server_max_window_bits

        #   Config  Req.    Resp.
        #   ------  ------  --------------------------------------------------
        #   None    None    None
        #   None    8≤M≤15  M
        #   8≤N≤15  None    N - must change value
        #   8≤N≤15  8≤M≤N   M
        #   8≤N≤15  N<M≤15  N - must change value

        if self.server_max_window_bits is None:
            pass

        else:
            if server_max_window_bits is None:
                server_max_window_bits = self.server_max_window_bits
            elif server_max_window_bits > self.server_max_window_bits:
                server_max_window_bits = self.server_max_window_bits

        # client_max_window_bits

        #   Config  Req.    Resp.
        #   ------  ------  --------------------------------------------------
        #   None    None    None
        #   None    True    None - must change value
        #   None    8≤M≤15  M (or None)
        #   8≤N≤15  None    Error!
        #   8≤N≤15  True    N - must change value
        #   8≤N≤15  8≤M≤N   M (or None)
        #   8≤N≤15  N<M≤15  N

        if self.client_max_window_bits is None:
            if client_max_window_bits is True:
                client_max_window_bits = self.client_max_window_bits

        else:
            if client_max_window_bits is None:
                raise NegotiationError("Required client_max_window_bits")
            elif client_max_window_bits is True:
                client_max_window_bits = self.client_max_window_bits
            elif self.client_max_window_bits < client_max_window_bits:
                client_max_window_bits = self.client_max_window_bits

        # A window size left at None means the default maximum, 15 bits.
        return (
            _build_parameters(
                server_no_context_takeover, client_no_context_takeover,
                server_max_window_bits, client_max_window_bits,
            ),
            PerMessageDeflate(
                client_no_context_takeover,     # remote_no_context_takeover
                server_no_context_takeover,     # local_no_context_takeover
                client_max_window_bits or 15,   # remote_max_window_bits
                server_max_window_bits or 15,   # local_max_window_bits
                self.compress_settings,
            )
        )


class PerMessageDeflate:
    """
    permessage-deflate extension.

    Decompresses incoming data frames and compresses outgoing data frames
    with raw deflate streams (negative ``wbits``), using the parameters
    negotiated by the extension factories.

    """
    name = 'permessage-deflate'

    def __init__(
        self,
        remote_no_context_takeover,
        local_no_context_takeover,
        remote_max_window_bits,
        local_max_window_bits,
        compress_settings=None,
    ):
        """
        Configure permessage-deflate extension.

        ``remote_*`` parameters govern decompression of incoming frames.
        ``local_*`` parameters govern compression of outgoing frames.

        When a ``*_no_context_takeover`` flag is set, the matching zlib object
        is created at the start of each message rather than once here.
        ``compress_settings`` holds extra keyword arguments for
        :func:`zlib.compressobj`.

        """
        if compress_settings is None:
            compress_settings = {}

        assert remote_no_context_takeover in [False, True]
        assert local_no_context_takeover in [False, True]
        assert 8 <= remote_max_window_bits <= 15
        assert 8 <= local_max_window_bits <= 15
        assert 'wbits' not in compress_settings

        self.remote_no_context_takeover = remote_no_context_takeover
        self.local_no_context_takeover = local_no_context_takeover
        self.remote_max_window_bits = remote_max_window_bits
        self.local_max_window_bits = local_max_window_bits
        self.compress_settings = compress_settings

        if not self.remote_no_context_takeover:
            self.decoder = zlib.decompressobj(
                wbits=-self.remote_max_window_bits)

        if not self.local_no_context_takeover:
            self.encoder = zlib.compressobj(
                wbits=-self.local_max_window_bits,
                **self.compress_settings)

        # To handle continuation frames properly, we must keep track of
        # whether the initial frame of the current message was encoded.
        self.decode_cont_data = False
        # There's no need for self.encode_cont_data because we always encode
        # outgoing frames, so it would always be True.

    def __repr__(self):
        return 'PerMessageDeflate({})'.format(', '.join([
            'remote_no_context_takeover={}'.format(
                self.remote_no_context_takeover),
            'local_no_context_takeover={}'.format(
                self.local_no_context_takeover),
            'remote_max_window_bits={}'.format(
                self.remote_max_window_bits),
            'local_max_window_bits={}'.format(
                self.local_max_window_bits),
        ]))

    def decode(self, frame, *, max_size=None):
        """
        Decode an incoming frame.

        Control frames are returned unchanged. So are the frames of messages
        whose initial frame does not have RSV1 set. Other frames are
        decompressed and returned with RSV1 cleared.

        ``max_size`` limits the length of the decompressed payload of this
        frame. ``None`` disables the limit and ``0`` allows no output at all.

        :raises PayloadTooBig: if the decompressed payload exceeds
            ``max_size``

        """
        # Skip control frames.
        if frame.opcode in CTRL_OPCODES:
            return frame

        # Handle continuation data frames:
        # - skip if the initial data frame wasn't encoded
        # - reset "decode continuation data" flag if it's a final frame
        if frame.opcode == OP_CONT:
            if not self.decode_cont_data:
                return frame
            if frame.fin:
                self.decode_cont_data = False

        # Handle text and binary data frames:
        # - skip if the frame isn't encoded
        # - set "decode continuation data" flag if it's a non-final frame
        else:
            if not frame.rsv1:
                return frame
            if not frame.fin:  # frame.rsv1 is True at this point
                self.decode_cont_data = True

            # Re-initialize per-message decoder.
            if self.remote_no_context_takeover:
                self.decoder = zlib.decompressobj(
                    wbits=-self.remote_max_window_bits)

        # Uncompress compressed frames. Protect against zip bombs by
        # preventing zlib from decompressing more than max_size bytes
        # (except when the limit is disabled with max_size = None).
        data = frame.data
        if frame.fin:
            data += _EMPTY_UNCOMPRESSED_BLOCK
        if max_size is None:
            max_length = 0  # zlib interprets max_length=0 as "no limit".
        else:
            # zlib also reads 0 as "no limit", so for max_size == 0 request a
            # single byte: any output at all then exceeds the limit.
            max_length = max_size or 1
        data = self.decoder.decompress(data, max_length)
        if self.decoder.unconsumed_tail or (
            max_size is not None and len(data) > max_size
        ):
            raise PayloadTooBig(
                "Uncompressed payload length exceeds size limit (? > {} bytes)"
                .format(max_size))

        # Allow garbage collection of the decoder if it won't be reused.
        if frame.fin and self.remote_no_context_takeover:
            self.decoder = None

        return frame._replace(data=data, rsv1=False)

    def encode(self, frame):
        """
        Encode an outgoing frame.

        Control frames are returned unchanged. Data frames are compressed and
        returned with RSV1 set. On the final frame of a message, the trailing
        empty stored block produced by the sync flush is stripped.

        """
        # Skip control frames.
        if frame.opcode in CTRL_OPCODES:
            return frame

        # Since we always encode and never fragment messages, there's no logic
        # similar to decode() here at this time.

        if frame.opcode != OP_CONT:
            # Re-initialize per-message encoder.
            if self.local_no_context_takeover:
                self.encoder = zlib.compressobj(
                    wbits=-self.local_max_window_bits,
                    **self.compress_settings)

        # Compress data frames.
        data = (
            self.encoder.compress(frame.data) +
            self.encoder.flush(zlib.Z_SYNC_FLUSH)
        )
        if frame.fin and data.endswith(_EMPTY_UNCOMPRESSED_BLOCK):
            data = data[:-4]

        # Allow garbage collection of the encoder if it won't be reused.
        if frame.fin and self.local_no_context_takeover:
            self.encoder = None

        return frame._replace(data=data, rsv1=True)