Changed formatting

This commit is contained in:
Emily 2023-03-15 06:06:58 +01:00
parent b3fedb8214
commit 75824c306f
2 changed files with 252 additions and 182 deletions

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, List, Union, Optional, Literal, Type
from typing import Any, List, Literal, Optional, Type, Union
# for seek() # for seek()
SEEK_SET = 0 SEEK_SET = 0
@ -30,8 +30,10 @@ def bits_to_int(bits: List[bool], little_endian: bool = False) -> int:
return num return num
def int_to_bits(number: int, bit_length: int, little_endian: bool = False) -> List[bool]: def int_to_bits(
bits = [bool(number & (0b1 << i)) for i in range(bit_length-1, -1, -1)] number: int, bit_length: int, little_endian: bool = False
) -> List[bool]:
bits = [bool(number & (0b1 << i)) for i in range(bit_length - 1, -1, -1)]
return bits[::-1] if little_endian else bits return bits[::-1] if little_endian else bits
@ -56,7 +58,9 @@ class Bitstream:
@staticmethod @staticmethod
def from_int(value: int, bit_length: int) -> Bitstream: def from_int(value: int, bit_length: int) -> Bitstream:
return Bitstream([bool(value & (0b1 << i)) for i in range(bit_length-1, -1, -1)]) return Bitstream(
[bool(value & (0b1 << i)) for i in range(bit_length - 1, -1, -1)]
)
def __len__(self) -> int: def __len__(self) -> int:
return self._bit_length return self._bit_length
@ -82,7 +86,9 @@ class Bitstream:
mask = 1 << (7 - (bit_index % 8)) mask = 1 << (7 - (bit_index % 8))
return bool(self._bytearray[bit_index >> 3] & mask) return bool(self._bytearray[bit_index >> 3] & mask)
def __setitem__(self, bit_index: int, value: Union[bool, List[bool], bytes, Bitstream]): def __setitem__(
self, bit_index: int, value: Union[bool, List[bool], bytes, Bitstream]
):
if bit_index >= self._bit_length: if bit_index >= self._bit_length:
raise IndexError("bit index out of range") raise IndexError("bit index out of range")
@ -101,7 +107,9 @@ class Bitstream:
self._bytearray[bit_index >> 3] |= mask self._bytearray[bit_index >> 3] |= mask
else: else:
self._bytearray[bit_index >> 3] &= ~mask self._bytearray[bit_index >> 3] &= ~mask
elif isinstance(value, Bitstream) or (type(value) is list and all(type(v) is bool for v in value)): elif isinstance(value, Bitstream) or (
type(value) is list and all(type(v) is bool for v in value)
):
for i in range(value_bit_length): for i in range(value_bit_length):
self[bit_index + i] = value[i] self[bit_index + i] = value[i]
else: else:
@ -164,7 +172,10 @@ class Bitstream:
def __repr__(self) -> str: def __repr__(self) -> str:
return "<Bitstream addr:0x%012x offset:%d len:%d data:'%s'>" % ( return "<Bitstream addr:0x%012x offset:%d len:%d data:'%s'>" % (
id(self), self._offset, len(self), bytes(self).hex(" ") id(self),
self._offset,
len(self),
bytes(self).hex(" "),
) )
@property @property
@ -223,7 +234,7 @@ class Bitstream:
self, self,
type: Type[Union[bool, bytes, Bitstream, int]], type: Type[Union[bool, bytes, Bitstream, int]],
bit_length: int, bit_length: int,
bit_index: Optional[int] = None bit_index: Optional[int] = None,
) -> Union[List[bool], bytes, Bitstream, int]: ) -> Union[List[bool], bytes, Bitstream, int]:
start = self._offset if bit_index is None else bit_index start = self._offset if bit_index is None else bit_index
@ -248,7 +259,11 @@ class Bitstream:
raise TypeError("Invalid data type") raise TypeError("Invalid data type")
def write(self, value: Union[bool, List[bool], bytes, Bitstream], bit_index: Optional[int] = None): def write(
self,
value: Union[bool, List[bool], bytes, Bitstream],
bit_index: Optional[int] = None,
):
start = self._offset if bit_index is None else bit_index start = self._offset if bit_index is None else bit_index
if start < 0: if start < 0:

View File

@ -1,47 +1,62 @@
from typing import List, Union, Type, Optional from typing import List, Optional, Type, Union
import pytest import pytest
from sampy.raknet.bitstream import Bitstream, get_bit_length from sampy.raknet.bitstream import Bitstream, get_bit_length
@pytest.mark.parametrize("value", [ @pytest.mark.parametrize(
"value",
[
b"", b"",
b"A", b"A",
b"AB", b"AB",
b"\xFF\x00\xAA\x55", b"\xFF\x00\xAA\x55",
]) ],
)
def test_from_bytes(value: bytes): def test_from_bytes(value: bytes):
bitstream = Bitstream.from_bytes(value) bitstream = Bitstream.from_bytes(value)
assert len(bitstream) == (len(value) << 3) assert len(bitstream) == (len(value) << 3)
assert bytes(bitstream) == value assert bytes(bitstream) == value
@pytest.mark.parametrize("value,length,expected", [ @pytest.mark.parametrize(
"value,length,expected",
[
(0b0, 0, None), (0b0, 0, None),
(0b1, 1, None), (0b1, 1, None),
(0b0, 1, None), (0b0, 1, None),
(0b1011, 4, None), (0b1011, 4, None),
(0b10111011, 8, None), (0b10111011, 8, None),
(0b10111011, 7, 0b0111011), # -> 0b0111011 (0b10111011, 7, 0b0111011), # -> 0b0111011
]) ],
)
def test_from_int(value: int, length: int, expected: int): def test_from_int(value: int, length: int, expected: int):
bitstream = Bitstream.from_int(value, length) bitstream = Bitstream.from_int(value, length)
assert len(bitstream) == length assert len(bitstream) == length
assert len(bytes(bitstream)) == ((length + 7) >> 3) assert len(bytes(bitstream)) == ((length + 7) >> 3)
assert bitstream.bits == [bool(value & (1 << i)) for i in range(length-1, -1, -1)] assert bitstream.bits == [bool(value & (1 << i)) for i in range(length - 1, -1, -1)]
if expected is not None: if expected is not None:
expected_bitstream = Bitstream.from_int(expected, length) expected_bitstream = Bitstream.from_int(expected, length)
assert bitstream.bits == expected_bitstream.bits assert bitstream.bits == expected_bitstream.bits
assert bytes(bitstream) == bytes(expected_bitstream) assert bytes(bitstream) == bytes(expected_bitstream)
@pytest.mark.parametrize("values", [ @pytest.mark.parametrize(
"values",
[
[True, False, False, True, True], [True, False, False, True, True],
[[True], [False, True], [False, False]], [[True], [False, True], [False, False]],
[b"a", b"bc"], [b"a", b"bc"],
[b"aa", b"b"], [b"aa", b"b"],
[Bitstream.from_bytes(b"A"), Bitstream.from_int(0b1011, 4), Bitstream.from_bytes(b"B")], [
Bitstream.from_bytes(b"A"),
Bitstream.from_int(0b1011, 4),
Bitstream.from_bytes(b"B"),
],
[True, b"A", b"B", [False, True], Bitstream.from_int(0b10110, 5), False], [True, b"A", b"B", [False, True], Bitstream.from_int(0b10110, 5), False],
]) ],
)
def test_init(values: List[Union[bool, List[bool], bytes, Bitstream]]): def test_init(values: List[Union[bool, List[bool], bytes, Bitstream]]):
bitstream = Bitstream(*values) bitstream = Bitstream(*values)
@ -75,12 +90,15 @@ def test_getitem_index_error():
bitstream[-2] # Inverse IndexError bitstream[-2] # Inverse IndexError
@pytest.mark.parametrize("value", [ @pytest.mark.parametrize(
"value",
[
True, True,
[True, True], [True, True],
b"A", b"A",
Bitstream(True, b"C"), Bitstream(True, b"C"),
]) ],
)
def test_setitem(value: Union[bool, List[bool], bytes, Bitstream]): def test_setitem(value: Union[bool, List[bool], bytes, Bitstream]):
bit_length = get_bit_length(value) bit_length = get_bit_length(value)
read_type = bool if type(value) is list else type(value) read_type = bool if type(value) is list else type(value)
@ -88,41 +106,38 @@ def test_setitem(value: Union[bool, List[bool], bytes, Bitstream]):
for i in range(length - bit_length + 1): for i in range(length - bit_length + 1):
bitstream = Bitstream([False] * length) bitstream = Bitstream([False] * length)
bitstream[i] = value bitstream[i] = value
assert bitstream.read(read_type, bit_length, i) == (value if type(value) is not bool else [value]) assert bitstream.read(read_type, bit_length, i) == (
value if type(value) is not bool else [value]
)
assert len(bitstream) == length assert len(bitstream) == length
@pytest.mark.parametrize("value,equals,not_equals", [ @pytest.mark.parametrize(
( "value,equals,not_equals",
[], [
[[]], ([], [[]], [[True], [False]]),
[[True], [False]] (True, [[True]], [[False], [True, False], b"A"]),
), ([True, False], [[True, False]], [[False], [False, False], b"A"]),
(
True,
[[True]],
[[False], [True, False], b"A"]
),
(
[True, False],
[[True, False]],
[[False], [False, False], b"A"]
),
( (
b"A", b"A",
[b"A", b"A", Bitstream.from_bytes(b"A"), Bitstream.from_bytes(b"A").bits], [b"A", b"A", Bitstream.from_bytes(b"A"), Bitstream.from_bytes(b"A").bits],
[[], [False], b"B", b"B", Bitstream.from_bytes(b"B"), Bitstream.from_bytes(b"B").bits] [
[],
[False],
b"B",
b"B",
Bitstream.from_bytes(b"B"),
Bitstream.from_bytes(b"B").bits,
],
), ),
( (b"ABC", [b"ABC"], [[], [False], b"A"]),
b"ABC", ],
[b"ABC"], )
[[], [False], b"A"]
),
])
def test_eq( def test_eq(
value: Union[bool, List[bool], bytes, Bitstream], value: Union[bool, List[bool], bytes, Bitstream],
equals: List[Union[bool, List[bool], bytes, Bitstream]], equals: List[Union[bool, List[bool], bytes, Bitstream]],
not_equals: List[Union[bool, List[bool], bytes, Bitstream]]): not_equals: List[Union[bool, List[bool], bytes, Bitstream]],
):
bitstream = Bitstream(value) bitstream = Bitstream(value)
for equal in equals: for equal in equals:
assert bitstream == equal assert bitstream == equal
@ -130,16 +145,20 @@ def test_eq(
assert bitstream != not_equal assert bitstream != not_equal
@pytest.mark.parametrize("values,number,expected", [ @pytest.mark.parametrize(
"values,number,expected",
[
([], 2, []), ([], 2, []),
([True, False], 2, [True, False, True, False]), ([True, False], 2, [True, False, True, False]),
(b"A", 2, b"AA"), (b"A", 2, b"AA"),
([False], 3, [False, False, False]), ([False], 3, [False, False, False]),
]) ],
)
def test_mul( def test_mul(
values: Union[bool, List[bool], bytes, Bitstream], values: Union[bool, List[bool], bytes, Bitstream],
number: int, number: int,
expected: Union[List[bool], bytes, Bitstream]): expected: Union[List[bool], bytes, Bitstream],
):
bitstream = Bitstream(values) bitstream = Bitstream(values)
assert (bitstream * number) == expected assert (bitstream * number) == expected
@ -149,15 +168,19 @@ def test_mul(
assert bitstream == expected assert bitstream == expected
@pytest.mark.parametrize("init_values,add_value,expected", [ @pytest.mark.parametrize(
"init_values,add_value,expected",
[
([], True, [True]), ([], True, [True]),
([True, False], False, [True, False, False]), ([True, False], False, [True, False, False]),
(b"A", b"B", b"AB"), (b"A", b"B", b"AB"),
]) ],
)
def test_add( def test_add(
init_values: Union[bool, List[bool], bytes, Bitstream], init_values: Union[bool, List[bool], bytes, Bitstream],
add_value: Union[bool, List[bool], bytes, Bitstream], add_value: Union[bool, List[bool], bytes, Bitstream],
expected: Union[List[bool], bytes, Bitstream]): expected: Union[List[bool], bytes, Bitstream],
):
bitstream = Bitstream(init_values) bitstream = Bitstream(init_values)
assert (bitstream + add_value) == expected assert (bitstream + add_value) == expected
@ -167,12 +190,15 @@ def test_add(
assert bitstream == expected assert bitstream == expected
@pytest.mark.parametrize("init_values", [ @pytest.mark.parametrize(
"init_values",
[
[], [],
[True, False], [True, False],
[b"A"], [b"A"],
[True, b"B"], [True, b"B"],
]) ],
)
def test_iter(init_values: List[Union[bool, List[bool], bytes, Bitstream]]): def test_iter(init_values: List[Union[bool, List[bool], bytes, Bitstream]]):
bitstream = Bitstream(*init_values) bitstream = Bitstream(*init_values)
assert len(list(bitstream)) == len(bitstream) assert len(list(bitstream)) == len(bitstream)
@ -180,29 +206,35 @@ def test_iter(init_values: List[Union[bool, List[bool], bytes, Bitstream]]):
assert bit == bitstream[index] assert bit == bitstream[index]
@pytest.mark.parametrize("init_values,expected", [ @pytest.mark.parametrize(
"init_values,expected",
[
([], b""), ([], b""),
([b"A"], b"A"), ([b"A"], b"A"),
([False], b"\x00"), ([False], b"\x00"),
([[False] * 6, True], b"\x02"), ([[False] * 6, True], b"\x02"),
([[False] * 7, True], b"\x01"), ([[False] * 7, True], b"\x01"),
]) ],
)
def test_bytes( def test_bytes(
init_values: List[Union[bool, List[bool], bytes, Bitstream]], init_values: List[Union[bool, List[bool], bytes, Bitstream]], expected: bytes
expected: bytes): ):
assert bytes(Bitstream(*init_values)) == expected assert bytes(Bitstream(*init_values)) == expected
@pytest.mark.parametrize("init_values,expected", [ @pytest.mark.parametrize(
"init_values,expected",
[
([], []), ([], []),
([b"\x01"], [False] * 7 + [True]), ([b"\x01"], [False] * 7 + [True]),
([False], [False]), ([False], [False]),
([False] * 6 + [True], [False] * 6 + [True]), ([False] * 6 + [True], [False] * 6 + [True]),
([False] * 7 + [True], [False] * 7 + [True]), ([False] * 7 + [True], [False] * 7 + [True]),
]) ],
)
def test_bits( def test_bits(
init_values: List[Union[bool, List[bool], bytes, Bitstream]], init_values: List[Union[bool, List[bool], bytes, Bitstream]], expected: List[bool]
expected: List[bool]): ):
assert Bitstream(*init_values).bits == expected assert Bitstream(*init_values).bits == expected
@ -213,12 +245,15 @@ def test_copy():
assert id(bitstream) != id(bitstream.copy) assert id(bitstream) != id(bitstream.copy)
@pytest.mark.parametrize("lst", [ @pytest.mark.parametrize(
"lst",
[
[False] * 10, [False] * 10,
[True] * 10, [True] * 10,
[True, False] * 5, [True, False] * 5,
[False, True] * 5, [False, True] * 5,
]) ],
)
def test_append(lst: List[bool]): def test_append(lst: List[bool]):
bitstream = Bitstream() bitstream = Bitstream()
for i in range(len(lst)): for i in range(len(lst)):
@ -227,17 +262,24 @@ def test_append(lst: List[bool]):
assert len(bitstream) == (i + 1) assert len(bitstream) == (i + 1)
@pytest.mark.parametrize("extend_value", [ @pytest.mark.parametrize(
"extend_value",
[
[True, False, True], [True, False, True],
b"A", b"A",
b"BB", b"BB",
Bitstream(True, b"C"), Bitstream(True, b"C"),
]) ],
)
def test_extend(extend_value: Union[List[bool], bytes, Bitstream]): def test_extend(extend_value: Union[List[bool], bytes, Bitstream]):
for i in range(10): for i in range(10):
bitstream = Bitstream([False] * i) bitstream = Bitstream([False] * i)
bitstream.extend(extend_value) bitstream.extend(extend_value)
value = bitstream.read(bool if type(extend_value) is list else type(extend_value), get_bit_length(extend_value), i) value = bitstream.read(
bool if type(extend_value) is list else type(extend_value),
get_bit_length(extend_value),
i,
)
assert value == extend_value assert value == extend_value
@ -256,12 +298,15 @@ def test_clear():
assert bitstream._offset == 0 assert bitstream._offset == 0
@pytest.mark.parametrize("lst", [ @pytest.mark.parametrize(
"lst",
[
[False] * 10, [False] * 10,
[True] * 10, [True] * 10,
[True, False] * 5, [True, False] * 5,
[False, True] * 5, [False, True] * 5,
]) ],
)
def test_pop(lst: List[bool]): def test_pop(lst: List[bool]):
bitstream = Bitstream(lst) bitstream = Bitstream(lst)
for i in range(len(bitstream)): for i in range(len(bitstream)):
@ -279,42 +324,40 @@ def test_read():
assert bitstream.read(int, 3) == 2 assert bitstream.read(int, 3) == 2
bitstream.seek(0) bitstream.seek(0)
assert bitstream.read(Bitstream, len(bitstream) - 1) == Bitstream(True, b"A", False, True) assert bitstream.read(Bitstream, len(bitstream) - 1) == Bitstream(
True, b"A", False, True
)
@pytest.mark.parametrize("bitstream,read_type,bit_length,bit_index,expected", [ @pytest.mark.parametrize(
"bitstream,read_type,bit_length,bit_index,expected",
[
(Bitstream(True, b"A", False, True, False), bool, 1, 0, [True]),
( (
Bitstream(True, b"A", False, True, False), Bitstream(True, b"A", False, True, False),
bool, 1, 0, bool,
[True] 3,
8 + 1,
[False, True, False],
), ),
(Bitstream(True, b"A", False, True, False), bytes, 8, 1, b"A"),
( (
Bitstream(True, b"A", False, True, False), Bitstream(True, b"A", False, True, False),
bool, 3, 8 + 1, Bitstream,
[False, True, False] 1 + 8 + 3,
), 0,
(
Bitstream(True, b"A", False, True, False), Bitstream(True, b"A", False, True, False),
bytes, 8, 1,
b"A"
), ),
( (Bitstream(True, b"A", False, True, False), int, 3, 1 + 8, 2),
Bitstream(True, b"A", False, True, False), ],
Bitstream, 1 + 8 + 3, 0, )
Bitstream(True, b"A", False, True, False)
),
(
Bitstream(True, b"A", False, True, False),
int, 3, 1 + 8,
2
),
])
def test_read_index( def test_read_index(
bitstream: Bitstream, bitstream: Bitstream,
read_type: Type[Union[bool, bytes, Bitstream, int]], read_type: Type[Union[bool, bytes, Bitstream, int]],
bit_length: int, bit_length: int,
bit_index: Optional[int], bit_index: Optional[int],
expected: Union[List[bool], bytes, Bitstream, int]): expected: Union[List[bool], bytes, Bitstream, int],
):
value = bitstream.read(read_type, bit_length, bit_index) value = bitstream.read(read_type, bit_length, bit_index)
if read_type is bool: if read_type is bool:
@ -337,7 +380,9 @@ def test_read_error():
bitstream.read(list, 1, 0) bitstream.read(list, 1, 0)
@pytest.mark.parametrize("bitstream,value,bit_index,current_offset", [ @pytest.mark.parametrize(
"bitstream,value,bit_index,current_offset",
[
(Bitstream(), True, None, None), (Bitstream(), True, None, None),
(Bitstream(), b"A", None, None), (Bitstream(), b"A", None, None),
(Bitstream(), True, 0, None), (Bitstream(), True, 0, None),
@ -348,8 +393,14 @@ def test_read_error():
(Bitstream(False), True, 0, None), (Bitstream(False), True, 0, None),
(Bitstream(False), True, 0, 0), (Bitstream(False), True, 0, 0),
(Bitstream(False), True, None, 0), (Bitstream(False), True, None, 0),
]) ],
def test_write(bitstream: Bitstream, value: Union[bool, List[bool], bytes, Bitstream], bit_index: Optional[int], current_offset: Optional[int]): )
def test_write(
bitstream: Bitstream,
value: Union[bool, List[bool], bytes, Bitstream],
bit_index: Optional[int],
current_offset: Optional[int],
):
if current_offset is not None: if current_offset is not None:
bitstream.seek(current_offset) bitstream.seek(current_offset)
@ -358,9 +409,13 @@ def test_write(bitstream: Bitstream, value: Union[bool, List[bool], bytes, Bitst
if bit_index is None: if bit_index is None:
bitstream.seek(-bit_length, 1) bitstream.seek(-bit_length, 1)
read_value = bitstream.read(bool if type(value) is list else type(value), bit_length) read_value = bitstream.read(
bool if type(value) is list else type(value), bit_length
)
else: else:
read_value = bitstream.read(bool if type(value) is list else type(value), bit_length, bit_index) read_value = bitstream.read(
bool if type(value) is list else type(value), bit_length, bit_index
)
assert read_value == ([value] if type(value) is bool else value) assert read_value == ([value] if type(value) is bool else value)