tuxbot-bot/venv/lib/python3.7/site-packages/websockets/test_protocol.py

998 lines
38 KiB
Python
Raw Normal View History

2019-12-16 18:12:10 +01:00
import asyncio
import contextlib
import functools
import logging
import os
import time
import unittest
import unittest.mock
from .compatibility import asyncio_ensure_future
from .exceptions import ConnectionClosed, InvalidState
from .framing import *
from .protocol import State, WebSocketCommonProtocol
# Avoid displaying stack traces at the ERROR logging level.
logging.basicConfig(level=logging.CRITICAL)

# Unit for timeouts. May be increased on slow machines by setting the
# WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. The factor is parsed
# as a float so that fractional values such as "1.5" are accepted; integer
# values give the same result as before.
MS = 0.001 * float(os.environ.get('WEBSOCKETS_TESTS_TIMEOUT_FACTOR', 1))

# asyncio's debug mode has a 10x performance penalty for this test suite.
if os.environ.get('PYTHONASYNCIODEBUG'):    # pragma: no cover
    MS *= 10

# Ensure that timeouts are larger than the clock's resolution (for Windows).
MS = max(MS, 2.5 * time.get_clock_info('monotonic').resolution)
class TransportMock(unittest.mock.Mock):
    """
    Transport mock to control the protocol's inputs and outputs in tests.

    It calls the protocol's connection_made and connection_lost methods like
    actual transports.

    It also calls the protocol's connection_open method to bypass the
    WebSocket handshake.

    To simulate incoming data, tests call the protocol's data_received and
    eof_received methods directly.

    They could also pause_writing and resume_writing to test flow control.

    """

    # This should happen in __init__ but overriding Mock.__init__ is hard.
    def setup_mock(self, loop, protocol):
        """
        Attach the mock to an event loop and a protocol, then open the
        connection.

        ``loop`` is used to schedule the simulated callbacks; ``protocol``
        receives connection_made, connection_open and connection_lost.

        """
        self.loop = loop
        self.protocol = protocol
        # Flags making write_eof() and close() take effect only once.
        self._eof = False
        self._closing = False
        # Simulate a successful TCP handshake.
        self.protocol.connection_made(self)
        # Simulate a successful WebSocket handshake.
        self.protocol.connection_open()

    def can_write_eof(self):
        """Report that half-closing the connection is supported."""
        return True

    def write_eof(self):
        """Half-close the connection; only the first call has an effect."""
        # When the protocol half-closes the TCP connection, it expects the
        # other end to close it. Simulate that.
        if not self._eof:
            self.loop.call_soon(self.close)
            self._eof = True

    def is_closing(self):
        """Return True once close() has been called."""
        return self._closing

    def close(self):
        """Schedule connection_lost(None); only the first call does so."""
        # Simulate how actual transports drop the connection.
        if not self._closing:
            self.loop.call_soon(self.protocol.connection_lost, None)
            self._closing = True

    def abort(self):
        """Schedule connection_lost(None) without touching the flags."""
        # Change this to an `if` if tests call abort() multiple times.
        assert self.protocol.state is not State.CLOSED
        self.loop.call_soon(self.protocol.connection_lost, None)
class CommonTests:
    """
    Mixin that defines most tests but doesn't inherit unittest.TestCase.

    Tests are run by the ServerTests and ClientTests subclasses.

    """

    def setUp(self):
        """Create a fresh event loop, protocol and transport mock."""
        super().setUp()
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.protocol = WebSocketCommonProtocol()
        self.transport = TransportMock()
        self.transport.setup_mock(self.loop, self.protocol)

    def tearDown(self):
        """Close the transport and the protocol, then the event loop."""
        self.transport.close()
        self.loop.run_until_complete(self.protocol.close())
        self.loop.close()
        super().tearDown()

    # Utilities for writing tests.

    def run_loop_once(self):
        """Run the callbacks that are currently scheduled with call_soon."""
        # Process callbacks scheduled with call_soon by appending a callback
        # to stop the event loop then running it until it hits that callback.
        self.loop.call_soon(self.loop.stop)
        self.loop.run_forever()

    def make_drain_slow(self, delay=MS):
        """Make the protocol writer's drain() sleep for ``delay`` first."""
        # Process connection_made in order to initialize self.protocol.writer.
        self.run_loop_once()

        original_drain = self.protocol.writer.drain

        @asyncio.coroutine
        def delayed_drain():
            yield from asyncio.sleep(delay, loop=self.loop)
            yield from original_drain()

        self.protocol.writer.drain = delayed_drain

    # Close frames shared by the tests of the closing handshake.
    close_frame = Frame(True, OP_CLOSE, serialize_close(1000, 'close'))
    local_close = Frame(True, OP_CLOSE, serialize_close(1000, 'local'))
    remote_close = Frame(True, OP_CLOSE, serialize_close(1000, 'remote'))

    @property
    def ensure_future(self):
        """Return asyncio_ensure_future bound to the test's event loop."""
        return functools.partial(asyncio_ensure_future, loop=self.loop)

    def receive_frame(self, frame):
        """
        Make the protocol receive a frame.

        """
        writer = self.protocol.data_received
        # The frame is masked when the protocol under test isn't the client.
        mask = not self.protocol.is_client
        frame.write(writer, mask=mask)

    def receive_eof(self):
        """
        Make the protocol receive the end of the data stream.

        Since ``WebSocketCommonProtocol.eof_received`` returns ``None``, an
        actual transport would close itself after calling it. This function
        emulates that behavior.

        """
        self.protocol.eof_received()
        self.loop.call_soon(self.transport.close)

    def receive_eof_if_client(self):
        """
        Like receive_eof, but only if this is the client side.

        Since the server is supposed to initiate the termination of the TCP
        connection, this method helps making tests work for both sides.

        """
        if self.protocol.is_client:
            self.receive_eof()

    def close_connection(self, code=1000, reason='close'):
        """
        Execute a closing handshake.

        This puts the connection in the CLOSED state.

        """
        close_frame_data = serialize_close(code, reason)
        # Prepare the response to the closing handshake from the remote side.
        self.receive_frame(Frame(True, OP_CLOSE, close_frame_data))
        self.receive_eof_if_client()
        # Trigger the closing handshake from the local side and complete it.
        self.loop.run_until_complete(self.protocol.close(code, reason))
        # Empty the outgoing data stream so we can make assertions later on.
        self.assertOneFrameSent(True, OP_CLOSE, close_frame_data)

        assert self.protocol.state is State.CLOSED

    def half_close_connection_local(self, code=1000, reason='close'):
        """
        Start a closing handshake but do not complete it.

        The main difference with `close_connection` is that the connection is
        left in the CLOSING state until the event loop runs again.

        The current implementation returns a task that must be awaited or
        cancelled, else asyncio complains about destroying a pending task.

        """
        close_frame_data = serialize_close(code, reason)
        # Trigger the closing handshake from the local endpoint.
        close_task = self.ensure_future(self.protocol.close(code, reason))
        self.run_loop_once()    # wait_for executes
        self.run_loop_once()    # write_frame executes
        # Empty the outgoing data stream so we can make assertions later on.
        self.assertOneFrameSent(True, OP_CLOSE, close_frame_data)

        assert self.protocol.state is State.CLOSING

        # Complete the closing sequence at 1ms intervals so the test can run
        # at each point even if it goes back to the event loop several times.
        self.loop.call_later(
            MS, self.receive_frame, Frame(True, OP_CLOSE, close_frame_data))
        self.loop.call_later(2 * MS, self.receive_eof_if_client)

        # This task must be awaited or cancelled by the caller.
        return close_task

    def half_close_connection_remote(self, code=1000, reason='close'):
        """
        Receive a closing handshake but do not complete it.

        The main difference with `close_connection` is that the connection is
        left in the CLOSING state until the event loop runs again.

        """
        # On the server side, websockets completes the closing handshake and
        # closes the TCP connection immediately. Yield to the event loop after
        # sending the close frame to run the test while the connection is in
        # the CLOSING state.
        if not self.protocol.is_client:
            self.make_drain_slow()

        close_frame_data = serialize_close(code, reason)
        # Trigger the closing handshake from the remote endpoint.
        self.receive_frame(Frame(True, OP_CLOSE, close_frame_data))
        self.run_loop_once()    # read_frame executes
        # Empty the outgoing data stream so we can make assertions later on.
        self.assertOneFrameSent(True, OP_CLOSE, close_frame_data)

        assert self.protocol.state is State.CLOSING

        # Complete the closing sequence at 1ms intervals so the test can run
        # at each point even if it goes back to the event loop several times.
        self.loop.call_later(2 * MS, self.receive_eof_if_client)

    def process_invalid_frames(self):
        """
        Make the protocol fail quickly after simulating invalid data.

        To achieve this, this function triggers the protocol's eof_received,
        which interrupts pending reads waiting for more data.

        """
        self.run_loop_once()
        self.receive_eof()
        self.loop.run_until_complete(self.protocol.close_connection_task)

    def last_sent_frame(self):
        """
        Read the last frame sent to the transport.

        This method assumes that at most one frame was sent. It raises an
        AssertionError otherwise.

        Returns the frame, or None when nothing was written. The recorded
        writes are cleared as a side effect.

        """
        stream = asyncio.StreamReader(loop=self.loop)

        # Replay everything written to the transport mock into the stream.
        for (data,), _kwargs in self.transport.write.call_args_list:
            stream.feed_data(data)
        self.transport.write.call_args_list = []
        stream.feed_eof()

        if stream.at_eof():
            frame = None
        else:
            frame = self.loop.run_until_complete(Frame.read(
                stream.readexactly, mask=self.protocol.is_client))

        if not stream.at_eof():     # pragma: no cover
            data = self.loop.run_until_complete(stream.read())
            raise AssertionError("Trailing data found: {!r}".format(data))

        return frame

    def assertOneFrameSent(self, *args):
        """Assert that the frame sent equals ``Frame(*args)``."""
        self.assertEqual(self.last_sent_frame(), Frame(*args))

    def assertNoFrameSent(self):
        """Assert that nothing was written to the transport."""
        self.assertIsNone(self.last_sent_frame())

    def assertConnectionClosed(self, code, message):
        """Assert a CLOSED state with the given close code and reason."""
        # The following line guarantees that connection_lost was called.
        self.assertEqual(self.protocol.state, State.CLOSED)
        # A close frame was received.
        self.assertEqual(self.protocol.close_code, code)
        self.assertEqual(self.protocol.close_reason, message)

    def assertConnectionFailed(self, code, message):
        """
        Assert a CLOSED state with close code 1006 and an empty reason.

        ``code`` and ``message`` describe the close frame expected to have
        been sent; with ``code == 1006`` no frame is expected at all.

        """
        # The following line guarantees that connection_lost was called.
        self.assertEqual(self.protocol.state, State.CLOSED)
        # No close frame was received.
        self.assertEqual(self.protocol.close_code, 1006)
        self.assertEqual(self.protocol.close_reason, '')
        # A close frame was sent -- unless the connection was already lost.
        if code == 1006:
            self.assertNoFrameSent()
        else:
            self.assertOneFrameSent(
                True, OP_CLOSE, serialize_close(code, message))

    @contextlib.contextmanager
    def assertCompletesWithin(self, min_time, max_time):
        """Assert that the block takes ``min_time <= dt < max_time``."""
        t0 = self.loop.time()
        yield
        t1 = self.loop.time()
        dt = t1 - t0
        self.assertGreaterEqual(
            dt, min_time, "Too fast: {} < {}".format(dt, min_time))
        self.assertLess(
            dt, max_time, "Too slow: {} >= {}".format(dt, max_time))

    # Test public attributes.

    def test_local_address(self):
        get_extra_info = unittest.mock.Mock(return_value=('host', 4312))
        self.transport.get_extra_info = get_extra_info

        self.assertEqual(self.protocol.local_address, ('host', 4312))
        get_extra_info.assert_called_with('sockname', None)

    def test_local_address_before_connection(self):
        # Emulate the situation before connection_open() runs.
        self.protocol.writer, _writer = None, self.protocol.writer

        try:
            self.assertIsNone(self.protocol.local_address)
        finally:
            self.protocol.writer = _writer

    def test_remote_address(self):
        get_extra_info = unittest.mock.Mock(return_value=('host', 4312))
        self.transport.get_extra_info = get_extra_info

        self.assertEqual(self.protocol.remote_address, ('host', 4312))
        get_extra_info.assert_called_with('peername', None)

    def test_remote_address_before_connection(self):
        # Emulate the situation before connection_open() runs.
        self.protocol.writer, _writer = None, self.protocol.writer

        try:
            self.assertIsNone(self.protocol.remote_address)
        finally:
            self.protocol.writer = _writer

    def test_open(self):
        self.assertTrue(self.protocol.open)
        self.close_connection()
        self.assertFalse(self.protocol.open)

    def test_closed(self):
        self.assertFalse(self.protocol.closed)
        self.close_connection()
        self.assertTrue(self.protocol.closed)

    # Test the recv coroutine.

    def test_recv_text(self):
        self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8')))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, 'café')

    def test_recv_binary(self):
        self.receive_frame(Frame(True, OP_BINARY, b'tea'))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, b'tea')

    def test_recv_on_closing_connection_local(self):
        close_task = self.half_close_connection_local()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.recv())

        self.loop.run_until_complete(close_task)    # cleanup

    def test_recv_on_closing_connection_remote(self):
        self.half_close_connection_remote()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.recv())

    def test_recv_on_closed_connection(self):
        self.close_connection()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.recv())

    def test_recv_protocol_error(self):
        self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8')))
        self.process_invalid_frames()
        self.assertConnectionFailed(1002, '')

    def test_recv_unicode_error(self):
        self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('latin-1')))
        self.process_invalid_frames()
        self.assertConnectionFailed(1007, '')

    def test_recv_text_payload_too_big(self):
        self.protocol.max_size = 1024
        self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8') * 205))
        self.process_invalid_frames()
        self.assertConnectionFailed(1009, '')

    def test_recv_binary_payload_too_big(self):
        self.protocol.max_size = 1024
        self.receive_frame(Frame(True, OP_BINARY, b'tea' * 342))
        self.process_invalid_frames()
        self.assertConnectionFailed(1009, '')

    def test_recv_text_no_max_size(self):
        self.protocol.max_size = None       # for test coverage
        self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8') * 205))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, 'café' * 205)

    def test_recv_binary_no_max_size(self):
        self.protocol.max_size = None       # for test coverage
        self.receive_frame(Frame(True, OP_BINARY, b'tea' * 342))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, b'tea' * 342)

    def test_recv_other_error(self):
        @asyncio.coroutine
        def read_message():
            raise Exception("BOOM")
        self.protocol.read_message = read_message
        self.process_invalid_frames()
        self.assertConnectionFailed(1011, '')

    def test_recv_cancelled(self):
        recv = self.ensure_future(self.protocol.recv())
        self.loop.call_soon(recv.cancel)

        with self.assertRaises(asyncio.CancelledError):
            self.loop.run_until_complete(recv)

        # The next frame doesn't disappear in a vacuum (it used to).
        self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8')))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, 'café')

    # Test the send coroutine.

    def test_send_text(self):
        self.loop.run_until_complete(self.protocol.send('café'))
        self.assertOneFrameSent(True, OP_TEXT, 'café'.encode('utf-8'))

    def test_send_binary(self):
        self.loop.run_until_complete(self.protocol.send(b'tea'))
        self.assertOneFrameSent(True, OP_BINARY, b'tea')

    def test_send_type_error(self):
        with self.assertRaises(TypeError):
            self.loop.run_until_complete(self.protocol.send(42))
        self.assertNoFrameSent()

    def test_send_on_closing_connection_local(self):
        close_task = self.half_close_connection_local()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.send('foobar'))
        self.assertNoFrameSent()

        self.loop.run_until_complete(close_task)    # cleanup

    def test_send_on_closing_connection_remote(self):
        self.half_close_connection_remote()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.send('foobar'))
        self.assertNoFrameSent()

    def test_send_on_closed_connection(self):
        self.close_connection()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.send('foobar'))
        self.assertNoFrameSent()

    # Test the ping coroutine.

    def test_ping_default(self):
        self.loop.run_until_complete(self.protocol.ping())
        # With our testing tools, it's more convenient to extract the expected
        # ping data from the library's internals than from the frame sent.
        ping_data = next(iter(self.protocol.pings))
        self.assertIsInstance(ping_data, bytes)
        self.assertEqual(len(ping_data), 4)
        self.assertOneFrameSent(True, OP_PING, ping_data)

    def test_ping_text(self):
        self.loop.run_until_complete(self.protocol.ping('café'))
        self.assertOneFrameSent(True, OP_PING, 'café'.encode('utf-8'))

    def test_ping_binary(self):
        self.loop.run_until_complete(self.protocol.ping(b'tea'))
        self.assertOneFrameSent(True, OP_PING, b'tea')

    def test_ping_type_error(self):
        with self.assertRaises(TypeError):
            self.loop.run_until_complete(self.protocol.ping(42))
        self.assertNoFrameSent()

    def test_ping_on_closing_connection_local(self):
        close_task = self.half_close_connection_local()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.ping())
        self.assertNoFrameSent()

        self.loop.run_until_complete(close_task)    # cleanup

    def test_ping_on_closing_connection_remote(self):
        self.half_close_connection_remote()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.ping())
        self.assertNoFrameSent()

    def test_ping_on_closed_connection(self):
        self.close_connection()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.ping())
        self.assertNoFrameSent()

    # Test the pong coroutine.

    def test_pong_default(self):
        self.loop.run_until_complete(self.protocol.pong())
        self.assertOneFrameSent(True, OP_PONG, b'')

    def test_pong_text(self):
        self.loop.run_until_complete(self.protocol.pong('café'))
        self.assertOneFrameSent(True, OP_PONG, 'café'.encode('utf-8'))

    def test_pong_binary(self):
        self.loop.run_until_complete(self.protocol.pong(b'tea'))
        self.assertOneFrameSent(True, OP_PONG, b'tea')

    def test_pong_type_error(self):
        with self.assertRaises(TypeError):
            self.loop.run_until_complete(self.protocol.pong(42))
        self.assertNoFrameSent()

    def test_pong_on_closing_connection_local(self):
        close_task = self.half_close_connection_local()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.pong())
        self.assertNoFrameSent()

        self.loop.run_until_complete(close_task)    # cleanup

    def test_pong_on_closing_connection_remote(self):
        self.half_close_connection_remote()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.pong())
        self.assertNoFrameSent()

    def test_pong_on_closed_connection(self):
        self.close_connection()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.pong())
        self.assertNoFrameSent()

    # Test the protocol's logic for acknowledging pings with pongs.

    def test_answer_ping(self):
        self.receive_frame(Frame(True, OP_PING, b'test'))
        self.run_loop_once()
        self.assertOneFrameSent(True, OP_PONG, b'test')

    def test_ignore_pong(self):
        self.receive_frame(Frame(True, OP_PONG, b'test'))
        self.run_loop_once()
        self.assertNoFrameSent()

    def test_acknowledge_ping(self):
        ping = self.loop.run_until_complete(self.protocol.ping())
        self.assertFalse(ping.done())
        ping_frame = self.last_sent_frame()
        pong_frame = Frame(True, OP_PONG, ping_frame.data)
        self.receive_frame(pong_frame)
        self.run_loop_once()
        self.run_loop_once()
        self.assertTrue(ping.done())

    def test_cancel_ping(self):
        ping = self.loop.run_until_complete(self.protocol.ping())
        # Remove the frame from the buffer, else close_connection() complains.
        self.last_sent_frame()
        self.assertFalse(ping.cancelled())
        self.close_connection()
        self.assertTrue(ping.cancelled())

    def test_acknowledge_previous_pings(self):
        # Each entry pairs the value returned by ping() with the frame sent.
        pings = [(
            self.loop.run_until_complete(self.protocol.ping()),
            self.last_sent_frame(),
        ) for _ in range(3)]
        # Unsolicited pong doesn't acknowledge pings
        self.receive_frame(Frame(True, OP_PONG, b''))
        self.run_loop_once()
        self.run_loop_once()
        self.assertFalse(pings[0][0].done())
        self.assertFalse(pings[1][0].done())
        self.assertFalse(pings[2][0].done())
        # Pong acknowledges all previous pings
        self.receive_frame(Frame(True, OP_PONG, pings[1][1].data))
        self.run_loop_once()
        self.run_loop_once()
        self.assertTrue(pings[0][0].done())
        self.assertTrue(pings[1][0].done())
        self.assertFalse(pings[2][0].done())

    def test_cancelled_ping(self):
        ping = self.loop.run_until_complete(self.protocol.ping())
        ping_frame = self.last_sent_frame()
        ping.cancel()
        pong_frame = Frame(True, OP_PONG, ping_frame.data)
        self.receive_frame(pong_frame)
        self.run_loop_once()
        self.run_loop_once()
        self.assertTrue(ping.cancelled())

    def test_duplicate_ping(self):
        self.loop.run_until_complete(self.protocol.ping(b'foobar'))
        self.assertOneFrameSent(True, OP_PING, b'foobar')
        with self.assertRaises(ValueError):
            self.loop.run_until_complete(self.protocol.ping(b'foobar'))
        self.assertNoFrameSent()

    # Test the protocol's logic for rebuilding fragmented messages.

    def test_fragmented_text(self):
        # The two fragments must add up to the asserted message: 'ca' + 'fé'.
        self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8')))
        self.receive_frame(Frame(True, OP_CONT, 'fé'.encode('utf-8')))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, 'café')

    def test_fragmented_binary(self):
        self.receive_frame(Frame(False, OP_BINARY, b't'))
        self.receive_frame(Frame(False, OP_CONT, b'e'))
        self.receive_frame(Frame(True, OP_CONT, b'a'))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, b'tea')

    def test_fragmented_text_payload_too_big(self):
        self.protocol.max_size = 1024
        self.receive_frame(Frame(False, OP_TEXT, 'café'.encode('utf-8') * 100))
        self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8') * 105))
        self.process_invalid_frames()
        self.assertConnectionFailed(1009, '')

    def test_fragmented_binary_payload_too_big(self):
        self.protocol.max_size = 1024
        self.receive_frame(Frame(False, OP_BINARY, b'tea' * 171))
        self.receive_frame(Frame(True, OP_CONT, b'tea' * 171))
        self.process_invalid_frames()
        self.assertConnectionFailed(1009, '')

    def test_fragmented_text_no_max_size(self):
        self.protocol.max_size = None       # for test coverage
        self.receive_frame(Frame(False, OP_TEXT, 'café'.encode('utf-8') * 100))
        self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8') * 105))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, 'café' * 205)

    def test_fragmented_binary_no_max_size(self):
        self.protocol.max_size = None       # for test coverage
        self.receive_frame(Frame(False, OP_BINARY, b'tea' * 171))
        self.receive_frame(Frame(True, OP_CONT, b'tea' * 171))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, b'tea' * 342)

    def test_control_frame_within_fragmented_text(self):
        # The two text fragments must add up to the asserted message.
        self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8')))
        self.receive_frame(Frame(True, OP_PING, b''))
        self.receive_frame(Frame(True, OP_CONT, 'fé'.encode('utf-8')))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, 'café')
        self.assertOneFrameSent(True, OP_PONG, b'')

    def test_unterminated_fragmented_text(self):
        self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8')))
        # Missing the second part of the fragmented frame.
        self.receive_frame(Frame(True, OP_BINARY, b'tea'))
        self.process_invalid_frames()
        self.assertConnectionFailed(1002, '')

    def test_close_handshake_in_fragmented_text(self):
        self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8')))
        self.receive_frame(Frame(True, OP_CLOSE, b''))
        self.process_invalid_frames()
        # The RFC may have overlooked this case: it says that control frames
        # can be interjected in the middle of a fragmented message and that a
        # close frame must be echoed. Even though there's an unterminated
        # message, technically, the closing handshake was successful.
        self.assertConnectionClosed(1005, '')

    def test_connection_close_in_fragmented_text(self):
        self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8')))
        self.process_invalid_frames()
        self.assertConnectionFailed(1006, '')

    # Test miscellaneous code paths to ensure full coverage.

    def test_connection_lost(self):
        # Test calling connection_lost without going through close_connection.
        self.protocol.connection_lost(None)

        self.assertConnectionFailed(1006, '')

    def test_ensure_open_before_opening_handshake(self):
        # Simulate a bug by forcibly reverting the protocol state.
        self.protocol.state = State.CONNECTING

        with self.assertRaises(InvalidState):
            self.loop.run_until_complete(self.protocol.ensure_open())

    def test_ensure_open_during_unclean_close(self):
        # Process connection_made in order to start transfer_data_task.
        self.run_loop_once()

        # Ensure the test terminates quickly.
        self.loop.call_later(MS, self.receive_eof_if_client)

        # Simulate the case when close() times out sending a close frame.
        self.protocol.fail_connection()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.ensure_open())

    def test_legacy_recv(self):
        # By default legacy_recv is disabled.
        self.assertEqual(self.protocol.legacy_recv, False)

        self.close_connection()

        # Enable legacy_recv.
        self.protocol.legacy_recv = True

        # Now recv() returns None instead of raising ConnectionClosed.
        self.assertIsNone(self.loop.run_until_complete(self.protocol.recv()))

    def test_connection_closed_attributes(self):
        self.close_connection()

        with self.assertRaises(ConnectionClosed) as context:
            self.loop.run_until_complete(self.protocol.recv())

        connection_closed_exc = context.exception
        self.assertEqual(connection_closed_exc.code, 1000)
        self.assertEqual(connection_closed_exc.reason, 'close')

    # Test the protocol logic for closing the connection.

    def test_local_close(self):
        # Emulate how the remote endpoint answers the closing handshake.
        self.loop.call_later(MS, self.receive_frame, self.close_frame)
        self.loop.call_later(MS, self.receive_eof_if_client)

        # Run the closing handshake.
        self.loop.run_until_complete(self.protocol.close(reason='close'))

        self.assertConnectionClosed(1000, 'close')
        self.assertOneFrameSent(*self.close_frame)

        # Closing the connection again is a no-op.
        self.loop.run_until_complete(self.protocol.close(reason='oh noes!'))

        self.assertConnectionClosed(1000, 'close')
        self.assertNoFrameSent()

    def test_remote_close(self):
        # Emulate how the remote endpoint initiates the closing handshake.
        self.loop.call_later(MS, self.receive_frame, self.close_frame)
        self.loop.call_later(MS, self.receive_eof_if_client)

        # Wait for some data in order to process the handshake.
        # After recv() raises ConnectionClosed, the connection is closed.
        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.recv())

        self.assertConnectionClosed(1000, 'close')
        self.assertOneFrameSent(*self.close_frame)

        # Closing the connection again is a no-op.
        self.loop.run_until_complete(self.protocol.close(reason='oh noes!'))

        self.assertConnectionClosed(1000, 'close')
        self.assertNoFrameSent()

    def test_simultaneous_close(self):
        # Receive the incoming close frame right after self.protocol.close()
        # starts executing. This reproduces the error described in:
        # https://github.com/aaugustin/websockets/issues/339
        self.loop.call_soon(self.receive_frame, self.remote_close)
        self.loop.call_soon(self.receive_eof_if_client)

        self.loop.run_until_complete(self.protocol.close(reason='local'))

        self.assertConnectionClosed(1000, 'remote')
        # The current implementation sends a close frame in response to the
        # close frame received from the remote end. It skips the close frame
        # that should be sent as a result of calling close().
        self.assertOneFrameSent(*self.remote_close)

    def test_close_preserves_incoming_frames(self):
        self.receive_frame(Frame(True, OP_TEXT, b'hello'))

        self.loop.call_later(MS, self.receive_frame, self.close_frame)
        self.loop.call_later(MS, self.receive_eof_if_client)
        self.loop.run_until_complete(self.protocol.close(reason='close'))

        self.assertConnectionClosed(1000, 'close')
        self.assertOneFrameSent(*self.close_frame)

        next_message = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(next_message, 'hello')

    def test_close_protocol_error(self):
        invalid_close_frame = Frame(True, OP_CLOSE, b'\x00')
        self.receive_frame(invalid_close_frame)
        self.receive_eof_if_client()
        self.run_loop_once()
        self.loop.run_until_complete(self.protocol.close(reason='close'))

        self.assertConnectionFailed(1002, '')

    def test_close_connection_lost(self):
        self.receive_eof()
        self.run_loop_once()
        self.loop.run_until_complete(self.protocol.close(reason='close'))

        self.assertConnectionFailed(1006, '')

    def test_local_close_during_recv(self):
        recv = self.ensure_future(self.protocol.recv())

        self.loop.call_later(MS, self.receive_frame, self.close_frame)
        self.loop.call_later(MS, self.receive_eof_if_client)

        self.loop.run_until_complete(self.protocol.close(reason='close'))

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(recv)

        self.assertConnectionClosed(1000, 'close')

    # There is no test_remote_close_during_recv because it would be identical
    # to test_remote_close.

    def test_remote_close_during_send(self):
        self.make_drain_slow()
        send = self.ensure_future(self.protocol.send('hello'))

        self.receive_frame(self.close_frame)
        self.receive_eof()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(send)

        self.assertConnectionClosed(1000, 'close')

    # There is no test_local_close_during_send because this cannot really
    # happen, considering that writes are serialized.
class ServerTests(CommonTests, unittest.TestCase):
    """
    Run the common tests, plus close timeout tests, on the server side.

    """

    def setUp(self):
        """Configure the protocol created by CommonTests.setUp as a server."""
        super().setUp()
        self.protocol.is_client = False
        self.protocol.side = 'server'

    def test_local_close_send_close_frame_timeout(self):
        self.protocol.timeout = 10 * MS
        self.make_drain_slow(50 * MS)
        # If we can't send a close frame, time out in 10ms.
        # Check the timing within -1/+9ms for robustness.
        with self.assertCompletesWithin(9 * MS, 19 * MS):
            self.loop.run_until_complete(self.protocol.close(reason='close'))
        self.assertConnectionClosed(1006, '')

    def test_local_close_receive_close_frame_timeout(self):
        self.protocol.timeout = 10 * MS
        # If the client doesn't send a close frame, time out in 10ms.
        # Check the timing within -1/+9ms for robustness.
        with self.assertCompletesWithin(9 * MS, 19 * MS):
            self.loop.run_until_complete(self.protocol.close(reason='close'))
        self.assertConnectionClosed(1006, '')

    def test_local_close_connection_lost_timeout_after_write_eof(self):
        self.protocol.timeout = 10 * MS
        # If the client doesn't close its side of the TCP connection after we
        # half-close our side with write_eof(), time out in 10ms.
        # Check the timing within -1/+9ms for robustness.
        with self.assertCompletesWithin(9 * MS, 19 * MS):
            # HACK: disable write_eof => other end drops connection emulation.
            self.transport._eof = True
            self.receive_frame(self.close_frame)
            self.loop.run_until_complete(self.protocol.close(reason='close'))
        self.assertConnectionClosed(1000, 'close')

    def test_local_close_connection_lost_timeout_after_close(self):
        self.protocol.timeout = 10 * MS
        # If the client doesn't close its side of the TCP connection after we
        # half-close our side with write_eof() and close it with close(), time
        # out in 20ms.
        # Check the timing within -1/+9ms for robustness.
        with self.assertCompletesWithin(19 * MS, 29 * MS):
            # HACK: disable write_eof => other end drops connection emulation.
            self.transport._eof = True
            # HACK: disable close => other end drops connection emulation.
            self.transport._closing = True
            self.receive_frame(self.close_frame)
            self.loop.run_until_complete(self.protocol.close(reason='close'))
        self.assertConnectionClosed(1000, 'close')
class ClientTests(CommonTests, unittest.TestCase):
    """
    Run the common tests, plus close timeout tests, on the client side.

    """

    def setUp(self):
        """Configure the protocol created by CommonTests.setUp as a client."""
        super().setUp()
        self.protocol.is_client = True
        self.protocol.side = 'client'

    def test_local_close_send_close_frame_timeout(self):
        self.protocol.timeout = 10 * MS
        self.make_drain_slow(50 * MS)
        # If we can't send a close frame, time out in 20ms:
        # - 10ms waiting for sending a close frame
        # - 10ms waiting for receiving a half-close
        # Check the timing within -1/+9ms for robustness.
        with self.assertCompletesWithin(19 * MS, 29 * MS):
            self.loop.run_until_complete(self.protocol.close(reason='close'))
        self.assertConnectionClosed(1006, '')

    def test_local_close_receive_close_frame_timeout(self):
        self.protocol.timeout = 10 * MS
        # If the server doesn't send a close frame, time out in 20ms:
        # - 10ms waiting for receiving a close frame
        # - 10ms waiting for receiving a half-close
        # Check the timing within -1/+9ms for robustness.
        with self.assertCompletesWithin(19 * MS, 29 * MS):
            self.loop.run_until_complete(self.protocol.close(reason='close'))
        self.assertConnectionClosed(1006, '')

    def test_local_close_connection_lost_timeout_after_write_eof(self):
        self.protocol.timeout = 10 * MS
        # If the server doesn't half-close its side of the TCP connection
        # after we send a close frame, time out in 20ms:
        # - 10ms waiting for receiving a half-close
        # - 10ms waiting for receiving a close after write_eof
        # Check the timing within -1/+9ms for robustness.
        with self.assertCompletesWithin(19 * MS, 29 * MS):
            # HACK: disable write_eof => other end drops connection emulation.
            self.transport._eof = True
            self.receive_frame(self.close_frame)
            self.loop.run_until_complete(self.protocol.close(reason='close'))
        self.assertConnectionClosed(1000, 'close')

    def test_local_close_connection_lost_timeout_after_close(self):
        self.protocol.timeout = 10 * MS
        # If the server doesn't close its side of the TCP connection after we
        # half-close our side with write_eof() and close it with close(), time
        # out in 30ms:
        # - 10ms waiting for receiving a half-close
        # - 10ms waiting for receiving a close after write_eof
        # - 10ms waiting for receiving a close after close
        # Check the timing within -1/+9ms for robustness.
        with self.assertCompletesWithin(29 * MS, 39 * MS):
            # HACK: disable write_eof => other end drops connection emulation.
            self.transport._eof = True
            # HACK: disable close => other end drops connection emulation.
            self.transport._closing = True
            self.receive_frame(self.close_frame)
            self.loop.run_until_complete(self.protocol.close(reason='close'))
        self.assertConnectionClosed(1000, 'close')