Removed str type from bitstream (str is ambiguous)
This commit is contained in:
parent
404f71dddd
commit
446ab11a61
|
@ -8,12 +8,12 @@ SEEK_CUR = 1
|
|||
SEEK_END = 2
|
||||
|
||||
|
||||
def get_bit_length(value: Union[bool, List[bool], bytes, str, Bitstream]) -> int:
|
||||
def get_bit_length(value: Union[bool, List[bool], bytes, Bitstream]) -> int:
|
||||
if type(value) is bool:
|
||||
return 1
|
||||
elif type(value) is list and all(type(v) is bool for v in value):
|
||||
return len(value)
|
||||
elif type(value) in (bytes, str):
|
||||
elif type(value) is bytes:
|
||||
return len(value) << 3
|
||||
elif type(value) is Bitstream:
|
||||
return len(value)
|
||||
|
@ -36,7 +36,7 @@ def int_to_bits(number: int, bit_length: int, little_endian: bool = False) -> Li
|
|||
|
||||
|
||||
class Bitstream:
|
||||
def __init__(self, *values: Union[bool, List[bool], bytes, str, Bitstream]):
|
||||
def __init__(self, *values: Union[bool, List[bool], bytes, Bitstream]):
|
||||
self._bytearray = bytearray()
|
||||
self._bit_length = 0
|
||||
self._offset = 0
|
||||
|
@ -66,9 +66,6 @@ class Bitstream:
|
|||
self._bytearray.pop()
|
||||
return bytes(self._bytearray)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return bytes(self).decode()
|
||||
|
||||
def __int__(self) -> int:
|
||||
num = 0
|
||||
for bit in self:
|
||||
|
@ -85,13 +82,10 @@ class Bitstream:
|
|||
mask = 1 << (7 - (bit_index % 8))
|
||||
return bool(self._bytearray[bit_index >> 3] & mask)
|
||||
|
||||
def __setitem__(self, bit_index: int, value: Union[bool, List[bool], bytes, str, Bitstream]):
|
||||
def __setitem__(self, bit_index: int, value: Union[bool, List[bool], bytes, Bitstream]):
|
||||
if bit_index >= self._bit_length:
|
||||
raise IndexError("bit index out of range")
|
||||
|
||||
if type(value) is str:
|
||||
value = value.encode()
|
||||
|
||||
if type(value) is bytes:
|
||||
self[bit_index] = Bitstream.from_bytes(value)
|
||||
return
|
||||
|
@ -111,13 +105,11 @@ class Bitstream:
|
|||
for i in range(value_bit_length):
|
||||
self[bit_index + i] = value[i]
|
||||
else:
|
||||
raise TypeError("Expected bool, list[bool], bytes, str or Bitstream")
|
||||
raise TypeError("Expected bool, list[bool], bytes or Bitstream")
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if type(other) is bool:
|
||||
other = [other]
|
||||
if type(other) is str:
|
||||
other = other.encode()
|
||||
|
||||
if isinstance(other, Bitstream):
|
||||
pass
|
||||
|
@ -126,7 +118,7 @@ class Bitstream:
|
|||
elif type(other) is list and all(type(v) is bool for v in other):
|
||||
pass
|
||||
else:
|
||||
raise TypeError("Expected Bitstream, bytes, str or list[bool]")
|
||||
raise TypeError("Expected Bitstream, bytes or list[bool]")
|
||||
|
||||
if isinstance(other, Bitstream) or type(other) is list:
|
||||
if len(self) != len(other):
|
||||
|
@ -151,7 +143,7 @@ class Bitstream:
|
|||
self.extend(copy)
|
||||
return self
|
||||
|
||||
def __add__(self, value: Union[bool, List[bool], bytes, str, Bitstream]) -> Bitstream:
|
||||
def __add__(self, value: Union[bool, List[bool], bytes, Bitstream]) -> Bitstream:
|
||||
copy = self.copy()
|
||||
if type(value) is bool:
|
||||
copy.append(value)
|
||||
|
@ -159,7 +151,7 @@ class Bitstream:
|
|||
copy.extend(value)
|
||||
return copy
|
||||
|
||||
def __iadd__(self, value: Union[bool, List[bool], bytes, str, Bitstream]) -> Bitstream:
|
||||
def __iadd__(self, value: Union[bool, List[bool], bytes, Bitstream]) -> Bitstream:
|
||||
if type(value) is bool:
|
||||
self.append(value)
|
||||
else:
|
||||
|
@ -196,15 +188,15 @@ class Bitstream:
|
|||
self._bit_length += 1
|
||||
self[prev_bit_length] = value
|
||||
|
||||
def extend(self, value: Union[List[bool], bytes, str, Bitstream]):
|
||||
if type(value) in (bytes, str):
|
||||
def extend(self, value: Union[List[bool], bytes, Bitstream]):
|
||||
if type(value) is bytes:
|
||||
pass
|
||||
elif isinstance(value, Bitstream):
|
||||
pass
|
||||
elif type(value) is list and all(type(v) is bool for v in value):
|
||||
pass
|
||||
else:
|
||||
raise TypeError("Expected list[bool], bytes, str or Bitstream")
|
||||
raise TypeError("Expected list[bool], bytes or Bitstream")
|
||||
|
||||
value_bit_length = get_bit_length(value)
|
||||
if value_bit_length == 0:
|
||||
|
@ -229,10 +221,10 @@ class Bitstream:
|
|||
|
||||
def read(
|
||||
self,
|
||||
type: Type[Union[bool, bytes, str, Bitstream, int]],
|
||||
type: Type[Union[bool, bytes, Bitstream, int]],
|
||||
bit_length: int,
|
||||
bit_index: Optional[int] = None
|
||||
) -> Union[List[bool], bytes, str, Bitstream, int]:
|
||||
) -> Union[List[bool], bytes, Bitstream, int]:
|
||||
start = self._offset if bit_index is None else bit_index
|
||||
|
||||
if (start + bit_length) > self._bit_length or (start < 0):
|
||||
|
@ -251,14 +243,12 @@ class Bitstream:
|
|||
return bitstream
|
||||
elif type is bytes:
|
||||
return bytes(bitstream)
|
||||
elif type is str:
|
||||
return str(bitstream)
|
||||
elif type is int:
|
||||
return int(bitstream)
|
||||
|
||||
raise TypeError("Invalid data type")
|
||||
|
||||
def write(self, value: Union[bool, List[bool], bytes, str, Bitstream], bit_index: Optional[int] = None):
|
||||
def write(self, value: Union[bool, List[bool], bytes, Bitstream], bit_index: Optional[int] = None):
|
||||
start = self._offset if bit_index is None else bit_index
|
||||
|
||||
if start < 0:
|
||||
|
|
|
@ -38,11 +38,11 @@ def test_from_int(value: int, length: int, expected: int):
|
|||
[True, False, False, True, True],
|
||||
[[True], [False, True], [False, False]],
|
||||
[b"a", b"bc"],
|
||||
["aa", "b"],
|
||||
[b"aa", b"b"],
|
||||
[Bitstream.from_bytes(b"A"), Bitstream.from_int(0b1011, 4), Bitstream.from_bytes(b"B")],
|
||||
[True, b"A", "B", [False, True], Bitstream.from_int(0b10110, 5), False],
|
||||
[True, b"A", b"B", [False, True], Bitstream.from_int(0b10110, 5), False],
|
||||
])
|
||||
def test_init(values: List[Union[bool, List[bool], bytes, str, Bitstream]]):
|
||||
def test_init(values: List[Union[bool, List[bool], bytes, Bitstream]]):
|
||||
bitstream = Bitstream(*values)
|
||||
|
||||
bitstream.seek(0)
|
||||
|
@ -79,10 +79,9 @@ def test_getitem_index_error():
|
|||
True,
|
||||
[True, True],
|
||||
b"A",
|
||||
"B",
|
||||
Bitstream(True, b"C"),
|
||||
])
|
||||
def test_setitem(value: Union[bool, List[bool], bytes, str, Bitstream]):
|
||||
def test_setitem(value: Union[bool, List[bool], bytes, Bitstream]):
|
||||
bit_length = get_bit_length(value)
|
||||
read_type = bool if type(value) is list else type(value)
|
||||
for length in range(10):
|
||||
|
@ -111,13 +110,8 @@ def test_setitem(value: Union[bool, List[bool], bytes, str, Bitstream]):
|
|||
),
|
||||
(
|
||||
b"A",
|
||||
[b"A", "A", Bitstream.from_bytes(b"A"), Bitstream.from_bytes(b"A").bits],
|
||||
[[], [False], b"B", "B", Bitstream.from_bytes(b"B"), Bitstream.from_bytes(b"B").bits]
|
||||
),
|
||||
(
|
||||
"B",
|
||||
[b"B", "B", Bitstream.from_bytes(b"B"), Bitstream.from_bytes(b"B").bits],
|
||||
[[], [False], b"A", "A", Bitstream.from_bytes(b"A"), Bitstream.from_bytes(b"A").bits]
|
||||
[b"A", b"A", Bitstream.from_bytes(b"A"), Bitstream.from_bytes(b"A").bits],
|
||||
[[], [False], b"B", b"B", Bitstream.from_bytes(b"B"), Bitstream.from_bytes(b"B").bits]
|
||||
),
|
||||
(
|
||||
b"ABC",
|
||||
|
@ -126,9 +120,9 @@ def test_setitem(value: Union[bool, List[bool], bytes, str, Bitstream]):
|
|||
),
|
||||
])
|
||||
def test_eq(
|
||||
value: Union[bool, List[bool], bytes, str, Bitstream],
|
||||
equals: List[Union[bool, List[bool], bytes, str, Bitstream]],
|
||||
not_equals: List[Union[bool, List[bool], bytes, str, Bitstream]]):
|
||||
value: Union[bool, List[bool], bytes, Bitstream],
|
||||
equals: List[Union[bool, List[bool], bytes, Bitstream]],
|
||||
not_equals: List[Union[bool, List[bool], bytes, Bitstream]]):
|
||||
bitstream = Bitstream(value)
|
||||
for equal in equals:
|
||||
assert bitstream == equal
|
||||
|
@ -143,9 +137,9 @@ def test_eq(
|
|||
([False], 3, [False, False, False]),
|
||||
])
|
||||
def test_mul(
|
||||
values: Union[bool, List[bool], bytes, str, Bitstream],
|
||||
values: Union[bool, List[bool], bytes, Bitstream],
|
||||
number: int,
|
||||
expected: Union[List[bool], bytes, str, Bitstream]):
|
||||
expected: Union[List[bool], bytes, Bitstream]):
|
||||
bitstream = Bitstream(values)
|
||||
|
||||
assert (bitstream * number) == expected
|
||||
|
@ -161,9 +155,9 @@ def test_mul(
|
|||
(b"A", b"B", b"AB"),
|
||||
])
|
||||
def test_add(
|
||||
init_values: Union[bool, List[bool], bytes, str, Bitstream],
|
||||
add_value: Union[bool, List[bool], bytes, str, Bitstream],
|
||||
expected: Union[List[bool], bytes, str, Bitstream]):
|
||||
init_values: Union[bool, List[bool], bytes, Bitstream],
|
||||
add_value: Union[bool, List[bool], bytes, Bitstream],
|
||||
expected: Union[List[bool], bytes, Bitstream]):
|
||||
bitstream = Bitstream(init_values)
|
||||
|
||||
assert (bitstream + add_value) == expected
|
||||
|
@ -179,7 +173,7 @@ def test_add(
|
|||
[b"A"],
|
||||
[True, b"B"],
|
||||
])
|
||||
def test_iter(init_values: List[Union[bool, List[bool], bytes, str, Bitstream]]):
|
||||
def test_iter(init_values: List[Union[bool, List[bool], bytes, Bitstream]]):
|
||||
bitstream = Bitstream(*init_values)
|
||||
assert len(list(bitstream)) == len(bitstream)
|
||||
for index, bit in enumerate(bitstream):
|
||||
|
@ -194,7 +188,7 @@ def test_iter(init_values: List[Union[bool, List[bool], bytes, str, Bitstream]])
|
|||
([[False] * 7, True], b"\x01"),
|
||||
])
|
||||
def test_bytes(
|
||||
init_values: List[Union[bool, List[bool], bytes, str, Bitstream]],
|
||||
init_values: List[Union[bool, List[bool], bytes, Bitstream]],
|
||||
expected: bytes):
|
||||
assert bytes(Bitstream(*init_values)) == expected
|
||||
|
||||
|
@ -207,7 +201,7 @@ def test_bytes(
|
|||
([False] * 7 + [True], [False] * 7 + [True]),
|
||||
])
|
||||
def test_bits(
|
||||
init_values: List[Union[bool, List[bool], bytes, str, Bitstream]],
|
||||
init_values: List[Union[bool, List[bool], bytes, Bitstream]],
|
||||
expected: List[bool]):
|
||||
assert Bitstream(*init_values).bits == expected
|
||||
|
||||
|
@ -236,10 +230,10 @@ def test_append(lst: List[bool]):
|
|||
@pytest.mark.parametrize("extend_value", [
|
||||
[True, False, True],
|
||||
b"A",
|
||||
"B",
|
||||
b"BB",
|
||||
Bitstream(True, b"C"),
|
||||
])
|
||||
def test_extend(extend_value: Union[List[bool], bytes, str, Bitstream]):
|
||||
def test_extend(extend_value: Union[List[bool], bytes, Bitstream]):
|
||||
for i in range(10):
|
||||
bitstream = Bitstream([False] * i)
|
||||
bitstream.extend(extend_value)
|
||||
|
@ -304,11 +298,6 @@ def test_read():
|
|||
bytes, 8, 1,
|
||||
b"A"
|
||||
),
|
||||
(
|
||||
Bitstream(True, b"A", False, True, False),
|
||||
str, 8, 1,
|
||||
"A"
|
||||
),
|
||||
(
|
||||
Bitstream(True, b"A", False, True, False),
|
||||
Bitstream, 1 + 8 + 3, 0,
|
||||
|
@ -322,10 +311,10 @@ def test_read():
|
|||
])
|
||||
def test_read_index(
|
||||
bitstream: Bitstream,
|
||||
read_type: Type[Union[bool, bytes, str, Bitstream, int]],
|
||||
read_type: Type[Union[bool, bytes, Bitstream, int]],
|
||||
bit_length: int,
|
||||
bit_index: Optional[int],
|
||||
expected: Union[List[bool], bytes, str, Bitstream, int]):
|
||||
expected: Union[List[bool], bytes, Bitstream, int]):
|
||||
value = bitstream.read(read_type, bit_length, bit_index)
|
||||
|
||||
if read_type is bool:
|
||||
|
@ -360,7 +349,7 @@ def test_read_error():
|
|||
(Bitstream(False), True, 0, 0),
|
||||
(Bitstream(False), True, None, 0),
|
||||
])
|
||||
def test_write(bitstream: Bitstream, value: Union[bool, List[bool], bytes, str, Bitstream], bit_index: Optional[int], current_offset: Optional[int]):
|
||||
def test_write(bitstream: Bitstream, value: Union[bool, List[bool], bytes, Bitstream], bit_index: Optional[int], current_offset: Optional[int]):
|
||||
if current_offset is not None:
|
||||
bitstream.seek(current_offset)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user