229 lines
7.5 KiB
Python
229 lines
7.5 KiB
Python
|
import asyncio
|
||
|
import codecs
|
||
|
import unittest
|
||
|
import unittest.mock
|
||
|
|
||
|
from .exceptions import PayloadTooBig, WebSocketProtocolError
|
||
|
from .framing import *
|
||
|
|
||
|
|
||
|
class FramingTests(unittest.TestCase):
    """Tests for frame (de)serialization and close-frame payload helpers.

    Covers ``Frame.read`` / ``Frame.write``, ``encode_data``,
    ``parse_close`` / ``serialize_close`` and the ``extensions`` hook.
    """

    def setUp(self):
        # Each test runs on its own fresh loop, installed as the current loop.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def tearDown(self):
        self.loop.close()
        # Uninstall the closed loop so code running after this test does not
        # pick it up as the current event loop.
        asyncio.set_event_loop(None)

    def decode(self, message, mask=False, max_size=None, extensions=None):
        """Parse *message* into one frame with ``Frame.read`` and return it.

        The bytes are fed to a ``StreamReader`` followed by EOF. After reading,
        the reader must be at EOF, i.e. the frame consumed the whole message.
        Exceptions raised by ``Frame.read`` propagate to the caller.
        """
        self.stream = asyncio.StreamReader(loop=self.loop)
        self.stream.feed_data(message)
        self.stream.feed_eof()
        frame = self.loop.run_until_complete(Frame.read(
            self.stream.readexactly, mask=mask,
            max_size=max_size, extensions=extensions,
        ))
        # Make sure all the data was consumed.
        self.assertTrue(self.stream.at_eof())
        return frame

    def encode(self, frame, mask=False, extensions=None):
        """Serialize *frame* with ``Frame.write`` and return the written bytes.

        A ``Mock`` stands in for the writer callable so the way it was
        invoked can be inspected.
        """
        writer = unittest.mock.Mock()
        frame.write(writer, mask=mask, extensions=extensions)
        # Ensure the entire frame is sent with a single call to writer().
        # Multiple calls cause TCP fragmentation and degrade performance.
        self.assertEqual(writer.call_count, 1)
        # The frame data is the single positional argument of that call.
        self.assertEqual(len(writer.call_args[0]), 1)
        self.assertEqual(len(writer.call_args[1]), 0)
        return writer.call_args[0][0]

    def round_trip(self, message, expected, mask=False, extensions=None):
        """Check that *message* decodes to *expected* and re-encodes correctly.

        Unmasked encoding is deterministic, so the re-encoded bytes must equal
        *message*. Masked encoding is not, so the re-encoded bytes are decoded
        once more and compared with *expected* instead.
        """
        decoded = self.decode(message, mask, extensions=extensions)
        self.assertEqual(decoded, expected)
        encoded = self.encode(decoded, mask, extensions=extensions)
        if mask:  # non-deterministic encoding
            decoded = self.decode(encoded, mask, extensions=extensions)
            self.assertEqual(decoded, expected)
        else:  # deterministic encoding
            self.assertEqual(encoded, message)

    def round_trip_close(self, data, code, reason):
        """Check that close payload *data* parses to ``(code, reason)`` and back."""
        parsed = parse_close(data)
        self.assertEqual(parsed, (code, reason))
        serialized = serialize_close(code, reason)
        self.assertEqual(serialized, data)

    def test_text(self):
        self.round_trip(
            b'\x81\x04Spam',
            Frame(True, OP_TEXT, b'Spam'),
        )

    def test_text_masked(self):
        self.round_trip(
            b'\x81\x84\x5b\xfb\xe1\xa8\x08\x8b\x80\xc5',
            Frame(True, OP_TEXT, b'Spam'),
            mask=True,
        )

    def test_binary(self):
        self.round_trip(
            b'\x82\x04Eggs',
            Frame(True, OP_BINARY, b'Eggs'),
        )

    def test_binary_masked(self):
        self.round_trip(
            b'\x82\x84\x53\xcd\xe2\x89\x16\xaa\x85\xfa',
            Frame(True, OP_BINARY, b'Eggs'),
            mask=True,
        )

    def test_non_ascii_text(self):
        self.round_trip(
            b'\x81\x05caf\xc3\xa9',
            Frame(True, OP_TEXT, 'café'.encode('utf-8')),
        )

    def test_non_ascii_text_masked(self):
        self.round_trip(
            b'\x81\x85\x64\xbe\xee\x7e\x07\xdf\x88\xbd\xcd',
            Frame(True, OP_TEXT, 'café'.encode('utf-8')),
            mask=True,
        )

    def test_close(self):
        self.round_trip(
            b'\x88\x00',
            Frame(True, OP_CLOSE, b''),
        )

    def test_ping(self):
        self.round_trip(
            b'\x89\x04ping',
            Frame(True, OP_PING, b'ping'),
        )

    def test_pong(self):
        self.round_trip(
            b'\x8a\x04pong',
            Frame(True, OP_PONG, b'pong'),
        )

    def test_long(self):
        # Length byte 0x7e is followed by a 2-byte length (0x007e == 126).
        self.round_trip(
            b'\x82\x7e\x00\x7e' + 126 * b'a',
            Frame(True, OP_BINARY, 126 * b'a'),
        )

    def test_very_long(self):
        # Length byte 0x7f is followed by an 8-byte length (0x10000 == 65536).
        self.round_trip(
            b'\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x00' + 65536 * b'a',
            Frame(True, OP_BINARY, 65536 * b'a'),
        )

    def test_payload_too_big(self):
        # Declared length 0x0401 == 1025 exceeds max_size=1024.
        with self.assertRaises(PayloadTooBig):
            self.decode(
                b'\x82\x7e\x04\x01' + 1025 * b'a',
                max_size=1024,
            )

    def test_bad_reserved_bits(self):
        # Each first byte sets one of the three bits just below FIN (0x80).
        for encoded in [b'\xc0\x00', b'\xa0\x00', b'\x90\x00']:
            with self.subTest(encoded=encoded):
                with self.assertRaises(WebSocketProtocolError):
                    self.decode(encoded)

    def test_good_opcode(self):
        for opcode in list(range(0x00, 0x03)) + list(range(0x08, 0x0b)):
            encoded = bytes([0x80 | opcode, 0])
            with self.subTest(encoded=encoded):
                self.decode(encoded)  # does not raise an exception

    def test_bad_opcode(self):
        for opcode in list(range(0x03, 0x08)) + list(range(0x0b, 0x10)):
            encoded = bytes([0x80 | opcode, 0])
            with self.subTest(encoded=encoded):
                with self.assertRaises(WebSocketProtocolError):
                    self.decode(encoded)

    def test_mask_flag(self):
        # Mask flag correctly set.
        self.decode(b'\x80\x80\x00\x00\x00\x00', mask=True)
        # Mask flag incorrectly unset.
        with self.assertRaises(WebSocketProtocolError):
            self.decode(b'\x80\x80\x00\x00\x00\x00')
        # Mask flag correctly unset.
        self.decode(b'\x80\x00')
        # Mask flag incorrectly set.
        with self.assertRaises(WebSocketProtocolError):
            self.decode(b'\x80\x00', mask=True)

    def test_control_frame_max_length(self):
        # At maximum allowed length.
        self.decode(b'\x88\x7e\x00\x7d' + 125 * b'a')
        # Above maximum allowed length.
        with self.assertRaises(WebSocketProtocolError):
            self.decode(b'\x88\x7e\x00\x7e' + 126 * b'a')

    def test_encode_data_str(self):
        self.assertEqual(encode_data('café'), b'caf\xc3\xa9')

    def test_encode_data_bytes(self):
        self.assertEqual(encode_data(b'tea'), b'tea')

    def test_encode_data_other(self):
        with self.assertRaises(TypeError):
            encode_data(None)

    def test_fragmented_control_frame(self):
        # Fin bit correctly set.
        self.decode(b'\x88\x00')
        # Fin bit incorrectly unset.
        with self.assertRaises(WebSocketProtocolError):
            self.decode(b'\x08\x00')

    def test_parse_close_and_serialize_close(self):
        self.round_trip_close(b'\x03\xe8', 1000, '')
        self.round_trip_close(b'\x03\xe8OK', 1000, 'OK')

    def test_parse_close_empty(self):
        self.assertEqual(parse_close(b''), (1005, ''))

    def test_parse_close_errors(self):
        with self.assertRaises(WebSocketProtocolError):
            parse_close(b'\x03')
        with self.assertRaises(WebSocketProtocolError):
            parse_close(b'\x03\xe7')
        with self.assertRaises(UnicodeDecodeError):
            parse_close(b'\x03\xe8\xff\xff')

    def test_serialize_close_errors(self):
        with self.assertRaises(WebSocketProtocolError):
            serialize_close(999, '')

    def test_extensions(self):

        class Rot13:
            """Toy extension that ROT13-transforms the payload of text frames."""

            @staticmethod
            def encode(frame):
                assert frame.opcode == OP_TEXT
                text = frame.data.decode()
                data = codecs.encode(text, 'rot13').encode()
                return frame._replace(data=data)

            # This extension is symmetrical.
            @staticmethod
            def decode(frame, *, max_size=None):
                return Rot13.encode(frame)

        self.round_trip(
            b'\x81\x05uryyb',
            Frame(True, OP_TEXT, b'hello'),
            extensions=[Rot13()],
        )
|