"""Mutable, MSB-first bit sequences with a file-like read/write cursor.

Bits are packed into a ``bytearray`` most-significant-bit first: bit index 0
is the top bit of the first byte. Bits beyond ``len(bitstream)`` in the last
byte are kept at zero, so ``bytes(bitstream)`` is zero-padded.
"""

from __future__ import annotations

from typing import Any, Iterator, List, Literal, Optional, Type, Union

# Values for the ``whence`` argument of ``Bitstream.seek`` (same meaning as
# ``io.SEEK_SET`` / ``io.SEEK_CUR`` / ``io.SEEK_END``).
SEEK_SET = 0
SEEK_CUR = 1
SEEK_END = 2


def get_bit_length(value: Union[bool, List[bool], bytes, Bitstream]) -> int:
    """Return how many bits ``value`` occupies when stored in a Bitstream.

    Raises:
        TypeError: if ``value`` is not a bool, a list of bools, bytes or a
            Bitstream.
    """
    if type(value) is bool:
        return 1
    if type(value) is list and all(type(v) is bool for v in value):
        return len(value)
    if type(value) is bytes:
        return len(value) << 3
    if isinstance(value, Bitstream):
        return len(value)
    raise TypeError("Invalid data type")


def bits_to_int(bits: List[bool], little_endian: bool = False) -> int:
    """Interpret ``bits`` as an unsigned integer.

    By default the first bit is the most significant one; with
    ``little_endian=True`` the first bit is the least significant one.
    """
    num = 0
    if little_endian:
        bits = bits[::-1]
    for bit in bits:
        num = (num << 1) | bit
    return num


def int_to_bits(
    number: int, bit_length: int, little_endian: bool = False
) -> List[bool]:
    """Return the low ``bit_length`` bits of ``number`` as a list of bools.

    Most significant bit first by default, least significant first with
    ``little_endian=True``. Negative numbers yield their two's-complement
    bits (Python's ``&`` semantics).
    """
    bits = [bool(number & (1 << i)) for i in range(bit_length - 1, -1, -1)]
    return bits[::-1] if little_endian else bits


class Bitstream:
    """A growable sequence of bits with a cursor used by read/write/seek.

    Internal state:
        ``_bytearray``  packed bits, MSB first, sized to ``ceil(len / 8)``.
        ``_bit_length`` number of valid bits.
        ``_offset``     cursor position (in bits) for ``read``/``write``.
    """

    def __init__(self, *values: Union[bool, List[bool], bytes, Bitstream]):
        """Create a bitstream holding the concatenation of ``values``."""
        self._bytearray = bytearray()
        self._bit_length = 0
        self._offset = 0
        for item in values:
            if type(item) is bool:
                self.append(item)
            else:
                self.extend(item)

    @staticmethod
    def from_bytes(value: bytes) -> Bitstream:
        """Return a bitstream containing every bit of ``value``."""
        bitstream = Bitstream()
        bitstream._bytearray += value
        bitstream._bit_length = len(bitstream._bytearray) << 3
        return bitstream

    @staticmethod
    def from_int(value: int, bit_length: int) -> Bitstream:
        """Return the low ``bit_length`` bits of ``value``, MSB first."""
        return Bitstream(int_to_bits(value, bit_length))

    def __len__(self) -> int:
        """Return the number of bits stored."""
        return self._bit_length

    def __bytes__(self) -> bytes:
        """Return the packed bits, zero-padded to a whole number of bytes."""
        # Slice instead of trimming in place: conversion has no side effects.
        return bytes(self._bytearray[: (self._bit_length + 7) >> 3])

    def __int__(self) -> int:
        """Interpret the bits as an unsigned MSB-first integer."""
        return bits_to_int(self.bits)

    def __getitem__(self, bit_index: int) -> bool:
        """Return the bit at ``bit_index``; negative indexes count from the end.

        Raises:
            IndexError: if ``bit_index`` is out of range.
        """
        if bit_index < 0:
            bit_index = self._bit_length + bit_index
        if bit_index >= self._bit_length or bit_index < 0:
            raise IndexError("bit index out of range")
        mask = 1 << (7 - (bit_index % 8))
        return bool(self._bytearray[bit_index >> 3] & mask)

    def __setitem__(
        self, bit_index: int, value: Union[bool, List[bool], bytes, Bitstream]
    ):
        """Overwrite the bits starting at ``bit_index`` with ``value``.

        ``bit_index`` may be negative (counted from the end, as in
        ``__getitem__``). The value must fit inside the current length; this
        never grows the stream.

        Raises:
            IndexError: if ``bit_index`` is out of range or ``value`` would
                run past the end of the stream.
            TypeError: if ``value`` has an unsupported type.
        """
        if bit_index < 0:
            # Normalise like __getitem__; otherwise the byte/mask arithmetic
            # below would address a padding bit instead of the intended one.
            bit_index = self._bit_length + bit_index
        if bit_index >= self._bit_length or bit_index < 0:
            raise IndexError("bit index out of range")
        if type(value) is bytes:
            self[bit_index] = Bitstream.from_bytes(value)
            return
        value_bit_length = get_bit_length(value)  # also validates the type
        if (bit_index + value_bit_length) > self._bit_length:
            raise IndexError("Cannot write bits that extends the size of bitstream")
        if type(value) is bool:
            mask = 1 << (7 - (bit_index % 8))
            if value:
                self._bytearray[bit_index >> 3] |= mask
            else:
                self._bytearray[bit_index >> 3] &= ~mask
        elif isinstance(value, Bitstream) or type(value) is list:
            for i in range(value_bit_length):
                self[bit_index + i] = value[i]
        else:
            raise TypeError("Expected bool, list[bool], bytes or Bitstream")

    def __eq__(self, other: Any) -> bool:
        """Compare with a Bitstream, a list of bools, a single bool or bytes.

        Bit sequences must match bit for bit, including length. ``bytes`` are
        compared against ``bytes(self)``, i.e. after zero-padding this stream
        to a whole byte.

        Raises:
            TypeError: for any other type of ``other``.
        """
        if type(other) is bool:
            other = [other]
        if isinstance(other, Bitstream) or (
            type(other) is list and all(type(v) is bool for v in other)
        ):
            if len(self) != len(other):
                return False
            return all(left == right for left, right in zip(self, other))
        if type(other) is bytes:
            return bytes(self) == other
        raise TypeError("Expected Bitstream, bytes or list[bool]")

    def __mul__(self, number: int) -> Bitstream:
        """Return a new bitstream holding ``number`` back-to-back copies.

        ``number <= 0`` yields an empty bitstream, mirroring ``list * n``.
        """
        if number <= 0:
            return Bitstream()
        result = self.copy()
        for _ in range(1, number):
            result.extend(self)
        return result

    # ``n * bitstream`` behaves like ``bitstream * n``.
    __rmul__ = __mul__

    def __imul__(self, number: int) -> Bitstream:
        """Repeat the contents in place ``number`` times (``<= 0`` empties)."""
        if number <= 0:
            self.clear()
            return self
        # Snapshot first: extending with ``self`` would read bits being added.
        snapshot = self.copy()
        for _ in range(1, number):
            self.extend(snapshot)
        return self

    def __add__(self, value: Union[bool, List[bool], bytes, Bitstream]) -> Bitstream:
        """Return a new bitstream with ``value`` appended."""
        result = self.copy()
        if type(value) is bool:
            result.append(value)
        else:
            result.extend(value)
        return result

    def __iadd__(self, value: Union[bool, List[bool], bytes, Bitstream]) -> Bitstream:
        """Append ``value`` in place."""
        if type(value) is bool:
            self.append(value)
        else:
            self.extend(value)
        return self

    def __iter__(self) -> Iterator[bool]:
        """Yield each bit in order."""
        for i in range(self._bit_length):
            yield self[i]

    def __repr__(self) -> str:
        """Show identity, cursor, length and packed bytes (never raises)."""
        return "<Bitstream at %#x offset=%d bit_length=%d bytes=[%s]>" % (
            id(self),
            self._offset,
            len(self),
            bytes(self).hex(" "),
        )

    @property
    def bits(self) -> List[bool]:
        """All bits as a new list of bools."""
        return list(self)

    def copy(self) -> Bitstream:
        """Return an independent copy, including the cursor position."""
        bitstream = Bitstream()
        bitstream._bytearray = self._bytearray.copy()
        bitstream._bit_length = self._bit_length
        bitstream._offset = self._offset
        return bitstream

    def append(self, value: bool):
        """Append a single bit.

        Raises:
            TypeError: if ``value`` is not a bool.
        """
        if type(value) is not bool:
            raise TypeError("Expected bool")
        while (len(self._bytearray) << 3) < (self._bit_length + 1):
            self._bytearray.append(0)
        prev_bit_length = self._bit_length
        self._bit_length += 1
        self[prev_bit_length] = value

    def extend(self, value: Union[List[bool], bytes, Bitstream]):
        """Append every bit of ``value``.

        Raises:
            TypeError: if ``value`` is not a list of bools, bytes or a
                Bitstream.
        """
        if not (
            type(value) is bytes
            or isinstance(value, Bitstream)
            or (type(value) is list and all(type(v) is bool for v in value))
        ):
            raise TypeError("Expected list[bool], bytes or Bitstream")
        if value is self:
            # Snapshot: growing self first would make ``value`` look longer
            # than the room reserved for it and __setitem__ would raise.
            value = self.copy()
        value_bit_length = get_bit_length(value)
        if value_bit_length == 0:
            return
        while (len(self._bytearray) << 3) < (self._bit_length + value_bit_length):
            self._bytearray.append(0)
        prev_bit_length = self._bit_length
        self._bit_length += value_bit_length
        self[prev_bit_length] = value

    def clear(self):
        """Remove all bits and rewind the cursor to 0."""
        self._bytearray.clear()
        self._bit_length = 0
        self._offset = 0

    def pop(self) -> bool:
        """Remove and return the last bit.

        The cursor is not moved. Raises ``IndexError`` when empty.
        """
        value = self[-1]
        self[-1] = False  # keep padding bits zero for bytes()/__eq__
        self._bit_length -= 1
        # Drop a byte that no longer holds any valid bit.
        del self._bytearray[(self._bit_length + 7) >> 3 :]
        return value

    def read(
        self,
        type: Type[Union[bool, bytes, Bitstream, int]],
        bit_length: int,
        bit_index: Optional[int] = None,
    ) -> Union[List[bool], bytes, Bitstream, int]:
        """Read ``bit_length`` bits and convert them to ``type``.

        Reads at the cursor (and advances it) when ``bit_index`` is None,
        otherwise at ``bit_index`` without moving the cursor. ``type=bool``
        returns a list of bools; ``bytes`` is zero-padded; ``int`` is
        unsigned MSB-first.

        Raises:
            ValueError: if ``bit_length`` is negative.
            IndexError: if the range falls outside the stream.
            TypeError: if ``type`` is unsupported (the cursor is not moved).
        """
        start = self._offset if bit_index is None else bit_index
        if bit_length < 0:
            raise ValueError("bit_length must be non-negative")
        if (start + bit_length) > self._bit_length or (start < 0):
            raise IndexError("bit index out of range")
        bits = [self[i] for i in range(start, start + bit_length)]
        if type is bool:
            result = bits
        elif type is Bitstream:
            result = Bitstream(bits)
        elif type is bytes:
            result = bytes(Bitstream(bits))
        elif type is int:
            result = bits_to_int(bits)
        else:
            raise TypeError("Invalid data type")
        # Advance only after a successful conversion.
        if bit_index is None:
            self._offset += bit_length
        return result

    def write(
        self,
        value: Union[bool, List[bool], bytes, Bitstream],
        bit_index: Optional[int] = None,
    ):
        """Write ``value`` at the cursor (advancing it) or at ``bit_index``.

        The stream grows as needed; writing past the end leaves zero bits in
        any gap.

        Raises:
            IndexError: if the start position is negative.
            TypeError: if ``value`` has an unsupported type.
        """
        start = self._offset if bit_index is None else bit_index
        if start < 0:
            raise IndexError("bit index out of range")
        if value is self:
            # Snapshot before growing, for the same reason as in extend().
            value = self.copy()
        value_bit_length = get_bit_length(value)
        if value_bit_length == 0:
            return
        end = start + value_bit_length
        while (len(self._bytearray) << 3) < end:
            self._bytearray.append(0)
        if bit_index is None:
            self._offset = end
        if end > self._bit_length:
            self._bit_length = end
        self[start] = value

    def seek(self, position: int, whence: Literal[0, 1, 2] = SEEK_SET) -> int:
        """Move the cursor like ``io.IOBase.seek`` and return the new position.

        The position is clamped to ``[0, len(self)]``.

        Raises:
            ValueError: if ``whence`` is not SEEK_SET, SEEK_CUR or SEEK_END.
        """
        if whence == SEEK_SET:
            offset = position
        elif whence == SEEK_CUR:
            offset = self._offset + position
        elif whence == SEEK_END:
            offset = self._bit_length + position
        else:
            raise ValueError("Invalid whence")
        self._offset = min(max(offset, 0), self._bit_length)
        return self._offset