Added Bitstream w/tests
This commit is contained in:
0
sampy/raknet/__init__.py
Normal file
0
sampy/raknet/__init__.py
Normal file
301
sampy/raknet/bitstream.py
Normal file
301
sampy/raknet/bitstream.py
Normal file
@@ -0,0 +1,301 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any, List, Union, Optional, Literal, Type
|
||||
|
||||
|
||||
# for seek()
|
||||
SEEK_SET = 0
|
||||
SEEK_CUR = 1
|
||||
SEEK_END = 2
|
||||
|
||||
|
||||
def get_bit_length(value: Union[bool, List[bool], bytes, str, Bitstream]) -> int:
|
||||
if type(value) is bool:
|
||||
return 1
|
||||
elif type(value) is list and all(type(v) is bool for v in value):
|
||||
return len(value)
|
||||
elif type(value) in (bytes, str):
|
||||
return len(value) << 3
|
||||
elif type(value) is Bitstream:
|
||||
return len(value)
|
||||
else:
|
||||
raise TypeError("Invalid data type")
|
||||
|
||||
|
||||
def bits_to_int(bits: List[bool], little_endian: bool = False) -> int:
|
||||
num = 0
|
||||
if little_endian:
|
||||
bits = bits[::-1]
|
||||
for bit in bits:
|
||||
num = (num << 1) | bit
|
||||
return num
|
||||
|
||||
|
||||
def int_to_bits(number: int, bit_length: int, little_endian: bool = False) -> List[bool]:
|
||||
bits = [bool(number & (0b1 << i)) for i in range(bit_length-1, -1, -1)]
|
||||
return bits[::-1] if little_endian else bits
|
||||
|
||||
|
||||
class Bitstream:
|
||||
def __init__(self, *values: Union[bool, List[bool], bytes, str, Bitstream]):
|
||||
self._bytearray = bytearray()
|
||||
self._bit_length = 0
|
||||
self._offset = 0
|
||||
|
||||
for item in values:
|
||||
if type(item) is bool:
|
||||
self.append(item)
|
||||
else:
|
||||
self.extend(item)
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(value: bytes) -> Bitstream:
|
||||
bitstream = Bitstream()
|
||||
bitstream._bytearray += value
|
||||
bitstream._bit_length = len(bitstream._bytearray) << 3
|
||||
return bitstream
|
||||
|
||||
@staticmethod
|
||||
def from_int(value: int, bit_length: int) -> Bitstream:
|
||||
return Bitstream([bool(value & (0b1 << i)) for i in range(bit_length-1, -1, -1)])
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._bit_length
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
while len(self._bytearray) > ((len(self) + 7) >> 3):
|
||||
self._bytearray.pop()
|
||||
return bytes(self._bytearray)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return bytes(self).decode()
|
||||
|
||||
def __int__(self) -> int:
|
||||
num = 0
|
||||
for bit in self:
|
||||
num = (num << 1) | bit
|
||||
return num
|
||||
|
||||
def __getitem__(self, bit_index: int) -> bool:
|
||||
if bit_index < 0:
|
||||
bit_index = len(self) + bit_index
|
||||
|
||||
if bit_index >= self._bit_length or bit_index < 0:
|
||||
raise IndexError("bit index out of range")
|
||||
|
||||
mask = 1 << (7 - (bit_index % 8))
|
||||
return bool(self._bytearray[bit_index >> 3] & mask)
|
||||
|
||||
def __setitem__(self, bit_index: int, value: Union[bool, List[bool], bytes, str, Bitstream]):
|
||||
if bit_index >= self._bit_length:
|
||||
raise IndexError("bit index out of range")
|
||||
|
||||
if type(value) is str:
|
||||
value = value.encode()
|
||||
|
||||
if type(value) is bytes:
|
||||
self[bit_index] = Bitstream.from_bytes(value)
|
||||
return
|
||||
|
||||
value_bit_length = get_bit_length(value)
|
||||
|
||||
if (bit_index + value_bit_length) > self._bit_length:
|
||||
raise IndexError("Cannot write bits that extends the size of bitstream")
|
||||
|
||||
if type(value) is bool:
|
||||
mask = 1 << (7 - (bit_index % 8))
|
||||
if value:
|
||||
self._bytearray[bit_index >> 3] |= mask
|
||||
else:
|
||||
self._bytearray[bit_index >> 3] &= ~mask
|
||||
elif isinstance(value, Bitstream) or (type(value) is list and all(type(v) is bool for v in value)):
|
||||
for i in range(value_bit_length):
|
||||
self[bit_index + i] = value[i]
|
||||
else:
|
||||
raise TypeError("Expected bool, list[bool], bytes, str or Bitstream")
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if type(other) is bool:
|
||||
other = [other]
|
||||
if type(other) is str:
|
||||
other = other.encode()
|
||||
|
||||
if isinstance(other, Bitstream):
|
||||
pass
|
||||
elif type(other) is bytes:
|
||||
pass
|
||||
elif type(other) is list and all(type(v) is bool for v in other):
|
||||
pass
|
||||
else:
|
||||
raise TypeError("Expected Bitstream, bytes, str or list[bool]")
|
||||
|
||||
if isinstance(other, Bitstream) or type(other) is list:
|
||||
if len(self) != len(other):
|
||||
return False
|
||||
for left, right in zip(self, other):
|
||||
if left != right:
|
||||
return False
|
||||
else:
|
||||
if bytes(self) != other:
|
||||
return False
|
||||
return True
|
||||
|
||||
def __mul__(self, number: int) -> Bitstream:
|
||||
copy = self.copy()
|
||||
for _ in range(1, number):
|
||||
copy.extend(self)
|
||||
return copy
|
||||
|
||||
def __imul__(self, number: int) -> Bitstream:
|
||||
copy = self.copy()
|
||||
for _ in range(1, number):
|
||||
self.extend(copy)
|
||||
return self
|
||||
|
||||
def __add__(self, value: Union[bool, List[bool], bytes, str, Bitstream]) -> Bitstream:
|
||||
copy = self.copy()
|
||||
if type(value) is bool:
|
||||
copy.append(value)
|
||||
else:
|
||||
copy.extend(value)
|
||||
return copy
|
||||
|
||||
def __iadd__(self, value: Union[bool, List[bool], bytes, str, Bitstream]) -> Bitstream:
|
||||
if type(value) is bool:
|
||||
self.append(value)
|
||||
else:
|
||||
self.extend(value)
|
||||
return self
|
||||
|
||||
def __iter__(self) -> Bitstream:
|
||||
self._iter_index = 0
|
||||
return self
|
||||
|
||||
def __next__(self) -> bool:
|
||||
if self._iter_index >= self._bit_length:
|
||||
raise StopIteration
|
||||
bit = self[self._iter_index]
|
||||
self._iter_index += 1
|
||||
return bit
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<Bitstream addr:0x%012x offset:%d len:%d data:'%s'>" % (
|
||||
id(self), self._offset, len(self), bytes(self).hex(" ")
|
||||
)
|
||||
|
||||
@property
|
||||
def bits(self) -> List[bool]:
|
||||
return [b for b in self]
|
||||
|
||||
def copy(self) -> Bitstream:
|
||||
bitstream = Bitstream()
|
||||
bitstream._bytearray = self._bytearray.copy()
|
||||
bitstream._bit_length = self._bit_length
|
||||
bitstream._offset = self._offset
|
||||
return bitstream
|
||||
|
||||
def append(self, value: bool):
|
||||
if type(value) is not bool:
|
||||
raise TypeError("Expected bool")
|
||||
|
||||
while (len(self._bytearray) << 3) < (self._bit_length + 1):
|
||||
self._bytearray.append(0)
|
||||
prev_bit_length = self._bit_length
|
||||
self._bit_length += 1
|
||||
self[prev_bit_length] = value
|
||||
|
||||
def extend(self, value: Union[List[bool], bytes, str, Bitstream]):
|
||||
if type(value) in (bytes, str):
|
||||
pass
|
||||
elif isinstance(value, Bitstream):
|
||||
pass
|
||||
elif type(value) is list and all(type(v) is bool for v in value):
|
||||
pass
|
||||
else:
|
||||
raise TypeError("Expected list[bool], bytes, str or Bitstream")
|
||||
|
||||
value_bit_length = get_bit_length(value)
|
||||
if value_bit_length == 0:
|
||||
return
|
||||
|
||||
while (len(self._bytearray) << 3) < (self._bit_length + value_bit_length):
|
||||
self._bytearray.append(0)
|
||||
prev_bit_length = self._bit_length
|
||||
self._bit_length += value_bit_length
|
||||
self[prev_bit_length] = value
|
||||
|
||||
def clear(self):
|
||||
self._bytearray.clear()
|
||||
self._bit_length = 0
|
||||
self._offset = 0
|
||||
|
||||
def pop(self) -> bool:
|
||||
value = self[-1]
|
||||
self[-1] = False
|
||||
self._bit_length -= 1
|
||||
return value
|
||||
|
||||
def read(
|
||||
self,
|
||||
type: Type[Union[bool, bytes, str, Bitstream, int]],
|
||||
bit_length: int,
|
||||
bit_index: Optional[int] = None
|
||||
) -> Union[List[bool], bytes, str, Bitstream, int]:
|
||||
start = self._offset if bit_index is None else bit_index
|
||||
|
||||
if (start + bit_length) > self._bit_length or (start < 0):
|
||||
raise IndexError("bit index out of range")
|
||||
|
||||
bits = [self[i] for i in range(start, start + bit_length)]
|
||||
|
||||
if bit_index is None:
|
||||
self._offset += bit_length
|
||||
|
||||
if type is bool:
|
||||
return bits
|
||||
|
||||
bitstream = Bitstream(bits)
|
||||
if type is Bitstream:
|
||||
return bitstream
|
||||
elif type is bytes:
|
||||
return bytes(bitstream)
|
||||
elif type is str:
|
||||
return str(bitstream)
|
||||
elif type is int:
|
||||
return int(bitstream)
|
||||
|
||||
raise TypeError("Invalid data type")
|
||||
|
||||
def write(self, value: Union[bool, List[bool], bytes, str, Bitstream], bit_index: Optional[int] = None):
|
||||
start = self._offset if bit_index is None else bit_index
|
||||
|
||||
if start < 0:
|
||||
raise IndexError("bit index out of range")
|
||||
|
||||
value_bit_length = get_bit_length(value)
|
||||
if value_bit_length == 0:
|
||||
return
|
||||
|
||||
while (len(self._bytearray) << 3) < (start + value_bit_length):
|
||||
self._bytearray.append(0)
|
||||
|
||||
if bit_index is None:
|
||||
self._offset += value_bit_length
|
||||
_bit_length = self._offset
|
||||
else:
|
||||
_bit_length = start + value_bit_length
|
||||
|
||||
if _bit_length > self._bit_length:
|
||||
self._bit_length = _bit_length
|
||||
self[start] = value
|
||||
|
||||
def seek(self, position: int, whence: Literal[0, 1, 2] = 0):
|
||||
if whence == 0:
|
||||
self._offset = position
|
||||
elif whence == 1:
|
||||
self._offset += position
|
||||
elif whence == 2:
|
||||
self._offset = len(self) + position
|
||||
else:
|
||||
raise ValueError("Invalid whence")
|
||||
|
||||
self._offset = min(max(self._offset, 0), len(self))
|
||||
Reference in New Issue
Block a user