2023-02-24 23:56:09 +01:00
|
|
|
from __future__ import annotations

from typing import Any, Iterator, List, Literal, Optional, Type, Union
|
2023-02-24 23:56:09 +01:00
|
|
|
|
|
|
|
# Values for the ``whence`` argument of Bitstream.seek().
SEEK_SET = 0  # position is counted from the first bit of the stream
SEEK_CUR = 1  # position is counted from the current read/write offset
SEEK_END = 2  # position is counted from the end of the stream
|
|
|
|
|
|
|
|
|
2023-03-15 03:43:18 +01:00
|
|
|
def get_bit_length(value: Union[bool, List[bool], bytes, Bitstream]) -> int:
    """Return how many bits *value* occupies.

    Args:
        value: A single ``bool`` (1 bit), a ``list`` whose items are all
            ``bool`` (one bit per item), ``bytes`` (8 bits per byte) or a
            ``Bitstream`` (its ``len()``).

    Returns:
        The number of bits represented by *value*.

    Raises:
        TypeError: If *value* is none of the supported types, including a
            list that contains anything other than ``bool``.
    """
    if type(value) is bool:
        return 1
    if type(value) is list and all(type(v) is bool for v in value):
        return len(value)
    if type(value) is bytes:
        return len(value) << 3
    # isinstance rather than an exact type match: Bitstream.extend(),
    # __setitem__() and __eq__() accept Bitstream subclasses, so the length
    # helper they rely on must accept them as well.
    if isinstance(value, Bitstream):
        return len(value)
    raise TypeError("Invalid data type")
|
|
|
|
|
|
|
|
|
|
|
|
def bits_to_int(bits: List[bool], little_endian: bool = False) -> int:
    """Fold a sequence of bits into an integer.

    Args:
        bits: The bits to combine. By default the first item is the most
            significant bit.
        little_endian: When true, the first item is treated as the least
            significant bit instead.

    Returns:
        The integer assembled from *bits* (0 for an empty sequence).
    """
    ordered = reversed(bits) if little_endian else bits
    result = 0
    for flag in ordered:
        # Make room for the next bit, then drop it into the lowest position.
        result = (result << 1) | flag
    return result
|
|
|
|
|
|
|
|
|
2023-03-15 06:06:58 +01:00
|
|
|
def int_to_bits(
    number: int, bit_length: int, little_endian: bool = False
) -> List[bool]:
    """Split the low *bit_length* bits of *number* into a list of bools.

    Args:
        number: The integer to decompose.
        bit_length: How many bits to emit; bits of *number* above this width
            are not included.
        little_endian: When true the least significant bit comes first,
            otherwise the most significant bit comes first.

    Returns:
        A list of *bit_length* booleans (empty when *bit_length* is 0 or
        negative).
    """
    # Collect the bits least-significant first, then flip for the default
    # most-significant-first ordering.
    lsb_first = [bool((number >> shift) & 1) for shift in range(bit_length)]
    if little_endian:
        return lsb_first
    return lsb_first[::-1]
|
|
|
|
|
|
|
|
|
|
|
|
class Bitstream:
    """Growable, bit-addressable buffer with a file-like cursor.

    Bits are packed most-significant-bit first into a ``bytearray``: bit
    index 0 is the 0x80 bit of byte 0. ``_bit_length`` is the number of valid
    bits and ``_offset`` is the cursor used by :meth:`read`, :meth:`write`
    and :meth:`seek` when no explicit bit index is supplied.
    """

    def __init__(self, *values: Union[bool, List[bool], bytes, Bitstream]):
        """Create a stream holding the concatenation of *values*, in order.

        Raises:
            TypeError: If a value is not a bool, list of bools, bytes or
                Bitstream.
        """
        self._bytearray = bytearray()
        self._bit_length = 0
        self._offset = 0

        for item in values:
            if type(item) is bool:
                self.append(item)
            else:
                self.extend(item)

    @staticmethod
    def from_bytes(value: bytes) -> Bitstream:
        """Return a new stream whose bits are exactly the bits of *value*."""
        bitstream = Bitstream()
        bitstream._bytearray += value
        bitstream._bit_length = len(bitstream._bytearray) << 3
        return bitstream

    @staticmethod
    def from_int(value: int, bit_length: int) -> Bitstream:
        """Return a new stream with the low *bit_length* bits of *value*.

        The most significant of those bits becomes bit index 0.
        """
        return Bitstream(
            [bool(value & (0b1 << i)) for i in range(bit_length - 1, -1, -1)]
        )

    def __len__(self) -> int:
        """Return the number of valid bits in the stream."""
        return self._bit_length

    def __bytes__(self) -> bytes:
        """Return the stream as bytes; a partial last byte is included as is.

        As a side effect, whole bytes that lie past the last valid bit are
        trimmed from the internal buffer.
        """
        del self._bytearray[(self._bit_length + 7) >> 3 :]
        return bytes(self._bytearray)

    def __int__(self) -> int:
        """Return the bits interpreted as an integer, first bit most significant."""
        num = 0
        for bit in self:
            num = (num << 1) | bit
        return num

    def __getitem__(self, bit_index: int) -> bool:
        """Return the bit at *bit_index*; negative indices count from the end.

        Raises:
            IndexError: If the index is outside the stream.
        """
        if bit_index < 0:
            bit_index = len(self) + bit_index

        if bit_index >= self._bit_length or bit_index < 0:
            raise IndexError("bit index out of range")

        mask = 1 << (7 - (bit_index % 8))
        return bool(self._bytearray[bit_index >> 3] & mask)

    def _set_bit(self, bit_index: int, bit: bool) -> None:
        """Store one bit at a non-negative index that is already backed by the buffer."""
        mask = 1 << (7 - (bit_index % 8))
        if bit:
            self._bytearray[bit_index >> 3] |= mask
        else:
            self._bytearray[bit_index >> 3] &= ~mask

    def _reserve(self, bit_count: int) -> None:
        """Grow the buffer with zero bytes until it can hold *bit_count* bits."""
        missing = ((bit_count + 7) >> 3) - len(self._bytearray)
        if missing > 0:
            self._bytearray.extend(bytes(missing))

    def __setitem__(
        self, bit_index: int, value: Union[bool, List[bool], bytes, Bitstream]
    ):
        """Overwrite bits starting at *bit_index* with *value*.

        A bool replaces one bit; a list of bools, bytes or a Bitstream
        replaces as many consecutive bits as it contains. The stream is never
        grown by this method.

        Raises:
            IndexError: If *bit_index* is outside the stream, or *value*
                would run past the end of the stream.
            TypeError: If *value* has an unsupported type.
        """
        # Negative indices count from the end, as in __getitem__. Without
        # this, self[-1] = x addressed the lowest bit of the last byte
        # instead of the last valid bit.
        if bit_index < 0:
            bit_index += self._bit_length
        if bit_index < 0 or bit_index >= self._bit_length:
            raise IndexError("bit index out of range")

        if type(value) is bool:
            # A single in-range bit can never overrun the stream.
            self._set_bit(bit_index, value)
            return

        if type(value) is bytes:
            value = Bitstream.from_bytes(value)

        value_bit_length = get_bit_length(value)

        if (bit_index + value_bit_length) > self._bit_length:
            raise IndexError("Cannot write bits that extends the size of bitstream")

        if isinstance(value, Bitstream) or (
            type(value) is list and all(type(v) is bool for v in value)
        ):
            for i in range(value_bit_length):
                self._set_bit(bit_index + i, value[i])
        else:
            raise TypeError("Expected bool, list[bool], bytes or Bitstream")

    def __eq__(self, other: Any) -> bool:
        """Compare with a Bitstream, a list of bools, a bool or bytes.

        A bool is treated as a one-item list. Bitstreams and lists are equal
        when they have the same length and the same bits. Bytes are compared
        against ``bytes(self)``, so the exact bit length is not compared.

        Raises:
            TypeError: If *other* is of any other type.
        """
        if type(other) is bool:
            other = [other]

        if type(other) is bytes:
            return bytes(self) == other

        if isinstance(other, Bitstream) or (
            type(other) is list and all(type(v) is bool for v in other)
        ):
            if len(self) != len(other):
                return False
            return all(left == right for left, right in zip(self, other))

        raise TypeError("Expected Bitstream, bytes or list[bool]")

    def __mul__(self, number: int) -> Bitstream:
        """Return a new stream containing this stream repeated *number* times.

        A *number* of zero or less yields an empty stream.
        """
        if number <= 0:
            return Bitstream()
        copy = self.copy()
        for _ in range(1, number):
            copy.extend(self)
        return copy

    def __imul__(self, number: int) -> Bitstream:
        """Repeat this stream in place *number* times.

        A *number* of zero or less empties the stream.
        """
        if number <= 0:
            self.clear()
            return self
        # Snapshot first: extending with self while it grows would re-read
        # bits that were just written.
        snapshot = self.copy()
        for _ in range(1, number):
            self.extend(snapshot)
        return self

    def __add__(self, value: Union[bool, List[bool], bytes, Bitstream]) -> Bitstream:
        """Return a copy of this stream with *value* appended."""
        copy = self.copy()
        if type(value) is bool:
            copy.append(value)
        else:
            copy.extend(value)
        return copy

    def __iadd__(self, value: Union[bool, List[bool], bytes, Bitstream]) -> Bitstream:
        """Append *value* to this stream in place."""
        if type(value) is bool:
            self.append(value)
        else:
            self.extend(value)
        return self

    def __iter__(self) -> Iterator[bool]:
        """Yield every bit from index 0 to the end.

        The length is sampled once when iteration starts.
        """
        for i in range(len(self)):
            yield self[i]

    def __repr__(self) -> str:
        """Return a debug string with identity, cursor, length and hex data."""
        return "<Bitstream addr:0x%012x offset:%d len:%d data:'%s'>" % (
            id(self),
            self._offset,
            len(self),
            bytes(self).hex(" "),
        )

    @property
    def bits(self) -> List[bool]:
        """All bits of the stream as a new list."""
        return list(self)

    def copy(self) -> Bitstream:
        """Return an independent copy, including the cursor position."""
        bitstream = Bitstream()
        bitstream._bytearray = self._bytearray.copy()
        bitstream._bit_length = self._bit_length
        bitstream._offset = self._offset
        return bitstream

    def append(self, value: bool):
        """Add one bit at the end of the stream.

        Raises:
            TypeError: If *value* is not a bool.
        """
        if type(value) is not bool:
            raise TypeError("Expected bool")

        index = self._bit_length
        self._reserve(index + 1)
        self._bit_length = index + 1
        self._set_bit(index, value)

    def extend(self, value: Union[List[bool], bytes, Bitstream]):
        """Add every bit of *value* at the end of the stream.

        Raises:
            TypeError: If *value* is not a list of bools, bytes or Bitstream.
        """
        supported = (
            type(value) is bytes
            or isinstance(value, Bitstream)
            or (type(value) is list and all(type(v) is bool for v in value))
        )
        if not supported:
            raise TypeError("Expected list[bool], bytes or Bitstream")

        value_bit_length = get_bit_length(value)
        if value_bit_length == 0:
            return

        start = self._bit_length
        self._reserve(start + value_bit_length)
        # The length must grow before the bits are stored because
        # __setitem__ refuses to write past the end of the stream.
        self._bit_length = start + value_bit_length
        self[start] = value

    def clear(self):
        """Remove all bits and rewind the cursor to 0."""
        self._bytearray.clear()
        self._bit_length = 0
        self._offset = 0

    def pop(self) -> bool:
        """Remove and return the last bit.

        Raises:
            IndexError: If the stream is empty.
        """
        value = self[-1]  # raises IndexError on an empty stream
        last = self._bit_length - 1
        # Zero the removed bit so it does not linger in bytes(self).
        self._set_bit(last, False)
        self._bit_length = last
        return value

    def read(
        self,
        type: Type[Union[bool, bytes, Bitstream, int]],
        bit_length: int,
        bit_index: Optional[int] = None,
    ) -> Union[List[bool], bytes, Bitstream, int]:
        """Read *bit_length* bits and return them converted to *type*.

        Args:
            type: ``bool`` for a list of bools, ``Bitstream``, ``bytes`` or
                ``int``.
            bit_length: Number of bits to read.
            bit_index: Where to start. When ``None`` the read starts at the
                cursor and the cursor advances by *bit_length*; otherwise
                the cursor is left untouched.

        Raises:
            IndexError: If the requested range is not inside the stream.
            TypeError: If *type* is not one of the supported types.
        """
        start = self._offset if bit_index is None else bit_index

        if (start + bit_length) > self._bit_length or (start < 0):
            raise IndexError("bit index out of range")

        # Reject an unsupported target type before touching the cursor, so a
        # failed read does not consume any bits.
        if not any(type is kind for kind in (bool, Bitstream, bytes, int)):
            raise TypeError("Invalid data type")

        bits = [self[i] for i in range(start, start + bit_length)]

        if type is bool:
            result = bits
        else:
            bitstream = Bitstream(bits)
            if type is Bitstream:
                result = bitstream
            elif type is bytes:
                result = bytes(bitstream)
            else:
                result = int(bitstream)

        if bit_index is None:
            self._offset += bit_length

        return result

    def write(
        self,
        value: Union[bool, List[bool], bytes, Bitstream],
        bit_index: Optional[int] = None,
    ):
        """Write *value* into the stream, growing it when necessary.

        Args:
            value: The bits to store.
            bit_index: Where to start. When ``None`` the write starts at the
                cursor and the cursor advances past the written bits;
                otherwise the cursor is left untouched.

        Raises:
            IndexError: If the start position is negative.
            TypeError: If *value* has an unsupported type.
        """
        start = self._offset if bit_index is None else bit_index

        if start < 0:
            raise IndexError("bit index out of range")

        value_bit_length = get_bit_length(value)
        if value_bit_length == 0:
            return

        end = start + value_bit_length
        self._reserve(end)

        if bit_index is None:
            self._offset = end

        # Extend the valid range first; __setitem__ will not write past it.
        if end > self._bit_length:
            self._bit_length = end
        self[start] = value

    def seek(self, position: int, whence: Literal[0, 1, 2] = 0):
        """Move the cursor, clamping the result to ``[0, len(self)]``.

        Args:
            position: Offset in bits relative to the reference point.
            whence: 0 = start of the stream, 1 = current cursor, 2 = end of
                the stream (the values of the module-level SEEK_SET,
                SEEK_CUR and SEEK_END constants).

        Raises:
            ValueError: If *whence* is not 0, 1 or 2.
        """
        if whence == 0:
            target = position
        elif whence == 1:
            target = self._offset + position
        elif whence == 2:
            target = len(self) + position
        else:
            raise ValueError("Invalid whence")

        self._offset = min(max(target, 0), len(self))
|