""" The :mod:`websockets.client` module defines a simple WebSocket client API. """ import asyncio import collections.abc import sys from .exceptions import ( InvalidHandshake, InvalidMessage, InvalidStatusCode, NegotiationError ) from .extensions.permessage_deflate import ClientPerMessageDeflateFactory from .handshake import build_request, check_response from .headers import ( build_basic_auth, build_extension_list, build_subprotocol_list, parse_extension_list, parse_subprotocol_list ) from .http import USER_AGENT, Headers, read_response from .protocol import WebSocketCommonProtocol from .uri import parse_uri __all__ = ['connect', 'WebSocketClientProtocol'] class WebSocketClientProtocol(WebSocketCommonProtocol): """ Complete WebSocket client implementation as an :class:`asyncio.Protocol`. This class inherits most of its methods from :class:`~websockets.protocol.WebSocketCommonProtocol`. """ is_client = True side = 'client' def __init__(self, *, origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds): self.origin = origin self.available_extensions = extensions self.available_subprotocols = subprotocols self.extra_headers = extra_headers super().__init__(**kwds) @asyncio.coroutine def write_http_request(self, path, headers): """ Write request line and headers to the HTTP request. """ self.path = path self.request_headers = headers # Since the path and headers only contain ASCII characters, # we can keep this simple. request = 'GET {path} HTTP/1.1\r\n'.format(path=path) request += str(headers) self.writer.write(request.encode()) @asyncio.coroutine def read_http_response(self): """ Read status line and headers from the HTTP response. Raise :exc:`~websockets.exceptions.InvalidMessage` if the HTTP message is malformed or isn't an HTTP/1.1 GET request. Don't attempt to read the response body because WebSocket handshake responses don't have one. If the response contains a body, it may be read from ``self.reader`` after this coroutine returns. """ try: status_code, headers = yield from read_response(self.reader) except ValueError as exc: raise InvalidMessage("Malformed HTTP message") from exc self.response_headers = headers return status_code, self.response_headers @staticmethod def process_extensions(headers, available_extensions): """ Handle the Sec-WebSocket-Extensions HTTP response header. Check that each extension is supported, as well as its parameters. Return the list of accepted extensions. Raise :exc:`~websockets.exceptions.InvalidHandshake` to abort the connection. :rfc:`6455` leaves the rules up to the specification of each :extension. To provide this level of flexibility, for each extension accepted by the server, we check for a match with each extension available in the client configuration. If no match is found, an exception is raised. If several variants of the same extension are accepted by the server, it may be configured severel times, which won't make sense in general. Extensions must implement their own requirements. For this purpose, the list of previously accepted extensions is provided. Other requirements, for example related to mandatory extensions or the order of extensions, may be implemented by overriding this method. """ accepted_extensions = [] header_values = headers.get_all('Sec-WebSocket-Extensions') if header_values: if available_extensions is None: raise InvalidHandshake("No extensions supported") parsed_header_values = sum([ parse_extension_list(header_value) for header_value in header_values ], []) for name, response_params in parsed_header_values: for extension_factory in available_extensions: # Skip non-matching extensions based on their name. if extension_factory.name != name: continue # Skip non-matching extensions based on their params. try: extension = extension_factory.process_response_params( response_params, accepted_extensions) except NegotiationError: continue # Add matching extension to the final list. accepted_extensions.append(extension) # Break out of the loop once we have a match. break # If we didn't break from the loop, no extension in our list # matched what the server sent. Fail the connection. else: raise NegotiationError( "Unsupported extension: name = {}, params = {}".format( name, response_params)) return accepted_extensions @staticmethod def process_subprotocol(headers, available_subprotocols): """ Handle the Sec-WebSocket-Protocol HTTP response header. Check that it contains exactly one supported subprotocol. Return the selected subprotocol. """ subprotocol = None header_values = headers.get_all('Sec-WebSocket-Protocol') if header_values: if available_subprotocols is None: raise InvalidHandshake("No subprotocols supported") parsed_header_values = sum([ parse_subprotocol_list(header_value) for header_value in header_values ], []) if len(parsed_header_values) > 1: raise InvalidHandshake( "Multiple subprotocols: {}".format( ', '.join(parsed_header_values))) subprotocol = parsed_header_values[0] if subprotocol not in available_subprotocols: raise NegotiationError( "Unsupported subprotocol: {}".format(subprotocol)) return subprotocol @asyncio.coroutine def handshake(self, wsuri, origin=None, available_extensions=None, available_subprotocols=None, extra_headers=None): """ Perform the client side of the opening handshake. If provided, ``origin`` sets the Origin HTTP header. If provided, ``available_extensions`` is a list of supported extensions in the order in which they should be used. If provided, ``available_subprotocols`` is a list of supported subprotocols in order of decreasing preference. If provided, ``extra_headers`` sets additional HTTP request headers. It must be a :class:`~websockets.http.Headers` instance, a :class:`~collections.abc.Mapping`, or an iterable of ``(name, value)`` pairs. Raise :exc:`~websockets.exceptions.InvalidHandshake` if the handshake fails. """ request_headers = Headers() if wsuri.port == (443 if wsuri.secure else 80): # pragma: no cover request_headers['Host'] = wsuri.host else: request_headers['Host'] = '{}:{}'.format(wsuri.host, wsuri.port) if wsuri.user_info: request_headers['Authorization'] = build_basic_auth( *wsuri.user_info) if origin is not None: request_headers['Origin'] = origin key = build_request(request_headers) if available_extensions is not None: extensions_header = build_extension_list([ ( extension_factory.name, extension_factory.get_request_params(), ) for extension_factory in available_extensions ]) request_headers['Sec-WebSocket-Extensions'] = extensions_header if available_subprotocols is not None: protocol_header = build_subprotocol_list(available_subprotocols) request_headers['Sec-WebSocket-Protocol'] = protocol_header if extra_headers is not None: if isinstance(extra_headers, Headers): extra_headers = extra_headers.raw_items() elif isinstance(extra_headers, collections.abc.Mapping): extra_headers = extra_headers.items() for name, value in extra_headers: request_headers[name] = value request_headers.setdefault('User-Agent', USER_AGENT) yield from self.write_http_request( wsuri.resource_name, request_headers) status_code, response_headers = yield from self.read_http_response() if status_code != 101: raise InvalidStatusCode(status_code) check_response(response_headers, key) self.extensions = self.process_extensions( response_headers, available_extensions) self.subprotocol = self.process_subprotocol( response_headers, available_subprotocols) self.connection_open() class Connect: """ Connect to the WebSocket server at the given ``uri``. :func:`connect` returns an awaitable. Awaiting it yields an instance of :class:`WebSocketClientProtocol` which can then be used to send and receive messages. On Python ≥ 3.5.1, :func:`connect` can be used as a asynchronous context manager. In that case, the connection is closed when exiting the context. :func:`connect` is a wrapper around the event loop's :meth:`~asyncio.BaseEventLoop.create_connection` method. Unknown keyword arguments are passed to :meth:`~asyncio.BaseEventLoop.create_connection`. For example, you can set the ``ssl`` keyword argument to a :class:`~ssl.SSLContext` to enforce some TLS settings. When connecting to a ``wss://`` URI, if this argument isn't provided explicitly, it's set to ``True``, which means Python's default :class:`~ssl.SSLContext` is used. The behavior of the ``timeout``, ``max_size``, and ``max_queue``, ``read_limit``, and ``write_limit`` optional arguments is described in the documentation of :class:`~websockets.protocol.WebSocketCommonProtocol`. The ``create_protocol`` parameter allows customizing the asyncio protocol that manages the connection. It should be a callable or class accepting the same arguments as :class:`WebSocketClientProtocol` and returning a :class:`WebSocketClientProtocol` instance. It defaults to :class:`WebSocketClientProtocol`. :func:`connect` also accepts the following optional arguments: * ``origin`` sets the Origin HTTP header * ``extensions`` is a list of supported extensions in order of decreasing preference * ``subprotocols`` is a list of supported subprotocols in order of decreasing preference * ``extra_headers`` sets additional HTTP request headers – it can be a :class:`~websockets.http.Headers` instance, a :class:`~collections.abc.Mapping`, or an iterable of ``(name, value)`` pairs * ``compression`` is a shortcut to configure compression extensions; by default it enables the "permessage-deflate" extension; set it to ``None`` to disable compression :func:`connect` raises :exc:`~websockets.uri.InvalidURI` if ``uri`` is invalid and :exc:`~websockets.handshake.InvalidHandshake` if the opening handshake fails. """ def __init__(self, uri, *, create_protocol=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, legacy_recv=False, klass=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', **kwds): if loop is None: loop = asyncio.get_event_loop() # Backwards-compatibility: create_protocol used to be called klass. # In the unlikely event that both are specified, klass is ignored. if create_protocol is None: create_protocol = klass if create_protocol is None: create_protocol = WebSocketClientProtocol wsuri = parse_uri(uri) if wsuri.secure: kwds.setdefault('ssl', True) elif kwds.get('ssl') is not None: raise ValueError("connect() received a SSL context for a ws:// " "URI, use a wss:// URI to enable TLS") if compression == 'deflate': if extensions is None: extensions = [] if not any( extension_factory.name == ClientPerMessageDeflateFactory.name for extension_factory in extensions ): extensions.append(ClientPerMessageDeflateFactory( client_max_window_bits=True, )) elif compression is not None: raise ValueError("Unsupported compression: {}".format(compression)) factory = lambda: create_protocol( host=wsuri.host, port=wsuri.port, secure=wsuri.secure, timeout=timeout, max_size=max_size, max_queue=max_queue, read_limit=read_limit, write_limit=write_limit, loop=loop, legacy_recv=legacy_recv, origin=origin, extensions=extensions, subprotocols=subprotocols, extra_headers=extra_headers, ) if kwds.get('sock') is None: host, port = wsuri.host, wsuri.port else: # If sock is given, host and port mustn't be specified. host, port = None, None self._wsuri = wsuri self._origin = origin # This is a coroutine object. self._creating_connection = loop.create_connection( factory, host, port, **kwds) @asyncio.coroutine def __iter__(self): # pragma: no cover transport, protocol = yield from self._creating_connection try: yield from protocol.handshake( self._wsuri, origin=self._origin, available_extensions=protocol.available_extensions, available_subprotocols=protocol.available_subprotocols, extra_headers=protocol.extra_headers, ) except Exception: yield from protocol.fail_connection() raise self.ws_client = protocol return protocol # We can't define __await__ on Python < 3.5.1 because asyncio.ensure_future # didn't accept arbitrary awaitables until Python 3.5.1. We don't define # __aenter__ and __aexit__ either on Python < 3.5.1 to keep things simple. if sys.version_info[:3] <= (3, 5, 0): # pragma: no cover @asyncio.coroutine def connect(*args, **kwds): return Connect(*args, **kwds).__iter__() connect.__doc__ = Connect.__doc__ else: from .py35.client import __aenter__, __aexit__, __await__ Connect.__aenter__ = __aenter__ Connect.__aexit__ = __aexit__ Connect.__await__ = __await__ connect = Connect