Removed str type from bitstream (str is ambiguous)

This commit is contained in:
2023-03-15 03:43:18 +01:00
parent 404f71dddd
commit 446ab11a61
2 changed files with 36 additions and 57 deletions

View File

@@ -8,12 +8,12 @@ SEEK_CUR = 1
SEEK_END = 2
def get_bit_length(value: Union[bool, List[bool], bytes, str, Bitstream]) -> int:
def get_bit_length(value: Union[bool, List[bool], bytes, Bitstream]) -> int:
if type(value) is bool:
return 1
elif type(value) is list and all(type(v) is bool for v in value):
return len(value)
elif type(value) in (bytes, str):
elif type(value) is bytes:
return len(value) << 3
elif type(value) is Bitstream:
return len(value)
@@ -36,7 +36,7 @@ def int_to_bits(number: int, bit_length: int, little_endian: bool = False) -> Li
class Bitstream:
def __init__(self, *values: Union[bool, List[bool], bytes, str, Bitstream]):
def __init__(self, *values: Union[bool, List[bool], bytes, Bitstream]):
self._bytearray = bytearray()
self._bit_length = 0
self._offset = 0
@@ -66,9 +66,6 @@ class Bitstream:
self._bytearray.pop()
return bytes(self._bytearray)
def __str__(self) -> str:
return bytes(self).decode()
def __int__(self) -> int:
num = 0
for bit in self:
@@ -85,13 +82,10 @@ class Bitstream:
mask = 1 << (7 - (bit_index % 8))
return bool(self._bytearray[bit_index >> 3] & mask)
def __setitem__(self, bit_index: int, value: Union[bool, List[bool], bytes, str, Bitstream]):
def __setitem__(self, bit_index: int, value: Union[bool, List[bool], bytes, Bitstream]):
if bit_index >= self._bit_length:
raise IndexError("bit index out of range")
if type(value) is str:
value = value.encode()
if type(value) is bytes:
self[bit_index] = Bitstream.from_bytes(value)
return
@@ -111,13 +105,11 @@ class Bitstream:
for i in range(value_bit_length):
self[bit_index + i] = value[i]
else:
raise TypeError("Expected bool, list[bool], bytes, str or Bitstream")
raise TypeError("Expected bool, list[bool], bytes or Bitstream")
def __eq__(self, other: Any) -> bool:
if type(other) is bool:
other = [other]
if type(other) is str:
other = other.encode()
if isinstance(other, Bitstream):
pass
@@ -126,7 +118,7 @@ class Bitstream:
elif type(other) is list and all(type(v) is bool for v in other):
pass
else:
raise TypeError("Expected Bitstream, bytes, str or list[bool]")
raise TypeError("Expected Bitstream, bytes or list[bool]")
if isinstance(other, Bitstream) or type(other) is list:
if len(self) != len(other):
@@ -151,7 +143,7 @@ class Bitstream:
self.extend(copy)
return self
def __add__(self, value: Union[bool, List[bool], bytes, str, Bitstream]) -> Bitstream:
def __add__(self, value: Union[bool, List[bool], bytes, Bitstream]) -> Bitstream:
copy = self.copy()
if type(value) is bool:
copy.append(value)
@@ -159,7 +151,7 @@ class Bitstream:
copy.extend(value)
return copy
def __iadd__(self, value: Union[bool, List[bool], bytes, str, Bitstream]) -> Bitstream:
def __iadd__(self, value: Union[bool, List[bool], bytes, Bitstream]) -> Bitstream:
if type(value) is bool:
self.append(value)
else:
@@ -196,15 +188,15 @@ class Bitstream:
self._bit_length += 1
self[prev_bit_length] = value
def extend(self, value: Union[List[bool], bytes, str, Bitstream]):
if type(value) in (bytes, str):
def extend(self, value: Union[List[bool], bytes, Bitstream]):
if type(value) is bytes:
pass
elif isinstance(value, Bitstream):
pass
elif type(value) is list and all(type(v) is bool for v in value):
pass
else:
raise TypeError("Expected list[bool], bytes, str or Bitstream")
raise TypeError("Expected list[bool], bytes or Bitstream")
value_bit_length = get_bit_length(value)
if value_bit_length == 0:
@@ -229,10 +221,10 @@ class Bitstream:
def read(
self,
type: Type[Union[bool, bytes, str, Bitstream, int]],
type: Type[Union[bool, bytes, Bitstream, int]],
bit_length: int,
bit_index: Optional[int] = None
) -> Union[List[bool], bytes, str, Bitstream, int]:
) -> Union[List[bool], bytes, Bitstream, int]:
start = self._offset if bit_index is None else bit_index
if (start + bit_length) > self._bit_length or (start < 0):
@@ -251,14 +243,12 @@ class Bitstream:
return bitstream
elif type is bytes:
return bytes(bitstream)
elif type is str:
return str(bitstream)
elif type is int:
return int(bitstream)
raise TypeError("Invalid data type")
def write(self, value: Union[bool, List[bool], bytes, str, Bitstream], bit_index: Optional[int] = None):
def write(self, value: Union[bool, List[bool], bytes, Bitstream], bit_index: Optional[int] = None):
start = self._offset if bit_index is None else bit_index
if start < 0: