# sampy3/sampy/raknet/bitstream.py
# (302 lines, 9.1 KiB, Python)

from __future__ import annotations
from typing import Any, Iterator, List, Union, Optional, Literal, Type
# ``whence`` values accepted by Bitstream.seek(): absolute position,
# relative to the current cursor, and relative to the end of the stream.
SEEK_SET, SEEK_CUR, SEEK_END = 0, 1, 2
def get_bit_length(value: Union[bool, List[bool], bytes, str, Bitstream]) -> int:
    """Return how many bits *value* occupies when written into a Bitstream.

    - ``bool``: 1 bit.
    - ``list`` of ``bool``: one bit per element.
    - ``bytes``: 8 bits per byte.
    - ``str``: 8 bits per byte of its UTF-8 encoding (``str.encode()`` is
      what Bitstream actually stores), so non-ASCII characters count as
      16-32 bits rather than 8.
    - ``Bitstream`` (including subclasses): its length in bits.

    Raises:
        TypeError: for any other type, including lists with non-bool items.
    """
    if type(value) is bool:
        return 1
    if type(value) is list and all(type(v) is bool for v in value):
        return len(value)
    if type(value) is bytes:
        return len(value) << 3
    if type(value) is str:
        # Count encoded bytes, not code points: counting characters made
        # every non-ASCII string overflow the space reserved for it.
        return len(value.encode()) << 3
    if isinstance(value, Bitstream):
        return len(value)
    raise TypeError("Invalid data type")
def bits_to_int(bits: List[bool], little_endian: bool = False) -> int:
    """Pack a sequence of bits into a non-negative integer.

    By default ``bits[0]`` is the most significant bit. With
    ``little_endian=True``, ``bits[0]`` is the least significant bit.
    """
    ordered = reversed(bits) if little_endian else bits
    result = 0
    for bit in ordered:
        # Shift the accumulated value left and drop the next bit into bit 0.
        result = (result << 1) | bit
    return result
def int_to_bits(number: int, bit_length: int, little_endian: bool = False) -> List[bool]:
    """Unpack the low *bit_length* bits of *number* into a list of bools.

    By default the most significant of those bits comes first. With
    ``little_endian=True``, element ``i`` is bit ``i`` of *number*. Higher
    bits of *number* are ignored.
    """
    if little_endian:
        positions = range(bit_length)
    else:
        positions = range(bit_length - 1, -1, -1)
    return [bool((number >> pos) & 1) for pos in positions]
class Bitstream:
    """Growable, MSB-first sequence of bits backed by a ``bytearray``.

    Bit ``i`` lives in byte ``i >> 3`` under mask ``1 << (7 - i % 8)``, so the
    first bit of the stream is the most significant bit of the first byte.
    A cursor (``_offset``) is used by :meth:`read`, :meth:`write` and
    :meth:`seek` when no explicit bit index is given.

    Invariant: every bit of ``_bytearray`` at or beyond ``_bit_length`` is 0,
    so ``bytes(self)`` zero-pads a partial final byte. The invariant also
    means writes that skip past the end leave zero bits in the gap.
    """

    def __init__(self, *values: Union[bool, List[bool], bytes, str, Bitstream]):
        """Create a stream by concatenating *values* in order.

        Each value may be a ``bool`` (one bit), a ``list`` of ``bool``,
        ``bytes``, a ``str`` (stored UTF-8 encoded) or another Bitstream.
        """
        self._bytearray = bytearray()
        self._bit_length = 0
        self._offset = 0
        # Cursor for the legacy ``next(bitstream)`` protocol; see __next__.
        self._iter_index = 0
        for item in values:
            if type(item) is bool:
                self.append(item)
            else:
                self.extend(item)

    @staticmethod
    def from_bytes(value: bytes) -> Bitstream:
        """Return a stream whose bits are exactly the bits of *value*."""
        bitstream = Bitstream()
        bitstream._bytearray += value
        bitstream._bit_length = len(bitstream._bytearray) << 3
        return bitstream

    @staticmethod
    def from_int(value: int, bit_length: int) -> Bitstream:
        """Return a *bit_length*-bit stream holding the low bits of *value*, MSB first."""
        return Bitstream([bool(value & (0b1 << i)) for i in range(bit_length - 1, -1, -1)])

    def __len__(self) -> int:
        """Return the number of bits in the stream."""
        return self._bit_length

    def __bytes__(self) -> bytes:
        """Return the bits packed MSB-first; a partial last byte is zero-padded."""
        # Drop spare storage bytes (e.g. left behind by pop()) so the result
        # is exactly ceil(len / 8) bytes long.
        del self._bytearray[(self._bit_length + 7) >> 3:]
        return bytes(self._bytearray)

    def __str__(self) -> str:
        """Decode ``bytes(self)`` with the default (UTF-8) codec."""
        return bytes(self).decode()

    def __int__(self) -> int:
        """Interpret the bits as an unsigned MSB-first integer."""
        num = 0
        for bit in self:
            num = (num << 1) | bit
        return num

    def __getitem__(self, bit_index: int) -> bool:
        """Return the bit at *bit_index* (negative indexes count from the end).

        Raises:
            IndexError: if *bit_index* is out of range.
        """
        if bit_index < 0:
            bit_index = len(self) + bit_index
        if bit_index >= self._bit_length or bit_index < 0:
            raise IndexError("bit index out of range")
        mask = 1 << (7 - (bit_index % 8))
        return bool(self._bytearray[bit_index >> 3] & mask)

    def __setitem__(self, bit_index: int, value: Union[bool, List[bool], bytes, str, Bitstream]):
        """Overwrite bits starting at *bit_index* (negative indexes count from the end).

        A ``bool`` sets one bit. A ``list``/``bytes``/``str``/Bitstream
        overwrites that many consecutive bits. The write may not extend past
        the end of the stream.

        Raises:
            IndexError: if *bit_index* is out of range or the value would run
                past the end of the stream.
            TypeError: if *value* has an unsupported type.
        """
        if bit_index < 0:
            # Normalise like __getitem__. Previously the raw negative index was
            # used to compute the byte/mask, which hit the wrong bit (pop()).
            bit_index += self._bit_length
        if bit_index >= self._bit_length or bit_index < 0:
            raise IndexError("bit index out of range")
        if type(value) is str:
            value = value.encode()
        if type(value) is bytes:
            value = Bitstream.from_bytes(value)
        value_bit_length = get_bit_length(value)
        if (bit_index + value_bit_length) > self._bit_length:
            raise IndexError("Cannot write bits that extends the size of bitstream")
        if type(value) is bool:
            mask = 1 << (7 - (bit_index % 8))
            if value:
                self._bytearray[bit_index >> 3] |= mask
            else:
                self._bytearray[bit_index >> 3] &= ~mask
        elif isinstance(value, Bitstream) or (type(value) is list and all(type(v) is bool for v in value)):
            for i in range(value_bit_length):
                self[bit_index + i] = value[i]
        else:
            raise TypeError("Expected bool, list[bool], bytes, str or Bitstream")

    def __eq__(self, other: Any) -> bool:
        """Compare with a Bitstream, list of bools, bool, bytes or str.

        Bitstreams, lists and bools are compared bit-for-bit, and lengths must
        match. ``bytes``/``str`` are compared against ``bytes(self)``, i.e.
        with the final partial byte zero-padded.

        Raises:
            TypeError: if *other* is of any other type.
        """
        if type(other) is bool:
            other = [other]
        elif type(other) is str:
            other = other.encode()
        if isinstance(other, Bitstream) or (type(other) is list and all(type(v) is bool for v in other)):
            # __iter__ yields independent iterators, so comparing a stream with
            # itself no longer interleaves a shared cursor.
            return len(self) == len(other) and all(left == right for left, right in zip(self, other))
        if type(other) is bytes:
            return bytes(self) == other
        raise TypeError("Expected Bitstream, bytes, str or list[bool]")

    def __mul__(self, number: int) -> Bitstream:
        """Return a new stream holding *number* back-to-back copies of this one.

        As with ``list * n``, a count of zero or less yields an empty stream.
        """
        if number <= 0:
            return Bitstream()
        result = self.copy()
        for _ in range(1, number):
            result.extend(self)
        return result

    # Support ``n * bitstream`` as well as ``bitstream * n``.
    __rmul__ = __mul__

    def __imul__(self, number: int) -> Bitstream:
        """Repeat this stream in place *number* times; ``<= 0`` empties it."""
        if number <= 0:
            self.clear()
            return self
        # Snapshot first: extending by ``self`` would see the growing stream.
        snapshot = self.copy()
        for _ in range(1, number):
            self.extend(snapshot)
        return self

    def __add__(self, value: Union[bool, List[bool], bytes, str, Bitstream]) -> Bitstream:
        """Return a new stream with *value* appended after this stream's bits."""
        copy = self.copy()
        if type(value) is bool:
            copy.append(value)
        else:
            copy.extend(value)
        return copy

    def __iadd__(self, value: Union[bool, List[bool], bytes, str, Bitstream]) -> Bitstream:
        """Append *value* in place (``bs += bs`` is supported)."""
        if type(value) is bool:
            self.append(value)
        else:
            self.extend(value)
        return self

    def __iter__(self) -> Iterator[bool]:
        """Yield the bits in order.

        Each call returns a fresh, independent iterator, so nested iteration
        over the same stream (e.g. ``bs == bs``) behaves correctly.
        """
        index = 0
        # Re-check the length each step so the iterator tracks the live stream.
        while index < self._bit_length:
            yield self[index]
            index += 1

    def __next__(self) -> bool:
        """Legacy ``next(bitstream)`` support, kept for backward compatibility.

        ``iter(bitstream)`` now returns an independent generator, so this
        cursor is only advanced by code that calls ``next()`` on the stream
        itself.
        """
        if self._iter_index >= self._bit_length:
            raise StopIteration
        bit = self[self._iter_index]
        self._iter_index += 1
        return bit

    def __repr__(self) -> str:
        return "<Bitstream addr:0x%012x offset:%d len:%d data:'%s'>" % (
            id(self), self._offset, len(self), bytes(self).hex(" ")
        )

    @property
    def bits(self) -> List[bool]:
        """The stream's bits as a new list of bools."""
        return list(self)

    def copy(self) -> Bitstream:
        """Return an independent copy, including the cursor position."""
        bitstream = Bitstream()
        bitstream._bytearray = self._bytearray.copy()
        bitstream._bit_length = self._bit_length
        bitstream._offset = self._offset
        return bitstream

    def _reserve(self, total_bits: int) -> None:
        """Grow the backing buffer with zero bytes until it holds *total_bits* bits."""
        missing = ((total_bits + 7) >> 3) - len(self._bytearray)
        if missing > 0:
            self._bytearray.extend(bytes(missing))

    def append(self, value: bool):
        """Append a single bit.

        Raises:
            TypeError: if *value* is not a ``bool``.
        """
        if type(value) is not bool:
            raise TypeError("Expected bool")
        self._reserve(self._bit_length + 1)
        prev_bit_length = self._bit_length
        self._bit_length += 1
        self[prev_bit_length] = value

    def extend(self, value: Union[List[bool], bytes, str, Bitstream]):
        """Append every bit of *value* (``str`` is stored UTF-8 encoded).

        Raises:
            TypeError: if *value* is not a list of bools, bytes, str or Bitstream.
        """
        if type(value) is str:
            # Measure the encoded form, which is what __setitem__ writes.
            value = value.encode()
        if value is self:
            # Snapshot: growing the stream first would change len(value) mid-write.
            value = self.copy()
        if not (
            type(value) is bytes
            or isinstance(value, Bitstream)
            or (type(value) is list and all(type(v) is bool for v in value))
        ):
            raise TypeError("Expected list[bool], bytes, str or Bitstream")
        value_bit_length = get_bit_length(value)
        if value_bit_length == 0:
            return
        self._reserve(self._bit_length + value_bit_length)
        prev_bit_length = self._bit_length
        self._bit_length += value_bit_length
        self[prev_bit_length] = value

    def clear(self):
        """Remove all bits and reset the cursor to 0."""
        self._bytearray.clear()
        self._bit_length = 0
        self._offset = 0

    def pop(self) -> bool:
        """Remove and return the last bit.

        Raises:
            IndexError: if the stream is empty.
        """
        value = self[-1]
        # Zero the bit before shrinking so bits past the end stay 0 and cannot
        # reappear in bytes(self) or in a later gap-filling write().
        self[-1] = False
        self._bit_length -= 1
        return value

    def read(
        self,
        type: Type[Union[bool, bytes, str, Bitstream, int]],
        bit_length: int,
        bit_index: Optional[int] = None
    ) -> Union[List[bool], bytes, str, Bitstream, int]:
        """Read *bit_length* bits and convert them to *type*.

        Reading starts at *bit_index*, or at the cursor when it is None. In the
        cursor case, the cursor advances past the bits read. ``type=bool``
        returns a list of bools. ``bytes``/``str`` pad a partial last byte
        with zeros, and ``int`` is unsigned MSB-first.

        Note: the parameter name ``type`` shadows the builtin inside this method.

        Raises:
            TypeError: if *type* is not one of bool, bytes, str, Bitstream, int.
            ValueError: if *bit_length* is negative.
            IndexError: if the requested range lies outside the stream.
        """
        # Validate before touching the cursor so a bad call has no side effect.
        if type not in (bool, Bitstream, bytes, str, int):
            raise TypeError("Invalid data type")
        if bit_length < 0:
            raise ValueError("bit_length must be non-negative")
        start = self._offset if bit_index is None else bit_index
        if (start + bit_length) > self._bit_length or (start < 0):
            raise IndexError("bit index out of range")
        bits = [self[i] for i in range(start, start + bit_length)]
        if bit_index is None:
            self._offset += bit_length
        if type is bool:
            return bits
        bitstream = Bitstream(bits)
        if type is Bitstream:
            return bitstream
        if type is bytes:
            return bytes(bitstream)
        if type is str:
            return str(bitstream)
        return int(bitstream)

    def write(self, value: Union[bool, List[bool], bytes, str, Bitstream], bit_index: Optional[int] = None):
        """Write *value* at *bit_index*, or at the cursor when it is None.

        The stream grows as needed. Writing past the current end leaves zero
        bits in the gap. In the cursor case, the cursor moves to just after the
        written bits.

        Raises:
            IndexError: if the start position is negative.
            TypeError: if *value* has an unsupported type.
        """
        start = self._offset if bit_index is None else bit_index
        if start < 0:
            raise IndexError("bit index out of range")
        if type(value) is str:
            # Measure the UTF-8 encoding, which is what actually gets written.
            value = value.encode()
        if value is self:
            # Snapshot: growing the stream first would change len(value) mid-write.
            value = self.copy()
        value_bit_length = get_bit_length(value)
        if value_bit_length == 0:
            return
        end = start + value_bit_length
        self._reserve(end)
        if bit_index is None:
            self._offset = end
        if end > self._bit_length:
            self._bit_length = end
        self[start] = value

    def seek(self, position: int, whence: Literal[0, 1, 2] = 0):
        """Move the cursor; the result is clamped to ``[0, len(self)]``.

        *whence* is SEEK_SET (absolute), SEEK_CUR (relative to the cursor) or
        SEEK_END (relative to the end of the stream).

        Raises:
            ValueError: for any other *whence*.
        """
        if whence == SEEK_SET:
            offset = position
        elif whence == SEEK_CUR:
            offset = self._offset + position
        elif whence == SEEK_END:
            offset = self._bit_length + position
        else:
            raise ValueError("Invalid whence")
        self._offset = min(max(offset, 0), self._bit_length)