389 lines
12 KiB
Python
389 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import ctypes as c
|
|
|
|
from array import array
|
|
|
|
from typing import Tuple, Union
|
|
|
|
from ..helpers import logging
|
|
logger = logging.getLogger(__name__)

# Run the (expensive) per-call tracing in Bitstream.debug_wrapper only when
# debug records from this module would actually be emitted. The former test
# ``(logger.root.level & logging.DEBUG) == 10`` was a bitwise AND on the level
# number, which is also true for WARNING (30 & 10 == 10) and so enabled the
# tracing exactly when debug output was suppressed.
DEBUG = logger.isEnabledFor(logging.DEBUG)

if DEBUG:
    # Frame introspection is needed only by the debug tracing.
    import inspect
|
|
|
|
def BITS_TO_BYTES(x: int) -> int:
    """Return the number of whole bytes needed to hold ``x`` bits (ceil(x / 8))."""
    return (x + 7) >> 3


def BYTES_TO_BITS(x: int) -> int:
    """Return the number of bits contained in ``x`` bytes."""
    return x << 3


# Nesting depth of traced Bitstream calls; rendered as a ">" prefix on debug lines.
debug_wrapper_depth = 0

# (start, end) bit ranges skipped by byte alignment during the current traced call.
ignored_bits = []
|
|
|
|
class Bitstream:
    """Bit-granular read/write buffer in the style of RakNet's ``BitStream``.

    A single cursor (``_offset``, counted in bits) is shared by reads and
    writes. Within each byte, bits are addressed MSB-first. ``length`` is the
    size of the backing buffer in bits.
    """

    def __init__(self, data: bytes = b""):
        # Bit position of the shared read/write cursor.
        self._offset = 0
        # Copy into a mutable buffer so the caller's object is never modified.
        self._buffer = bytearray(data)

    def debug_wrapper(func):
        """Method decorator that traces calls when the module-level ``DEBUG`` is true.

        For each traced call it logs a header (method and arguments, calling
        function, file and line). After the call it logs ``debug_edits``
        relative to a snapshot taken before the call. It is used as a
        decorator inside this class body.
        """
        import functools

        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if not DEBUG:
                return func(self, *args, **kwargs)

            global debug_wrapper_depth

            # Collect only the alignment skips made during this call.
            ignored_bits.clear()

            caller = inspect.getouterframes(inspect.currentframe(), 2)[1]
            snapshot = self.clone()

            shown_args = [repr(arg) for arg in args]
            shown_args += ["%s=%r" % item for item in kwargs.items()]
            header_func = "%s(%s)" % (func.__name__, ", ".join(shown_args))
            # Strip the directory for both "/" and "\\" path separators.
            file_name = caller.filename.replace("\\", "/").rsplit("/", 1)[-1]

            header = " {BLUE}{header_func}{SPACE_FUNC}{CYAN}{caller}{SPACE_CALLER}{GREEN}{file}{YELLOW}@L{lineno}{RESET}".format(
                header_func=header_func,
                caller=caller.function,
                file=file_name,
                lineno=caller.lineno,
                CYAN=logging.color(logging.COLOR_CYAN),
                BLUE=logging.color(logging.COLOR_BLUE),
                GREEN=logging.color(logging.COLOR_GREEN),
                YELLOW=logging.color(logging.COLOR_YELLOW),
                RESET=logging.color(logging.COLOR_RESET),
                SPACE_FUNC=" " * (28 - len(header_func)),
                SPACE_CALLER=" " * (32 - len(caller.function)),
            )
            prefix = "%s%s%s" % (
                logging.color(logging.COLOR_RED),
                ">" * debug_wrapper_depth,
                logging.color(logging.COLOR_RESET),
            )
            logger.debug(prefix + header)

            debug_wrapper_depth += 1
            try:
                result = func(self, *args, **kwargs)
            finally:
                # Restore the nesting depth even if the wrapped call raises.
                debug_wrapper_depth -= 1

            logger.debug(self.debug_edits(snapshot))
            return result

        return wrapper

    @property
    def offset(self) -> int:
        """Current cursor position, in bits."""
        return self._offset

    @property
    def length(self) -> int:
        """Size of the backing buffer, in bits."""
        return len(self._buffer) << 3

    def clone(self) -> Bitstream:
        """Return an independent copy with its own buffer and the same cursor."""
        copy = Bitstream(self._buffer)
        copy._offset = self._offset
        return copy

    def _can_access_bits(self, bits_to_access: int) -> bool:
        """Return True if ``bits_to_access`` bits fit between the cursor and the buffer end."""
        return self._offset + bits_to_access <= len(self._buffer) << 3

    def _reserve_bits(self, bit_length: int) -> None:
        """Zero-extend the buffer, if needed, so ``bit_length`` bits fit after the cursor."""
        needed = (self._offset + bit_length + 7) >> 3
        if needed > len(self._buffer):
            self._buffer.extend(bytes(needed - len(self._buffer)))

    def add_bits(self, bit_length: int):
        """Append ``ceil(bit_length / 8)`` zero bytes to the buffer (no-op for ``<= 0``)."""
        if bit_length <= 0:
            return
        self._buffer.extend(bytes((bit_length + 7) >> 3))

    def align_read_to_byte_boundary(self):
        """Advance the cursor to the next multiple of 8 bits (no-op if already aligned)."""
        skipped_from = self._offset
        self._offset = (self._offset + 7) & ~7
        if DEBUG and self._offset != skipped_from:
            # Record the skipped span so debug_edits can highlight it.
            ignored_bits.append((skipped_from, self._offset))

    @debug_wrapper
    def write_bit(self, value: bool) -> bool:
        """Set (truthy ``value``) or clear the bit at the cursor and advance by one.

        Returns False, writing nothing, when the cursor is at the end of the
        buffer. This method does not grow the buffer.
        """
        if not self._can_access_bits(1):
            return False

        mask = 0x80 >> (self._offset & 7)
        if value:
            self._buffer[self._offset >> 3] |= mask
        else:
            self._buffer[self._offset >> 3] &= ~mask

        self._offset += 1
        return True

    @debug_wrapper
    def read_bit(self) -> Tuple[bool, bool]:
        """Read the bit at the cursor and advance; return ``(success, bit)``."""
        if not self._can_access_bits(1):
            return False, False

        value = self._buffer[self._offset >> 3] & (0x80 >> (self._offset & 7))
        self._offset += 1
        return True, value > 0

    @debug_wrapper
    def write(self, value: bytes, bit_length: int) -> bool:
        """Overwrite ``bit_length`` bits at the cursor with the leading bits of ``value``.

        Bits are taken MSB-first from ``value``. The buffer is not grown.
        Returns False, writing nothing, if ``bit_length`` is negative, the
        buffer is too short, or ``value`` holds fewer than ``bit_length`` bits.
        """
        if bit_length < 0 or not self._can_access_bits(bit_length):
            return False

        if len(value) << 3 < bit_length:
            return False

        for index in range(bit_length):
            source_bit = (value[index >> 3] >> (7 - (index & 7))) & 1
            mask = 0x80 >> (self._offset & 7)
            position = self._offset >> 3
            if source_bit:
                self._buffer[position] |= mask
            else:
                self._buffer[position] &= ~mask
            self._offset += 1

        return True

    @debug_wrapper
    def read_int(self, bit_length: int) -> Tuple[bool, int]:
        """Read ``bit_length`` bits (MSB-first) at the cursor as an unsigned integer.

        Returns ``(success, value)``. On failure (negative length or not
        enough data) it returns ``(False, 0)`` and leaves the cursor unchanged.
        """
        if bit_length < 0 or not self._can_access_bits(bit_length):
            return False, 0

        start = self._offset
        end = start + bit_length
        first_byte = start >> 3
        last_byte = (end + 7) >> 3

        # Load the covering bytes as one big-endian integer, drop the bits after
        # ``end``, then mask off the bits before ``start``.
        window = int.from_bytes(self._buffer[first_byte:last_byte], "big")
        value = (window >> ((last_byte << 3) - end)) & ((1 << bit_length) - 1)

        self._offset = end
        return True, value

    def read(self, bit_length: int) -> Tuple[bool, bytes]:
        """Read ``bit_length`` bits and return them right-aligned as big-endian bytes."""
        success, value = self.read_int(bit_length)
        return success, value.to_bytes(max(0, (bit_length + 7) >> 3), "big")

    @debug_wrapper
    def write_compressed(self, data: bytes, bit_length: int, unsigned: bool = True) -> bool:
        """Write ``bit_length`` bits of ``data`` in compressed form.

        ``data`` is scanned from its last byte towards its first, so it is
        treated as little-endian. Each high byte equal to the extension byte
        (0x00 if ``unsigned`` else 0xFF) is written as a single 1 bit. The
        first non-matching byte emits a 0 bit followed by bytes
        ``0..current_byte`` in full.

        If only byte 0 remains, its handling depends on its high nibble. When
        the high nibble is 0x0 (unsigned) or 0xF (signed), a 1 bit and its low
        nibble are written. Otherwise a 0 bit and the whole byte are written.

        Always returns True.
        """
        current_byte = (bit_length >> 3) - 1
        byte_match = 0x00 if unsigned else 0xFF
        nibble_match = 0x00 if unsigned else 0xF0

        while current_byte > 0:
            # write_bit never grows the buffer; reserve room so the flag isn't dropped.
            self._reserve_bits(1)
            if data[current_byte] != byte_match:
                self.write_bit(False)
                self.write_bits(data, (current_byte + 1) << 3, True)
                return True
            self.write_bit(True)
            # Move to the next lower byte (without this the loop never ends).
            current_byte -= 1

        self._reserve_bits(1)
        if data[current_byte] & 0xF0 == nibble_match:
            self.write_bit(True)
            self.write_bits(data[current_byte:], 4, True)
        else:
            self.write_bit(False)
            self.write_bits(data[current_byte:], 8, True)

        return True

    @debug_wrapper
    def read_compressed(self, bit_length: int, unsigned: bool = True) -> Tuple[bool, bytes]:
        """Read a value written by ``write_compressed``.

        Returns ``(success, out)``. ``out`` is a bytearray of
        ``bit_length // 8`` bytes in the same (low-byte-first) layout as the
        ``data`` given to ``write_compressed``. On failure, ``out`` holds
        whatever was decoded so far.
        """
        current_byte = (bit_length >> 3) - 1
        out = bytearray(current_byte + 1)

        byte_match = 0x00 if unsigned else 0xFF
        half_byte_match = 0x00 if unsigned else 0xF0

        while current_byte > 0:
            success, bit = self.read_bit()
            if not success:  # End of stream.
                return False, out

            if not bit:
                # The remaining low bytes were stored verbatim.
                success, data = self.read_bits((current_byte + 1) << 3, True)
                if not success:
                    return False, out
                out[:len(data)] = data
                return True, out

            out[current_byte] = byte_match
            current_byte -= 1

        success, bit = self.read_bit()
        if not success:  # End of stream.
            return False, out

        if bit:
            success, data = self.read_bits(4, True)
            if not success:
                return False, out
            # read_bits leaves the high nibble zero; restore it for signed values.
            out[current_byte] = data[0] | half_byte_match
        else:
            success, data = self.read_bits(8, True)
            if not success:
                return False, out
            out[current_byte] = data[0]

        return True, out

    @debug_wrapper
    def write_bits(self, data: bytes, bit_length: int, align_bits_to_right: bool = True) -> bool:
        """Write ``bit_length`` bits from ``data`` at the cursor, growing the buffer as needed.

        Bits are taken MSB-first from ``data``. When the final byte is partial
        and ``align_bits_to_right`` is true, its low ``bit_length % 8`` bits
        are used instead of its high bits. At an unaligned cursor, the new bits
        are OR-ed into the current byte, so the bits after the cursor are
        assumed to be zero (as they are when appending).

        Returns False for a non-positive ``bit_length`` or when ``data`` holds
        fewer than ``bit_length`` bits.
        """
        if bit_length <= 0:
            return False

        if len(data) << 3 < bit_length:
            return False

        # Grow only as far as this write needs. The previous unconditional
        # add_bits(bit_length) left trailing zero bytes after partial writes.
        self._reserve_bits(bit_length)

        shift = self._offset & 7
        remaining = bit_length
        for index in range((bit_length + 7) >> 3):
            byte = data[index]
            if remaining < 8 and align_bits_to_right:
                # Move the low bits to the top; & 0xFF emulates unsigned-char
                # truncation (a Python int would otherwise exceed 255).
                byte = (byte << (8 - remaining)) & 0xFF

            position = self._offset >> 3
            if shift == 0:
                self._buffer[position] = byte
            else:
                self._buffer[position] |= byte >> shift
                if 8 - shift < remaining:
                    # Spill the bits that did not fit into the next byte.
                    self._buffer[position + 1] = (byte << (8 - shift)) & 0xFF

            self._offset += min(remaining, 8)
            remaining -= 8

        return True

    @debug_wrapper
    def read_bits(self, bit_length: int, align_bits_to_right: bool = True) -> Tuple[bool, bytes]:
        """Read ``bit_length`` bits at the cursor into ``ceil(bit_length / 8)`` bytes.

        Bits are packed MSB-first. When the last byte is partial and
        ``align_bits_to_right`` is true, it is shifted so its bits sit at the
        low end. Returns ``(False, b"")`` for a non-positive length or when not
        enough data remains.
        """
        if bit_length <= 0 or not self._can_access_bits(bit_length):
            return False, b""

        out = bytearray((bit_length + 7) >> 3)
        shift = self._offset & 7
        remaining = bit_length
        index = 0

        while remaining > 0:
            position = self._offset >> 3
            out[index] |= (self._buffer[position] << shift) & 0xFF
            if shift > 0 and remaining > 8 - shift:
                # The requested bits straddle a byte boundary.
                out[index] |= self._buffer[position + 1] >> (8 - shift)

            if remaining < 8:
                if align_bits_to_right:
                    out[index] >>= 8 - remaining
                self._offset += remaining
            else:
                self._offset += 8

            remaining -= 8
            index += 1

        return True, out

    @debug_wrapper
    def read_aligned_bytes(self, byte_length: int) -> Tuple[bool, bytes]:
        """Align the cursor to a byte boundary, then read ``byte_length`` whole bytes.

        The alignment happens even when the read then fails for lack of data.
        """
        if byte_length <= 0:
            return False, b""

        self.align_read_to_byte_boundary()

        if not self._can_access_bits(byte_length << 3):
            return False, b""

        start = self._offset >> 3
        out = self._buffer[start:start + byte_length]
        self._offset += byte_length << 3
        return True, out

    def pretty(self) -> str:
        """Return the buffer as space-separated bits with ``|`` marking the cursor."""
        # Every bit is preceded by one space, plus one extra space between
        # bytes, so the separator in front of bit k sits at index 2*k + k//8.
        text = bytearray(
            " " + "  ".join(" ".join(format(byte, "08b")) for byte in self._buffer) + " ",
            "ascii",
        )
        # At the very end of the buffer only the single trailing space exists,
        # one index before where the formula lands.
        marker = self._offset * 2 + (self._offset >> 3) - (self._offset // (max(1, len(self._buffer)) * 8))
        text[marker] = ord("|")
        return text.decode()

    def debug_edits(self, prev_state: Bitstream) -> str:
        """Render the buffer with colour markers relative to ``prev_state``.

        The markers are:

        * ``P`` (``COLOR_PINK``) at ``prev_state``'s cursor.
        * ``R`` (``COLOR_RESET``) at the current cursor.
        * ``S`` (``COLOR_RED``) at the start of each range in ``ignored_bits``.

        The letters are replaced by ``logging.color`` codes before returning.
        """
        groups = [list(format(byte, "08b")) for byte in self._buffer]
        # Extra empty group so a marker at the very end of the buffer has a slot.
        groups.append([])

        markers = {prev_state._offset: "P", self._offset: "R"}
        for start, end in ignored_bits:
            # After the skipped range, reuse the marker of the latest-inserted
            # position at or before ``start`` ("R" if there is none).
            preceding = [pos for pos in markers if pos <= start]
            markers[end] = markers[preceding[-1]] if preceding else "R"
            markers[start] = "S"

        # Insert from the highest position down so lower indices stay valid.
        for position, marker in sorted(markers.items(), reverse=True):
            groups[position // 8][position % 8:position % 8] = marker

        text = " ".join(
            "".join(ch + " " if ch in "01" else ch for ch in group).strip()
            for group in groups
        )

        text = text.replace("R", logging.color(logging.COLOR_RESET))
        text = text.replace("P", logging.color(logging.COLOR_PINK))
        text = text.replace("S", logging.color(logging.COLOR_RED))
        return text

    def __repr__(self) -> str:
        return "<Bitstream addr:0x%012x offset:%d len:%d>" % (id(self), self._offset, len(self._buffer) << 3)
|