diff --git a/sampy/raknet/bitstream.py b/sampy/raknet/bitstream.py index ebf14d1..85c3d45 100644 --- a/sampy/raknet/bitstream.py +++ b/sampy/raknet/bitstream.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, List, Union, Optional, Literal, Type +from typing import Any, List, Literal, Optional, Type, Union # for seek() SEEK_SET = 0 @@ -30,8 +30,10 @@ def bits_to_int(bits: List[bool], little_endian: bool = False) -> int: return num -def int_to_bits(number: int, bit_length: int, little_endian: bool = False) -> List[bool]: - bits = [bool(number & (0b1 << i)) for i in range(bit_length-1, -1, -1)] +def int_to_bits( + number: int, bit_length: int, little_endian: bool = False +) -> List[bool]: + bits = [bool(number & (0b1 << i)) for i in range(bit_length - 1, -1, -1)] return bits[::-1] if little_endian else bits @@ -56,7 +58,9 @@ class Bitstream: @staticmethod def from_int(value: int, bit_length: int) -> Bitstream: - return Bitstream([bool(value & (0b1 << i)) for i in range(bit_length-1, -1, -1)]) + return Bitstream( + [bool(value & (0b1 << i)) for i in range(bit_length - 1, -1, -1)] + ) def __len__(self) -> int: return self._bit_length @@ -82,7 +86,9 @@ class Bitstream: mask = 1 << (7 - (bit_index % 8)) return bool(self._bytearray[bit_index >> 3] & mask) - def __setitem__(self, bit_index: int, value: Union[bool, List[bool], bytes, Bitstream]): + def __setitem__( + self, bit_index: int, value: Union[bool, List[bool], bytes, Bitstream] + ): if bit_index >= self._bit_length: raise IndexError("bit index out of range") @@ -101,7 +107,9 @@ class Bitstream: self._bytearray[bit_index >> 3] |= mask else: self._bytearray[bit_index >> 3] &= ~mask - elif isinstance(value, Bitstream) or (type(value) is list and all(type(v) is bool for v in value)): + elif isinstance(value, Bitstream) or ( + type(value) is list and all(type(v) is bool for v in value) + ): for i in range(value_bit_length): self[bit_index + i] = value[i] else: @@ -164,7 +172,10 @@ class Bitstream: def __repr__(self) -> str: return "" % ( - id(self), self._offset, len(self), bytes(self).hex(" ") + id(self), + self._offset, + len(self), + bytes(self).hex(" "), ) @property @@ -220,11 +231,11 @@ class Bitstream: return value def read( - self, - type: Type[Union[bool, bytes, Bitstream, int]], - bit_length: int, - bit_index: Optional[int] = None - ) -> Union[List[bool], bytes, Bitstream, int]: + self, + type: Type[Union[bool, bytes, Bitstream, int]], + bit_length: int, + bit_index: Optional[int] = None, + ) -> Union[List[bool], bytes, Bitstream, int]: start = self._offset if bit_index is None else bit_index if (start + bit_length) > self._bit_length or (start < 0): @@ -248,7 +259,11 @@ class Bitstream: raise TypeError("Invalid data type") - def write(self, value: Union[bool, List[bool], bytes, Bitstream], bit_index: Optional[int] = None): + def write( + self, + value: Union[bool, List[bool], bytes, Bitstream], + bit_index: Optional[int] = None, + ): start = self._offset if bit_index is None else bit_index if start < 0: diff --git a/tests/test_bitstream.py b/tests/test_bitstream.py index 27b8cb0..fe11af2 100644 --- a/tests/test_bitstream.py +++ b/tests/test_bitstream.py @@ -1,47 +1,62 @@ -from typing import List, Union, Type, Optional +from typing import List, Optional, Type, Union + import pytest + from sampy.raknet.bitstream import Bitstream, get_bit_length -@pytest.mark.parametrize("value", [ - b"", - b"A", - b"AB", - b"\xFF\x00\xAA\x55", -]) +@pytest.mark.parametrize( + "value", + [ + b"", + b"A", + b"AB", + b"\xFF\x00\xAA\x55", + ], +) def test_from_bytes(value: bytes): bitstream = Bitstream.from_bytes(value) assert len(bitstream) == (len(value) << 3) assert bytes(bitstream) == value -@pytest.mark.parametrize("value,length,expected", [ - (0b0, 0, None), - (0b1, 1, None), - (0b0, 1, None), - (0b1011, 4, None), - (0b10111011, 8, None), - (0b10111011, 7, 0b0111011), # -> 0b0111011 -]) +@pytest.mark.parametrize( + "value,length,expected", + [ + (0b0, 0, None), + (0b1, 1, None), + (0b0, 1, None), + (0b1011, 4, None), + (0b10111011, 8, None), + (0b10111011, 7, 0b0111011), # -> 0b0111011 + ], +) def test_from_int(value: int, length: int, expected: int): bitstream = Bitstream.from_int(value, length) assert len(bitstream) == length assert len(bytes(bitstream)) == ((length + 7) >> 3) - assert bitstream.bits == [bool(value & (1 << i)) for i in range(length-1, -1, -1)] + assert bitstream.bits == [bool(value & (1 << i)) for i in range(length - 1, -1, -1)] if expected is not None: expected_bitstream = Bitstream.from_int(expected, length) assert bitstream.bits == expected_bitstream.bits assert bytes(bitstream) == bytes(expected_bitstream) -@pytest.mark.parametrize("values", [ - [True, False, False, True, True], - [[True], [False, True], [False, False]], - [b"a", b"bc"], - [b"aa", b"b"], - [Bitstream.from_bytes(b"A"), Bitstream.from_int(0b1011, 4), Bitstream.from_bytes(b"B")], - [True, b"A", b"B", [False, True], Bitstream.from_int(0b10110, 5), False], -]) +@pytest.mark.parametrize( + "values", + [ + [True, False, False, True, True], + [[True], [False, True], [False, False]], + [b"a", b"bc"], + [b"aa", b"b"], + [ + Bitstream.from_bytes(b"A"), + Bitstream.from_int(0b1011, 4), + Bitstream.from_bytes(b"B"), + ], + [True, b"A", b"B", [False, True], Bitstream.from_int(0b10110, 5), False], + ], +) def test_init(values: List[Union[bool, List[bool], bytes, Bitstream]]): bitstream = Bitstream(*values) @@ -75,12 +90,15 @@ def test_getitem_index_error(): bitstream[-2] # Inverse IndexError -@pytest.mark.parametrize("value", [ - True, - [True, True], - b"A", - Bitstream(True, b"C"), -]) +@pytest.mark.parametrize( + "value", + [ + True, + [True, True], + b"A", + Bitstream(True, b"C"), + ], +) def test_setitem(value: Union[bool, List[bool], bytes, Bitstream]): bit_length = get_bit_length(value) read_type = bool if type(value) is list else type(value) @@ -88,41 +106,38 @@ def test_setitem(value: Union[bool, List[bool], bytes, Bitstream]): for i in range(length - bit_length + 1): bitstream = Bitstream([False] * length) bitstream[i] = value - assert bitstream.read(read_type, bit_length, i) == (value if type(value) is not bool else [value]) + assert bitstream.read(read_type, bit_length, i) == ( + value if type(value) is not bool else [value] + ) assert len(bitstream) == length -@pytest.mark.parametrize("value,equals,not_equals", [ - ( - [], - [[]], - [[True], [False]] - ), - ( - True, - [[True]], - [[False], [True, False], b"A"] - ), - ( - [True, False], - [[True, False]], - [[False], [False, False], b"A"] - ), - ( - b"A", - [b"A", b"A", Bitstream.from_bytes(b"A"), Bitstream.from_bytes(b"A").bits], - [[], [False], b"B", b"B", Bitstream.from_bytes(b"B"), Bitstream.from_bytes(b"B").bits] - ), - ( - b"ABC", - [b"ABC"], - [[], [False], b"A"] - ), -]) +@pytest.mark.parametrize( + "value,equals,not_equals", + [ + ([], [[]], [[True], [False]]), + (True, [[True]], [[False], [True, False], b"A"]), + ([True, False], [[True, False]], [[False], [False, False], b"A"]), + ( + b"A", + [b"A", b"A", Bitstream.from_bytes(b"A"), Bitstream.from_bytes(b"A").bits], + [ + [], + [False], + b"B", + b"B", + Bitstream.from_bytes(b"B"), + Bitstream.from_bytes(b"B").bits, + ], + ), + (b"ABC", [b"ABC"], [[], [False], b"A"]), + ], +) def test_eq( - value: Union[bool, List[bool], bytes, Bitstream], - equals: List[Union[bool, List[bool], bytes, Bitstream]], - not_equals: List[Union[bool, List[bool], bytes, Bitstream]]): + value: Union[bool, List[bool], bytes, Bitstream], + equals: List[Union[bool, List[bool], bytes, Bitstream]], + not_equals: List[Union[bool, List[bool], bytes, Bitstream]], +): bitstream = Bitstream(value) for equal in equals: assert bitstream == equal @@ -130,16 +145,20 @@ def test_eq( assert bitstream != not_equal -@pytest.mark.parametrize("values,number,expected", [ - ([], 2, []), - ([True, False], 2, [True, False, True, False]), - (b"A", 2, b"AA"), - ([False], 3, [False, False, False]), -]) +@pytest.mark.parametrize( + "values,number,expected", + [ + ([], 2, []), + ([True, False], 2, [True, False, True, False]), + (b"A", 2, b"AA"), + ([False], 3, [False, False, False]), + ], +) def test_mul( - values: Union[bool, List[bool], bytes, Bitstream], - number: int, - expected: Union[List[bool], bytes, Bitstream]): + values: Union[bool, List[bool], bytes, Bitstream], + number: int, + expected: Union[List[bool], bytes, Bitstream], +): bitstream = Bitstream(values) assert (bitstream * number) == expected @@ -149,15 +168,19 @@ def test_mul( assert bitstream == expected -@pytest.mark.parametrize("init_values,add_value,expected", [ - ([], True, [True]), - ([True, False], False, [True, False, False]), - (b"A", b"B", b"AB"), -]) +@pytest.mark.parametrize( + "init_values,add_value,expected", + [ + ([], True, [True]), + ([True, False], False, [True, False, False]), + (b"A", b"B", b"AB"), + ], +) def test_add( - init_values: Union[bool, List[bool], bytes, Bitstream], - add_value: Union[bool, List[bool], bytes, Bitstream], - expected: Union[List[bool], bytes, Bitstream]): + init_values: Union[bool, List[bool], bytes, Bitstream], + add_value: Union[bool, List[bool], bytes, Bitstream], + expected: Union[List[bool], bytes, Bitstream], +): bitstream = Bitstream(init_values) assert (bitstream + add_value) == expected @@ -167,12 +190,15 @@ def test_add( assert bitstream == expected -@pytest.mark.parametrize("init_values", [ - [], - [True, False], - [b"A"], - [True, b"B"], -]) +@pytest.mark.parametrize( + "init_values", + [ + [], + [True, False], + [b"A"], + [True, b"B"], + ], +) def test_iter(init_values: List[Union[bool, List[bool], bytes, Bitstream]]): bitstream = Bitstream(*init_values) assert len(list(bitstream)) == len(bitstream) @@ -180,29 +206,35 @@ def test_iter(init_values: List[Union[bool, List[bool], bytes, Bitstream]]): assert bit == bitstream[index] -@pytest.mark.parametrize("init_values,expected", [ - ([], b""), - ([b"A"], b"A"), - ([False], b"\x00"), - ([[False] * 6, True], b"\x02"), - ([[False] * 7, True], b"\x01"), -]) +@pytest.mark.parametrize( + "init_values,expected", + [ + ([], b""), + ([b"A"], b"A"), + ([False], b"\x00"), + ([[False] * 6, True], b"\x02"), + ([[False] * 7, True], b"\x01"), + ], +) def test_bytes( - init_values: List[Union[bool, List[bool], bytes, Bitstream]], - expected: bytes): + init_values: List[Union[bool, List[bool], bytes, Bitstream]], expected: bytes +): assert bytes(Bitstream(*init_values)) == expected -@pytest.mark.parametrize("init_values,expected", [ - ([], []), - ([b"\x01"], [False] * 7 + [True]), - ([False], [False]), - ([False] * 6 + [True], [False] * 6 + [True]), - ([False] * 7 + [True], [False] * 7 + [True]), -]) +@pytest.mark.parametrize( + "init_values,expected", + [ + ([], []), + ([b"\x01"], [False] * 7 + [True]), + ([False], [False]), + ([False] * 6 + [True], [False] * 6 + [True]), + ([False] * 7 + [True], [False] * 7 + [True]), + ], +) def test_bits( - init_values: List[Union[bool, List[bool], bytes, Bitstream]], - expected: List[bool]): + init_values: List[Union[bool, List[bool], bytes, Bitstream]], expected: List[bool] +): assert Bitstream(*init_values).bits == expected @@ -213,12 +245,15 @@ def test_copy(): assert id(bitstream) != id(bitstream.copy) -@pytest.mark.parametrize("lst", [ - [False] * 10, - [True] * 10, - [True, False] * 5, - [False, True] * 5, -]) +@pytest.mark.parametrize( + "lst", + [ + [False] * 10, + [True] * 10, + [True, False] * 5, + [False, True] * 5, + ], +) def test_append(lst: List[bool]): bitstream = Bitstream() for i in range(len(lst)): @@ -227,17 +262,24 @@ def test_append(lst: List[bool]): assert len(bitstream) == (i + 1) -@pytest.mark.parametrize("extend_value", [ - [True, False, True], - b"A", - b"BB", - Bitstream(True, b"C"), -]) +@pytest.mark.parametrize( + "extend_value", + [ + [True, False, True], + b"A", + b"BB", + Bitstream(True, b"C"), + ], +) def test_extend(extend_value: Union[List[bool], bytes, Bitstream]): for i in range(10): bitstream = Bitstream([False] * i) bitstream.extend(extend_value) - value = bitstream.read(bool if type(extend_value) is list else type(extend_value), get_bit_length(extend_value), i) + value = bitstream.read( + bool if type(extend_value) is list else type(extend_value), + get_bit_length(extend_value), + i, + ) assert value == extend_value @@ -256,12 +298,15 @@ def test_clear(): assert bitstream._offset == 0 -@pytest.mark.parametrize("lst", [ - [False] * 10, - [True] * 10, - [True, False] * 5, - [False, True] * 5, -]) +@pytest.mark.parametrize( + "lst", + [ + [False] * 10, + [True] * 10, + [True, False] * 5, + [False, True] * 5, + ], +) def test_pop(lst: List[bool]): bitstream = Bitstream(lst) for i in range(len(bitstream)): @@ -279,42 +324,40 @@ def test_read(): assert bitstream.read(int, 3) == 2 bitstream.seek(0) - assert bitstream.read(Bitstream, len(bitstream) - 1) == Bitstream(True, b"A", False, True) + assert bitstream.read(Bitstream, len(bitstream) - 1) == Bitstream( + True, b"A", False, True + ) -@pytest.mark.parametrize("bitstream,read_type,bit_length,bit_index,expected", [ - ( - Bitstream(True, b"A", False, True, False), - bool, 1, 0, - [True] - ), - ( - Bitstream(True, b"A", False, True, False), - bool, 3, 8 + 1, - [False, True, False] - ), - ( - Bitstream(True, b"A", False, True, False), - bytes, 8, 1, - b"A" - ), - ( - Bitstream(True, b"A", False, True, False), - Bitstream, 1 + 8 + 3, 0, - Bitstream(True, b"A", False, True, False) - ), - ( - Bitstream(True, b"A", False, True, False), - int, 3, 1 + 8, - 2 - ), -]) +@pytest.mark.parametrize( + "bitstream,read_type,bit_length,bit_index,expected", + [ + (Bitstream(True, b"A", False, True, False), bool, 1, 0, [True]), + ( + Bitstream(True, b"A", False, True, False), + bool, + 3, + 8 + 1, + [False, True, False], + ), + (Bitstream(True, b"A", False, True, False), bytes, 8, 1, b"A"), + ( + Bitstream(True, b"A", False, True, False), + Bitstream, + 1 + 8 + 3, + 0, + Bitstream(True, b"A", False, True, False), + ), + (Bitstream(True, b"A", False, True, False), int, 3, 1 + 8, 2), + ], +) def test_read_index( - bitstream: Bitstream, - read_type: Type[Union[bool, bytes, Bitstream, int]], - bit_length: int, - bit_index: Optional[int], - expected: Union[List[bool], bytes, Bitstream, int]): + bitstream: Bitstream, + read_type: Type[Union[bool, bytes, Bitstream, int]], + bit_length: int, + bit_index: Optional[int], + expected: Union[List[bool], bytes, Bitstream, int], +): value = bitstream.read(read_type, bit_length, bit_index) if read_type is bool: @@ -337,19 +380,27 @@ def test_read_error(): bitstream.read(list, 1, 0) -@pytest.mark.parametrize("bitstream,value,bit_index,current_offset", [ - (Bitstream(), True, None, None), - (Bitstream(), b"A", None, None), - (Bitstream(), True, 0, None), - (Bitstream(), True, 3, None), - (Bitstream(), True, 7, None), - (Bitstream(), True, 8, None), - (Bitstream(False), True, None, None), - (Bitstream(False), True, 0, None), - (Bitstream(False), True, 0, 0), - (Bitstream(False), True, None, 0), -]) -def test_write(bitstream: Bitstream, value: Union[bool, List[bool], bytes, Bitstream], bit_index: Optional[int], current_offset: Optional[int]): +@pytest.mark.parametrize( + "bitstream,value,bit_index,current_offset", + [ + (Bitstream(), True, None, None), + (Bitstream(), b"A", None, None), + (Bitstream(), True, 0, None), + (Bitstream(), True, 3, None), + (Bitstream(), True, 7, None), + (Bitstream(), True, 8, None), + (Bitstream(False), True, None, None), + (Bitstream(False), True, 0, None), + (Bitstream(False), True, 0, 0), + (Bitstream(False), True, None, 0), + ], +) +def test_write( + bitstream: Bitstream, + value: Union[bool, List[bool], bytes, Bitstream], + bit_index: Optional[int], + current_offset: Optional[int], +): if current_offset is not None: bitstream.seek(current_offset) @@ -358,9 +409,13 @@ def test_write(bitstream: Bitstream, value: Union[bool, List[bool], bytes, Bitst if bit_index is None: bitstream.seek(-bit_length, 1) - read_value = bitstream.read(bool if type(value) is list else type(value), bit_length) + read_value = bitstream.read( + bool if type(value) is list else type(value), bit_length + ) else: - read_value = bitstream.read(bool if type(value) is list else type(value), bit_length, bit_index) + read_value = bitstream.read( + bool if type(value) is list else type(value), bit_length, bit_index + ) assert read_value == ([value] if type(value) is bool else value)