sampy3/sampy/raknet/bitstream.py

300 lines
8.7 KiB
Python
Raw Normal View History

2023-02-24 23:56:09 +01:00
from __future__ import annotations

from typing import Any, Iterator, List, Literal, Optional, Type, Union
# Values for the ``whence`` argument of Bitstream.seek(): measure the new
# position from the start of the stream, from the current cursor, or from the
# end of the stream (same numbering as the io module's SEEK_* constants).
SEEK_SET, SEEK_CUR, SEEK_END = 0, 1, 2
def get_bit_length(value: Union[bool, List[bool], bytes, Bitstream]) -> int:
    """Return how many bits *value* occupies when written into a Bitstream.

    Args:
        value: a single bool (1 bit), a list of bools (one bit per item),
            a bytes object (8 bits per byte) or a Bitstream (its bit length).

    Returns:
        The bit length of *value*.

    Raises:
        TypeError: if *value* is none of the supported types, or is a list
            that contains non-bool items.
    """
    if type(value) is bool:
        return 1
    if type(value) is bytes:
        return len(value) << 3
    if type(value) is list and all(type(v) is bool for v in value):
        return len(value)
    # isinstance rather than an exact type check, so Bitstream subclasses are
    # accepted here just as they are by Bitstream.extend()/__setitem__().
    if isinstance(value, Bitstream):
        return len(value)
    raise TypeError("Invalid data type")
def bits_to_int(bits: List[bool], little_endian: bool = False) -> int:
    """Fold a sequence of bits into an unsigned integer.

    By default the first bit is the most significant one.  With
    *little_endian* set, the sequence is reversed first, so the last bit is
    the most significant.  An empty sequence yields 0.
    """
    ordered = bits[::-1] if little_endian else bits
    value = 0
    for flag in ordered:
        value = (value << 1) | flag
    return value
def int_to_bits(
    number: int, bit_length: int, little_endian: bool = False
) -> List[bool]:
    """Split the low *bit_length* bits of *number* into a list of bools.

    The result is most-significant-bit first unless *little_endian* is true,
    in which case the least significant bit comes first.  A non-positive
    *bit_length* yields an empty list.
    """
    positions = range(bit_length)
    if not little_endian:
        positions = reversed(positions)
    return [bool((number >> pos) & 1) for pos in positions]
class Bitstream:
    """A growable sequence of bits with a read/write cursor.

    Bits are packed most-significant-bit first into an internal bytearray:
    bit index ``i`` lives in byte ``i >> 3`` under mask ``0x80 >> (i % 8)``.
    ``_bit_length`` is the logical length; the bytearray may hold additional
    zeroed trailing storage (e.g. after pop()).  ``_offset`` is the cursor
    used by read()/write() when no explicit bit index is given; seek() and
    pop() keep it within ``0 .. len(self)``.

    __eq__ is defined without __hash__, so instances are unhashable.
    """

    def __init__(self, *values: Union[bool, List[bool], bytes, Bitstream]):
        """Create a bitstream holding *values* concatenated in order.

        Each value may be a bool, a list of bools, bytes or another Bitstream.
        """
        self._bytearray = bytearray()
        self._bit_length = 0
        self._offset = 0
        for item in values:
            if type(item) is bool:
                self.append(item)
            else:
                self.extend(item)

    @staticmethod
    def from_bytes(value: bytes) -> Bitstream:
        """Return a bitstream containing exactly the bits of *value*."""
        bitstream = Bitstream()
        bitstream._bytearray += value
        bitstream._bit_length = len(bitstream._bytearray) << 3
        return bitstream

    @staticmethod
    def from_int(value: int, bit_length: int) -> Bitstream:
        """Return the low *bit_length* bits of *value*, most significant first."""
        return Bitstream(int_to_bits(value, bit_length))

    def __len__(self) -> int:
        """Return the number of bits in the stream."""
        return self._bit_length

    def __bytes__(self) -> bytes:
        """Return the bits packed into bytes; the last byte is zero-padded.

        Does not modify the internal buffer.
        """
        byte_length = (self._bit_length + 7) >> 3
        data = bytearray(self._bytearray[:byte_length])
        spare = (byte_length << 3) - self._bit_length
        if spare:
            # Never expose storage bits that lie past the logical end.
            data[-1] &= (0xFF << spare) & 0xFF
        return bytes(data)

    def __int__(self) -> int:
        """Interpret the whole stream as an unsigned big-endian integer."""
        # bytes() zero-pads the final byte; shift the padding back out.
        spare = (-self._bit_length) % 8
        return int.from_bytes(bytes(self), "big") >> spare

    def __getitem__(self, bit_index: int) -> bool:
        """Return the bit at *bit_index*; negative indices count from the end.

        Raises:
            IndexError: if the index lies outside the stream.
        """
        if bit_index < 0:
            bit_index += self._bit_length
        if not 0 <= bit_index < self._bit_length:
            raise IndexError("bit index out of range")
        return bool(self._bytearray[bit_index >> 3] & (0x80 >> (bit_index & 7)))

    def __setitem__(
        self, bit_index: int, value: Union[bool, List[bool], bytes, Bitstream]
    ):
        """Overwrite bits starting at *bit_index* with *value*.

        A bool sets a single bit; a list of bools, bytes or a Bitstream
        overwrite consecutive bits.  Negative indices count from the end.
        This never grows the stream.

        Raises:
            IndexError: if *bit_index* is out of range, or *value* would run
                past the end of the stream.
            TypeError: if *value* has an unsupported type.
        """
        # Normalise negative indices exactly like __getitem__ does; pop()
        # relies on ``self[-1] = False`` clearing the *last bit*.
        if bit_index < 0:
            bit_index += self._bit_length
        if not 0 <= bit_index < self._bit_length:
            raise IndexError("bit index out of range")
        if type(value) is bool:
            mask = 0x80 >> (bit_index & 7)
            if value:
                self._bytearray[bit_index >> 3] |= mask
            else:
                self._bytearray[bit_index >> 3] &= ~mask
            return
        if type(value) is bytes:
            value = Bitstream.from_bytes(value)
        # Raises TypeError for anything that is not a supported bit container.
        value_bit_length = get_bit_length(value)
        if (bit_index + value_bit_length) > self._bit_length:
            raise IndexError("Cannot write bits that extends the size of bitstream")
        for i in range(value_bit_length):
            self[bit_index + i] = value[i]

    def __eq__(self, other: Any) -> bool:
        """Compare bit-for-bit with a Bitstream, list[bool] or single bool.

        Comparison with bytes compares the zero-padded byte form, so a
        3-bit stream can equal a 1-byte bytes object.

        Raises:
            TypeError: for any other operand type.
        """
        # NOTE(review): raising instead of returning NotImplemented makes
        # ``bitstream == 5`` and ``bitstream in mixed_list`` raise; kept for
        # backward compatibility with existing callers.
        if type(other) is bool:
            other = [other]
        if type(other) is bytes:
            return bytes(self) == other
        if isinstance(other, Bitstream) or (
            type(other) is list and all(type(v) is bool for v in other)
        ):
            if len(self) != len(other):
                return False
            return all(left == right for left, right in zip(self, other))
        raise TypeError("Expected Bitstream, bytes or list[bool]")

    def __mul__(self, number: int) -> Bitstream:
        """Return a new stream with this stream's bits repeated *number* times.

        Like ``list * n``, a non-positive *number* yields an empty stream.
        """
        if number <= 0:
            return Bitstream()
        result = self.copy()
        for _ in range(1, number):
            result.extend(self)
        return result

    __rmul__ = __mul__

    def __imul__(self, number: int) -> Bitstream:
        """Repeat this stream's bits *number* times in place.

        A non-positive *number* empties the stream.
        """
        if number <= 0:
            self.clear()
            return self
        # Snapshot first: extending with self while it grows would compound.
        pattern = self.copy()
        for _ in range(1, number):
            self.extend(pattern)
        return self

    def __add__(self, value: Union[bool, List[bool], bytes, Bitstream]) -> Bitstream:
        """Return a copy of this stream with *value* appended to the end."""
        result = self.copy()
        result += value
        return result

    def __iadd__(self, value: Union[bool, List[bool], bytes, Bitstream]) -> Bitstream:
        """Append *value* (a bool or any extend()-compatible value) in place."""
        if type(value) is bool:
            self.append(value)
        else:
            self.extend(value)
        return self

    def __iter__(self) -> Iterator[bool]:
        """Yield the bits from first to last."""
        for i in range(self._bit_length):
            yield self[i]

    def __repr__(self) -> str:
        return "<Bitstream addr:0x%012x offset:%d len:%d data:'%s'>" % (
            id(self),
            self._offset,
            len(self),
            bytes(self).hex(" "),
        )

    @property
    def bits(self) -> List[bool]:
        """The bits as a new list of bools."""
        return list(self)

    def copy(self) -> Bitstream:
        """Return an independent copy, including the cursor position."""
        bitstream = Bitstream()
        bitstream._bytearray = self._bytearray.copy()
        bitstream._bit_length = self._bit_length
        bitstream._offset = self._offset
        return bitstream

    def _reserve(self, bit_count: int) -> None:
        """Grow the backing buffer with zero bytes until it holds *bit_count* bits."""
        needed = (bit_count + 7) >> 3
        if len(self._bytearray) < needed:
            self._bytearray.extend(bytes(needed - len(self._bytearray)))

    def append(self, value: bool):
        """Append a single bit to the end of the stream.

        Raises:
            TypeError: if *value* is not a bool.
        """
        if type(value) is not bool:
            raise TypeError("Expected bool")
        self._reserve(self._bit_length + 1)
        self._bit_length += 1
        self[self._bit_length - 1] = value

    def extend(self, value: Union[List[bool], bytes, Bitstream]):
        """Append every bit of *value* to the end of the stream.

        Raises:
            TypeError: if *value* is not a list of bools, bytes or Bitstream.
        """
        if not (
            type(value) is bytes
            or isinstance(value, Bitstream)
            or (type(value) is list and all(type(v) is bool for v in value))
        ):
            raise TypeError("Expected list[bool], bytes or Bitstream")
        value_bit_length = get_bit_length(value)
        if value_bit_length == 0:
            return
        prev_bit_length = self._bit_length
        self._reserve(prev_bit_length + value_bit_length)
        self._bit_length += value_bit_length
        self[prev_bit_length] = value

    def clear(self):
        """Remove all bits and reset the cursor to 0."""
        self._bytearray.clear()
        self._bit_length = 0
        self._offset = 0

    def pop(self) -> bool:
        """Remove and return the last bit.

        Raises:
            IndexError: if the stream is empty.
        """
        value = self[-1]
        # Zero the vacated bit so later growth/bytes() never sees stale data.
        self[-1] = False
        self._bit_length -= 1
        # Keep the cursor within 0..len(self), the same invariant seek() keeps.
        self._offset = min(self._offset, self._bit_length)
        return value

    def read(
        self,
        type: Type[Union[bool, bytes, Bitstream, int]],
        bit_length: int,
        bit_index: Optional[int] = None,
    ) -> Union[List[bool], bytes, Bitstream, int]:
        """Read *bit_length* bits and convert them according to *type*.

        Args:
            type: ``bool`` returns a list of bools; ``Bitstream``, ``bytes``
                and ``int`` return the bits in that form (``int`` is unsigned,
                most significant bit first).
            bit_length: number of bits to read; must be non-negative.
            bit_index: absolute start position.  When omitted, reading starts
                at the cursor and the cursor advances past the bits read.

        Raises:
            ValueError: if *bit_length* is negative.
            IndexError: if the requested range lies outside the stream.
            TypeError: if *type* is unsupported (the cursor is not moved).
        """
        if bit_length < 0:
            raise ValueError("bit_length must be non-negative")
        start = self._offset if bit_index is None else bit_index
        if (start + bit_length) > self._bit_length or (start < 0):
            raise IndexError("bit index out of range")
        # Validate before touching the cursor so a bad call has no side effect.
        if type not in (bool, bytes, Bitstream, int):
            raise TypeError("Invalid data type")
        bits = [self[i] for i in range(start, start + bit_length)]
        if bit_index is None:
            self._offset += bit_length
        if type is bool:
            return bits
        bitstream = Bitstream(bits)
        if type is Bitstream:
            return bitstream
        if type is bytes:
            return bytes(bitstream)
        return int(bitstream)

    def write(
        self,
        value: Union[bool, List[bool], bytes, Bitstream],
        bit_index: Optional[int] = None,
    ):
        """Write *value* at *bit_index* (or at the cursor), growing as needed.

        When *bit_index* is omitted the cursor advances past the written bits.
        Writing past the current end extends the stream; any gap is zero.

        Raises:
            IndexError: if the start position is negative.
            TypeError: if *value* has an unsupported type.
        """
        start = self._offset if bit_index is None else bit_index
        if start < 0:
            raise IndexError("bit index out of range")
        value_bit_length = get_bit_length(value)
        if value_bit_length == 0:
            return
        end = start + value_bit_length
        self._reserve(end)
        if end > self._bit_length:
            self._bit_length = end
        if bit_index is None:
            self._offset = end
        self[start] = value

    def seek(self, position: int, whence: Literal[0, 1, 2] = 0):
        """Move the cursor, clamping the result to ``0 .. len(self)``.

        *whence* is SEEK_SET (from the start), SEEK_CUR (from the cursor) or
        SEEK_END (from the end).

        Raises:
            ValueError: for any other *whence*; the cursor is left unchanged.
        """
        if whence == SEEK_SET:
            base = 0
        elif whence == SEEK_CUR:
            base = self._offset
        elif whence == SEEK_END:
            base = self._bit_length
        else:
            raise ValueError("Invalid whence")
        self._offset = min(max(base + position, 0), self._bit_length)