# Reconstructed from the git patch "Added Bitstream w/tests"
# (commit 73eaa0e89b0c17ef97f120574b6f62000a306e5c, Emily Sunpy,
# Fri, 24 Feb 2023 23:56:09 +0100).
#
# The patch creates three files:
#   sampy/raknet/__init__.py   (empty)
#   sampy/raknet/bitstream.py  (the "bitstream" section below)
#   tests/test_bitstream.py    (the "tests" section below)
# Both sections are kept in one listing, so the tests use the names defined
# in the first section instead of importing them from sampy.raknet.bitstream.
from __future__ import annotations

from typing import Any, List, Literal, Optional, Type, Union

import pytest


# ---------------------------------------------------------------------------
# sampy/raknet/bitstream.py
# ---------------------------------------------------------------------------

# for seek()
SEEK_SET = 0
SEEK_CUR = 1
SEEK_END = 2


def _is_bool_list(value: Any) -> bool:
    """Return True when *value* is a plain list containing only bools."""
    return type(value) is list and all(type(v) is bool for v in value)


def get_bit_length(value: Union[bool, List[bool], bytes, str, Bitstream]) -> int:
    """Return the number of bits *value* occupies inside a Bitstream.

    A bool is 1 bit, a list of bools one bit per element, bytes 8 bits per
    byte, a str 8 bits per *encoded* (UTF-8) byte and a Bitstream its own
    length.

    Raises:
        TypeError: if *value* is none of the supported types.
    """
    if type(value) is bool:
        return 1
    if _is_bool_list(value):
        return len(value)
    if type(value) is bytes:
        return len(value) << 3
    if type(value) is str:
        # The stream stores the encoded form, so count encoded bytes;
        # len(value) would be too small for non-ASCII text.
        return len(value.encode()) << 3
    if isinstance(value, Bitstream):
        return len(value)
    raise TypeError("Invalid data type")


def bits_to_int(bits: List[bool], little_endian: bool = False) -> int:
    """Fold *bits* into an integer (first bit is most significant by default)."""
    num = 0
    if little_endian:
        bits = bits[::-1]
    for bit in bits:
        num = (num << 1) | bit
    return num


def int_to_bits(number: int, bit_length: int, little_endian: bool = False) -> List[bool]:
    """Return the lowest *bit_length* bits of *number*, most significant first.

    With ``little_endian=True`` the resulting list is reversed.
    """
    bits = [bool(number & (0b1 << i)) for i in range(bit_length - 1, -1, -1)]
    return bits[::-1] if little_endian else bits


class Bitstream:
    """A growable sequence of bits with a read/write cursor.

    Bits are packed most-significant-bit first into an internal bytearray.
    Padding bits after the last valid bit of the final byte are kept at zero.

    Iteration note: ``iter(bitstream)`` returns the bitstream itself and uses
    a single shared cursor, so two simultaneous iterations over the same
    object (nested loops, ``zip(bs, bs)``) interfere with each other. Use
    ``bitstream.bits`` when an independent snapshot is needed.
    """

    # Instances are mutable and define __eq__, so they must not be hashable.
    __hash__ = None

    def __init__(self, *values: Union[bool, List[bool], bytes, str, Bitstream]):
        """Create a bitstream holding *values* concatenated in order."""
        self._bytearray = bytearray()
        self._bit_length = 0
        self._offset = 0
        # Initialised here so next() on a fresh object cannot hit AttributeError.
        self._iter_index = 0

        for item in values:
            if type(item) is bool:
                self.append(item)
            else:
                self.extend(item)

    @staticmethod
    def from_bytes(value: bytes) -> Bitstream:
        """Create a bitstream whose content is exactly the bytes of *value*."""
        bitstream = Bitstream()
        bitstream._bytearray += value
        bitstream._bit_length = len(bitstream._bytearray) << 3
        return bitstream

    @staticmethod
    def from_int(value: int, bit_length: int) -> Bitstream:
        """Create a bitstream from the lowest *bit_length* bits of *value*."""
        return Bitstream(int_to_bits(value, bit_length))

    # -- internal helpers ---------------------------------------------------

    def _reserve(self, bit_length: int) -> None:
        """Grow the backing bytearray (zero-filled) to hold *bit_length* bits."""
        missing = ((bit_length + 7) >> 3) - len(self._bytearray)
        if missing > 0:
            self._bytearray.extend(bytes(missing))

    def _set_bit(self, bit_index: int, value: bool) -> None:
        """Set one bit; *bit_index* must be non-negative and already reserved."""
        mask = 0x80 >> (bit_index & 7)
        if value:
            self._bytearray[bit_index >> 3] |= mask
        else:
            self._bytearray[bit_index >> 3] &= ~mask & 0xFF

    # -- conversions --------------------------------------------------------

    def __len__(self) -> int:
        """Return the number of bits in the stream."""
        return self._bit_length

    def __bytes__(self) -> bytes:
        """Return the content as bytes, zero-padded up to a whole byte.

        The internal buffer is not modified.
        """
        byte_length = (self._bit_length + 7) >> 3
        data = self._bytearray[:byte_length]  # slicing copies
        padding = (byte_length << 3) - self._bit_length
        if padding:
            # Never expose bits that lie beyond the logical end.
            data[-1] &= (0xFF << padding) & 0xFF
        return bytes(data)

    def __str__(self) -> str:
        """Decode the padded byte content as UTF-8 text."""
        return bytes(self).decode()

    def __int__(self) -> int:
        """Interpret all bits as an unsigned integer, first bit most significant."""
        return bits_to_int(self.bits)

    def __repr__(self) -> str:
        data = bytes(self).hex(" ")
        return "<Bitstream addr:0x%012x offset:%d len:%d data:'%s'>" % (
            id(self), self._offset, len(self), data
        )

    # -- item access --------------------------------------------------------

    def __getitem__(self, bit_index: int) -> bool:
        """Return the bit at *bit_index*; negative indices count from the end.

        Raises:
            IndexError: if the index is outside the stream.
        """
        if bit_index < 0:
            bit_index = len(self) + bit_index

        if bit_index >= self._bit_length or bit_index < 0:
            raise IndexError("bit index out of range")

        mask = 1 << (7 - (bit_index % 8))
        return bool(self._bytearray[bit_index >> 3] & mask)

    def __setitem__(self, bit_index: int, value: Union[bool, List[bool], bytes, str, Bitstream]):
        """Overwrite bits starting at *bit_index* with *value*.

        The stream never grows here: the whole value has to fit inside the
        current length. Negative indices count from the end.

        Raises:
            IndexError: if the index is out of range or the value would run
                past the end of the stream.
            TypeError: if *value* has an unsupported type.
        """
        if bit_index < 0:
            bit_index += self._bit_length
        if bit_index < 0 or bit_index >= self._bit_length:
            raise IndexError("bit index out of range")

        if type(value) is str:
            value = value.encode()
        if type(value) is bytes:
            value = Bitstream.from_bytes(value)

        value_bit_length = get_bit_length(value)  # raises TypeError if unsupported

        if (bit_index + value_bit_length) > self._bit_length:
            raise IndexError("Cannot write bits that extends the size of bitstream")

        if type(value) is bool:
            self._set_bit(bit_index, value)
        else:
            for position in range(value_bit_length):
                self._set_bit(bit_index + position, value[position])

    # -- comparison and arithmetic -----------------------------------------

    def __eq__(self, other: Any) -> bool:
        """Compare bit for bit with a bool, list[bool], bytes, str or Bitstream.

        Both the bit length and the content have to match. Any other type
        yields ``NotImplemented`` so Python falls back to its default
        handling instead of raising.
        """
        if type(other) is bool:
            other = [other]
        elif type(other) is str:
            other = other.encode()
        if type(other) is bytes:
            other = Bitstream.from_bytes(other)

        if isinstance(other, Bitstream):
            # Padding is masked by __bytes__, so equal length plus equal
            # bytes means equal bits.
            return len(self) == len(other) and bytes(self) == bytes(other)
        if _is_bool_list(other):
            return self.bits == other
        return NotImplemented

    def __mul__(self, number: int) -> Bitstream:
        """Return a new stream with the content repeated *number* times."""
        result = self.copy()
        result *= number
        return result

    def __imul__(self, number: int) -> Bitstream:
        """Repeat the content in place; ``number <= 0`` empties the stream."""
        if number <= 0:
            self.clear()
            return self
        unit = self.copy()
        for _ in range(number - 1):
            self.extend(unit)
        return self

    def __add__(self, value: Union[bool, List[bool], bytes, str, Bitstream]) -> Bitstream:
        """Return a copy of this stream with *value* appended."""
        result = self.copy()
        result += value
        return result

    def __iadd__(self, value: Union[bool, List[bool], bytes, str, Bitstream]) -> Bitstream:
        """Append *value* in place."""
        if type(value) is bool:
            self.append(value)
        else:
            self.extend(value)
        return self

    # -- iteration ----------------------------------------------------------

    def __iter__(self) -> Bitstream:
        """Restart the shared iteration cursor and return ``self``."""
        self._iter_index = 0
        return self

    def __next__(self) -> bool:
        if self._iter_index >= self._bit_length:
            raise StopIteration
        bit = self[self._iter_index]
        self._iter_index += 1
        return bit

    @property
    def bits(self) -> List[bool]:
        """All bits as a new list (does not touch the iteration cursor)."""
        return [self[index] for index in range(self._bit_length)]

    # -- mutation -----------------------------------------------------------

    def copy(self) -> Bitstream:
        """Return an independent copy, including the read/write offset."""
        bitstream = Bitstream()
        bitstream._bytearray = self._bytearray.copy()
        bitstream._bit_length = self._bit_length
        bitstream._offset = self._offset
        return bitstream

    def append(self, value: bool):
        """Append a single bit.

        Raises:
            TypeError: if *value* is not a bool.
        """
        if type(value) is not bool:
            raise TypeError("Expected bool")

        position = self._bit_length
        self._reserve(position + 1)
        self._bit_length = position + 1
        self._set_bit(position, value)

    def extend(self, value: Union[List[bool], bytes, str, Bitstream]):
        """Append every bit of *value* (a single bool is rejected; use append).

        Raises:
            TypeError: if *value* is not a list[bool], bytes, str or Bitstream.
        """
        if type(value) is str:
            value = value.encode()
        elif not (type(value) is bytes or isinstance(value, Bitstream) or _is_bool_list(value)):
            raise TypeError("Expected list[bool], bytes, str or Bitstream")

        if value is self:
            # Snapshot first: growing self would change the source length.
            value = self.copy()

        value_bit_length = get_bit_length(value)
        if value_bit_length == 0:
            return

        start = self._bit_length
        self._reserve(start + value_bit_length)
        self._bit_length = start + value_bit_length
        self[start] = value

    def clear(self):
        """Remove all bits and reset the offset."""
        self._bytearray.clear()
        self._bit_length = 0
        self._offset = 0

    def pop(self) -> bool:
        """Remove and return the last bit.

        Raises:
            IndexError: if the stream is empty.
        """
        value = self[-1]
        # Zero the removed bit so it cannot resurface as padding.
        self._set_bit(self._bit_length - 1, False)
        self._bit_length -= 1
        # Keep the cursor inside the stream, as seek() does.
        self._offset = min(self._offset, self._bit_length)
        return value

    # -- cursor based access -----------------------------------------------

    def read(
            self,
            type: Type[Union[bool, bytes, str, Bitstream, int]],
            bit_length: int,
            bit_index: Optional[int] = None
    ) -> Union[List[bool], bytes, str, Bitstream, int]:
        """Read *bit_length* bits and convert them to *type*.

        Args:
            type: ``bool`` (returns list[bool]), ``bytes``, ``str``,
                ``Bitstream`` or ``int``.
            bit_length: number of bits to read.
            bit_index: absolute start position. When omitted, reading starts
                at the current offset and the offset advances by *bit_length*.

        Raises:
            ValueError: if *bit_length* is negative.
            IndexError: if the requested range is outside the stream.
            TypeError: if *type* is unsupported (the offset is left untouched).
        """
        if bit_length < 0:
            raise ValueError("bit_length must not be negative")

        start = self._offset if bit_index is None else bit_index

        if (start + bit_length) > self._bit_length or (start < 0):
            raise IndexError("bit index out of range")

        # Validate before consuming so a failed read has no side effect.
        if type not in (bool, Bitstream, bytes, str, int):
            raise TypeError("Invalid data type")

        bits = [self[i] for i in range(start, start + bit_length)]

        if bit_index is None:
            self._offset += bit_length

        if type is bool:
            return bits
        if type is int:
            return bits_to_int(bits)

        bitstream = Bitstream(bits)
        if type is Bitstream:
            return bitstream
        if type is bytes:
            return bytes(bitstream)
        return str(bitstream)

    def write(self, value: Union[bool, List[bool], bytes, str, Bitstream], bit_index: Optional[int] = None):
        """Write *value*, overwriting existing bits and growing when needed.

        Args:
            value: the bits to write.
            bit_index: absolute start position. When omitted, writing starts
                at the current offset and the offset advances past the value.

        Raises:
            IndexError: if the start position is negative.
            TypeError: if *value* has an unsupported type.
        """
        start = self._offset if bit_index is None else bit_index

        if start < 0:
            raise IndexError("bit index out of range")

        if type(value) is str:
            value = value.encode()
        if value is self:
            # Snapshot first: growing self would change the source length.
            value = self.copy()

        value_bit_length = get_bit_length(value)
        if value_bit_length == 0:
            return

        end = start + value_bit_length
        self._reserve(end)
        if end > self._bit_length:
            self._bit_length = end
        self[start] = value

        if bit_index is None:
            self._offset = end

    def seek(self, position: int, whence: Literal[0, 1, 2] = 0):
        """Move the offset, clamped to ``[0, len(self)]``.

        *whence* follows the file API: ``SEEK_SET`` (0) is absolute,
        ``SEEK_CUR`` (1) is relative to the current offset and ``SEEK_END``
        (2) is relative to the end.

        Raises:
            ValueError: if *whence* is not 0, 1 or 2.
        """
        if whence == SEEK_SET:
            target = position
        elif whence == SEEK_CUR:
            target = self._offset + position
        elif whence == SEEK_END:
            target = len(self) + position
        else:
            raise ValueError("Invalid whence")

        self._offset = min(max(target, 0), len(self))


# ---------------------------------------------------------------------------
# tests/test_bitstream.py
# ---------------------------------------------------------------------------

@pytest.mark.parametrize("value", [
    b"",
    b"A",
    b"AB",
    b"\xFF\x00\xAA\x55",
])
def test_from_bytes(value: bytes):
    bitstream = Bitstream.from_bytes(value)
    assert len(bitstream) == (len(value) << 3)
    assert bytes(bitstream) == value


@pytest.mark.parametrize("value,length,expected", [
    (0b0, 0, None),
    (0b1, 1, None),
    (0b0, 1, None),
    (0b1011, 4, None),
    (0b10111011, 8, None),
    (0b10111011, 7, 0b0111011),  # -> 0b0111011
])
def test_from_int(value: int, length: int, expected: int):
    bitstream = Bitstream.from_int(value, length)
    assert len(bitstream) == length
    assert len(bytes(bitstream)) == ((length + 7) >> 3)
    assert bitstream.bits == [bool(value & (1 << i)) for i in range(length - 1, -1, -1)]
    if expected is not None:
        expected_bitstream = Bitstream.from_int(expected, length)
        assert bitstream.bits == expected_bitstream.bits
        assert bytes(bitstream) == bytes(expected_bitstream)


@pytest.mark.parametrize("values", [
    [True, False, False, True, True],
    [[True], [False, True], [False, False]],
    [b"a", b"bc"],
    ["aa", "b"],
    [Bitstream.from_bytes(b"A"), Bitstream.from_int(0b1011, 4), Bitstream.from_bytes(b"B")],
    [True, b"A", "B", [False, True], Bitstream.from_int(0b10110, 5), False],
])
def test_init(values: List[Union[bool, List[bool], bytes, str, Bitstream]]):
    bitstream = Bitstream(*values)

    bitstream.seek(0)
    for item in values:
        bit_length = get_bit_length(item)
        if type(item) is bool:
            item = [item]

        read_type = bool if type(item) is list else type(item)
        value = bitstream.read(read_type, bit_length)
        assert value == item


def test_getitem():
    sample_values = [True, False, False, True, False, True, True, True, True, False]

    for length in range(1, len(sample_values)):
        values = sample_values[:length]
        bitstream = Bitstream(values)
        for i in range(length):
            assert bitstream[i] == values[i]
            assert bitstream[-i - 1] == values[-i - 1]  # For inverse lookup


def test_getitem_index_error():
    bitstream = Bitstream(True)
    with pytest.raises(IndexError):
        bitstream[1]
    with pytest.raises(IndexError):
        bitstream[-2]  # Inverse IndexError


@pytest.mark.parametrize("value", [
    True,
    [True, True],
    b"A",
    "B",
    Bitstream(True, b"C"),
])
def test_setitem(value: Union[bool, List[bool], bytes, str, Bitstream]):
    bit_length = get_bit_length(value)
    read_type = bool if type(value) is list else type(value)
    for length in range(10):
        for i in range(length - bit_length + 1):
            bitstream = Bitstream([False] * length)
            bitstream[i] = value
            assert bitstream.read(read_type, bit_length, i) == (value if type(value) is not bool else [value])
            assert len(bitstream) == length


@pytest.mark.parametrize("value,equals,not_equals", [
    (
        [],
        [[]],
        [[True], [False]]
    ),
    (
        True,
        [[True]],
        [[False], [True, False], b"A"]
    ),
    (
        [True, False],
        [[True, False]],
        [[False], [False, False], b"A"]
    ),
    (
        b"A",
        [b"A", "A", Bitstream.from_bytes(b"A"), Bitstream.from_bytes(b"A").bits],
        [[], [False], b"B", "B", Bitstream.from_bytes(b"B"), Bitstream.from_bytes(b"B").bits]
    ),
    (
        "B",
        [b"B", "B", Bitstream.from_bytes(b"B"), Bitstream.from_bytes(b"B").bits],
        [[], [False], b"A", "A", Bitstream.from_bytes(b"A"), Bitstream.from_bytes(b"A").bits]
    ),
    (
        b"ABC",
        [b"ABC"],
        [[], [False], b"A"]
    ),
])
def test_eq(
        value: Union[bool, List[bool], bytes, str, Bitstream],
        equals: List[Union[bool, List[bool], bytes, str, Bitstream]],
        not_equals: List[Union[bool, List[bool], bytes, str, Bitstream]]):
    bitstream = Bitstream(value)
    for equal in equals:
        assert bitstream == equal
    for not_equal in not_equals:
        assert bitstream != not_equal


@pytest.mark.parametrize("values,number,expected", [
    ([], 2, []),
    ([True, False], 2, [True, False, True, False]),
    (b"A", 2, b"AA"),
    ([False], 3, [False, False, False]),
])
def test_mul(
        values: Union[bool, List[bool], bytes, str, Bitstream],
        number: int,
        expected: Union[List[bool], bytes, str, Bitstream]):
    bitstream = Bitstream(values)

    assert (bitstream * number) == expected
    assert bitstream == values

    bitstream *= number
    assert bitstream == expected


@pytest.mark.parametrize("init_values,add_value,expected", [
    ([], True, [True]),
    ([True, False], False, [True, False, False]),
    (b"A", b"B", b"AB"),
])
def test_add(
        init_values: Union[bool, List[bool], bytes, str, Bitstream],
        add_value: Union[bool, List[bool], bytes, str, Bitstream],
        expected: Union[List[bool], bytes, str, Bitstream]):
    bitstream = Bitstream(init_values)

    assert (bitstream + add_value) == expected
    assert len(bitstream) == get_bit_length(init_values)

    bitstream += add_value
    assert bitstream == expected


@pytest.mark.parametrize("init_values", [
    [],
    [True, False],
    [b"A"],
    [True, b"B"],
])
def test_iter(init_values: List[Union[bool, List[bool], bytes, str, Bitstream]]):
    bitstream = Bitstream(*init_values)
    assert len(iter(bitstream)) == len(bitstream)
    for index, bit in enumerate(bitstream):
        assert bit == bitstream[index]


@pytest.mark.parametrize("init_values,expected", [
    ([], b""),
    ([b"A"], b"A"),
    ([False], b"\x00"),
    ([[False] * 6, True], b"\x02"),
    ([[False] * 7, True], b"\x01"),
])
def test_bytes(
        init_values: List[Union[bool, List[bool], bytes, str, Bitstream]],
        expected: bytes):
    assert bytes(Bitstream(*init_values)) == expected


@pytest.mark.parametrize("init_values,expected", [
    ([], []),
    ([b"\x01"], [False] * 7 + [True]),
    ([False], [False]),
    ([False] * 6 + [True], [False] * 6 + [True]),
    ([False] * 7 + [True], [False] * 7 + [True]),
])
def test_bits(
        init_values: List[Union[bool, List[bool], bytes, str, Bitstream]],
        expected: List[bool]):
    assert Bitstream(*init_values).bits == expected


def test_copy():
    bitstream = Bitstream(True, False)
    bitstream_copy = bitstream.copy()
    assert bitstream == bitstream_copy
    # Compare against the copy itself (not the bound `copy` method) and make
    # sure the two objects do not share their buffer.
    assert bitstream_copy is not bitstream
    bitstream_copy.append(True)
    assert len(bitstream) == 2


@pytest.mark.parametrize("lst", [
    [False] * 10,
    [True] * 10,
    [True, False] * 5,
    [False, True] * 5,
])
def test_append(lst: List[bool]):
    bitstream = Bitstream()
    for i in range(len(lst)):
        bitstream.append(lst[i])
        assert bitstream[i] == lst[i]
        assert len(bitstream) == (i + 1)


@pytest.mark.parametrize("extend_value", [
    [True, False, True],
    b"A",
    "B",
    Bitstream(True, b"C"),
])
def test_extend(extend_value: Union[List[bool], bytes, str, Bitstream]):
    for i in range(10):
        bitstream = Bitstream([False] * i)
        bitstream.extend(extend_value)
        value = bitstream.read(
            bool if type(extend_value) is list else type(extend_value),
            get_bit_length(extend_value),
            i
        )
        assert value == extend_value


def test_extend_error():
    bitstream = Bitstream()
    with pytest.raises(TypeError):
        bitstream.extend(True)


def test_clear():
    bitstream = Bitstream(b"A", [True, False, False, True])
    bitstream._offset = 3
    bitstream.clear()
    assert len(bitstream) == 0
    assert len(bytes(bitstream)) == 0
    assert bitstream._offset == 0


@pytest.mark.parametrize("lst", [
    [False] * 10,
    [True] * 10,
    [True, False] * 5,
    [False, True] * 5,
])
def test_pop(lst: List[bool]):
    bitstream = Bitstream(lst)
    for i in range(len(bitstream)):
        assert bitstream.pop() == lst[-i - 1]

    with pytest.raises(IndexError):
        assert bitstream.pop()


def test_read():
    bitstream = Bitstream(True, b"A", False, True, False)

    assert bitstream.read(bool, 1) == [True]
    assert bitstream.read(bytes, 8) == b"A"
    assert bitstream.read(int, 3) == 2

    bitstream.seek(0)
    assert bitstream.read(Bitstream, len(bitstream) - 1) == Bitstream(True, b"A", False, True)


@pytest.mark.parametrize("bitstream,read_type,bit_length,bit_index,expected", [
    (
        Bitstream(True, b"A", False, True, False),
        bool, 1, 0,
        [True]
    ),
    (
        Bitstream(True, b"A", False, True, False),
        bool, 3, 8 + 1,
        [False, True, False]
    ),
    (
        Bitstream(True, b"A", False, True, False),
        bytes, 8, 1,
        b"A"
    ),
    (
        Bitstream(True, b"A", False, True, False),
        str, 8, 1,
        "A"
    ),
    (
        Bitstream(True, b"A", False, True, False),
        Bitstream, 1 + 8 + 3, 0,
        Bitstream(True, b"A", False, True, False)
    ),
    (
        Bitstream(True, b"A", False, True, False),
        int, 3, 1 + 8,
        2
    ),
])
def test_read_index(
        bitstream: Bitstream,
        read_type: Type[Union[bool, bytes, str, Bitstream, int]],
        bit_length: int,
        bit_index: Optional[int],
        expected: Union[List[bool], bytes, str, Bitstream, int]):
    value = bitstream.read(read_type, bit_length, bit_index)

    if read_type is bool:
        assert type(value) == list
    else:
        assert type(value) == read_type
    assert value == expected


def test_read_error():
    bitstream = Bitstream(True, b"A", False, True, False)

    with pytest.raises(IndexError):
        bitstream.read(bool, 1, len(bitstream))

    with pytest.raises(IndexError):
        bitstream.read(bool, 1, -1)

    with pytest.raises(TypeError):
        bitstream.read(list, 1, 0)
@pytest.mark.parametrize("bitstream,value,bit_index,current_offset", [
    (Bitstream(), True, None, None),
    (Bitstream(), b"A", None, None),
    (Bitstream(), True, 0, None),
    (Bitstream(), True, 3, None),
    (Bitstream(), True, 7, None),
    (Bitstream(), True, 8, None),
    (Bitstream(False), True, None, None),
    (Bitstream(False), True, 0, None),
    (Bitstream(False), True, 0, 0),
    (Bitstream(False), True, None, 0),
])
def test_write(
        bitstream: Bitstream,
        value: Union[bool, List[bool], bytes, str, Bitstream],
        bit_index: Optional[int],
        current_offset: Optional[int]):
    """Write a value and read it back from the position it was written to."""
    if current_offset is not None:
        bitstream.seek(current_offset)

    # read() hands back list[bool] for bit reads, so a lone bool is compared
    # as a one-element list.
    read_type = bool if type(value) is list else type(value)
    expected = [value] if type(value) is bool else value
    bit_length = get_bit_length(value)

    bitstream.write(value, bit_index)

    if bit_index is not None:
        read_value = bitstream.read(read_type, bit_length, bit_index)
    else:
        # A cursor write advanced the offset; step back over what was written.
        bitstream.seek(-bit_length, 1)
        read_value = bitstream.read(read_type, bit_length)

    assert read_value == expected


def test_seek():
    """seek() honours absolute, relative and end-relative positioning."""
    bitstream = Bitstream([True] * 8)

    # (position, whence, offset expected afterwards); the steps build on
    # each other, so their order matters.
    steps = [
        (4, 0, 4),
        (-1, 1, 3),
        (-1, 2, 7),
    ]
    for position, whence, expected_offset in steps:
        bitstream.seek(position, whence)
        assert bitstream._offset == expected_offset