118 lines
4 KiB
Python
118 lines
4 KiB
Python
|
import contextlib
|
||
|
import unittest
|
||
|
|
||
|
from .exceptions import InvalidHandshake
|
||
|
from .handshake import *
|
||
|
from .handshake import accept # private API
|
||
|
|
||
|
|
||
|
class HandshakeTests(unittest.TestCase):
    """Tests for the handshake helpers imported from ``.handshake``.

    Negative tests start from headers produced by ``build_request`` /
    ``build_response``, corrupt or remove a single header, and expect
    ``check_request`` / ``check_response`` to raise ``InvalidHandshake``.
    """

    def test_accept(self):
        # Test vector from RFC 6455: a sample key and its expected accept value.
        sample_key = "dGhlIHNhbXBsZSBub25jZQ=="
        expected_accept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
        self.assertEqual(accept(sample_key), expected_accept)

    def test_round_trip(self):
        req_hdrs, resp_hdrs = {}, {}

        # The key returned while building the request must be the one
        # that check_request recovers from those same headers.
        sent_key = build_request(req_hdrs)
        recovered_key = check_request(req_hdrs)
        self.assertEqual(sent_key, recovered_key)

        # A response built from the recovered key must validate against
        # the key originally returned by build_request.
        build_response(resp_hdrs, recovered_key)
        check_response(resp_hdrs, sent_key)

    @contextlib.contextmanager
    def assertInvalidRequestHeaders(self):
        """
        Yield freshly built request headers for the caller to tamper with.

        If the ``with`` body completes without raising, assert that
        ``check_request`` rejects the tampered headers with ``InvalidHandshake``.
        """
        tampered = {}
        build_request(tampered)
        yield tampered
        with self.assertRaises(InvalidHandshake):
            check_request(tampered)

    def test_request_invalid_upgrade(self):
        with self.assertInvalidRequestHeaders() as hdrs:
            hdrs['Upgrade'] = 'socketweb'

    def test_request_missing_upgrade(self):
        with self.assertInvalidRequestHeaders() as hdrs:
            del hdrs['Upgrade']

    def test_request_invalid_connection(self):
        with self.assertInvalidRequestHeaders() as hdrs:
            hdrs['Connection'] = 'Downgrade'

    def test_request_missing_connection(self):
        with self.assertInvalidRequestHeaders() as hdrs:
            del hdrs['Connection']

    def test_request_invalid_key_not_base64(self):
        with self.assertInvalidRequestHeaders() as hdrs:
            hdrs['Sec-WebSocket-Key'] = "!@#$%^&*()"

    def test_request_invalid_key_not_well_padded(self):
        with self.assertInvalidRequestHeaders() as hdrs:
            # Same characters as a valid key, but with the trailing "==" dropped.
            hdrs['Sec-WebSocket-Key'] = "CSIRmL8dWYxeAdr/XpEHRw"

    def test_request_invalid_key_not_16_bytes_long(self):
        with self.assertInvalidRequestHeaders() as hdrs:
            # Well-formed base64, but it decodes to only 8 bytes.
            hdrs['Sec-WebSocket-Key'] = "ZLpprpvK4PE="

    def test_request_missing_key(self):
        with self.assertInvalidRequestHeaders() as hdrs:
            del hdrs['Sec-WebSocket-Key']

    def test_request_invalid_version(self):
        with self.assertInvalidRequestHeaders() as hdrs:
            hdrs['Sec-WebSocket-Version'] = '42'

    def test_request_missing_version(self):
        with self.assertInvalidRequestHeaders() as hdrs:
            del hdrs['Sec-WebSocket-Version']

    @contextlib.contextmanager
    def assertInvalidResponseHeaders(self, key='CSIRmL8dWYxeAdr/XpEHRw=='):
        """
        Yield response headers built for ``key`` for the caller to tamper with.

        If the ``with`` body completes without raising, assert that
        ``check_response`` rejects the tampered headers for the same ``key``
        with ``InvalidHandshake``.
        """
        tampered = {}
        build_response(tampered, key)
        yield tampered
        with self.assertRaises(InvalidHandshake):
            check_response(tampered, key)

    def test_response_invalid_upgrade(self):
        with self.assertInvalidResponseHeaders() as hdrs:
            hdrs['Upgrade'] = 'socketweb'

    def test_response_missing_upgrade(self):
        with self.assertInvalidResponseHeaders() as hdrs:
            del hdrs['Upgrade']

    def test_response_invalid_connection(self):
        with self.assertInvalidResponseHeaders() as hdrs:
            hdrs['Connection'] = 'Downgrade'

    def test_response_missing_connection(self):
        with self.assertInvalidResponseHeaders() as hdrs:
            del hdrs['Connection']

    def test_response_invalid_accept(self):
        with self.assertInvalidResponseHeaders() as hdrs:
            # An accept value derived from a key other than the helper's default.
            hdrs['Sec-WebSocket-Accept'] = accept("1Eq4UDEFQYg3YspNgqxv5g==")

    def test_response_missing_accept(self):
        with self.assertInvalidResponseHeaders() as hdrs:
            del hdrs['Sec-WebSocket-Accept']
|