# sampy3/tests/test_bitstream.py

from typing import List, Union, Type, Optional
import pytest
from sampy.raknet.bitstream import Bitstream, get_bit_length
@pytest.mark.parametrize("value", [
    b"",
    b"A",
    b"AB",
    b"\xFF\x00\xAA\x55",
])
def test_from_bytes(value: bytes):
    """Round-trip raw bytes through ``Bitstream.from_bytes``.

    The stream must hold exactly eight bits per input byte and serialize
    back to the original bytes unchanged.
    """
    stream = Bitstream.from_bytes(value)
    expected_bit_count = len(value) * 8
    assert len(stream) == expected_bit_count
    assert bytes(stream) == value
@pytest.mark.parametrize("value,length,expected", [
    (0b0, 0, None),
    (0b1, 1, None),
    (0b0, 1, None),
    (0b1011, 4, None),
    (0b10111011, 8, None),
    (0b10111011, 7, 0b0111011),  # bit 7 falls outside the 7-bit window
])
def test_from_int(value: int, length: int, expected: Optional[int]):
    """Build a stream from the low ``length`` bits of an integer.

    Bits are compared most-significant first. When ``expected`` is given,
    the stream must also match one built directly from that (already
    truncated) value, both bit-for-bit and as serialized bytes.
    """
    stream = Bitstream.from_int(value, length)
    assert len(stream) == length
    byte_count = (length + 7) // 8  # round up to whole bytes
    assert len(bytes(stream)) == byte_count
    msb_first = [bool((value >> shift) & 1) for shift in reversed(range(length))]
    assert stream.bits == msb_first
    if expected is None:
        return
    reference = Bitstream.from_int(expected, length)
    assert stream.bits == reference.bits
    assert bytes(stream) == bytes(reference)
@pytest.mark.parametrize("values", [
    [True, False, False, True, True],
    [[True], [False, True], [False, False]],
    [b"a", b"bc"],
    ["aa", "b"],
    [Bitstream.from_bytes(b"A"), Bitstream.from_int(0b1011, 4), Bitstream.from_bytes(b"B")],
    [True, b"A", "B", [False, True], Bitstream.from_int(0b10110, 5), False],
])
def test_init(values: List[Union[bool, List[bool], bytes, str, Bitstream]]):
    """Construct a stream from mixed chunks and read each chunk back in order."""
    stream = Bitstream(*values)
    stream.seek(0)
    for chunk in values:
        width = get_bit_length(chunk)
        # A lone bool is read back as a one-element list of bools.
        wanted = [chunk] if type(chunk) is bool else chunk
        kind = bool if type(wanted) is list else type(wanted)
        assert stream.read(kind, width) == wanted
def test_getitem():
    """Index every position of streams of length 1..10, forwards and backwards."""
    sample_values = [True, False, False, True, False, True, True, True, True, False]
    # ``+ 1`` so the complete sample is exercised too, not just its prefixes.
    for length in range(1, len(sample_values) + 1):
        values = sample_values[:length]
        bitstream = Bitstream(values)
        for i, expected in enumerate(values):
            assert bitstream[i] == expected
            # Negative indices count from the end, like a list.
            assert bitstream[-i - 1] == values[-i - 1]
def test_getitem_index_error():
    """Indexing one past either end of a one-bit stream raises IndexError."""
    single = Bitstream(True)
    for out_of_range in (1, -2):  # past the end, and past the start via negative index
        with pytest.raises(IndexError):
            single[out_of_range]
@pytest.mark.parametrize("value", [
    True,
    [True, True],
    b"A",
    "B",
    Bitstream(True, b"C"),
])
def test_setitem(value: Union[bool, List[bool], bytes, str, Bitstream]):
    """Assign ``value`` at every offset where it fits inside an all-False stream.

    Reading back at that offset must return the value (a bool comes back as a
    one-element list), and the assignment must not change the stream length.
    """
    width = get_bit_length(value)
    kind = bool if type(value) is list else type(value)
    expected = [value] if type(value) is bool else value
    for length in range(10):
        last_start = length - width
        for start in range(last_start + 1):
            stream = Bitstream([False] * length)
            stream[start] = value
            assert stream.read(kind, width, start) == expected
            assert len(stream) == length
@pytest.mark.parametrize("value,equals,not_equals", [
    (
        [],
        [[]],
        [[True], [False]],
    ),
    (
        True,
        [[True]],
        [[False], [True, False], b"A"],
    ),
    (
        [True, False],
        [[True, False]],
        [[False], [False, False], b"A"],
    ),
    (
        b"A",
        [b"A", "A", Bitstream.from_bytes(b"A"), Bitstream.from_bytes(b"A").bits],
        [[], [False], b"B", "B", Bitstream.from_bytes(b"B"), Bitstream.from_bytes(b"B").bits],
    ),
    (
        "B",
        [b"B", "B", Bitstream.from_bytes(b"B"), Bitstream.from_bytes(b"B").bits],
        [[], [False], b"A", "A", Bitstream.from_bytes(b"A"), Bitstream.from_bytes(b"A").bits],
    ),
    (
        b"ABC",
        [b"ABC"],
        [[], [False], b"A"],
    ),
])
def test_eq(
        value: Union[bool, List[bool], bytes, str, Bitstream],
        equals: List[Union[bool, List[bool], bytes, str, Bitstream]],
        not_equals: List[Union[bool, List[bool], bytes, str, Bitstream]]):
    """Check ``==`` / ``!=`` between a stream and assorted comparands.

    The comparands mix bool lists, bytes, str, and other streams; the stream
    must equal every item in ``equals`` and differ from every item in
    ``not_equals``.
    """
    stream = Bitstream(value)
    for same in equals:
        assert stream == same
    for different in not_equals:
        assert stream != different
@pytest.mark.parametrize("values,number,expected", [
    ([], 2, []),
    ([True, False], 2, [True, False, True, False]),
    (b"A", 2, b"AA"),
    ([False], 3, [False, False, False]),
])
def test_mul(
        values: Union[bool, List[bool], bytes, str, Bitstream],
        number: int,
        expected: Union[List[bool], bytes, str, Bitstream]):
    """``*`` returns a repeated stream without mutating; ``*=`` repeats in place."""
    stream = Bitstream(values)
    product = stream * number
    assert product == expected
    # The binary operator must leave its left operand untouched.
    assert stream == values
    stream *= number
    assert stream == expected
@pytest.mark.parametrize("init_values,add_value,expected", [
    ([], True, [True]),
    ([True, False], False, [True, False, False]),
    (b"A", b"B", b"AB"),
])
def test_add(
        init_values: Union[bool, List[bool], bytes, str, Bitstream],
        add_value: Union[bool, List[bool], bytes, str, Bitstream],
        expected: Union[List[bool], bytes, str, Bitstream]):
    """``+`` concatenates into a new stream; ``+=`` concatenates in place."""
    stream = Bitstream(init_values)
    concatenated = stream + add_value
    assert concatenated == expected
    # ``+`` must not grow the left operand.
    assert len(stream) == get_bit_length(init_values)
    stream += add_value
    assert stream == expected
@pytest.mark.parametrize("init_values", [
    [],
    [True, False],
    [b"A"],
    [True, b"B"],
])
def test_iter(init_values: List[Union[bool, List[bool], bytes, str, Bitstream]]):
    """Iterating a stream yields its bits in index order."""
    stream = Bitstream(*init_values)
    # This test requires the object returned by ``iter()`` to support ``len()``.
    assert len(iter(stream)) == len(stream)
    position = 0
    for bit in stream:
        assert bit == stream[position]
        position += 1
@pytest.mark.parametrize("init_values,expected", [
    ([], b""),
    ([b"A"], b"A"),
    ([False], b"\x00"),
    ([[False] * 6, True], b"\x02"),
    ([[False] * 7, True], b"\x01"),
])
def test_bytes(
        init_values: List[Union[bool, List[bool], bytes, str, Bitstream]],
        expected: bytes):
    """``bytes()`` packs bits MSB-first and zero-pads the final partial byte."""
    stream = Bitstream(*init_values)
    assert bytes(stream) == expected
@pytest.mark.parametrize("init_values,expected", [
    ([], []),
    ([b"\x01"], [False] * 7 + [True]),
    ([False], [False]),
    ([False] * 6 + [True], [False] * 6 + [True]),
    ([False] * 7 + [True], [False] * 7 + [True]),
])
def test_bits(
        init_values: List[Union[bool, List[bool], bytes, str, Bitstream]],
        expected: List[bool]):
    """``.bits`` exposes the stored bits (bytes unpacked MSB-first) without padding."""
    stream = Bitstream(*init_values)
    assert stream.bits == expected
def test_copy():
    """``copy()`` returns an equal but distinct and independent stream."""
    bitstream = Bitstream(True, False)
    bitstream_copy = bitstream.copy()
    assert bitstream == bitstream_copy
    # Identity check on the returned object itself (comparing against the
    # bound method ``bitstream.copy`` would always succeed trivially).
    assert bitstream_copy is not bitstream
    # Mutating the copy must not leak into the original.
    bitstream_copy.append(True)
    assert len(bitstream) == 2
    assert bitstream == [True, False]
@pytest.mark.parametrize("lst", [
    [False] * 10,
    [True] * 10,
    [True, False] * 5,
    [False, True] * 5,
])
def test_append(lst: List[bool]):
    """Append bits one at a time, checking the new tail bit and the length."""
    stream = Bitstream()
    for count, bit in enumerate(lst, start=1):
        stream.append(bit)
        assert stream[count - 1] == bit
        assert len(stream) == count
@pytest.mark.parametrize("extend_value", [
    [True, False, True],
    b"A",
    "B",
    Bitstream(True, b"C"),
])
def test_extend(extend_value: Union[List[bool], bytes, str, Bitstream]):
    """Extend all-False prefixes of length 0..9 and read the appended data back."""
    kind = bool if type(extend_value) is list else type(extend_value)
    for prefix_len in range(10):
        stream = Bitstream([False] * prefix_len)
        stream.extend(extend_value)
        width = get_bit_length(extend_value)
        # The appended data starts right after the prefix.
        assert stream.read(kind, width, prefix_len) == extend_value
def test_extend_error():
    """``extend`` rejects a bare bool with TypeError."""
    empty = Bitstream()
    with pytest.raises(TypeError):
        empty.extend(True)
def test_clear():
    """``clear`` empties the stream and resets the cursor offset to 0."""
    stream = Bitstream(b"A", [True, False, False, True])
    stream._offset = 3  # move the cursor away from 0 so the reset is observable
    stream.clear()
    assert len(stream) == 0
    assert bytes(stream) == b""
    assert stream._offset == 0
@pytest.mark.parametrize("lst", [
    [False] * 10,
    [True] * 10,
    [True, False] * 5,
    [False, True] * 5,
])
def test_pop(lst: List[bool]):
    """Pop every bit from the end, then check that popping an empty stream raises.

    Each pop must return the current last bit and shrink the stream by one.
    """
    bitstream = Bitstream(lst)
    for popped, expected in enumerate(reversed(lst), start=1):
        assert bitstream.pop() == expected
        assert len(bitstream) == len(lst) - popped
    # Call pop() bare: wrapping it in ``assert`` would turn a falsy,
    # non-raising return into an AssertionError instead of "DID NOT RAISE".
    with pytest.raises(IndexError):
        bitstream.pop()
def test_read():
    """Sequential reads consume successive chunks; ``seek(0)`` rewinds the cursor."""
    stream = Bitstream(True, b"A", False, True, False)
    assert stream.read(bool, 1) == [True]
    assert stream.read(bytes, 8) == b"A"
    # The remaining three bits, 0b010, decode as the integer 2.
    assert stream.read(int, 3) == 2
    stream.seek(0)
    all_but_last = len(stream) - 1
    assert stream.read(Bitstream, all_but_last) == Bitstream(True, b"A", False, True)
@pytest.mark.parametrize("bitstream,read_type,bit_length,bit_index,expected", [
    (
        Bitstream(True, b"A", False, True, False),
        bool, 1, 0,
        [True],
    ),
    (
        Bitstream(True, b"A", False, True, False),
        bool, 3, 8 + 1,
        [False, True, False],
    ),
    (
        Bitstream(True, b"A", False, True, False),
        bytes, 8, 1,
        b"A",
    ),
    (
        Bitstream(True, b"A", False, True, False),
        str, 8, 1,
        "A",
    ),
    (
        Bitstream(True, b"A", False, True, False),
        Bitstream, 1 + 8 + 3, 0,
        Bitstream(True, b"A", False, True, False),
    ),
    (
        Bitstream(True, b"A", False, True, False),
        int, 3, 1 + 8,
        2,
    ),
])
def test_read_index(
        bitstream: Bitstream,
        read_type: Type[Union[bool, bytes, str, Bitstream, int]],
        bit_length: int,
        bit_index: Optional[int],
        expected: Union[List[bool], bytes, str, Bitstream, int]):
    """Read ``bit_length`` bits at an explicit ``bit_index`` as ``read_type``.

    The result must have the requested type exactly (``bool`` reads come back
    as a list of bools) and equal ``expected``.
    """
    value = bitstream.read(read_type, bit_length, bit_index)
    result_type = list if read_type is bool else read_type
    assert type(value) is result_type
    assert value == expected
def test_read_error():
    """Out-of-range start indices raise IndexError; unsupported types raise TypeError."""
    stream = Bitstream(True, b"A", False, True, False)
    # One past the end, and a negative index (read does not wrap negatives).
    for bad_index in (len(stream), -1):
        with pytest.raises(IndexError):
            stream.read(bool, 1, bad_index)
    with pytest.raises(TypeError):
        stream.read(list, 1, 0)
@pytest.mark.parametrize("bitstream,value,bit_index,current_offset", [
    (Bitstream(), True, None, None),
    (Bitstream(), b"A", None, None),
    (Bitstream(), True, 0, None),
    (Bitstream(), True, 3, None),
    (Bitstream(), True, 7, None),
    (Bitstream(), True, 8, None),
    (Bitstream(False), True, None, None),
    (Bitstream(False), True, 0, None),
    (Bitstream(False), True, 0, 0),
    (Bitstream(False), True, None, 0),
])
def test_write(
        bitstream: Bitstream,
        value: Union[bool, List[bool], bytes, str, Bitstream],
        bit_index: Optional[int],
        current_offset: Optional[int]):
    """Write ``value`` and read it back from where it landed.

    ``current_offset``, if given, is seeked to first. With ``bit_index=None``
    the test assumes the write went to the cursor and advanced it, so it
    rewinds by the written width (relative seek) before reading; otherwise it
    reads directly at ``bit_index``.
    """
    if current_offset is not None:
        bitstream.seek(current_offset)
    width = get_bit_length(value)
    kind = bool if type(value) is list else type(value)
    bitstream.write(value, bit_index)
    if bit_index is not None:
        read_back = bitstream.read(kind, width, bit_index)
    else:
        bitstream.seek(-width, 1)  # whence=1: relative to the current cursor
        read_back = bitstream.read(kind, width)
    expected = [value] if type(value) is bool else value
    assert read_back == expected
def test_seek():
    """``seek`` honours whence 0 (absolute), 1 (relative) and 2 (from the end)."""
    stream = Bitstream([True] * 8)
    steps = (
        (4, 0, 4),   # absolute
        (-1, 1, 3),  # one back from the current offset
        (-1, 2, 7),  # one back from the end of the 8-bit stream
    )
    for offset, whence, expected_offset in steps:
        stream.seek(offset, whence)
        assert stream._offset == expected_offset