# Source file: sampy/sampy/raknet/bitstream.py
# (Listing header from the original viewer: 389 lines, 12 KiB, Python.)

from __future__ import annotations

import ctypes as c
import functools
import os
from array import array
from typing import Tuple, Union

from ..helpers import logging
logger = logging.getLogger(__name__)
# Enable per-call bit tracing in Bitstream only when this module's logger
# would actually emit DEBUG records. A bitwise test such as
# `(level & DEBUG) == 10` is wrong: it also matches WARNING (30 & 10 == 10),
# the stdlib default root level, which would turn tracing on by default.
DEBUG = logger.isEnabledFor(logging.DEBUG)
if DEBUG:
    import inspect  # only used by the debug tracing wrapper


def BITS_TO_BYTES(x: int) -> int:
    """Return the number of bytes needed to hold ``x`` bits (rounded up)."""
    return (x + 7) >> 3


def BYTES_TO_BITS(x: int) -> int:
    """Return the number of bits contained in ``x`` bytes."""
    return x << 3


# Current nesting depth of traced Bitstream calls; used to indent trace lines.
debug_wrapper_depth = 0
# (start, end) bit spans skipped by byte alignment during the current traced call.
ignored_bits = []
class Bitstream:
    """Bit-addressed buffer for reading and writing RakNet-style packets.

    A single cursor (``offset``, in bits) is shared by reads and writes.
    Within each byte of the buffer, bits are addressed most-significant
    first. ``length`` is the size of the buffer in bits.
    """

    def __init__(self, data: bytes = b""):
        # Cursor position in bits; both reads and writes start here.
        self._offset = 0
        # Mutable private copy of the initial data.
        self._buffer = bytearray(data)

    def debug_wrapper(func):
        """Decorator that traces a method call when the module-level ``DEBUG`` is true.

        Tracing logs the method name and arguments plus the caller's name,
        file and line. After the call it logs the buffer with colour
        markers around the bits the call moved over (see ``debug_edits``).
        When ``DEBUG`` is false the method is called directly.
        """
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if not DEBUG:
                return func(self, *args, **kwargs)
            global debug_wrapper_depth
            ignored_bits.clear()
            # The frame that called the decorated method. Only this one frame
            # is needed, so avoid inspect.getouterframes, which walks the
            # whole stack and reads source context.
            caller = inspect.currentframe().f_back
            caller_name = caller.f_code.co_name
            snapshot = self.clone()
            header_func = "%s%s" % (func.__name__, (*args, *kwargs))
            header = " {BLUE}{header_func}{SPACE_FUNC}{CYAN}{caller}{SPACE_CALLER}{GREEN}{file}{YELLOW}@L{lineno}{RESET}".format(
                header_func=header_func,
                caller=caller_name,
                file=os.path.basename(caller.f_code.co_filename),
                lineno=caller.f_lineno,
                CYAN=logging.color(logging.COLOR_CYAN),
                BLUE=logging.color(logging.COLOR_BLUE),
                GREEN=logging.color(logging.COLOR_GREEN),
                YELLOW=logging.color(logging.COLOR_YELLOW),
                RESET=logging.color(logging.COLOR_RESET),
                SPACE_FUNC=" " * (28 - len(header_func)),
                SPACE_CALLER=" " * (32 - len(caller_name)),
            )
            prefix = "%s%s%s" % (
                logging.color(logging.COLOR_RED),
                ">" * debug_wrapper_depth,
                logging.color(logging.COLOR_RESET),
            )
            logger.debug(prefix + header)
            debug_wrapper_depth += 1
            try:
                result = func(self, *args, **kwargs)
            finally:
                # Restore the depth even if the wrapped method raises.
                debug_wrapper_depth -= 1
            logger.debug(self.debug_edits(snapshot))
            return result
        return wrapper

    @property
    def offset(self) -> int:
        """Current cursor position, in bits."""
        return self._offset

    @property
    def length(self) -> int:
        """Size of the buffer, in bits."""
        return len(self._buffer) << 3

    def clone(self) -> Bitstream:
        """Return an independent copy (buffer copied) with the same cursor."""
        copy = Bitstream(self._buffer)
        copy._offset = self._offset
        return copy

    def _can_access_bits(self, bits_to_access: int) -> bool:
        """Return True if ``bits_to_access`` bits fit between the cursor and the buffer end."""
        return self._offset + bits_to_access <= len(self._buffer) << 3

    def add_bits(self, bit_length: int):
        """Grow the buffer (zero-filled) so that ``bit_length`` bits fit after the cursor.

        Like RakNet's ``AddBitsAndReallocate``, the buffer is only extended
        by the bytes that are actually missing. Writing over already
        allocated data therefore does not inflate the stream. Non-positive
        ``bit_length`` is a no-op.
        """
        if bit_length <= 0:
            return
        missing = BITS_TO_BYTES(self._offset + bit_length) - len(self._buffer)
        if missing > 0:
            self._buffer.extend(bytes(missing))

    def align_read_to_byte_boundary(self):
        """Move the cursor forward to the next multiple of 8 (no-op if already aligned)."""
        aligned = (self._offset + 7) & ~7
        if DEBUG and aligned != self._offset:
            # Let debug_edits highlight the padding bits that were skipped.
            # Empty spans are not recorded, so later output is not mis-coloured.
            ignored_bits.append((self._offset, aligned))
        self._offset = aligned

    @debug_wrapper
    def write_bit(self, value: bool) -> bool:
        """Set (truthy ``value``) or clear the bit at the cursor and advance by one.

        The buffer is never grown. Returns False, writing nothing, if the
        cursor is already at the end of the buffer.
        """
        if not self._can_access_bits(1):
            return False
        mask = 1 << (7 - (self._offset & 7))
        if value:
            self._buffer[self._offset >> 3] |= mask
        else:
            self._buffer[self._offset >> 3] &= ~mask
        self._offset += 1
        return True

    @debug_wrapper
    def read_bit(self) -> Tuple[bool, bool]:
        """Read the bit at the cursor and advance by one.

        Returns ``(True, bit)``, or ``(False, False)`` at the end of the buffer.
        """
        if not self._can_access_bits(1):
            return False, False
        mask = 1 << (7 - (self._offset & 7))
        value = self._buffer[self._offset >> 3] & mask
        self._offset += 1
        return True, value > 0

    @debug_wrapper
    def write(self, value: bytes, bit_length: int) -> bool:
        """Overwrite ``bit_length`` bits at the cursor with the leading bits of ``value``.

        Bits are taken most-significant first from ``value``. The buffer is
        not grown. Returns False, writing nothing, if the bits do not fit,
        if ``value`` holds fewer than ``bit_length`` bits, or if
        ``bit_length`` is negative. A zero-length write succeeds.
        """
        if bit_length <= 0:
            return bit_length == 0
        if not self._can_access_bits(bit_length):
            return False
        if len(value) << 3 < bit_length:
            return False
        for bits_written in range(bit_length):
            position = self._offset
            mask = 1 << (7 - (position & 7))
            if value[bits_written >> 3] & (1 << (7 - (bits_written & 7))):
                self._buffer[position >> 3] |= mask
            else:
                self._buffer[position >> 3] &= ~mask
            self._offset += 1
        return True

    @debug_wrapper
    def read_int(self, bit_length: int) -> Tuple[bool, int]:
        """Read ``bit_length`` bits at the cursor as an unsigned integer.

        The first bit read becomes the most significant bit of the result.
        Returns ``(False, 0)`` without moving the cursor if the bits are
        not available or ``bit_length`` is negative.
        """
        if bit_length < 0 or not self._can_access_bits(bit_length):
            return False, 0
        start = self._offset
        end = start + bit_length
        first_byte = start >> 3
        last_byte = (end + 7) >> 3
        # Treat the covering bytes as one big-endian integer. Shift out the
        # bits after `end`, then mask away the bits before `start`.
        chunk = int.from_bytes(self._buffer[first_byte:last_byte], "big")
        trailing_bits = (last_byte << 3) - end
        value = (chunk >> trailing_bits) & ((1 << bit_length) - 1)
        self._offset = end
        return True, value

    def read(self, bit_length: int) -> Tuple[bool, bytes]:
        """Read ``bit_length`` bits as big-endian bytes, right-aligned (see ``read_int``).

        The result always has ``ceil(bit_length / 8)`` bytes (zeros on failure).
        """
        success, value = self.read_int(bit_length)
        return success, value.to_bytes((max(bit_length, 0) + 7) >> 3, "big")

    @debug_wrapper
    def write_compressed(self, data: bytes, bit_length: int, unsigned: bool = True) -> bool:
        """Write ``data`` using RakNet's compressed encoding, growing the buffer as needed.

        Byte ``(bit_length >> 3) - 1`` of ``data`` is treated as the most
        significant, i.e. ``data`` is little-endian. The encoding works from
        that byte downwards:

        - each byte equal to the sign-extension byte (0x00 when ``unsigned``,
          0xFF otherwise) is replaced by a single 1 bit;
        - the first byte that differs is preceded by a 0 bit and written
          together with all lower bytes.

        If only the lowest byte remains, a 1 bit plus its low nibble is
        written when its high nibble is also sign-extension. Otherwise a
        0 bit plus the whole byte is written.

        Returns False if ``bit_length`` is below 8 or ``data`` is shorter
        than ``bit_length >> 3`` bytes.
        """
        byte_count = bit_length >> 3
        if byte_count < 1 or len(data) < byte_count:
            return False
        current_byte = byte_count - 1
        byte_match = 0x00 if unsigned else 0xFF
        half_byte_match = 0x00 if unsigned else 0xF0
        while current_byte > 0:
            # write_bit never grows the buffer, so reserve room for the flag.
            self.add_bits(1)
            if data[current_byte] != byte_match:
                self.write_bit(False)
                return self.write_bits(data, (current_byte + 1) << 3, True)
            self.write_bit(True)
            current_byte -= 1
        self.add_bits(1)
        if data[current_byte] & 0xF0 == half_byte_match:
            self.write_bit(True)
            return self.write_bits(data[current_byte:], 4, True)
        self.write_bit(False)
        return self.write_bits(data[current_byte:], 8, True)

    @debug_wrapper
    def read_compressed(self, bit_length: int, unsigned: bool = True) -> Tuple[bool, bytes]:
        """Read a value encoded by ``write_compressed``.

        Returns ``(success, out)``. ``out`` is a little-endian bytearray of
        exactly ``bit_length >> 3`` bytes and may be partially filled when
        ``success`` is False. Fails immediately if ``bit_length`` is below 8.
        """
        byte_count = bit_length >> 3
        out = bytearray(max(byte_count, 0))
        if byte_count < 1:
            return False, out
        current_byte = byte_count - 1
        byte_match = 0x00 if unsigned else 0xFF
        half_byte_match = 0x00 if unsigned else 0xF0
        while current_byte > 0:
            success, bit = self.read_bit()
            if not success:  # end of stream
                return False, out
            if not bit:
                # The remaining low bytes were written verbatim.
                success, data = self.read_bits((current_byte + 1) << 3, True)
                if not success:
                    return False, out
                out[:current_byte + 1] = data
                return True, out
            out[current_byte] = byte_match
            current_byte -= 1
        # Only the lowest byte is left: a flag bit, then a nibble or a full byte.
        if self.offset + 1 > self.length:
            return False, out
        success, bit = self.read_bit()
        if not success:  # end of stream
            return False, out
        if bit:
            success, data = self.read_bits(4, True)
            if not success:
                return False, out
            # read_bits leaves the high nibble zero; restore the sign extension
            # (RakNet BitStream::ReadCompressed does the same).
            out[current_byte] = data[0] | half_byte_match
        else:
            success, data = self.read_bits(8, True)
            if not success:
                return False, out
            out[current_byte] = data[0]
        return True, out

    @debug_wrapper
    def write_bits(self, data: bytes, bit_length: int, align_bits_to_right: bool = True) -> bool:
        """Write ``bit_length`` bits from ``data`` at the cursor, growing the buffer as needed.

        Full bytes of ``data`` are written most-significant bit first. For
        a final partial byte, ``align_bits_to_right`` selects its low bits
        (True) or its high bits (False). When the cursor is not byte
        aligned, new bits are OR-ed into the current byte, so those bits
        are expected to be zero, as they are when appending at the end.

        Returns False if ``bit_length`` is not positive or ``data`` holds
        fewer than ``ceil(bit_length / 8)`` bytes.
        """
        if bit_length <= 0:
            return False
        byte_total = BITS_TO_BYTES(bit_length)
        if len(data) < byte_total:
            return False
        self.add_bits(bit_length)
        used_mod8 = self._offset & 7
        for index in range(byte_total):
            remaining = bit_length - (index << 3)
            byte = data[index]
            if remaining < 8 and align_bits_to_right:
                # Move the low `remaining` bits to the top. Mask to 8 bits as
                # the C original's unsigned char does implicitly.
                byte = (byte << (8 - remaining)) & 0xFF
            position = self._offset >> 3
            if used_mod8 == 0:
                self._buffer[position] = byte
            else:
                self._buffer[position] |= byte >> used_mod8
                if 8 - used_mod8 < remaining:
                    # The byte straddles a boundary: its low bits start the next byte.
                    self._buffer[position + 1] = (byte << (8 - used_mod8)) & 0xFF
            self._offset += min(remaining, 8)
        return True

    @debug_wrapper
    def read_bits(self, bit_length: int, align_bits_to_right: bool = True) -> Tuple[bool, bytes]:
        """Read ``bit_length`` bits at the cursor into ``ceil(bit_length / 8)`` bytes.

        Full bytes come out most-significant bit first. For a final partial
        byte, ``align_bits_to_right`` shifts its bits down to the low end.
        Returns ``(False, b"")`` if ``bit_length`` is not positive or the
        bits are not available.
        """
        if bit_length <= 0:
            return False, b""
        if not self._can_access_bits(bit_length):
            return False, b""
        out = bytearray(BITS_TO_BYTES(bit_length))
        read_offset_mod8 = self._offset & 7
        out_index = 0
        while bit_length > 0:
            # High part comes from the current byte.
            out[out_index] |= (self._buffer[self._offset >> 3] << read_offset_mod8) & 255
            if read_offset_mod8 > 0 and bit_length > 8 - read_offset_mod8:
                # Low part spills over from the following byte.
                out[out_index] |= self._buffer[(self._offset >> 3) + 1] >> (8 - read_offset_mod8)
            bit_length -= 8
            if bit_length < 0:
                # Partial final byte: only -bit_length fewer bits were wanted.
                if align_bits_to_right:
                    out[out_index] >>= -bit_length
                self._offset += 8 + bit_length
            else:
                self._offset += 8
            out_index += 1
        return True, out

    @debug_wrapper
    def read_aligned_bytes(self, byte_length: int) -> Tuple[bool, bytes]:
        """Align the cursor to a byte boundary, then read ``byte_length`` whole bytes.

        The alignment happens even when the read then fails. Returns
        ``(False, b"")`` if ``byte_length`` is not positive or too few
        bytes remain.
        """
        if byte_length <= 0:
            return False, b""
        self.align_read_to_byte_boundary()
        if self.offset + (byte_length << 3) > (len(self._buffer) << 3):
            return False, b""
        start = self.offset >> 3
        out = self._buffer[start:start + byte_length]
        self._offset += byte_length << 3
        return True, out

    def pretty(self) -> str:
        """Return the buffer as space-separated bits with ``|`` marking the cursor."""
        text = bytearray(" " + " ".join(" ".join(format(byte, "08b")) for byte in self._buffer) + " ", "ascii")
        # Bit i is at index 1 + 2*i, so the gap in front of it is at 2*i.
        # When the cursor is at the end, that index is the trailing space.
        text[self._offset * 2] = 124  # ord("|")
        return text.decode()

    def debug_edits(self, prev_state: Bitstream) -> str:
        """Render the buffer with colour markers showing the bits moved over since ``prev_state``.

        Marker letters are spliced into the bit strings and then replaced by
        colour codes:

        - ``P`` (pink) at ``prev_state``'s cursor;
        - ``R`` (reset) at the current cursor;
        - ``S`` (red) at the start of each span in ``ignored_bits``.
        """
        bits = [list(format(byte, "08b")) for byte in self._buffer]
        # Extra slot so a marker at offset == length has somewhere to go.
        bits.append([])
        start = prev_state._offset
        end = self._offset
        edits = {start: "P", end: "R"}
        for skip_start, skip_end in ignored_bits:
            preceding = [position for position in edits if position <= skip_start]
            # After the skipped span, resume the colour of the nearest marker before it.
            edits[skip_end] = edits[max(preceding)] if preceding else "R"
            edits[skip_start] = "S"
        # Insert from the highest position down so insertions don't shift
        # the indices of markers still to be placed in the same byte.
        for position, marker in sorted(edits.items(), reverse=True):
            bits[position // 8][position % 8:position % 8] = marker
        rendered = " ".join("".join([y + " " if y in "01" else y for y in x]).strip() for x in bits)
        rendered = rendered.replace("R", logging.color(logging.COLOR_RESET))
        rendered = rendered.replace("P", logging.color(logging.COLOR_PINK))
        rendered = rendered.replace("S", logging.color(logging.COLOR_RED))
        return rendered

    def __repr__(self) -> str:
        return "<Bitstream addr:0x%012x offset:%d len:%d>" % (id(self), self._offset, len(self._buffer) << 3)