"""
Tests for ``WebSocketCommonProtocol``.

The shared tests live in the ``CommonTests`` mixin and are run once with the
protocol configured as a server (``ServerTests``) and once as a client
(``ClientTests``).

"""

import asyncio
import contextlib
import functools
import logging
import os
import time
import unittest
import unittest.mock

from .compatibility import asyncio_ensure_future
from .exceptions import ConnectionClosed, InvalidState
from .framing import *
from .protocol import State, WebSocketCommonProtocol


# Avoid displaying stack traces at the ERROR logging level.
logging.basicConfig(level=logging.CRITICAL)


# Unit for timeouts. May be increased on slow machines by setting the
# WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. The factor is parsed
# as a float, so fractional values such as 1.5 are accepted. Comments below
# that mention "ms" refer to multiples of this unit.
MS = 0.001 * float(os.environ.get('WEBSOCKETS_TESTS_TIMEOUT_FACTOR', 1))

# asyncio's debug mode has a 10x performance penalty for this test suite.
if os.environ.get('PYTHONASYNCIODEBUG'):  # pragma: no cover
    MS *= 10

# Ensure that timeouts are larger than the clock's resolution (for Windows).
MS = max(MS, 2.5 * time.get_clock_info('monotonic').resolution)


class TransportMock(unittest.mock.Mock):
    """
    Transport mock to control the protocol's inputs and outputs in tests.

    It calls the protocol's connection_made and connection_lost methods like
    actual transports.

    It also calls the protocol's connection_open method to bypass the
    WebSocket handshake.

    To simulate incoming data, tests call the protocol's data_received and
    eof_received methods directly.

    They could also call pause_writing and resume_writing to test flow
    control.

    """

    # This should happen in __init__ but overriding Mock.__init__ is hard.
    def setup_mock(self, loop, protocol):
        self.loop = loop
        self.protocol = protocol
        # _eof and _closing record whether write_eof() / close() already ran,
        # so each schedules its follow-up callback at most once. Tests set
        # them to True up front to disable the simulated peer behavior.
        self._eof = False
        self._closing = False
        # Simulate a successful TCP handshake.
        self.protocol.connection_made(self)
        # Simulate a successful WebSocket handshake.
        self.protocol.connection_open()

    def can_write_eof(self):
        return True

    def write_eof(self):
        # When the protocol half-closes the TCP connection, it expects the
        # other end to close it. Simulate that.
        if not self._eof:
            self.loop.call_soon(self.close)
        self._eof = True

    def is_closing(self):
        return self._closing

    def close(self):
        # Simulate how actual transports drop the connection.
        if not self._closing:
            self.loop.call_soon(self.protocol.connection_lost, None)
        self._closing = True

    def abort(self):
        # Change this to an `if` if tests call abort() multiple times.
        assert self.protocol.state is not State.CLOSED
        self.loop.call_soon(self.protocol.connection_lost, None)


class CommonTests:
    """
    Mixin that defines most tests but doesn't inherit unittest.TestCase.

    Tests are run by the ServerTests and ClientTests subclasses.

    """

    def setUp(self):
        super().setUp()
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.protocol = WebSocketCommonProtocol()
        self.transport = TransportMock()
        self.transport.setup_mock(self.loop, self.protocol)

    def tearDown(self):
        self.transport.close()
        self.loop.run_until_complete(self.protocol.close())
        self.loop.close()
        super().tearDown()

    # Utilities for writing tests.

    def run_loop_once(self):
        # Process callbacks scheduled with call_soon by appending a callback
        # to stop the event loop then running it until it hits that callback.
        self.loop.call_soon(self.loop.stop)
        self.loop.run_forever()

    def make_drain_slow(self, delay=MS):
        """
        Make every drain of ``self.protocol.writer`` sleep ``delay`` seconds.

        The replacement coroutine sleeps on the test loop, then delegates to
        the original ``drain``. Tests use this to keep the protocol busy
        writing, e.g. to observe the CLOSING state or to trigger timeouts.

        """
        # Process connection_made in order to initialize self.protocol.writer.
        self.run_loop_once()

        original_drain = self.protocol.writer.drain

        @asyncio.coroutine
        def delayed_drain():
            yield from asyncio.sleep(delay, loop=self.loop)
            yield from original_drain()

        self.protocol.writer.drain = delayed_drain

    # Canned close frames with status code 1000, used both as incoming frames
    # and as expected outgoing frames.
    close_frame = Frame(True, OP_CLOSE, serialize_close(1000, 'close'))
    local_close = Frame(True, OP_CLOSE, serialize_close(1000, 'local'))
    remote_close = Frame(True, OP_CLOSE, serialize_close(1000, 'remote'))

    @property
    def ensure_future(self):
        """
        ``asyncio_ensure_future`` with ``loop`` bound to this test's loop.

        """
        return functools.partial(asyncio_ensure_future, loop=self.loop)

    def receive_frame(self, frame):
        """
        Make the protocol receive a frame.

        The frame is serialized straight into ``data_received``. It is masked
        when the protocol is the server side, mirroring RFC 6455 where frames
        sent by clients are masked.

        """
        writer = self.protocol.data_received
        mask = not self.protocol.is_client
        frame.write(writer, mask=mask)

    def receive_eof(self):
        """
        Make the protocol receive the end of the data stream.

        Since ``WebSocketCommonProtocol.eof_received`` returns ``None``, an
        actual transport would close itself after calling it. This function
        emulates that behavior.

        """
        self.protocol.eof_received()
        self.loop.call_soon(self.transport.close)

    def receive_eof_if_client(self):
        """
        Like receive_eof, but only if this is the client side.

        Since the server is supposed to initiate the termination of the TCP
        connection, this method helps making tests work for both sides.

        """
        if self.protocol.is_client:
            self.receive_eof()

    def close_connection(self, code=1000, reason='close'):
        """
        Execute a closing handshake.

        This puts the connection in the CLOSED state.

        """
        close_frame_data = serialize_close(code, reason)
        # Prepare the response to the closing handshake from the remote side.
        self.receive_frame(Frame(True, OP_CLOSE, close_frame_data))
        self.receive_eof_if_client()
        # Trigger the closing handshake from the local side and complete it.
        self.loop.run_until_complete(self.protocol.close(code, reason))
        # Empty the outgoing data stream so we can make assertions later on.
        self.assertOneFrameSent(True, OP_CLOSE, close_frame_data)

        assert self.protocol.state is State.CLOSED

    def half_close_connection_local(self, code=1000, reason='close'):
        """
        Start a closing handshake but do not complete it.

        The main difference with `close_connection` is that the connection is
        left in the CLOSING state until the event loop runs again.

        The current implementation returns a task that must be awaited or
        cancelled, else asyncio complains about destroying a pending task.

        """
        close_frame_data = serialize_close(code, reason)
        # Trigger the closing handshake from the local endpoint.
        close_task = self.ensure_future(self.protocol.close(code, reason))
        self.run_loop_once()  # wait_for executes
        self.run_loop_once()  # write_frame executes
        # Empty the outgoing data stream so we can make assertions later on.
        self.assertOneFrameSent(True, OP_CLOSE, close_frame_data)

        assert self.protocol.state is State.CLOSING

        # Complete the closing sequence at MS intervals (1ms by default) so
        # the test can run at each point even if it goes back to the event
        # loop several times.
        self.loop.call_later(
            MS, self.receive_frame, Frame(True, OP_CLOSE, close_frame_data))
        self.loop.call_later(2 * MS, self.receive_eof_if_client)

        # This task must be awaited or cancelled by the caller.
        return close_task

    def half_close_connection_remote(self, code=1000, reason='close'):
        """
        Receive a closing handshake but do not complete it.

        The main difference with `close_connection` is that the connection is
        left in the CLOSING state until the event loop runs again.

        """
        # On the server side, websockets completes the closing handshake and
        # closes the TCP connection immediately. Yield to the event loop after
        # sending the close frame to run the test while the connection is in
        # the CLOSING state.
        if not self.protocol.is_client:
            self.make_drain_slow()

        close_frame_data = serialize_close(code, reason)
        # Trigger the closing handshake from the remote endpoint.
        self.receive_frame(Frame(True, OP_CLOSE, close_frame_data))
        self.run_loop_once()  # read_frame executes
        # Empty the outgoing data stream so we can make assertions later on.
        self.assertOneFrameSent(True, OP_CLOSE, close_frame_data)

        assert self.protocol.state is State.CLOSING

        # Complete the closing sequence (on the client side, receive EOF)
        # after 2 * MS so the test can run while the connection is CLOSING,
        # even if it goes back to the event loop several times.
        self.loop.call_later(2 * MS, self.receive_eof_if_client)

    def process_invalid_frames(self):
        """
        Make the protocol fail quickly after simulating invalid data.

        To achieve this, this function triggers the protocol's eof_received,
        which interrupts pending reads waiting for more data.

        """
        self.run_loop_once()
        self.receive_eof()
        self.loop.run_until_complete(self.protocol.close_connection_task)

    def last_sent_frame(self):
        """
        Read the last frame sent to the transport.

        This method assumes that at most one frame was sent. It raises an
        AssertionError otherwise. It also clears the recorded writes, so a
        frame is only reported once. Returns ``None`` if nothing was written.

        """
        stream = asyncio.StreamReader(loop=self.loop)

        for (data,), _ in self.transport.write.call_args_list:
            stream.feed_data(data)
        self.transport.write.call_args_list = []
        stream.feed_eof()

        if stream.at_eof():
            frame = None
        else:
            frame = self.loop.run_until_complete(Frame.read(
                stream.readexactly, mask=self.protocol.is_client))

        if not stream.at_eof():  # pragma: no cover
            data = self.loop.run_until_complete(stream.read())
            raise AssertionError("Trailing data found: {!r}".format(data))

        return frame

    def assertOneFrameSent(self, *args):
        """
        Assert that exactly the frame ``Frame(*args)`` was sent.

        """
        self.assertEqual(self.last_sent_frame(), Frame(*args))

    def assertNoFrameSent(self):
        """
        Assert that nothing was written to the transport.

        """
        self.assertIsNone(self.last_sent_frame())

    def assertConnectionClosed(self, code, message):
        """
        Assert a clean close: CLOSED state with the received code and reason.

        """
        # The following line guarantees that connection_lost was called.
        self.assertEqual(self.protocol.state, State.CLOSED)
        # A close frame was received.
        self.assertEqual(self.protocol.close_code, code)
        self.assertEqual(self.protocol.close_reason, message)

    def assertConnectionFailed(self, code, message):
        """
        Assert an abnormal close: CLOSED state without a received close frame.

        ``code`` and ``message`` describe the close frame expected to have
        been *sent*. A ``code`` of 1006 means no frame was sent at all.

        """
        # The following line guarantees that connection_lost was called.
        self.assertEqual(self.protocol.state, State.CLOSED)
        # No close frame was received.
        self.assertEqual(self.protocol.close_code, 1006)
        self.assertEqual(self.protocol.close_reason, '')
        # A close frame was sent -- unless the connection was already lost.
        if code == 1006:
            self.assertNoFrameSent()
        else:
            self.assertOneFrameSent(
                True, OP_CLOSE, serialize_close(code, message))

    @contextlib.contextmanager
    def assertCompletesWithin(self, min_time, max_time):
        """
        Assert that the ``with`` body takes ``min_time <= dt < max_time``.

        Time is measured with the event loop's clock (``self.loop.time()``).

        """
        t0 = self.loop.time()
        yield
        t1 = self.loop.time()
        dt = t1 - t0
        self.assertGreaterEqual(
            dt, min_time, "Too fast: {} < {}".format(dt, min_time))
        self.assertLess(
            dt, max_time, "Too slow: {} >= {}".format(dt, max_time))

    # Test public attributes.

    def test_local_address(self):
        get_extra_info = unittest.mock.Mock(return_value=('host', 4312))
        self.transport.get_extra_info = get_extra_info

        self.assertEqual(self.protocol.local_address, ('host', 4312))
        get_extra_info.assert_called_with('sockname', None)

    def test_local_address_before_connection(self):
        # Emulate the situation before connection_open() runs.
        self.protocol.writer, _writer = None, self.protocol.writer

        try:
            self.assertEqual(self.protocol.local_address, None)
        finally:
            self.protocol.writer = _writer

    def test_remote_address(self):
        get_extra_info = unittest.mock.Mock(return_value=('host', 4312))
        self.transport.get_extra_info = get_extra_info

        self.assertEqual(self.protocol.remote_address, ('host', 4312))
        get_extra_info.assert_called_with('peername', None)

    def test_remote_address_before_connection(self):
        # Emulate the situation before connection_open() runs.
        self.protocol.writer, _writer = None, self.protocol.writer

        try:
            self.assertEqual(self.protocol.remote_address, None)
        finally:
            self.protocol.writer = _writer

    def test_open(self):
        self.assertTrue(self.protocol.open)
        self.close_connection()
        self.assertFalse(self.protocol.open)

    def test_closed(self):
        self.assertFalse(self.protocol.closed)
        self.close_connection()
        self.assertTrue(self.protocol.closed)

    # Test the recv coroutine.

    def test_recv_text(self):
        self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8')))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, 'café')

    def test_recv_binary(self):
        self.receive_frame(Frame(True, OP_BINARY, b'tea'))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, b'tea')

    def test_recv_on_closing_connection_local(self):
        close_task = self.half_close_connection_local()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.recv())

        self.loop.run_until_complete(close_task)  # cleanup

    def test_recv_on_closing_connection_remote(self):
        self.half_close_connection_remote()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.recv())

    def test_recv_on_closed_connection(self):
        self.close_connection()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.recv())

    def test_recv_protocol_error(self):
        self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8')))
        self.process_invalid_frames()
        self.assertConnectionFailed(1002, '')

    def test_recv_unicode_error(self):
        self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('latin-1')))
        self.process_invalid_frames()
        self.assertConnectionFailed(1007, '')

    def test_recv_text_payload_too_big(self):
        self.protocol.max_size = 1024
        self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8') * 205))
        self.process_invalid_frames()
        self.assertConnectionFailed(1009, '')

    def test_recv_binary_payload_too_big(self):
        self.protocol.max_size = 1024
        self.receive_frame(Frame(True, OP_BINARY, b'tea' * 342))
        self.process_invalid_frames()
        self.assertConnectionFailed(1009, '')

    def test_recv_text_no_max_size(self):
        self.protocol.max_size = None  # for test coverage
        self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8') * 205))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, 'café' * 205)

    def test_recv_binary_no_max_size(self):
        self.protocol.max_size = None  # for test coverage
        self.receive_frame(Frame(True, OP_BINARY, b'tea' * 342))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, b'tea' * 342)

    def test_recv_other_error(self):
        @asyncio.coroutine
        def read_message():
            raise Exception("BOOM")
        self.protocol.read_message = read_message
        self.process_invalid_frames()
        self.assertConnectionFailed(1011, '')

    def test_recv_cancelled(self):
        recv = self.ensure_future(self.protocol.recv())
        self.loop.call_soon(recv.cancel)
        with self.assertRaises(asyncio.CancelledError):
            self.loop.run_until_complete(recv)

        # The next frame doesn't disappear in a vacuum (it used to).
        self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8')))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, 'café')

    # Test the send coroutine.

    def test_send_text(self):
        self.loop.run_until_complete(self.protocol.send('café'))
        self.assertOneFrameSent(True, OP_TEXT, 'café'.encode('utf-8'))

    def test_send_binary(self):
        self.loop.run_until_complete(self.protocol.send(b'tea'))
        self.assertOneFrameSent(True, OP_BINARY, b'tea')

    def test_send_type_error(self):
        with self.assertRaises(TypeError):
            self.loop.run_until_complete(self.protocol.send(42))
        self.assertNoFrameSent()

    def test_send_on_closing_connection_local(self):
        close_task = self.half_close_connection_local()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.send('foobar'))
        self.assertNoFrameSent()

        self.loop.run_until_complete(close_task)  # cleanup

    def test_send_on_closing_connection_remote(self):
        self.half_close_connection_remote()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.send('foobar'))
        self.assertNoFrameSent()

    def test_send_on_closed_connection(self):
        self.close_connection()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.send('foobar'))
        self.assertNoFrameSent()

    # Test the ping coroutine.

    def test_ping_default(self):
        self.loop.run_until_complete(self.protocol.ping())
        # With our testing tools, it's more convenient to extract the expected
        # ping data from the library's internals than from the frame sent.
        ping_data = next(iter(self.protocol.pings))
        self.assertIsInstance(ping_data, bytes)
        self.assertEqual(len(ping_data), 4)
        self.assertOneFrameSent(True, OP_PING, ping_data)

    def test_ping_text(self):
        self.loop.run_until_complete(self.protocol.ping('café'))
        self.assertOneFrameSent(True, OP_PING, 'café'.encode('utf-8'))

    def test_ping_binary(self):
        self.loop.run_until_complete(self.protocol.ping(b'tea'))
        self.assertOneFrameSent(True, OP_PING, b'tea')

    def test_ping_type_error(self):
        with self.assertRaises(TypeError):
            self.loop.run_until_complete(self.protocol.ping(42))
        self.assertNoFrameSent()

    def test_ping_on_closing_connection_local(self):
        close_task = self.half_close_connection_local()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.ping())
        self.assertNoFrameSent()

        self.loop.run_until_complete(close_task)  # cleanup

    def test_ping_on_closing_connection_remote(self):
        self.half_close_connection_remote()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.ping())
        self.assertNoFrameSent()

    def test_ping_on_closed_connection(self):
        self.close_connection()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.ping())
        self.assertNoFrameSent()

    # Test the pong coroutine.

    def test_pong_default(self):
        self.loop.run_until_complete(self.protocol.pong())
        self.assertOneFrameSent(True, OP_PONG, b'')

    def test_pong_text(self):
        self.loop.run_until_complete(self.protocol.pong('café'))
        self.assertOneFrameSent(True, OP_PONG, 'café'.encode('utf-8'))

    def test_pong_binary(self):
        self.loop.run_until_complete(self.protocol.pong(b'tea'))
        self.assertOneFrameSent(True, OP_PONG, b'tea')

    def test_pong_type_error(self):
        with self.assertRaises(TypeError):
            self.loop.run_until_complete(self.protocol.pong(42))
        self.assertNoFrameSent()

    def test_pong_on_closing_connection_local(self):
        close_task = self.half_close_connection_local()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.pong())
        self.assertNoFrameSent()

        self.loop.run_until_complete(close_task)  # cleanup

    def test_pong_on_closing_connection_remote(self):
        self.half_close_connection_remote()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.pong())
        self.assertNoFrameSent()

    def test_pong_on_closed_connection(self):
        self.close_connection()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.pong())
        self.assertNoFrameSent()

    # Test the protocol's logic for acknowledging pings with pongs.

    def test_answer_ping(self):
        self.receive_frame(Frame(True, OP_PING, b'test'))
        self.run_loop_once()
        self.assertOneFrameSent(True, OP_PONG, b'test')

    def test_ignore_pong(self):
        self.receive_frame(Frame(True, OP_PONG, b'test'))
        self.run_loop_once()
        self.assertNoFrameSent()

    def test_acknowledge_ping(self):
        ping = self.loop.run_until_complete(self.protocol.ping())
        self.assertFalse(ping.done())
        ping_frame = self.last_sent_frame()
        pong_frame = Frame(True, OP_PONG, ping_frame.data)
        self.receive_frame(pong_frame)
        self.run_loop_once()
        self.run_loop_once()
        self.assertTrue(ping.done())

    def test_cancel_ping(self):
        ping = self.loop.run_until_complete(self.protocol.ping())
        # Remove the frame from the buffer, else close_connection() complains.
        self.last_sent_frame()
        self.assertFalse(ping.cancelled())
        self.close_connection()
        self.assertTrue(ping.cancelled())

    def test_acknowledge_previous_pings(self):
        pings = [(
            self.loop.run_until_complete(self.protocol.ping()),
            self.last_sent_frame(),
        ) for _ in range(3)]
        # Unsolicited pong doesn't acknowledge pings
        self.receive_frame(Frame(True, OP_PONG, b''))
        self.run_loop_once()
        self.run_loop_once()
        self.assertFalse(pings[0][0].done())
        self.assertFalse(pings[1][0].done())
        self.assertFalse(pings[2][0].done())
        # Pong acknowledges all previous pings
        self.receive_frame(Frame(True, OP_PONG, pings[1][1].data))
        self.run_loop_once()
        self.run_loop_once()
        self.assertTrue(pings[0][0].done())
        self.assertTrue(pings[1][0].done())
        self.assertFalse(pings[2][0].done())

    def test_cancelled_ping(self):
        ping = self.loop.run_until_complete(self.protocol.ping())
        ping_frame = self.last_sent_frame()
        ping.cancel()
        pong_frame = Frame(True, OP_PONG, ping_frame.data)
        self.receive_frame(pong_frame)
        self.run_loop_once()
        self.run_loop_once()
        self.assertTrue(ping.cancelled())

    def test_duplicate_ping(self):
        self.loop.run_until_complete(self.protocol.ping(b'foobar'))
        self.assertOneFrameSent(True, OP_PING, b'foobar')
        with self.assertRaises(ValueError):
            self.loop.run_until_complete(self.protocol.ping(b'foobar'))
        self.assertNoFrameSent()

    # Test the protocol's logic for rebuilding fragmented messages.

    def test_fragmented_text(self):
        self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8')))
        self.receive_frame(Frame(True, OP_CONT, 'fé'.encode('utf-8')))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, 'café')

    def test_fragmented_binary(self):
        self.receive_frame(Frame(False, OP_BINARY, b't'))
        self.receive_frame(Frame(False, OP_CONT, b'e'))
        self.receive_frame(Frame(True, OP_CONT, b'a'))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, b'tea')

    def test_fragmented_text_payload_too_big(self):
        self.protocol.max_size = 1024
        self.receive_frame(Frame(False, OP_TEXT, 'café'.encode('utf-8') * 100))
        self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8') * 105))
        self.process_invalid_frames()
        self.assertConnectionFailed(1009, '')

    def test_fragmented_binary_payload_too_big(self):
        self.protocol.max_size = 1024
        self.receive_frame(Frame(False, OP_BINARY, b'tea' * 171))
        self.receive_frame(Frame(True, OP_CONT, b'tea' * 171))
        self.process_invalid_frames()
        self.assertConnectionFailed(1009, '')

    def test_fragmented_text_no_max_size(self):
        self.protocol.max_size = None  # for test coverage
        self.receive_frame(Frame(False, OP_TEXT, 'café'.encode('utf-8') * 100))
        self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8') * 105))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, 'café' * 205)

    def test_fragmented_binary_no_max_size(self):
        self.protocol.max_size = None  # for test coverage
        self.receive_frame(Frame(False, OP_BINARY, b'tea' * 171))
        self.receive_frame(Frame(True, OP_CONT, b'tea' * 171))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, b'tea' * 342)

    def test_control_frame_within_fragmented_text(self):
        self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8')))
        self.receive_frame(Frame(True, OP_PING, b''))
        self.receive_frame(Frame(True, OP_CONT, 'fé'.encode('utf-8')))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, 'café')
        self.assertOneFrameSent(True, OP_PONG, b'')

    def test_unterminated_fragmented_text(self):
        self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8')))
        # Missing the second part of the fragmented frame.
        self.receive_frame(Frame(True, OP_BINARY, b'tea'))
        self.process_invalid_frames()
        self.assertConnectionFailed(1002, '')

    def test_close_handshake_in_fragmented_text(self):
        self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8')))
        self.receive_frame(Frame(True, OP_CLOSE, b''))
        self.process_invalid_frames()
        # The RFC may have overlooked this case: it says that control frames
        # can be interjected in the middle of a fragmented message and that a
        # close frame must be echoed. Even though there's an unterminated
        # message, technically, the closing handshake was successful.
        self.assertConnectionClosed(1005, '')

    def test_connection_close_in_fragmented_text(self):
        self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8')))
        self.process_invalid_frames()
        self.assertConnectionFailed(1006, '')

    # Test miscellaneous code paths to ensure full coverage.

    def test_connection_lost(self):
        # Test calling connection_lost without going through close_connection.
        self.protocol.connection_lost(None)

        self.assertConnectionFailed(1006, '')

    def test_ensure_open_before_opening_handshake(self):
        # Simulate a bug by forcibly reverting the protocol state.
        self.protocol.state = State.CONNECTING

        with self.assertRaises(InvalidState):
            self.loop.run_until_complete(self.protocol.ensure_open())

    def test_ensure_open_during_unclean_close(self):
        # Process connection_made in order to start transfer_data_task.
        self.run_loop_once()

        # Ensure the test terminates quickly.
        self.loop.call_later(MS, self.receive_eof_if_client)

        # Simulate the case when close() times out sending a close frame.
        self.protocol.fail_connection()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.ensure_open())

    def test_legacy_recv(self):
        # By default legacy_recv is disabled.
        self.assertEqual(self.protocol.legacy_recv, False)

        self.close_connection()

        # Enable legacy_recv.
        self.protocol.legacy_recv = True

        # Now recv() returns None instead of raising ConnectionClosed.
        self.assertIsNone(self.loop.run_until_complete(self.protocol.recv()))

    def test_connection_closed_attributes(self):
        self.close_connection()

        with self.assertRaises(ConnectionClosed) as context:
            self.loop.run_until_complete(self.protocol.recv())

        connection_closed_exc = context.exception
        self.assertEqual(connection_closed_exc.code, 1000)
        self.assertEqual(connection_closed_exc.reason, 'close')

    # Test the protocol logic for closing the connection.

    def test_local_close(self):
        # Emulate how the remote endpoint answers the closing handshake.
        self.loop.call_later(MS, self.receive_frame, self.close_frame)
        self.loop.call_later(MS, self.receive_eof_if_client)

        # Run the closing handshake.
        self.loop.run_until_complete(self.protocol.close(reason='close'))

        self.assertConnectionClosed(1000, 'close')
        self.assertOneFrameSent(*self.close_frame)

        # Closing the connection again is a no-op.
        self.loop.run_until_complete(self.protocol.close(reason='oh noes!'))

        self.assertConnectionClosed(1000, 'close')
        self.assertNoFrameSent()

    def test_remote_close(self):
        # Emulate how the remote endpoint initiates the closing handshake.
        self.loop.call_later(MS, self.receive_frame, self.close_frame)
        self.loop.call_later(MS, self.receive_eof_if_client)

        # Wait for some data in order to process the handshake.
        # After recv() raises ConnectionClosed, the connection is closed.
        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.recv())

        self.assertConnectionClosed(1000, 'close')
        self.assertOneFrameSent(*self.close_frame)

        # Closing the connection again is a no-op.
        self.loop.run_until_complete(self.protocol.close(reason='oh noes!'))

        self.assertConnectionClosed(1000, 'close')
        self.assertNoFrameSent()

    def test_simultaneous_close(self):
        # Receive the incoming close frame right after self.protocol.close()
        # starts executing. This reproduces the error described in:
        # https://github.com/aaugustin/websockets/issues/339
        self.loop.call_soon(self.receive_frame, self.remote_close)
        self.loop.call_soon(self.receive_eof_if_client)

        self.loop.run_until_complete(self.protocol.close(reason='local'))

        self.assertConnectionClosed(1000, 'remote')
        # The current implementation sends a close frame in response to the
        # close frame received from the remote end. It skips the close frame
        # that should be sent as a result of calling close().
        self.assertOneFrameSent(*self.remote_close)

    def test_close_preserves_incoming_frames(self):
        self.receive_frame(Frame(True, OP_TEXT, b'hello'))

        self.loop.call_later(MS, self.receive_frame, self.close_frame)
        self.loop.call_later(MS, self.receive_eof_if_client)
        self.loop.run_until_complete(self.protocol.close(reason='close'))

        self.assertConnectionClosed(1000, 'close')
        self.assertOneFrameSent(*self.close_frame)

        next_message = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(next_message, 'hello')

    def test_close_protocol_error(self):
        invalid_close_frame = Frame(True, OP_CLOSE, b'\x00')
        self.receive_frame(invalid_close_frame)
        self.receive_eof_if_client()
        self.run_loop_once()
        self.loop.run_until_complete(self.protocol.close(reason='close'))

        self.assertConnectionFailed(1002, '')

    def test_close_connection_lost(self):
        self.receive_eof()
        self.run_loop_once()
        self.loop.run_until_complete(self.protocol.close(reason='close'))

        self.assertConnectionFailed(1006, '')

    def test_local_close_during_recv(self):
        recv = self.ensure_future(self.protocol.recv())

        self.loop.call_later(MS, self.receive_frame, self.close_frame)
        self.loop.call_later(MS, self.receive_eof_if_client)

        self.loop.run_until_complete(self.protocol.close(reason='close'))

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(recv)

        self.assertConnectionClosed(1000, 'close')

    # There is no test_remote_close_during_recv because it would be identical
    # to test_remote_close.

    def test_remote_close_during_send(self):
        self.make_drain_slow()
        send = self.ensure_future(self.protocol.send('hello'))

        self.receive_frame(self.close_frame)
        self.receive_eof()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(send)

        self.assertConnectionClosed(1000, 'close')

    # There is no test_local_close_during_send because this cannot really
    # happen, considering that writes are serialized.


class ServerTests(CommonTests, unittest.TestCase):
    """
    Run the common tests with the protocol configured as the server side.

    """

    def setUp(self):
        super().setUp()
        self.protocol.is_client = False
        self.protocol.side = 'server'

    def test_local_close_send_close_frame_timeout(self):
        self.protocol.timeout = 10 * MS
        self.make_drain_slow(50 * MS)
        # If we can't send a close frame, time out in 10ms.
        # Check the timing within -1/+9ms for robustness.
        with self.assertCompletesWithin(9 * MS, 19 * MS):
            self.loop.run_until_complete(self.protocol.close(reason='close'))
        self.assertConnectionClosed(1006, '')

    def test_local_close_receive_close_frame_timeout(self):
        self.protocol.timeout = 10 * MS
        # If the client doesn't send a close frame, time out in 10ms.
        # Check the timing within -1/+9ms for robustness.
        with self.assertCompletesWithin(9 * MS, 19 * MS):
            self.loop.run_until_complete(self.protocol.close(reason='close'))
        self.assertConnectionClosed(1006, '')

    def test_local_close_connection_lost_timeout_after_write_eof(self):
        self.protocol.timeout = 10 * MS
        # If the client doesn't close its side of the TCP connection after we
        # half-close our side with write_eof(), time out in 10ms.
        # Check the timing within -1/+9ms for robustness.
        with self.assertCompletesWithin(9 * MS, 19 * MS):
            # HACK: disable write_eof => other end drops connection emulation.
            self.transport._eof = True
            self.receive_frame(self.close_frame)
            self.loop.run_until_complete(self.protocol.close(reason='close'))
        self.assertConnectionClosed(1000, 'close')

    def test_local_close_connection_lost_timeout_after_close(self):
        self.protocol.timeout = 10 * MS
        # If the client doesn't close its side of the TCP connection after we
        # half-close our side with write_eof() and close it with close(), time
        # out in 20ms.
        # Check the timing within -1/+9ms for robustness.
        with self.assertCompletesWithin(19 * MS, 29 * MS):
            # HACK: disable write_eof => other end drops connection emulation.
            self.transport._eof = True
            # HACK: disable close => other end drops connection emulation.
            self.transport._closing = True
            self.receive_frame(self.close_frame)
            self.loop.run_until_complete(self.protocol.close(reason='close'))
        self.assertConnectionClosed(1000, 'close')


class ClientTests(CommonTests, unittest.TestCase):
    """
    Run the common tests with the protocol configured as the client side.

    """

    def setUp(self):
        super().setUp()
        self.protocol.is_client = True
        self.protocol.side = 'client'

    def test_local_close_send_close_frame_timeout(self):
        self.protocol.timeout = 10 * MS
        self.make_drain_slow(50 * MS)
        # If we can't send a close frame, time out in 20ms.
        # - 10ms waiting for sending a close frame
        # - 10ms waiting for receiving a half-close
        # Check the timing within -1/+9ms for robustness.
        with self.assertCompletesWithin(19 * MS, 29 * MS):
            self.loop.run_until_complete(self.protocol.close(reason='close'))
        self.assertConnectionClosed(1006, '')

    def test_local_close_receive_close_frame_timeout(self):
        self.protocol.timeout = 10 * MS
        # If the server doesn't send a close frame, time out in 20ms:
        # - 10ms waiting for receiving a close frame
        # - 10ms waiting for receiving a half-close
        # Check the timing within -1/+9ms for robustness.
        with self.assertCompletesWithin(19 * MS, 29 * MS):
            self.loop.run_until_complete(self.protocol.close(reason='close'))
        self.assertConnectionClosed(1006, '')

    def test_local_close_connection_lost_timeout_after_write_eof(self):
        self.protocol.timeout = 10 * MS
        # If the server doesn't half-close its side of the TCP connection
        # after we send a close frame, time out in 20ms:
        # - 10ms waiting for receiving a half-close
        # - 10ms waiting for receiving a close after write_eof
        # Check the timing within -1/+9ms for robustness.
        with self.assertCompletesWithin(19 * MS, 29 * MS):
            # HACK: disable write_eof => other end drops connection emulation.
            self.transport._eof = True
            self.receive_frame(self.close_frame)
            self.loop.run_until_complete(self.protocol.close(reason='close'))
        self.assertConnectionClosed(1000, 'close')

    def test_local_close_connection_lost_timeout_after_close(self):
        self.protocol.timeout = 10 * MS
        # If the server doesn't close its side of the TCP connection after we
        # half-close our side with write_eof() and close it with close(), time
        # out in 30ms:
        # - 10ms waiting for receiving a half-close
        # - 10ms waiting for receiving a close after write_eof
        # - 10ms waiting for receiving a close after close
        # Check the timing within -1/+9ms for robustness.
        with self.assertCompletesWithin(29 * MS, 39 * MS):
            # HACK: disable write_eof => other end drops connection emulation.
            self.transport._eof = True
            # HACK: disable close => other end drops connection emulation.
            self.transport._closing = True
            self.receive_frame(self.close_frame)
            self.loop.run_until_complete(self.protocol.close(reason='close'))
        self.assertConnectionClosed(1000, 'close')