Added Bitstream w/tests
This commit is contained in:
parent
0a0e8b3592
commit
73eaa0e89b
0
sampy/raknet/__init__.py
Normal file
0
sampy/raknet/__init__.py
Normal file
301
sampy/raknet/bitstream.py
Normal file
301
sampy/raknet/bitstream.py
Normal file
|
@ -0,0 +1,301 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
from typing import Any, List, Union, Optional, Literal, Type
|
||||||
|
|
||||||
|
|
||||||
|
# for seek()
|
||||||
|
SEEK_SET = 0
|
||||||
|
SEEK_CUR = 1
|
||||||
|
SEEK_END = 2
|
||||||
|
|
||||||
|
|
||||||
|
def get_bit_length(value: Union[bool, List[bool], bytes, str, Bitstream]) -> int:
|
||||||
|
if type(value) is bool:
|
||||||
|
return 1
|
||||||
|
elif type(value) is list and all(type(v) is bool for v in value):
|
||||||
|
return len(value)
|
||||||
|
elif type(value) in (bytes, str):
|
||||||
|
return len(value) << 3
|
||||||
|
elif type(value) is Bitstream:
|
||||||
|
return len(value)
|
||||||
|
else:
|
||||||
|
raise TypeError("Invalid data type")
|
||||||
|
|
||||||
|
|
||||||
|
def bits_to_int(bits: List[bool], little_endian: bool = False) -> int:
|
||||||
|
num = 0
|
||||||
|
if little_endian:
|
||||||
|
bits = bits[::-1]
|
||||||
|
for bit in bits:
|
||||||
|
num = (num << 1) | bit
|
||||||
|
return num
|
||||||
|
|
||||||
|
|
||||||
|
def int_to_bits(number: int, bit_length: int, little_endian: bool = False) -> List[bool]:
|
||||||
|
bits = [bool(number & (0b1 << i)) for i in range(bit_length-1, -1, -1)]
|
||||||
|
return bits[::-1] if little_endian else bits
|
||||||
|
|
||||||
|
|
||||||
|
class Bitstream:
|
||||||
|
def __init__(self, *values: Union[bool, List[bool], bytes, str, Bitstream]):
|
||||||
|
self._bytearray = bytearray()
|
||||||
|
self._bit_length = 0
|
||||||
|
self._offset = 0
|
||||||
|
|
||||||
|
for item in values:
|
||||||
|
if type(item) is bool:
|
||||||
|
self.append(item)
|
||||||
|
else:
|
||||||
|
self.extend(item)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_bytes(value: bytes) -> Bitstream:
|
||||||
|
bitstream = Bitstream()
|
||||||
|
bitstream._bytearray += value
|
||||||
|
bitstream._bit_length = len(bitstream._bytearray) << 3
|
||||||
|
return bitstream
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_int(value: int, bit_length: int) -> Bitstream:
|
||||||
|
return Bitstream([bool(value & (0b1 << i)) for i in range(bit_length-1, -1, -1)])
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return self._bit_length
|
||||||
|
|
||||||
|
def __bytes__(self) -> bytes:
|
||||||
|
while len(self._bytearray) > ((len(self) + 7) >> 3):
|
||||||
|
self._bytearray.pop()
|
||||||
|
return bytes(self._bytearray)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return bytes(self).decode()
|
||||||
|
|
||||||
|
def __int__(self) -> int:
|
||||||
|
num = 0
|
||||||
|
for bit in self:
|
||||||
|
num = (num << 1) | bit
|
||||||
|
return num
|
||||||
|
|
||||||
|
def __getitem__(self, bit_index: int) -> bool:
|
||||||
|
if bit_index < 0:
|
||||||
|
bit_index = len(self) + bit_index
|
||||||
|
|
||||||
|
if bit_index >= self._bit_length or bit_index < 0:
|
||||||
|
raise IndexError("bit index out of range")
|
||||||
|
|
||||||
|
mask = 1 << (7 - (bit_index % 8))
|
||||||
|
return bool(self._bytearray[bit_index >> 3] & mask)
|
||||||
|
|
||||||
|
def __setitem__(self, bit_index: int, value: Union[bool, List[bool], bytes, str, Bitstream]):
|
||||||
|
if bit_index >= self._bit_length:
|
||||||
|
raise IndexError("bit index out of range")
|
||||||
|
|
||||||
|
if type(value) is str:
|
||||||
|
value = value.encode()
|
||||||
|
|
||||||
|
if type(value) is bytes:
|
||||||
|
self[bit_index] = Bitstream.from_bytes(value)
|
||||||
|
return
|
||||||
|
|
||||||
|
value_bit_length = get_bit_length(value)
|
||||||
|
|
||||||
|
if (bit_index + value_bit_length) > self._bit_length:
|
||||||
|
raise IndexError("Cannot write bits that extends the size of bitstream")
|
||||||
|
|
||||||
|
if type(value) is bool:
|
||||||
|
mask = 1 << (7 - (bit_index % 8))
|
||||||
|
if value:
|
||||||
|
self._bytearray[bit_index >> 3] |= mask
|
||||||
|
else:
|
||||||
|
self._bytearray[bit_index >> 3] &= ~mask
|
||||||
|
elif isinstance(value, Bitstream) or (type(value) is list and all(type(v) is bool for v in value)):
|
||||||
|
for i in range(value_bit_length):
|
||||||
|
self[bit_index + i] = value[i]
|
||||||
|
else:
|
||||||
|
raise TypeError("Expected bool, list[bool], bytes, str or Bitstream")
|
||||||
|
|
||||||
|
def __eq__(self, other: Any) -> bool:
|
||||||
|
if type(other) is bool:
|
||||||
|
other = [other]
|
||||||
|
if type(other) is str:
|
||||||
|
other = other.encode()
|
||||||
|
|
||||||
|
if isinstance(other, Bitstream):
|
||||||
|
pass
|
||||||
|
elif type(other) is bytes:
|
||||||
|
pass
|
||||||
|
elif type(other) is list and all(type(v) is bool for v in other):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise TypeError("Expected Bitstream, bytes, str or list[bool]")
|
||||||
|
|
||||||
|
if isinstance(other, Bitstream) or type(other) is list:
|
||||||
|
if len(self) != len(other):
|
||||||
|
return False
|
||||||
|
for left, right in zip(self, other):
|
||||||
|
if left != right:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
if bytes(self) != other:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def __mul__(self, number: int) -> Bitstream:
|
||||||
|
copy = self.copy()
|
||||||
|
for _ in range(1, number):
|
||||||
|
copy.extend(self)
|
||||||
|
return copy
|
||||||
|
|
||||||
|
def __imul__(self, number: int) -> Bitstream:
|
||||||
|
copy = self.copy()
|
||||||
|
for _ in range(1, number):
|
||||||
|
self.extend(copy)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __add__(self, value: Union[bool, List[bool], bytes, str, Bitstream]) -> Bitstream:
|
||||||
|
copy = self.copy()
|
||||||
|
if type(value) is bool:
|
||||||
|
copy.append(value)
|
||||||
|
else:
|
||||||
|
copy.extend(value)
|
||||||
|
return copy
|
||||||
|
|
||||||
|
def __iadd__(self, value: Union[bool, List[bool], bytes, str, Bitstream]) -> Bitstream:
|
||||||
|
if type(value) is bool:
|
||||||
|
self.append(value)
|
||||||
|
else:
|
||||||
|
self.extend(value)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __iter__(self) -> Bitstream:
|
||||||
|
self._iter_index = 0
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self) -> bool:
|
||||||
|
if self._iter_index >= self._bit_length:
|
||||||
|
raise StopIteration
|
||||||
|
bit = self[self._iter_index]
|
||||||
|
self._iter_index += 1
|
||||||
|
return bit
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return "<Bitstream addr:0x%012x offset:%d len:%d data:'%s'>" % (
|
||||||
|
id(self), self._offset, len(self), bytes(self).hex(" ")
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def bits(self) -> List[bool]:
|
||||||
|
return [b for b in self]
|
||||||
|
|
||||||
|
def copy(self) -> Bitstream:
|
||||||
|
bitstream = Bitstream()
|
||||||
|
bitstream._bytearray = self._bytearray.copy()
|
||||||
|
bitstream._bit_length = self._bit_length
|
||||||
|
bitstream._offset = self._offset
|
||||||
|
return bitstream
|
||||||
|
|
||||||
|
def append(self, value: bool):
|
||||||
|
if type(value) is not bool:
|
||||||
|
raise TypeError("Expected bool")
|
||||||
|
|
||||||
|
while (len(self._bytearray) << 3) < (self._bit_length + 1):
|
||||||
|
self._bytearray.append(0)
|
||||||
|
prev_bit_length = self._bit_length
|
||||||
|
self._bit_length += 1
|
||||||
|
self[prev_bit_length] = value
|
||||||
|
|
||||||
|
def extend(self, value: Union[List[bool], bytes, str, Bitstream]):
|
||||||
|
if type(value) in (bytes, str):
|
||||||
|
pass
|
||||||
|
elif isinstance(value, Bitstream):
|
||||||
|
pass
|
||||||
|
elif type(value) is list and all(type(v) is bool for v in value):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise TypeError("Expected list[bool], bytes, str or Bitstream")
|
||||||
|
|
||||||
|
value_bit_length = get_bit_length(value)
|
||||||
|
if value_bit_length == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
while (len(self._bytearray) << 3) < (self._bit_length + value_bit_length):
|
||||||
|
self._bytearray.append(0)
|
||||||
|
prev_bit_length = self._bit_length
|
||||||
|
self._bit_length += value_bit_length
|
||||||
|
self[prev_bit_length] = value
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self._bytearray.clear()
|
||||||
|
self._bit_length = 0
|
||||||
|
self._offset = 0
|
||||||
|
|
||||||
|
def pop(self) -> bool:
|
||||||
|
value = self[-1]
|
||||||
|
self[-1] = False
|
||||||
|
self._bit_length -= 1
|
||||||
|
return value
|
||||||
|
|
||||||
|
def read(
|
||||||
|
self,
|
||||||
|
type: Type[Union[bool, bytes, str, Bitstream, int]],
|
||||||
|
bit_length: int,
|
||||||
|
bit_index: Optional[int] = None
|
||||||
|
) -> Union[List[bool], bytes, str, Bitstream, int]:
|
||||||
|
start = self._offset if bit_index is None else bit_index
|
||||||
|
|
||||||
|
if (start + bit_length) > self._bit_length or (start < 0):
|
||||||
|
raise IndexError("bit index out of range")
|
||||||
|
|
||||||
|
bits = [self[i] for i in range(start, start + bit_length)]
|
||||||
|
|
||||||
|
if bit_index is None:
|
||||||
|
self._offset += bit_length
|
||||||
|
|
||||||
|
if type is bool:
|
||||||
|
return bits
|
||||||
|
|
||||||
|
bitstream = Bitstream(bits)
|
||||||
|
if type is Bitstream:
|
||||||
|
return bitstream
|
||||||
|
elif type is bytes:
|
||||||
|
return bytes(bitstream)
|
||||||
|
elif type is str:
|
||||||
|
return str(bitstream)
|
||||||
|
elif type is int:
|
||||||
|
return int(bitstream)
|
||||||
|
|
||||||
|
raise TypeError("Invalid data type")
|
||||||
|
|
||||||
|
def write(self, value: Union[bool, List[bool], bytes, str, Bitstream], bit_index: Optional[int] = None):
|
||||||
|
start = self._offset if bit_index is None else bit_index
|
||||||
|
|
||||||
|
if start < 0:
|
||||||
|
raise IndexError("bit index out of range")
|
||||||
|
|
||||||
|
value_bit_length = get_bit_length(value)
|
||||||
|
if value_bit_length == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
while (len(self._bytearray) << 3) < (start + value_bit_length):
|
||||||
|
self._bytearray.append(0)
|
||||||
|
|
||||||
|
if bit_index is None:
|
||||||
|
self._offset += value_bit_length
|
||||||
|
_bit_length = self._offset
|
||||||
|
else:
|
||||||
|
_bit_length = start + value_bit_length
|
||||||
|
|
||||||
|
if _bit_length > self._bit_length:
|
||||||
|
self._bit_length = _bit_length
|
||||||
|
self[start] = value
|
||||||
|
|
||||||
|
def seek(self, position: int, whence: Literal[0, 1, 2] = 0):
|
||||||
|
if whence == 0:
|
||||||
|
self._offset = position
|
||||||
|
elif whence == 1:
|
||||||
|
self._offset += position
|
||||||
|
elif whence == 2:
|
||||||
|
self._offset = len(self) + position
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid whence")
|
||||||
|
|
||||||
|
self._offset = min(max(self._offset, 0), len(self))
|
389
tests/test_bitstream.py
Normal file
389
tests/test_bitstream.py
Normal file
|
@ -0,0 +1,389 @@
|
||||||
|
from typing import List, Union, Type, Optional
|
||||||
|
import pytest
|
||||||
|
from sampy.raknet.bitstream import Bitstream, get_bit_length
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("value", [
|
||||||
|
b"",
|
||||||
|
b"A",
|
||||||
|
b"AB",
|
||||||
|
b"\xFF\x00\xAA\x55",
|
||||||
|
])
|
||||||
|
def test_from_bytes(value: bytes):
|
||||||
|
bitstream = Bitstream.from_bytes(value)
|
||||||
|
assert len(bitstream) == (len(value) << 3)
|
||||||
|
assert bytes(bitstream) == value
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("value,length,expected", [
|
||||||
|
(0b0, 0, None),
|
||||||
|
(0b1, 1, None),
|
||||||
|
(0b0, 1, None),
|
||||||
|
(0b1011, 4, None),
|
||||||
|
(0b10111011, 8, None),
|
||||||
|
(0b10111011, 7, 0b0111011), # -> 0b0111011
|
||||||
|
])
|
||||||
|
def test_from_int(value: int, length: int, expected: int):
|
||||||
|
bitstream = Bitstream.from_int(value, length)
|
||||||
|
assert len(bitstream) == length
|
||||||
|
assert len(bytes(bitstream)) == ((length + 7) >> 3)
|
||||||
|
assert bitstream.bits == [bool(value & (1 << i)) for i in range(length-1, -1, -1)]
|
||||||
|
if expected is not None:
|
||||||
|
expected_bitstream = Bitstream.from_int(expected, length)
|
||||||
|
assert bitstream.bits == expected_bitstream.bits
|
||||||
|
assert bytes(bitstream) == bytes(expected_bitstream)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("values", [
|
||||||
|
[True, False, False, True, True],
|
||||||
|
[[True], [False, True], [False, False]],
|
||||||
|
[b"a", b"bc"],
|
||||||
|
["aa", "b"],
|
||||||
|
[Bitstream.from_bytes(b"A"), Bitstream.from_int(0b1011, 4), Bitstream.from_bytes(b"B")],
|
||||||
|
[True, b"A", "B", [False, True], Bitstream.from_int(0b10110, 5), False],
|
||||||
|
])
|
||||||
|
def test_init(values: List[Union[bool, List[bool], bytes, str, Bitstream]]):
|
||||||
|
bitstream = Bitstream(*values)
|
||||||
|
|
||||||
|
bitstream.seek(0)
|
||||||
|
for item in values:
|
||||||
|
bit_length = get_bit_length(item)
|
||||||
|
if type(item) is bool:
|
||||||
|
item = [item]
|
||||||
|
|
||||||
|
read_type = bool if type(item) is list else type(item)
|
||||||
|
value = bitstream.read(read_type, bit_length)
|
||||||
|
assert value == item
|
||||||
|
|
||||||
|
|
||||||
|
def test_getitem():
|
||||||
|
sample_values = [True, False, False, True, False, True, True, True, True, False]
|
||||||
|
|
||||||
|
for length in range(1, len(sample_values)):
|
||||||
|
values = sample_values[:length]
|
||||||
|
bitstream = Bitstream(values)
|
||||||
|
for i in range(length):
|
||||||
|
assert bitstream[i] == values[i]
|
||||||
|
assert bitstream[-i - 1] == values[-i - 1] # For inverse lookup
|
||||||
|
|
||||||
|
|
||||||
|
def test_getitem_index_error():
|
||||||
|
bitstream = Bitstream(True)
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
bitstream[1]
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
bitstream[-2] # Inverse IndexError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("value", [
|
||||||
|
True,
|
||||||
|
[True, True],
|
||||||
|
b"A",
|
||||||
|
"B",
|
||||||
|
Bitstream(True, b"C"),
|
||||||
|
])
|
||||||
|
def test_setitem(value: Union[bool, List[bool], bytes, str, Bitstream]):
|
||||||
|
bit_length = get_bit_length(value)
|
||||||
|
read_type = bool if type(value) is list else type(value)
|
||||||
|
for length in range(10):
|
||||||
|
for i in range(length - bit_length + 1):
|
||||||
|
bitstream = Bitstream([False] * length)
|
||||||
|
bitstream[i] = value
|
||||||
|
assert bitstream.read(read_type, bit_length, i) == (value if type(value) is not bool else [value])
|
||||||
|
assert len(bitstream) == length
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("value,equals,not_equals", [
|
||||||
|
(
|
||||||
|
[],
|
||||||
|
[[]],
|
||||||
|
[[True], [False]]
|
||||||
|
),
|
||||||
|
(
|
||||||
|
True,
|
||||||
|
[[True]],
|
||||||
|
[[False], [True, False], b"A"]
|
||||||
|
),
|
||||||
|
(
|
||||||
|
[True, False],
|
||||||
|
[[True, False]],
|
||||||
|
[[False], [False, False], b"A"]
|
||||||
|
),
|
||||||
|
(
|
||||||
|
b"A",
|
||||||
|
[b"A", "A", Bitstream.from_bytes(b"A"), Bitstream.from_bytes(b"A").bits],
|
||||||
|
[[], [False], b"B", "B", Bitstream.from_bytes(b"B"), Bitstream.from_bytes(b"B").bits]
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"B",
|
||||||
|
[b"B", "B", Bitstream.from_bytes(b"B"), Bitstream.from_bytes(b"B").bits],
|
||||||
|
[[], [False], b"A", "A", Bitstream.from_bytes(b"A"), Bitstream.from_bytes(b"A").bits]
|
||||||
|
),
|
||||||
|
(
|
||||||
|
b"ABC",
|
||||||
|
[b"ABC"],
|
||||||
|
[[], [False], b"A"]
|
||||||
|
),
|
||||||
|
])
|
||||||
|
def test_eq(
|
||||||
|
value: Union[bool, List[bool], bytes, str, Bitstream],
|
||||||
|
equals: List[Union[bool, List[bool], bytes, str, Bitstream]],
|
||||||
|
not_equals: List[Union[bool, List[bool], bytes, str, Bitstream]]):
|
||||||
|
bitstream = Bitstream(value)
|
||||||
|
for equal in equals:
|
||||||
|
assert bitstream == equal
|
||||||
|
for not_equal in not_equals:
|
||||||
|
assert bitstream != not_equal
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("values,number,expected", [
|
||||||
|
([], 2, []),
|
||||||
|
([True, False], 2, [True, False, True, False]),
|
||||||
|
(b"A", 2, b"AA"),
|
||||||
|
([False], 3, [False, False, False]),
|
||||||
|
])
|
||||||
|
def test_mul(
|
||||||
|
values: Union[bool, List[bool], bytes, str, Bitstream],
|
||||||
|
number: int,
|
||||||
|
expected: Union[List[bool], bytes, str, Bitstream]):
|
||||||
|
bitstream = Bitstream(values)
|
||||||
|
|
||||||
|
assert (bitstream * number) == expected
|
||||||
|
assert bitstream == values
|
||||||
|
|
||||||
|
bitstream *= number
|
||||||
|
assert bitstream == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("init_values,add_value,expected", [
|
||||||
|
([], True, [True]),
|
||||||
|
([True, False], False, [True, False, False]),
|
||||||
|
(b"A", b"B", b"AB"),
|
||||||
|
])
|
||||||
|
def test_add(
|
||||||
|
init_values: Union[bool, List[bool], bytes, str, Bitstream],
|
||||||
|
add_value: Union[bool, List[bool], bytes, str, Bitstream],
|
||||||
|
expected: Union[List[bool], bytes, str, Bitstream]):
|
||||||
|
bitstream = Bitstream(init_values)
|
||||||
|
|
||||||
|
assert (bitstream + add_value) == expected
|
||||||
|
assert len(bitstream) == get_bit_length(init_values)
|
||||||
|
|
||||||
|
bitstream += add_value
|
||||||
|
assert bitstream == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("init_values", [
|
||||||
|
[],
|
||||||
|
[True, False],
|
||||||
|
[b"A"],
|
||||||
|
[True, b"B"],
|
||||||
|
])
|
||||||
|
def test_iter(init_values: List[Union[bool, List[bool], bytes, str, Bitstream]]):
|
||||||
|
bitstream = Bitstream(*init_values)
|
||||||
|
assert len(iter(bitstream)) == len(bitstream)
|
||||||
|
for index, bit in enumerate(bitstream):
|
||||||
|
assert bit == bitstream[index]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("init_values,expected", [
|
||||||
|
([], b""),
|
||||||
|
([b"A"], b"A"),
|
||||||
|
([False], b"\x00"),
|
||||||
|
([[False] * 6, True], b"\x02"),
|
||||||
|
([[False] * 7, True], b"\x01"),
|
||||||
|
])
|
||||||
|
def test_bytes(
|
||||||
|
init_values: List[Union[bool, List[bool], bytes, str, Bitstream]],
|
||||||
|
expected: bytes):
|
||||||
|
assert bytes(Bitstream(*init_values)) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("init_values,expected", [
|
||||||
|
([], []),
|
||||||
|
([b"\x01"], [False] * 7 + [True]),
|
||||||
|
([False], [False]),
|
||||||
|
([False] * 6 + [True], [False] * 6 + [True]),
|
||||||
|
([False] * 7 + [True], [False] * 7 + [True]),
|
||||||
|
])
|
||||||
|
def test_bits(
|
||||||
|
init_values: List[Union[bool, List[bool], bytes, str, Bitstream]],
|
||||||
|
expected: List[bool]):
|
||||||
|
assert Bitstream(*init_values).bits == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_copy():
|
||||||
|
bitstream = Bitstream(True, False)
|
||||||
|
bitstream_copy = bitstream.copy()
|
||||||
|
assert bitstream == bitstream_copy
|
||||||
|
assert id(bitstream) != id(bitstream.copy)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("lst", [
|
||||||
|
[False] * 10,
|
||||||
|
[True] * 10,
|
||||||
|
[True, False] * 5,
|
||||||
|
[False, True] * 5,
|
||||||
|
])
|
||||||
|
def test_append(lst: List[bool]):
|
||||||
|
bitstream = Bitstream()
|
||||||
|
for i in range(len(lst)):
|
||||||
|
bitstream.append(lst[i])
|
||||||
|
assert bitstream[i] == lst[i]
|
||||||
|
assert len(bitstream) == (i + 1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("extend_value", [
|
||||||
|
[True, False, True],
|
||||||
|
b"A",
|
||||||
|
"B",
|
||||||
|
Bitstream(True, b"C"),
|
||||||
|
])
|
||||||
|
def test_extend(extend_value: Union[List[bool], bytes, str, Bitstream]):
|
||||||
|
for i in range(10):
|
||||||
|
bitstream = Bitstream([False] * i)
|
||||||
|
bitstream.extend(extend_value)
|
||||||
|
value = bitstream.read(bool if type(extend_value) is list else type(extend_value), get_bit_length(extend_value), i)
|
||||||
|
assert value == extend_value
|
||||||
|
|
||||||
|
|
||||||
|
def test_extend_error():
|
||||||
|
bitstream = Bitstream()
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
bitstream.extend(True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_clear():
|
||||||
|
bitstream = Bitstream(b"A", [True, False, False, True])
|
||||||
|
bitstream._offset = 3
|
||||||
|
bitstream.clear()
|
||||||
|
assert len(bitstream) == 0
|
||||||
|
assert len(bytes(bitstream)) == 0
|
||||||
|
assert bitstream._offset == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("lst", [
|
||||||
|
[False] * 10,
|
||||||
|
[True] * 10,
|
||||||
|
[True, False] * 5,
|
||||||
|
[False, True] * 5,
|
||||||
|
])
|
||||||
|
def test_pop(lst: List[bool]):
|
||||||
|
bitstream = Bitstream(lst)
|
||||||
|
for i in range(len(bitstream)):
|
||||||
|
assert bitstream.pop() == lst[-i - 1]
|
||||||
|
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
assert bitstream.pop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_read():
|
||||||
|
bitstream = Bitstream(True, b"A", False, True, False)
|
||||||
|
|
||||||
|
assert bitstream.read(bool, 1) == [True]
|
||||||
|
assert bitstream.read(bytes, 8) == b"A"
|
||||||
|
assert bitstream.read(int, 3) == 2
|
||||||
|
|
||||||
|
bitstream.seek(0)
|
||||||
|
assert bitstream.read(Bitstream, len(bitstream) - 1) == Bitstream(True, b"A", False, True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("bitstream,read_type,bit_length,bit_index,expected", [
|
||||||
|
(
|
||||||
|
Bitstream(True, b"A", False, True, False),
|
||||||
|
bool, 1, 0,
|
||||||
|
[True]
|
||||||
|
),
|
||||||
|
(
|
||||||
|
Bitstream(True, b"A", False, True, False),
|
||||||
|
bool, 3, 8 + 1,
|
||||||
|
[False, True, False]
|
||||||
|
),
|
||||||
|
(
|
||||||
|
Bitstream(True, b"A", False, True, False),
|
||||||
|
bytes, 8, 1,
|
||||||
|
b"A"
|
||||||
|
),
|
||||||
|
(
|
||||||
|
Bitstream(True, b"A", False, True, False),
|
||||||
|
str, 8, 1,
|
||||||
|
"A"
|
||||||
|
),
|
||||||
|
(
|
||||||
|
Bitstream(True, b"A", False, True, False),
|
||||||
|
Bitstream, 1 + 8 + 3, 0,
|
||||||
|
Bitstream(True, b"A", False, True, False)
|
||||||
|
),
|
||||||
|
(
|
||||||
|
Bitstream(True, b"A", False, True, False),
|
||||||
|
int, 3, 1 + 8,
|
||||||
|
2
|
||||||
|
),
|
||||||
|
])
|
||||||
|
def test_read_index(
|
||||||
|
bitstream: Bitstream,
|
||||||
|
read_type: Type[Union[bool, bytes, str, Bitstream, int]],
|
||||||
|
bit_length: int,
|
||||||
|
bit_index: Optional[int],
|
||||||
|
expected: Union[List[bool], bytes, str, Bitstream, int]):
|
||||||
|
value = bitstream.read(read_type, bit_length, bit_index)
|
||||||
|
|
||||||
|
if read_type is bool:
|
||||||
|
assert type(value) == list
|
||||||
|
else:
|
||||||
|
assert type(value) == read_type
|
||||||
|
assert value == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_error():
|
||||||
|
bitstream = Bitstream(True, b"A", False, True, False)
|
||||||
|
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
bitstream.read(bool, 1, len(bitstream))
|
||||||
|
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
bitstream.read(bool, 1, -1)
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
bitstream.read(list, 1, 0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("bitstream,value,bit_index,current_offset", [
|
||||||
|
(Bitstream(), True, None, None),
|
||||||
|
(Bitstream(), b"A", None, None),
|
||||||
|
(Bitstream(), True, 0, None),
|
||||||
|
(Bitstream(), True, 3, None),
|
||||||
|
(Bitstream(), True, 7, None),
|
||||||
|
(Bitstream(), True, 8, None),
|
||||||
|
(Bitstream(False), True, None, None),
|
||||||
|
(Bitstream(False), True, 0, None),
|
||||||
|
(Bitstream(False), True, 0, 0),
|
||||||
|
(Bitstream(False), True, None, 0),
|
||||||
|
])
|
||||||
|
def test_write(bitstream: Bitstream, value: Union[bool, List[bool], bytes, str, Bitstream], bit_index: Optional[int], current_offset: Optional[int]):
|
||||||
|
if current_offset is not None:
|
||||||
|
bitstream.seek(current_offset)
|
||||||
|
|
||||||
|
bit_length = get_bit_length(value)
|
||||||
|
bitstream.write(value, bit_index)
|
||||||
|
|
||||||
|
if bit_index is None:
|
||||||
|
bitstream.seek(-bit_length, 1)
|
||||||
|
read_value = bitstream.read(bool if type(value) is list else type(value), bit_length)
|
||||||
|
else:
|
||||||
|
read_value = bitstream.read(bool if type(value) is list else type(value), bit_length, bit_index)
|
||||||
|
|
||||||
|
assert read_value == ([value] if type(value) is bool else value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seek():
|
||||||
|
bitstream = Bitstream([True] * 8)
|
||||||
|
|
||||||
|
bitstream.seek(4, 0)
|
||||||
|
assert bitstream._offset == 4
|
||||||
|
|
||||||
|
bitstream.seek(-1, 1)
|
||||||
|
assert bitstream._offset == 3
|
||||||
|
|
||||||
|
bitstream.seek(-1, 2)
|
||||||
|
assert bitstream._offset == 7
|
Loading…
Reference in New Issue
Block a user