420 lines
15 KiB
Python
420 lines
15 KiB
Python
"""
|
||
The :mod:`websockets.client` module defines a simple WebSocket client API.
|
||
|
||
"""
|
||
|
||
import asyncio
|
||
import collections.abc
|
||
import sys
|
||
|
||
from .exceptions import (
|
||
InvalidHandshake, InvalidMessage, InvalidStatusCode, NegotiationError
|
||
)
|
||
from .extensions.permessage_deflate import ClientPerMessageDeflateFactory
|
||
from .handshake import build_request, check_response
|
||
from .headers import (
|
||
build_basic_auth, build_extension_list, build_subprotocol_list,
|
||
parse_extension_list, parse_subprotocol_list
|
||
)
|
||
from .http import USER_AGENT, Headers, read_response
|
||
from .protocol import WebSocketCommonProtocol
|
||
from .uri import parse_uri
|
||
|
||
|
||
# Names exported by ``from <this module> import *``. ``connect`` is bound at
# the bottom of this file, either to a coroutine function or to ``Connect``.
__all__ = ['connect', 'WebSocketClientProtocol']
|
||
|
||
|
||
class WebSocketClientProtocol(WebSocketCommonProtocol):
    """
    Complete WebSocket client implementation as an :class:`asyncio.Protocol`.

    This class inherits most of its methods from
    :class:`~websockets.protocol.WebSocketCommonProtocol`.

    The ``origin``, ``extensions``, ``subprotocols`` and ``extra_headers``
    keyword arguments are stored on the instance as ``origin``,
    ``available_extensions``, ``available_subprotocols`` and
    ``extra_headers``. Every other keyword argument is forwarded to the
    parent class.

    """
    is_client = True
    side = 'client'

    def __init__(self, *,
                 origin=None, extensions=None, subprotocols=None,
                 extra_headers=None, **kwds):
        self.origin = origin
        self.available_extensions = extensions
        self.available_subprotocols = subprotocols
        self.extra_headers = extra_headers
        super().__init__(**kwds)

    @asyncio.coroutine
    def write_http_request(self, path, headers):
        """
        Write request line and headers to the HTTP request.

        ``path`` and ``headers`` are also recorded on the instance as
        ``self.path`` and ``self.request_headers``.

        """
        self.path = path
        self.request_headers = headers

        # Since the path and headers only contain ASCII characters,
        # we can keep this simple.
        request = 'GET {path} HTTP/1.1\r\n'.format(path=path)
        request += str(headers)

        self.writer.write(request.encode())

    @asyncio.coroutine
    def read_http_response(self):
        """
        Read status line and headers from the HTTP response.

        Return a ``(status_code, headers)`` pair. The headers are also
        recorded on the instance as ``self.response_headers``.

        Raise :exc:`~websockets.exceptions.InvalidMessage` if the HTTP
        response cannot be parsed, that is, if reading it raises
        :exc:`ValueError`.

        Don't attempt to read the response body because WebSocket handshake
        responses don't have one. If the response contains a body, it may be
        read from ``self.reader`` after this coroutine returns.

        """
        try:
            status_code, headers = yield from read_response(self.reader)
        except ValueError as exc:
            raise InvalidMessage("Malformed HTTP message") from exc

        self.response_headers = headers

        return status_code, self.response_headers

    @staticmethod
    def process_extensions(headers, available_extensions):
        """
        Handle the Sec-WebSocket-Extensions HTTP response header.

        Check that each extension is supported, as well as its parameters.

        Return the list of accepted extensions.

        Raise :exc:`~websockets.exceptions.InvalidHandshake` to abort the
        connection.

        :rfc:`6455` leaves the rules up to the specification of each
        extension.

        To provide this level of flexibility, for each extension accepted by
        the server, we check for a match with each extension available in the
        client configuration. If no match is found, an exception is raised.

        If several variants of the same extension are accepted by the server,
        it may be configured several times, which won't make sense in general.
        Extensions must implement their own requirements. For this purpose,
        the list of previously accepted extensions is provided.

        Other requirements, for example related to mandatory extensions or the
        order of extensions, may be implemented by overriding this method.

        """
        accepted_extensions = []

        header_values = headers.get_all('Sec-WebSocket-Extensions')

        if header_values:

            if available_extensions is None:
                raise InvalidHandshake("No extensions supported")

            # Flatten the extensions parsed from every header occurrence
            # into one list, preserving order. A comprehension avoids the
            # repeated list copies of ``sum(lists, [])``.
            parsed_header_values = [
                parsed_extension
                for header_value in header_values
                for parsed_extension in parse_extension_list(header_value)
            ]

            for name, response_params in parsed_header_values:

                for extension_factory in available_extensions:

                    # Skip non-matching extensions based on their name.
                    if extension_factory.name != name:
                        continue

                    # Skip non-matching extensions based on their params.
                    try:
                        extension = extension_factory.process_response_params(
                            response_params, accepted_extensions)
                    except NegotiationError:
                        continue

                    # Add matching extension to the final list.
                    accepted_extensions.append(extension)

                    # Break out of the loop once we have a match.
                    break

                # If we didn't break from the loop, no extension in our list
                # matched what the server sent. Fail the connection.
                else:
                    raise NegotiationError(
                        "Unsupported extension: name = {}, params = {}".format(
                            name, response_params))

        return accepted_extensions

    @staticmethod
    def process_subprotocol(headers, available_subprotocols):
        """
        Handle the Sec-WebSocket-Protocol HTTP response header.

        Check that it contains exactly one supported subprotocol.

        Return the selected subprotocol, or ``None`` if the header is absent.

        Raise :exc:`~websockets.exceptions.InvalidHandshake` if the server
        selected a subprotocol although none was offered, or if it selected
        more than one. Raise :exc:`~websockets.exceptions.NegotiationError`
        if the selected subprotocol isn't in ``available_subprotocols``.

        """
        subprotocol = None

        header_values = headers.get_all('Sec-WebSocket-Protocol')

        if header_values:

            if available_subprotocols is None:
                raise InvalidHandshake("No subprotocols supported")

            # Flatten the subprotocols parsed from every header occurrence
            # into one list, preserving order.
            parsed_header_values = [
                parsed_subprotocol
                for header_value in header_values
                for parsed_subprotocol in parse_subprotocol_list(header_value)
            ]

            if len(parsed_header_values) > 1:
                raise InvalidHandshake(
                    "Multiple subprotocols: {}".format(
                        ', '.join(parsed_header_values)))

            subprotocol = parsed_header_values[0]

            if subprotocol not in available_subprotocols:
                raise NegotiationError(
                    "Unsupported subprotocol: {}".format(subprotocol))

        return subprotocol

    @asyncio.coroutine
    def handshake(self, wsuri, origin=None, available_extensions=None,
                  available_subprotocols=None, extra_headers=None):
        """
        Perform the client side of the opening handshake.

        If provided, ``origin`` sets the Origin HTTP header.

        If provided, ``available_extensions`` is a list of supported
        extensions in the order in which they should be used.

        If provided, ``available_subprotocols`` is a list of supported
        subprotocols in order of decreasing preference.

        If provided, ``extra_headers`` sets additional HTTP request headers.
        It must be a :class:`~websockets.http.Headers` instance, a
        :class:`~collections.abc.Mapping`, or an iterable of ``(name, value)``
        pairs.

        Raise :exc:`~websockets.exceptions.InvalidHandshake` if the handshake
        fails.

        """
        request_headers = Headers()

        # Omit the port from the Host header when it's the scheme's default.
        if wsuri.port == (443 if wsuri.secure else 80):     # pragma: no cover
            request_headers['Host'] = wsuri.host
        else:
            request_headers['Host'] = '{}:{}'.format(wsuri.host, wsuri.port)

        if wsuri.user_info:
            request_headers['Authorization'] = build_basic_auth(
                *wsuri.user_info)

        if origin is not None:
            request_headers['Origin'] = origin

        # The key returned here is needed to validate the server's response.
        key = build_request(request_headers)

        if available_extensions is not None:
            extensions_header = build_extension_list([
                (
                    extension_factory.name,
                    extension_factory.get_request_params(),
                )
                for extension_factory in available_extensions
            ])
            request_headers['Sec-WebSocket-Extensions'] = extensions_header

        if available_subprotocols is not None:
            protocol_header = build_subprotocol_list(available_subprotocols)
            request_headers['Sec-WebSocket-Protocol'] = protocol_header

        if extra_headers is not None:
            # Normalize the supported input types to (name, value) pairs.
            if isinstance(extra_headers, Headers):
                extra_headers = extra_headers.raw_items()
            elif isinstance(extra_headers, collections.abc.Mapping):
                extra_headers = extra_headers.items()
            for name, value in extra_headers:
                request_headers[name] = value

        # Only set a default User-Agent if the caller didn't provide one.
        request_headers.setdefault('User-Agent', USER_AGENT)

        yield from self.write_http_request(
            wsuri.resource_name, request_headers)

        status_code, response_headers = yield from self.read_http_response()

        if status_code != 101:
            raise InvalidStatusCode(status_code)

        check_response(response_headers, key)

        self.extensions = self.process_extensions(
            response_headers, available_extensions)

        self.subprotocol = self.process_subprotocol(
            response_headers, available_subprotocols)

        self.connection_open()
|
||
|
||
|
||
class Connect:
    """
    Connect to the WebSocket server at the given ``uri``.

    :func:`connect` returns an awaitable. Awaiting it yields an instance of
    :class:`WebSocketClientProtocol` which can then be used to send and
    receive messages.

    On Python ≥ 3.5.1, :func:`connect` can be used as an asynchronous context
    manager. In that case, the connection is closed when exiting the context.

    :func:`connect` is a wrapper around the event loop's
    :meth:`~asyncio.BaseEventLoop.create_connection` method. Unknown keyword
    arguments are passed to :meth:`~asyncio.BaseEventLoop.create_connection`.

    For example, you can set the ``ssl`` keyword argument to a
    :class:`~ssl.SSLContext` to enforce some TLS settings. When connecting to
    a ``wss://`` URI, if this argument isn't provided explicitly, it's set to
    ``True``, which means Python's default :class:`~ssl.SSLContext` is used.

    The behavior of the ``timeout``, ``max_size``, ``max_queue``,
    ``read_limit``, and ``write_limit`` optional arguments is described in the
    documentation of :class:`~websockets.protocol.WebSocketCommonProtocol`.

    The ``create_protocol`` parameter allows customizing the asyncio protocol
    that manages the connection. It should be a callable or class accepting
    the same arguments as :class:`WebSocketClientProtocol` and returning a
    :class:`WebSocketClientProtocol` instance. It defaults to
    :class:`WebSocketClientProtocol`. ``klass`` is the former name of this
    parameter; it is still accepted and is ignored when ``create_protocol``
    is also given.

    :func:`connect` also accepts the following optional arguments:

    * ``origin`` sets the Origin HTTP header
    * ``extensions`` is a list of supported extensions in order of
      decreasing preference; the list passed by the caller is never modified
    * ``subprotocols`` is a list of supported subprotocols in order of
      decreasing preference
    * ``extra_headers`` sets additional HTTP request headers – it can be a
      :class:`~websockets.http.Headers` instance, a
      :class:`~collections.abc.Mapping`, or an iterable of ``(name, value)``
      pairs
    * ``compression`` is a shortcut to configure compression extensions;
      by default it enables the "permessage-deflate" extension; set it to
      ``None`` to disable compression

    :func:`connect` raises :exc:`~websockets.uri.InvalidURI` if ``uri`` is
    invalid and :exc:`~websockets.exceptions.InvalidHandshake` if the opening
    handshake fails. It raises :exc:`ValueError` if an ``ssl`` argument is
    given for a ``ws://`` URI or if ``compression`` has an unsupported value.

    """

    def __init__(self, uri, *,
                 create_protocol=None,
                 timeout=10, max_size=2 ** 20, max_queue=2 ** 5,
                 read_limit=2 ** 16, write_limit=2 ** 16,
                 loop=None, legacy_recv=False, klass=None,
                 origin=None, extensions=None, subprotocols=None,
                 extra_headers=None, compression='deflate', **kwds):
        if loop is None:
            loop = asyncio.get_event_loop()

        # Backwards-compatibility: create_protocol used to be called klass.
        # In the unlikely event that both are specified, klass is ignored.
        if create_protocol is None:
            create_protocol = klass

        if create_protocol is None:
            create_protocol = WebSocketClientProtocol

        wsuri = parse_uri(uri)
        if wsuri.secure:
            kwds.setdefault('ssl', True)
        elif kwds.get('ssl') is not None:
            raise ValueError("connect() received a SSL context for a ws:// "
                             "URI, use a wss:// URI to enable TLS")

        if compression == 'deflate':
            # Work on a copy so that the list object supplied by the caller
            # is never mutated when the default extension is added.
            extensions = [] if extensions is None else list(extensions)
            if not any(
                extension_factory.name == ClientPerMessageDeflateFactory.name
                for extension_factory in extensions
            ):
                extensions.append(ClientPerMessageDeflateFactory(
                    client_max_window_bits=True,
                ))
        elif compression is not None:
            raise ValueError("Unsupported compression: {}".format(compression))

        def factory():
            # Protocol factory handed to the event loop; it is called
            # without arguments, so the configuration is captured here.
            return create_protocol(
                host=wsuri.host, port=wsuri.port, secure=wsuri.secure,
                timeout=timeout, max_size=max_size, max_queue=max_queue,
                read_limit=read_limit, write_limit=write_limit,
                loop=loop, legacy_recv=legacy_recv,
                origin=origin, extensions=extensions,
                subprotocols=subprotocols,
                extra_headers=extra_headers,
            )

        if kwds.get('sock') is None:
            host, port = wsuri.host, wsuri.port
        else:
            # If sock is given, host and port mustn't be specified.
            host, port = None, None

        self._wsuri = wsuri
        self._origin = origin

        # This is a coroutine object.
        self._creating_connection = loop.create_connection(
            factory, host, port, **kwds)

    @asyncio.coroutine
    def __iter__(self):                                     # pragma: no cover
        """
        Open the connection, perform the opening handshake and return the
        protocol instance.

        If the handshake raises, the connection is failed and the exception
        is re-raised.

        """
        transport, protocol = yield from self._creating_connection

        try:
            yield from protocol.handshake(
                self._wsuri, origin=self._origin,
                available_extensions=protocol.available_extensions,
                available_subprotocols=protocol.available_subprotocols,
                extra_headers=protocol.extra_headers,
            )
        except Exception:
            yield from protocol.fail_connection()
            raise

        self.ws_client = protocol
        return protocol
|
||
|
||
|
||
# We can't define __await__ on Python < 3.5.1 because asyncio.ensure_future
# didn't accept arbitrary awaitables until Python 3.5.1. We don't define
# __aenter__ and __aexit__ either on Python < 3.5.1 to keep things simple.
if sys.version_info[:3] <= (3, 5, 0):                       # pragma: no cover
    # On these versions, connect() is a generator-based coroutine that
    # builds a Connect object and delegates to its __iter__ coroutine.
    @asyncio.coroutine
    def connect(*args, **kwds):
        return Connect(*args, **kwds).__iter__()

    # Reuse the documentation written on the Connect class.
    connect.__doc__ = Connect.__doc__

else:
    # On newer versions, attach the awaitable and asynchronous context
    # manager hooks imported from the py35 submodule to Connect, then
    # expose the class itself under the name ``connect``.
    from .py35.client import __aenter__, __aexit__, __await__
    Connect.__aenter__ = __aenter__
    Connect.__aexit__ = __aexit__
    Connect.__await__ = __await__
    connect = Connect
|