import asyncio import contextlib import functools import logging import pathlib import random import socket import ssl import sys import tempfile import unittest import unittest.mock import urllib.error import urllib.request from .client import * from .compatibility import FORBIDDEN, OK, UNAUTHORIZED from .exceptions import ( ConnectionClosed, InvalidHandshake, InvalidStatusCode, NegotiationError ) from .extensions.permessage_deflate import ( ClientPerMessageDeflateFactory, PerMessageDeflate, ServerPerMessageDeflateFactory ) from .handshake import build_response from .http import USER_AGENT, Headers, read_response from .protocol import State from .server import * from .test_protocol import MS # Avoid displaying stack traces at the ERROR logging level. logging.basicConfig(level=logging.CRITICAL) # Generate TLS certificate with: # $ openssl req -x509 -config test_localhost.cnf -days 15340 -newkey rsa:2048 \ # -out test_localhost.crt -keyout test_localhost.key # $ cat test_localhost.key test_localhost.crt > test_localhost.pem # $ rm test_localhost.key test_localhost.crt testcert = bytes(pathlib.Path(__file__).with_name('test_localhost.pem')) @asyncio.coroutine def handler(ws, path): if path == '/attributes': yield from ws.send(repr((ws.host, ws.port, ws.secure))) elif path == '/path': yield from ws.send(str(ws.path)) elif path == '/headers': yield from ws.send(repr(ws.request_headers)) yield from ws.send(repr(ws.response_headers)) elif path == '/extensions': yield from ws.send(repr(ws.extensions)) elif path == '/subprotocol': yield from ws.send(repr(ws.subprotocol)) elif path == '/slow_stop': try: yield from asyncio.sleep(1000 * MS) except asyncio.CancelledError: yield from asyncio.sleep(MS) raise else: yield from ws.send((yield from ws.recv())) @contextlib.contextmanager def temp_test_server(test, **kwds): test.start_server(**kwds) try: yield finally: test.stop_server() @contextlib.contextmanager def temp_test_client(test, *args, **kwds): test.start_client(*args, **kwds) try: yield finally: test.stop_client() def with_manager(manager, *args, **kwds): """ Return a decorator that wraps a function with a context manager. """ def decorate(func): @functools.wraps(func) def _decorate(self, *_args, **_kwds): with manager(self, *args, **kwds): return func(self, *_args, **_kwds) return _decorate return decorate def with_server(**kwds): """ Return a decorator for TestCase methods that starts and stops a server. """ return with_manager(temp_test_server, **kwds) def with_client(*args, **kwds): """ Return a decorator for TestCase methods that starts and stops a client. """ return with_manager(temp_test_client, *args, **kwds) def get_server_uri(server, secure=False, resource_name='/', user_info=None): """ Return a WebSocket URI for connecting to the given server. """ proto = 'wss' if secure else 'ws' user_info = ':'.join(user_info) + '@' if user_info else '' # Pick a random socket in order to test both IPv4 and IPv6 on systems # where both are available. Randomizing tests is usually a bad idea. If # needed, either use the first socket, or test separately IPv4 and IPv6. server_socket = random.choice(server.sockets) if server_socket.family == socket.AF_INET6: # pragma: no cover host, port = server_socket.getsockname()[:2] # (no IPv6 on CI) host = '[{}]'.format(host) elif server_socket.family == socket.AF_INET: host, port = server_socket.getsockname() elif server_socket.family == socket.AF_UNIX: # The host and port are ignored when connecting to a Unix socket. host, port = 'localhost', 0 else: # pragma: no cover raise ValueError("Expected an IPv6, IPv4, or Unix socket") return '{}://{}{}:{}{}'.format(proto, user_info, host, port, resource_name) class UnauthorizedServerProtocol(WebSocketServerProtocol): @asyncio.coroutine def process_request(self, path, request_headers): # Test returning headers as a Headers instance (1/3) return UNAUTHORIZED, Headers([('X-Access', 'denied')]), b'' class ForbiddenServerProtocol(WebSocketServerProtocol): @asyncio.coroutine def process_request(self, path, request_headers): # Test returning headers as a dict (2/3) return FORBIDDEN, {'X-Access': 'denied'}, b'' class HealthCheckServerProtocol(WebSocketServerProtocol): @asyncio.coroutine def process_request(self, path, request_headers): # Test returning headers as a list of pairs (3/3) if path == '/__health__/': return OK, [('X-Access', 'OK')], b'status = green\n' class FooClientProtocol(WebSocketClientProtocol): pass class BarClientProtocol(WebSocketClientProtocol): pass class ClientNoOpExtensionFactory: name = 'x-no-op' def get_request_params(self): return [] def process_response_params(self, params, accepted_extensions): if params: raise NegotiationError() return NoOpExtension() class ServerNoOpExtensionFactory: name = 'x-no-op' def __init__(self, params=None): self.params = params or [] def process_request_params(self, params, accepted_extensions): return self.params, NoOpExtension() class NoOpExtension: name = 'x-no-op' def __repr__(self): return 'NoOpExtension()' def decode(self, frame, *, max_size=None): return frame def encode(self, frame): return frame class ClientServerTests(unittest.TestCase): secure = False def setUp(self): self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) def tearDown(self): self.loop.close() def run_loop_once(self): # Process callbacks scheduled with call_soon by appending a callback # to stop the event loop then running it until it hits that callback. self.loop.call_soon(self.loop.stop) self.loop.run_forever() def start_server(self, **kwds): # Don't enable compression by default in tests. kwds.setdefault('compression', None) start_server = serve(handler, 'localhost', 0, **kwds) self.server = self.loop.run_until_complete(start_server) def start_client(self, resource_name='/', user_info=None, **kwds): # Don't enable compression by default in tests. kwds.setdefault('compression', None) secure = kwds.get('ssl') is not None server_uri = get_server_uri( self.server, secure, resource_name, user_info) start_client = connect(server_uri, **kwds) self.client = self.loop.run_until_complete(start_client) def stop_client(self): try: self.loop.run_until_complete( asyncio.wait_for(self.client.close_connection_task, timeout=1)) except asyncio.TimeoutError: # pragma: no cover self.fail("Client failed to stop") def stop_server(self): self.server.close() try: self.loop.run_until_complete( asyncio.wait_for(self.server.wait_closed(), timeout=1)) except asyncio.TimeoutError: # pragma: no cover self.fail("Server failed to stop") @contextlib.contextmanager def temp_server(self, **kwds): with temp_test_server(self, **kwds): yield @contextlib.contextmanager def temp_client(self, *args, **kwds): with temp_test_client(self, *args, **kwds): yield @with_server() @with_client() def test_basic(self): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") def test_server_close_while_client_connected(self): with self.temp_server(loop=self.loop): # This endpoint waits just a bit when the connection is cancelled # in order to test that wait_closed() really waits for completion. self.start_client('/slow_stop') with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.client.recv()) # Connection ends with 1001 going away. self.assertEqual(self.client.close_code, 1001) def test_explicit_event_loop(self): with self.temp_server(loop=self.loop): with self.temp_client(loop=self.loop): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") # The way the legacy SSL implementation wraps sockets makes it extremely # hard to write a test for Python 3.4. @unittest.skipIf( sys.version_info[:2] <= (3, 4), 'this test requires Python 3.5+') @with_server() def test_explicit_socket(self): class TrackedSocket(socket.socket): def __init__(self, *args, **kwargs): self.used_for_read = False self.used_for_write = False super().__init__(*args, **kwargs) def recv(self, *args, **kwargs): self.used_for_read = True return super().recv(*args, **kwargs) def send(self, *args, **kwargs): self.used_for_write = True return super().send(*args, **kwargs) server_socket = [ s for s in self.server.sockets if s.family == socket.AF_INET][0] client_socket = TrackedSocket(socket.AF_INET, socket.SOCK_STREAM) client_socket.connect(server_socket.getsockname()) try: self.assertFalse(client_socket.used_for_read) self.assertFalse(client_socket.used_for_write) with self.temp_client( sock=client_socket, # "You must set server_hostname when using ssl without a host" server_hostname='localhost' if self.secure else None, ): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") self.assertTrue(client_socket.used_for_read) self.assertTrue(client_socket.used_for_write) finally: client_socket.close() @unittest.skipUnless( hasattr(socket, 'AF_UNIX'), 'this test requires Unix sockets') def test_unix_socket(self): with tempfile.TemporaryDirectory() as temp_dir: path = bytes(pathlib.Path(temp_dir) / 'websockets') # Like self.start_server() but with unix_serve(). unix_server = unix_serve(handler, path) self.server = self.loop.run_until_complete(unix_server) client_socket = socket.socket(socket.AF_UNIX) client_socket.connect(path) try: with self.temp_client(sock=client_socket): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") finally: client_socket.close() self.stop_server() @with_server() @with_client('/attributes') def test_protocol_attributes(self): # The test could be connecting with IPv6 or IPv4. expected_client_attrs = [ server_socket.getsockname()[:2] + (self.secure,) for server_socket in self.server.sockets ] client_attrs = (self.client.host, self.client.port, self.client.secure) self.assertIn(client_attrs, expected_client_attrs) expected_server_attrs = ('localhost', 0, self.secure) server_attrs = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_attrs, repr(expected_server_attrs)) @with_server() @with_client('/path') def test_protocol_path(self): client_path = self.client.path self.assertEqual(client_path, '/path') server_path = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_path, '/path') @with_server() @with_client('/headers', user_info=('user', 'pass')) def test_protocol_basic_auth(self): self.assertEqual( self.client.request_headers['Authorization'], 'Basic dXNlcjpwYXNz', ) @with_server() @with_client('/headers') def test_protocol_headers(self): client_req = self.client.request_headers client_resp = self.client.response_headers self.assertEqual(client_req['User-Agent'], USER_AGENT) self.assertEqual(client_resp['Server'], USER_AGENT) server_req = self.loop.run_until_complete(self.client.recv()) server_resp = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_req, repr(client_req)) self.assertEqual(server_resp, repr(client_resp)) @with_server() @with_client('/headers', extra_headers=Headers({'X-Spam': 'Eggs'})) def test_protocol_custom_request_headers(self): req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", req_headers) @with_server() @with_client('/headers', extra_headers={'X-Spam': 'Eggs'}) def test_protocol_custom_request_headers_dict(self): req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", req_headers) @with_server() @with_client('/headers', extra_headers=[('X-Spam', 'Eggs')]) def test_protocol_custom_request_headers_list(self): req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", req_headers) @with_server() @with_client('/headers', extra_headers=[('User-Agent', 'Eggs')]) def test_protocol_custom_request_user_agent(self): req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertEqual(req_headers.count("User-Agent"), 1) self.assertIn("('User-Agent', 'Eggs')", req_headers) @with_server(extra_headers=lambda p, r: Headers({'X-Spam': 'Eggs'})) @with_client('/headers') def test_protocol_custom_response_headers_callable(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) @with_server(extra_headers=lambda p, r: {'X-Spam': 'Eggs'}) @with_client('/headers') def test_protocol_custom_response_headers_callable_dict(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) @with_server(extra_headers=lambda p, r: [('X-Spam', 'Eggs')]) @with_client('/headers') def test_protocol_custom_response_headers_callable_list(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) @with_server(extra_headers=Headers({'X-Spam': 'Eggs'})) @with_client('/headers') def test_protocol_custom_response_headers(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) @with_server(extra_headers={'X-Spam': 'Eggs'}) @with_client('/headers') def test_protocol_custom_response_headers_dict(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) @with_server(extra_headers=[('X-Spam', 'Eggs')]) @with_client('/headers') def test_protocol_custom_response_headers_list(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) @with_server(extra_headers=[('Server', 'Eggs')]) @with_client('/headers') def test_protocol_custom_response_user_agent(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertEqual(resp_headers.count("Server"), 1) self.assertIn("('Server', 'Eggs')", resp_headers) def make_http_request(self, path='/'): # Set url to 'https?://:'. url = get_server_uri( self.server, resource_name=path, secure=self.secure) url = url.replace('ws', 'http') if self.secure: open_health_check = functools.partial( urllib.request.urlopen, url, context=self.client_context) else: open_health_check = functools.partial( urllib.request.urlopen, url) return self.loop.run_in_executor(None, open_health_check) @with_server(create_protocol=HealthCheckServerProtocol) def test_http_request_http_endpoint(self): # Making a HTTP request to a HTTP endpoint succeeds. response = self.loop.run_until_complete( self.make_http_request('/__health__/')) with contextlib.closing(response): self.assertEqual(response.code, 200) self.assertEqual(response.read(), b'status = green\n') @with_server(create_protocol=HealthCheckServerProtocol) def test_http_request_ws_endpoint(self): # Making a HTTP request to a WS endpoint fails. with self.assertRaises(urllib.error.HTTPError) as raised: self.loop.run_until_complete(self.make_http_request()) self.assertEqual(raised.exception.code, 426) self.assertEqual(raised.exception.headers['Upgrade'], 'websocket') @with_server(create_protocol=HealthCheckServerProtocol) def test_ws_connection_http_endpoint(self): # Making a WS connection to a HTTP endpoint fails. with self.assertRaises(InvalidStatusCode) as raised: self.start_client('/__health__/') self.assertEqual(raised.exception.status_code, 200) @with_server(create_protocol=HealthCheckServerProtocol) def test_ws_connection_ws_endpoint(self): # Making a WS connection to a WS endpoint succeeds. self.start_client() self.loop.run_until_complete(self.client.send("Hello!")) self.loop.run_until_complete(self.client.recv()) self.stop_client() def assert_client_raises_code(self, status_code): with self.assertRaises(InvalidStatusCode) as raised: self.start_client() self.assertEqual(raised.exception.status_code, status_code) @with_server(create_protocol=UnauthorizedServerProtocol) def test_server_create_protocol(self): self.assert_client_raises_code(401) @with_server(create_protocol=(lambda *args, **kwargs: UnauthorizedServerProtocol(*args, **kwargs))) def test_server_create_protocol_function(self): self.assert_client_raises_code(401) @with_server(klass=UnauthorizedServerProtocol) def test_server_klass(self): self.assert_client_raises_code(401) @with_server(create_protocol=ForbiddenServerProtocol, klass=UnauthorizedServerProtocol) def test_server_create_protocol_over_klass(self): self.assert_client_raises_code(403) @with_server() @with_client('/path', create_protocol=FooClientProtocol) def test_client_create_protocol(self): self.assertIsInstance(self.client, FooClientProtocol) @with_server() @with_client('/path', create_protocol=( lambda *args, **kwargs: FooClientProtocol(*args, **kwargs))) def test_client_create_protocol_function(self): self.assertIsInstance(self.client, FooClientProtocol) @with_server() @with_client('/path', klass=FooClientProtocol) def test_client_klass(self): self.assertIsInstance(self.client, FooClientProtocol) @with_server() @with_client('/path', create_protocol=BarClientProtocol, klass=FooClientProtocol) def test_client_create_protocol_over_klass(self): self.assertIsInstance(self.client, BarClientProtocol) @with_server() @with_client('/extensions') def test_no_extension(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_extensions, repr([])) self.assertEqual(repr(self.client.extensions), repr([])) @with_server(extensions=[ServerNoOpExtensionFactory()]) @with_client('/extensions', extensions=[ClientNoOpExtensionFactory()]) def test_extension(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_extensions, repr([NoOpExtension()])) self.assertEqual(repr(self.client.extensions), repr([NoOpExtension()])) @with_server() @with_client('/extensions', extensions=[ClientNoOpExtensionFactory()]) def test_extension_not_accepted(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_extensions, repr([])) self.assertEqual(repr(self.client.extensions), repr([])) @with_server(extensions=[ServerNoOpExtensionFactory()]) @with_client('/extensions') def test_extension_not_requested(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_extensions, repr([])) self.assertEqual(repr(self.client.extensions), repr([])) @with_server(extensions=[ServerNoOpExtensionFactory([('foo', None)])]) def test_extension_client_rejection(self): with self.assertRaises(NegotiationError): self.start_client( '/extensions', extensions=[ClientNoOpExtensionFactory()], ) @with_server( extensions=[ # No match because the client doesn't send client_max_window_bits. ServerPerMessageDeflateFactory(client_max_window_bits=10), ServerPerMessageDeflateFactory(), ], ) @with_client( '/extensions', extensions=[ ClientPerMessageDeflateFactory(), ], ) def test_extension_no_match_then_match(self): # The order requested by the client has priority. server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_extensions, repr([ PerMessageDeflate(False, False, 15, 15), ])) self.assertEqual(repr(self.client.extensions), repr([ PerMessageDeflate(False, False, 15, 15), ])) @with_server(extensions=[ServerPerMessageDeflateFactory()]) @with_client('/extensions', extensions=[ClientNoOpExtensionFactory()]) def test_extension_mismatch(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_extensions, repr([])) self.assertEqual(repr(self.client.extensions), repr([])) @with_server( extensions=[ ServerNoOpExtensionFactory(), ServerPerMessageDeflateFactory(), ], ) @with_client( '/extensions', extensions=[ ClientPerMessageDeflateFactory(), ClientNoOpExtensionFactory(), ], ) def test_extension_order(self): # The order requested by the client has priority. server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_extensions, repr([ PerMessageDeflate(False, False, 15, 15), NoOpExtension(), ])) self.assertEqual(repr(self.client.extensions), repr([ PerMessageDeflate(False, False, 15, 15), NoOpExtension(), ])) @with_server(extensions=[ServerNoOpExtensionFactory()]) @unittest.mock.patch.object(WebSocketServerProtocol, 'process_extensions') def test_extensions_error(self, _process_extensions): _process_extensions.return_value = 'x-no-op', [NoOpExtension()] with self.assertRaises(NegotiationError): self.start_client( '/extensions', extensions=[ClientPerMessageDeflateFactory()], ) @with_server(extensions=[ServerNoOpExtensionFactory()]) @unittest.mock.patch.object(WebSocketServerProtocol, 'process_extensions') def test_extensions_error_no_extensions(self, _process_extensions): _process_extensions.return_value = 'x-no-op', [NoOpExtension()] with self.assertRaises(InvalidHandshake): self.start_client('/extensions') @with_server(compression='deflate') @with_client('/extensions', compression='deflate') def test_compression_deflate(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_extensions, repr([ PerMessageDeflate(False, False, 15, 15), ])) self.assertEqual(repr(self.client.extensions), repr([ PerMessageDeflate(False, False, 15, 15), ])) @with_server( extensions=[ ServerPerMessageDeflateFactory( client_no_context_takeover=True, server_max_window_bits=10, ), ], compression='deflate', # overridden by explicit config ) @with_client( '/extensions', extensions=[ ClientPerMessageDeflateFactory( server_no_context_takeover=True, client_max_window_bits=12, ), ], compression='deflate', # overridden by explicit config ) def test_compression_deflate_and_explicit_config(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_extensions, repr([ PerMessageDeflate(True, True, 12, 10), ])) self.assertEqual(repr(self.client.extensions), repr([ PerMessageDeflate(True, True, 10, 12), ])) def test_compression_unsupported_server(self): with self.assertRaises(ValueError): self.loop.run_until_complete(self.start_server(compression='xz')) @with_server() def test_compression_unsupported_client(self): with self.assertRaises(ValueError): self.loop.run_until_complete(self.start_client(compression='xz')) @with_server() @with_client('/subprotocol') def test_no_subprotocol(self): server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) @with_server(subprotocols=['superchat', 'chat']) @with_client('/subprotocol', subprotocols=['otherchat', 'chat']) def test_subprotocol(self): server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr('chat')) self.assertEqual(self.client.subprotocol, 'chat') @with_server(subprotocols=['superchat']) @with_client('/subprotocol', subprotocols=['otherchat']) def test_subprotocol_not_accepted(self): server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) @with_server() @with_client('/subprotocol', subprotocols=['otherchat', 'chat']) def test_subprotocol_not_offered(self): server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) @with_server(subprotocols=['superchat', 'chat']) @with_client('/subprotocol') def test_subprotocol_not_requested(self): server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) @with_server(subprotocols=['superchat']) @unittest.mock.patch.object(WebSocketServerProtocol, 'process_subprotocol') def test_subprotocol_error(self, _process_subprotocol): _process_subprotocol.return_value = 'superchat' with self.assertRaises(NegotiationError): self.start_client('/subprotocol', subprotocols=['otherchat']) self.run_loop_once() @with_server(subprotocols=['superchat']) @unittest.mock.patch.object(WebSocketServerProtocol, 'process_subprotocol') def test_subprotocol_error_no_subprotocols(self, _process_subprotocol): _process_subprotocol.return_value = 'superchat' with self.assertRaises(InvalidHandshake): self.start_client('/subprotocol') self.run_loop_once() @with_server(subprotocols=['superchat', 'chat']) @unittest.mock.patch.object(WebSocketServerProtocol, 'process_subprotocol') def test_subprotocol_error_two_subprotocols(self, _process_subprotocol): _process_subprotocol.return_value = 'superchat, chat' with self.assertRaises(InvalidHandshake): self.start_client( '/subprotocol', subprotocols=['superchat', 'chat']) self.run_loop_once() @with_server() @unittest.mock.patch('websockets.server.read_request') def test_server_receives_malformed_request(self, _read_request): _read_request.side_effect = ValueError("read_request failed") with self.assertRaises(InvalidHandshake): self.start_client() @with_server() @unittest.mock.patch('websockets.client.read_response') def test_client_receives_malformed_response(self, _read_response): _read_response.side_effect = ValueError("read_response failed") with self.assertRaises(InvalidHandshake): self.start_client() self.run_loop_once() @with_server() @unittest.mock.patch('websockets.client.build_request') def test_client_sends_invalid_handshake_request(self, _build_request): def wrong_build_request(headers): return '42' _build_request.side_effect = wrong_build_request with self.assertRaises(InvalidHandshake): self.start_client() @with_server() @unittest.mock.patch('websockets.server.build_response') def test_server_sends_invalid_handshake_response(self, _build_response): def wrong_build_response(headers, key): return build_response(headers, '42') _build_response.side_effect = wrong_build_response with self.assertRaises(InvalidHandshake): self.start_client() @with_server() @unittest.mock.patch('websockets.client.read_response') def test_server_does_not_switch_protocols(self, _read_response): @asyncio.coroutine def wrong_read_response(stream): status_code, headers = yield from read_response(stream) return 400, headers _read_response.side_effect = wrong_read_response with self.assertRaises(InvalidStatusCode): self.start_client() self.run_loop_once() @with_server() @unittest.mock.patch( 'websockets.server.WebSocketServerProtocol.process_request') def test_server_error_in_handshake(self, _process_request): _process_request.side_effect = Exception("process_request crashed") with self.assertRaises(InvalidHandshake): self.start_client() @with_server() @unittest.mock.patch('websockets.server.WebSocketServerProtocol.send') def test_server_handler_crashes(self, send): send.side_effect = ValueError("send failed") with self.temp_client(): self.loop.run_until_complete(self.client.send("Hello!")) with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.client.recv()) # Connection ends with an unexpected error. self.assertEqual(self.client.close_code, 1011) @with_server() @unittest.mock.patch('websockets.server.WebSocketServerProtocol.close') def test_server_close_crashes(self, close): close.side_effect = ValueError("close failed") with self.temp_client(): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") # Connection ends with an abnormal closure. self.assertEqual(self.client.close_code, 1006) @with_server() @with_client() @unittest.mock.patch.object(WebSocketClientProtocol, 'handshake') def test_client_closes_connection_before_handshake(self, handshake): # We have mocked the handshake() method to prevent the client from # performing the opening handshake. Force it to close the connection. self.client.writer.close() # The server should stop properly anyway. It used to hang because the # task handling the connection was waiting for the opening handshake. @with_server() @unittest.mock.patch('websockets.server.read_request') def test_server_shuts_down_during_opening_handshake(self, _read_request): _read_request.side_effect = asyncio.CancelledError self.server.closing = True with self.assertRaises(InvalidHandshake) as raised: self.start_client() # Opening handshake fails with 503 Service Unavailable self.assertEqual(str(raised.exception), "Status code not 101: 503") @with_server() def test_server_shuts_down_during_connection_handling(self): with self.temp_client(): self.server.close() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.client.recv()) # Websocket connection terminates with 1001 Going Away. self.assertEqual(self.client.close_code, 1001) @with_server() @unittest.mock.patch('websockets.server.WebSocketServerProtocol.close') def test_server_shuts_down_during_connection_close(self, _close): _close.side_effect = asyncio.CancelledError self.server.closing = True with self.temp_client(): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") # Websocket connection terminates abnormally. self.assertEqual(self.client.close_code, 1006) @with_server(create_protocol=ForbiddenServerProtocol) def test_invalid_status_error_during_client_connect(self): with self.assertRaises(InvalidStatusCode) as raised: self.start_client() exception = raised.exception self.assertEqual(str(exception), "Status code not 101: 403") self.assertEqual(exception.status_code, 403) @with_server() @unittest.mock.patch( 'websockets.server.WebSocketServerProtocol.write_http_response') @unittest.mock.patch( 'websockets.server.WebSocketServerProtocol.read_http_request') def test_connection_error_during_opening_handshake( self, _read_http_request, _write_http_response): _read_http_request.side_effect = ConnectionError # This exception is currently platform-dependent. It was observed to # be ConnectionResetError on Linux in the non-SSL case, and # InvalidMessage otherwise (including both Linux and macOS). This # doesn't matter though since this test is primarily for testing a # code path on the server side. with self.assertRaises(Exception): self.start_client() # No response must not be written if the network connection is broken. _write_http_response.assert_not_called() @with_server() @unittest.mock.patch('websockets.server.WebSocketServerProtocol.close') def test_connection_error_during_closing_handshake(self, close): close.side_effect = ConnectionError with self.temp_client(): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") # Connection ends with an abnormal closure. self.assertEqual(self.client.close_code, 1006) class SSLClientServerTests(ClientServerTests): secure = True @property def server_context(self): # Change to ssl.PROTOCOL_TLS_SERVER when dropping Python < 3.6. ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ssl_context.load_cert_chain(testcert) return ssl_context @property def client_context(self): # Change to ssl.PROTOCOL_TLS_CLIENT when dropping Python < 3.6. # Then remove verify_mode and check_hostname below. ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ssl_context.load_verify_locations(testcert) ssl_context.verify_mode = ssl.CERT_REQUIRED # ssl.match_hostname can't match IP addresses on Python < 3.5. # We're using IP addresses to enforce testing of IPv4 and IPv6. if sys.version_info[:2] >= (3, 5): # pragma: no cover ssl_context.check_hostname = True return ssl_context def start_server(self, **kwds): kwds.setdefault('ssl', self.server_context) super().start_server(**kwds) def start_client(self, path='/', **kwds): kwds.setdefault('ssl', self.client_context) super().start_client(path, **kwds) # TLS over Unix sockets doesn't make sense. test_unix_socket = None @with_server() def test_ws_uri_is_rejected(self): with self.assertRaises(ValueError): client = connect( get_server_uri(self.server, secure=False), ssl=self.client_context, ) # With Python ≥ 3.5, the exception is raised by connect() even # before awaiting. However, with Python 3.4 the exception is # raised only when awaiting. self.loop.run_until_complete(client) # pragma: no cover class ClientServerOriginTests(unittest.TestCase): def setUp(self): self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) def tearDown(self): self.loop.close() def test_checking_origin_succeeds(self): server = self.loop.run_until_complete( serve(handler, 'localhost', 0, origins=['http://localhost'])) client = self.loop.run_until_complete( connect(get_server_uri(server), origin='http://localhost')) self.loop.run_until_complete(client.send("Hello!")) self.assertEqual(self.loop.run_until_complete(client.recv()), "Hello!") self.loop.run_until_complete(client.close()) server.close() self.loop.run_until_complete(server.wait_closed()) def test_checking_origin_fails(self): server = self.loop.run_until_complete( serve(handler, 'localhost', 0, origins=['http://localhost'])) with self.assertRaisesRegex(InvalidHandshake, "Status code not 101: 403"): self.loop.run_until_complete( connect(get_server_uri(server), origin='http://otherhost')) server.close() self.loop.run_until_complete(server.wait_closed()) def test_checking_lack_of_origin_succeeds(self): server = self.loop.run_until_complete( serve(handler, 'localhost', 0, origins=[''])) client = self.loop.run_until_complete(connect(get_server_uri(server))) self.loop.run_until_complete(client.send("Hello!")) self.assertEqual(self.loop.run_until_complete(client.recv()), "Hello!") self.loop.run_until_complete(client.close()) server.close() self.loop.run_until_complete(server.wait_closed()) class YieldFromTests(unittest.TestCase): def setUp(self): self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) def tearDown(self): self.loop.close() def test_client(self): start_server = serve(handler, 'localhost', 0) server = self.loop.run_until_complete(start_server) @asyncio.coroutine def run_client(): # Yield from connect. client = yield from connect(get_server_uri(server)) self.assertEqual(client.state, State.OPEN) yield from client.close() self.assertEqual(client.state, State.CLOSED) self.loop.run_until_complete(run_client()) server.close() self.loop.run_until_complete(server.wait_closed()) def test_server(self): @asyncio.coroutine def run_server(): # Yield from serve. server = yield from serve(handler, 'localhost', 0) self.assertTrue(server.sockets) server.close() yield from server.wait_closed() self.assertFalse(server.sockets) self.loop.run_until_complete(run_server()) if sys.version_info[:2] >= (3, 5): # pragma: no cover from .py35._test_client_server import AsyncAwaitTests # noqa from .py35._test_client_server import ContextManagerTests # noqa if sys.version_info[:2] >= (3, 6): # pragma: no cover from .py36._test_client_server import AsyncIteratorTests # noqa