Compare commits

...

9 Commits

Author SHA1 Message Date
ee04ff6e6d Added 0.3.7 obfuscation layer 2023-03-19 06:02:55 +01:00
e975a5c492 Indent error 2023-03-19 06:02:36 +01:00
eca447fd49 Query info packet 0 size strings fix 2023-03-19 05:17:44 +01:00
2422bee5fa Change default config 2023-03-19 05:14:30 +01:00
faf9ccffbc Config and logging 2023-03-19 01:43:14 +01:00
8e59db2ad6 Interactive server console 2023-03-15 07:27:33 +01:00
f1152cfb25 Server with query protocol 2023-03-15 07:12:47 +01:00
75824c306f Changed formatting 2023-03-15 06:06:58 +01:00
b3fedb8214 __main__.py with argparse 2023-03-15 06:06:13 +01:00
12 changed files with 947 additions and 184 deletions

6
.gitignore vendored
View File

@ -130,3 +130,9 @@ dmypy.json
# VSC # VSC
/.vscode /.vscode
# Ignore ini files
*.ini
# Ignore logs directory
logs/

113
sampy/__main__.py Normal file
View File

@ -0,0 +1,113 @@
import argparse
import asyncio
import logging
import textwrap
def main(args: argparse.Namespace) -> int:
from sampy.config import Config
config = Config(*args.config, logging_level=args.log)
from sampy.network.protocol import Protocol
from sampy.server import InteractiveServer
server = InteractiveServer(
Protocol,
config=config,
)
server.start()
asyncio.get_event_loop().run_forever()
return 0
if __name__ == "__main__":
class Formatter(
argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter
):
pass
parser = argparse.ArgumentParser(
prog="sampy",
description=textwrap.dedent(
"""
A SAMP server made in python
SAMP (or SA-MP) is a free multiplayer mod for the PC port of GTA: San Andreas.
GTA: San Andreas was developed by Rockstar North and released in 2005.
SA-MP is an unofficial multiplayer mod made by the `SA-MP.com` team released @[sa-mp.com](https://www.sa-mp.com/)
"""
),
epilog=textwrap.dedent(
"""
example:
%(prog)s --config default.ini secret.ini race.ini
This will first load default.ini configuration,
secret.ini might override the rcon- and server-password,
race.ini might change the hostname and rules.
example default.ini:
[sampy]
hostname = My SAMP Server
password =
rcon_password = VerySecretPassword
[sampy.rules]
lagcomp = off
[logging.loggers]
keys = root
[logging.handlers]
keys = console, file
[logging.formatters]
keys = simple, color
[logging.logger_root]
level = DEBUG
handlers = console, file
[logging.handler_console]
class = StreamHandler
formatter = color
args = (sys.stdout,)
[logging.handler_file]
class = logging.handlers.TimedRotatingFileHandler
formatter = simple
args = ("logs/sampy.log", "d", 1, 7,) # Note that the logs folder has to exist
[logging.formatter_simple]
format = %%(asctime)s.%%(msecs)03d | %%(levelname)-8s | %%(message)s
datefmt = %%Y-%%m-%%d %%H:%%M:%%S
[logging.formatter_color]
class = sampy.config.ColorFormatter
format = §6%%(asctime)s.%%(msecs)03d §1| %%(levelname)-8s §1|§r %%(message)s
datefmt = %%Y-%%m-%%d %%H:%%M:%%S
"""
),
formatter_class=Formatter,
)
parser.add_argument(
"--config",
type=str,
nargs="+",
default=[],
help="Config filenames",
)
parser.add_argument(
"--log",
type=str,
nargs="?",
help="Global logging level (Overrides any config)",
choices=logging._nameToLevel.keys(),
)
args = parser.parse_args()
raise SystemExit(main(args))

15
sampy/client/player.py Normal file
View File

@ -0,0 +1,15 @@
from ctypes import c_ubyte
class Player:
id: c_ubyte
username: str
score: int
ping: int
health: float
armor: float
# position: Vector3 # TODO
rotation: float
def __init__(self, username: str):
self.username = username

182
sampy/config.py Normal file
View File

@ -0,0 +1,182 @@
from __future__ import annotations
import logging
import logging.config
from configparser import ConfigParser
from typing import Any, Dict, Mapping, Optional, Union
class Config(ConfigParser):
DEFAULTS: Mapping[str, Mapping[str, Union[str, int]]] = {
"sampy": {
"host": "0.0.0.0",
"port": 7777,
"hostname": "sam.py",
"password": "",
"rcon_password": "changeme",
"max_players": 50,
"mode": "",
"language": "English",
},
"sampy.rules": {
"weburl": "https://git.osufx.com/Sunpy/sampy",
},
"logging.loggers": {
"keys": "root",
},
"logging.handlers": {
"keys": "console",
},
"logging.formatters": {
"keys": "simple",
},
"logging.logger_root": {
"level": "INFO",
"handlers": "console",
},
"logging.handler_console": {
"class": "StreamHandler",
"formatter": "simple",
"args": "(sys.stdout,)",
},
"logging.formatter_simple": {
"format": "%(levelname)s - %(message)s",
"datefmt": "%Y-%m-%d %H:%M:%S",
},
}
def __init__(
self,
*filenames,
dictionary: Mapping[str, Mapping[str, Union[str, int]]] = {},
logging_level: Optional[int] = None,
):
super().__init__(interpolation=None)
if logging_level is not None:
logging.root.setLevel(logging_level)
self.read_dict(self.DEFAULTS)
self.read_dict(dictionary)
found = self.read(filenames, encoding="utf-8-sig")
missing = set(filenames) - set(found)
if len(missing):
logging.warn("Config files not found: %s" % missing)
logging_config = self.get_logging_config()
if logging_config:
logging.config.fileConfig(logging_config)
if logging_level is not None:
logging.root.setLevel(logging_level)
logging.debug("Logging module has been configured")
else:
logging.warn("Logging module was not configured")
def get_logging_config(self) -> ConfigParser:
config = ConfigParser(interpolation=None)
for section in self.sections():
if not section.startswith("logging."):
continue
config[section.replace("logging.", "")] = self[section]
return config
@property
def host(self) -> str:
return self.get("sampy", "host")
@property
def port(self) -> int:
return self.getint("sampy", "port")
@property
def hostname(self) -> str:
return self.get("sampy", "hostname")
@property
def password(self) -> str:
return self.get("sampy", "password")
@property
def rcon_password(self) -> str:
return self.get("sampy", "rcon_password")
@property
def max_players(self) -> int:
return self.getint("sampy", "max_players")
@property
def mode(self) -> str:
return self.get("sampy", "mode")
@property
def language(self) -> str:
return self.get("sampy", "language")
@property
def rules(self) -> Dict[str, str]:
return self["sampy.rules"]
class LogRecordProxy:
def __init__(self, record: logging.LogRecord):
self._record = record
def __getattribute__(self, name: str) -> Any:
attr = {
k: v
for k, v in object.__getattribute__(self, "__dict__").items()
if k != "_record"
}
if name in attr:
return attr[name]
elif name == "__dict__": # Combine dicts
return {**object.__getattribute__(self, "_record").__dict__, **attr}
return object.__getattribute__(self, "_record").__getattribute__(name)
class ColorFormatter(logging.Formatter):
COLORS: Dict[str, str] = {
"0": "30", # Black
"1": "34", # Blue
"2": "32", # Green
"3": "36", # Cyan
"4": "31", # Red,
"5": "35", # Purple/Magenta
"6": "33", # Yellow/Gold
"7": "37", # White/Light Gray
"8": "30;1", # Dark Gray
"9": "34;1", # Light Blue
"a": "32;1", # Light Green
"b": "36;1", # Light Cyan
"c": "31;1", # Light Red
"d": "35;1", # Light Purple/Magenta
"e": "33;1", # Yellow
"f": "37;1", # White
"r": "0", # Reset
"l": "1", # Bold
"n": "4", # Underline
}
LEVEL_COLOR = {
logging.CRITICAL: "31",
logging.ERROR: "31",
logging.WARNING: "33",
logging.INFO: "32",
logging.DEBUG: "35",
logging.NOTSET: "37",
}
def format(self, record: logging.LogRecord) -> str:
record = LogRecordProxy(record)
level_color = ColorFormatter.LEVEL_COLOR.get(record.levelno, None)
if level_color is not None:
record.levelname = "\x1b[%sm%s\x1b[0m" % (level_color, record.levelname)
message = super().format(record)
for k, v in ColorFormatter.COLORS.items():
message = message.replace("§%s" % k, "\x1b[%sm" % v)
return message

View File

75
sampy/network/game.py Normal file
View File

@ -0,0 +1,75 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Tuple
if TYPE_CHECKING:
from sampy.server import Server
class Game:
@staticmethod
async def on_packet(server: Server, packet: bytes, addr: Tuple[str, int]) -> bool:
packet = Obfuscation.deobfuscate(server, packet)
# TODO
return False
class Obfuscation:
# Found @ addr 0x004C88E0 in windows server executable
LOOKUP_TABLE = bytes.fromhex(
"""
b46207e59daf63dde3d0ccfedcdb6b2e6a40ab47c9d153d52091a50e4adf1889
fd6f2512b713770065366d49ec572aa9115ffa7895a4bd1ed97944cdde81eb09
3ef6eeda7fa31aa72da6adc14693d21b9caad74e4b4d4cf3b834c0ca88f494cb
04393082d673b0bf2201416e482ca875b10aae9f278010cef02928850d05f735
bbbc1506f56071031fea5a33928de7905be9cf9ed35ded311c0b5216510f86c5
689b210c8b4287ff4fbec8e8c7d47ae0552f8a8eba9837e4b238a1b632833a7b
843c61fb8c143d433b1dc3a296b3f8c4f2262bd87cfc232466ef6964505459f1
a074acc67db5e6e2c27e67175ee1b93f6c700899455676f99a9719725c028f58
"""
)
# I think this used to be zlib compression, but has been swapped out with obfuscation instead
@staticmethod
def deobfuscate(server: Server, packet: bytes) -> bytes:
checksum, data = packet[0], bytearray(packet[1:])
data = Obfuscation.xor_every_other_byte(
Obfuscation.get_port_xor_key(server), data
)
data = bytes(Obfuscation.LOOKUP_TABLE[b] for b in packet)
if checksum != Obfuscation.calc_checksum(data):
logging.error("Checksum failed!")
raise Exception("Checksum fail")
return bytes(data)
@staticmethod
def obfuscate(server: Server, packet: bytes) -> bytes:
data = bytes(Obfuscation.LOOKUP_TABLE.index(b) for b in packet)
data = Obfuscation.xor_every_other_byte(
Obfuscation.get_port_xor_key(server), data
)
checksum = Obfuscation.calc_checksum(data)
return bytes([checksum]) + data
@staticmethod
def xor_every_other_byte(xor: int, packet: bytearray) -> bytearray:
for i in range(1, len(packet), 2):
packet[i] ^= xor
return packet
@staticmethod
def calc_checksum(packet: bytearray) -> int:
checksum = 0
for byte in packet:
checksum ^= byte & 0xAA
return checksum
@staticmethod
def get_port_xor_key(server: Server) -> int:
return (server.config.port ^ 0xCCCC) & 0xFF

26
sampy/network/protocol.py Normal file
View File

@ -0,0 +1,26 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Tuple
from sampy.network.game import Game
from sampy.network.query import Query
if TYPE_CHECKING:
from sampy.server import Server
class Protocol:
VERSION = "0.3.7"
@staticmethod
async def on_packet(server: Server, packet: bytes, addr: Tuple[str, int]) -> bool:
logging.debug("on_packet")
if await Query.on_packet(server, packet, addr):
return True
if await Game.on_packet(server, packet, addr):
return True
logging.debug("Unhandled: %r" % packet)
return False

185
sampy/network/query.py Normal file
View File

@ -0,0 +1,185 @@
from __future__ import annotations
import functools
import logging
import struct
from typing import TYPE_CHECKING, Callable, Dict, Tuple
if TYPE_CHECKING:
from sampy.server import Server
# Holds a dict of handlers for different packet IDs. Since every query packet id is a single byte, we use an int here.
HANDLERS: Dict[int, Callable[[Server, bytes, Tuple[str, int]], bool]] = {}
# Decorator that adds the function to the HANDLERS dict.
def handler(packet_id: bytes):
if len(packet_id) > 1:
raise Exception("Query opcode length cannot be bigger then 1")
def outer(func):
@functools.wraps(func)
def inner(
server: Server, packet: bytes, addr: Tuple[str, int], *args, **kwargs
):
return func(server, packet, addr, *args, **kwargs)
HANDLERS[packet_id[0]] = func
logging.debug("Added Query handler: %s -> %s" % (packet_id, func))
return inner
return outer
class Query:
"""Query handler
Reference: https://team.sa-mp.com/wiki/Query_Mechanism.html (Unable to archive due to https://sa-mp.com/robots.txt disallowing ia_archiver)
The Query protocol is *mostly* used for the samp server browser and other systems that queries for server info.
The exception to this is would be the "player_details"(d) packet on version 0.2.x and below where an ingame player pressing TAB would use this packet.
Structure:
- "SAMP" header (4 bytes)
- Server's ipv4 (4 bytes)
- Server's port (2 bytes)
- Packet type (1 byte)
- Packet data
Note that the server and client will both use the same first 4 parts when communicating.
Not all packets have packet data, and most packets doesn't have data when the client sends it as its a "request".
The server will always have packet data attached.
List of packets:
- i: Information packet. This includes hostname, players (online/max), mode, language and whether password is required.
- r: Rules packet. This is a list of "rules". The name "rules" is subjective in this case, as most are general optional information.
- c: Client list packet. This is a list of players and scores. Players being just their username and score a number.
- d: Detailed player list packet. Extends the client list packet with player id and ping.
- p: Ping packet. A client uses this packet with 4 random bytes and measures how long it takes before it gets the same packet back.
- x: Remote console packet. A client can send and receive anything on this packet. Usually used for remote console.
Additional findings:
- On info packet, strings can not be of size 0. This will make parsing of the rest of the packet fail.
- This is strange due to how we are sending the length of the string first, which should allow this...
- Fix: If string size is 0, we replace the string with a single space (Considering using a NULL byte, but unsure if this could cause other issues)
"""
HEADER_LENGTH = 11
@staticmethod
async def on_packet(server: Server, packet: bytes, addr: Tuple[str, int]) -> bool:
if len(packet) < 11: # Packet is too small
return False
magic, _ip, _port, packet_id = struct.unpack(
b"<4sIHB", packet[: Query.HEADER_LENGTH]
) # Unpack packet
if magic != b"SAMP": # Validate magic
return False
return HANDLERS.get(packet_id, lambda *_: False)(server, packet, addr)
@handler(b"i")
@staticmethod
def info(server: Server, packet: bytes, addr: Tuple[str, int]) -> bool:
packet = packet[: Query.HEADER_LENGTH] # Discard additional data if passed
hostname = server.config.hostname.encode()
mode = server.config.mode.encode()
language = server.config.language.encode()
if len(hostname) == 0:
hostname = b" "
if len(mode) == 0:
mode = b" "
if len(language) == 0:
language = b" "
packet += struct.pack(
b"<?HHI%dsI%dsI%ds" % (len(hostname), len(mode), len(language)),
len(server.config.password) != 0,
len(server.players),
server.config.max_players,
len(hostname),
hostname,
len(mode),
mode,
len(language),
language,
)
server.sendto(packet, addr)
return True
@handler(b"r")
@staticmethod
def rules(server: Server, packet: bytes, addr: Tuple[str, int]) -> bool:
packet = packet[: Query.HEADER_LENGTH] # Discard additional data if passed
rules = server.config.rules
rules["version"] = server.protocol.VERSION # Add game version (read-only)
rules["worldtime"] = ""
rules["weather"] = "10"
packet += struct.pack(b"<H", len(rules))
for item in (item.encode() for pair in rules.items() for item in pair):
packet += struct.pack(b"<B%ds" % len(item), len(item), item)
server.sendto(packet, addr)
return True
@handler(b"c")
@staticmethod
def client_list(server: Server, packet: bytes, addr: Tuple[str, int]) -> bool:
packet = packet[: Query.HEADER_LENGTH] # Discard additional data if passed
players = server.players
packet += struct.pack(b"<H", len(players))
for player in players:
username = player.username.encode()
packet += struct.pack(b"<B%dsI", len(username), username, player.score)
server.sendto(packet, addr)
return True
@handler(b"d")
@staticmethod
def player_details(server: Server, packet: bytes, addr: Tuple[str, int]) -> bool:
packet = packet[: Query.HEADER_LENGTH] # Discard additional data if passed
players = server.players
packet += struct.pack(b"<H", len(players))
for player in players:
username = player.username.encode()
packet += struct.pack(
b"<BB%dsII",
player.id,
len(username),
username,
player.score,
player.ping,
)
server.sendto(packet, addr)
return True
@handler(b"p")
@staticmethod
def ping(server: Server, packet: bytes, addr: Tuple[str, int]) -> bool:
if (
len(packet) < Query.HEADER_LENGTH + 4
): # Packet is too small (Missing random)
return False
packet = packet[
: Query.HEADER_LENGTH + 4
] # Discard additional data if passed (+4 to include random)
server.sendto(packet, addr)
return True
@handler(b"x")
@staticmethod
def rcon(server: Server, packet: bytes, addr: Tuple[str, int]) -> bool:
return False # TODO

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, List, Union, Optional, Literal, Type
from typing import Any, List, Literal, Optional, Type, Union
# for seek() # for seek()
SEEK_SET = 0 SEEK_SET = 0
@ -30,8 +30,10 @@ def bits_to_int(bits: List[bool], little_endian: bool = False) -> int:
return num return num
def int_to_bits(number: int, bit_length: int, little_endian: bool = False) -> List[bool]: def int_to_bits(
bits = [bool(number & (0b1 << i)) for i in range(bit_length-1, -1, -1)] number: int, bit_length: int, little_endian: bool = False
) -> List[bool]:
bits = [bool(number & (0b1 << i)) for i in range(bit_length - 1, -1, -1)]
return bits[::-1] if little_endian else bits return bits[::-1] if little_endian else bits
@ -56,7 +58,9 @@ class Bitstream:
@staticmethod @staticmethod
def from_int(value: int, bit_length: int) -> Bitstream: def from_int(value: int, bit_length: int) -> Bitstream:
return Bitstream([bool(value & (0b1 << i)) for i in range(bit_length-1, -1, -1)]) return Bitstream(
[bool(value & (0b1 << i)) for i in range(bit_length - 1, -1, -1)]
)
def __len__(self) -> int: def __len__(self) -> int:
return self._bit_length return self._bit_length
@ -82,7 +86,9 @@ class Bitstream:
mask = 1 << (7 - (bit_index % 8)) mask = 1 << (7 - (bit_index % 8))
return bool(self._bytearray[bit_index >> 3] & mask) return bool(self._bytearray[bit_index >> 3] & mask)
def __setitem__(self, bit_index: int, value: Union[bool, List[bool], bytes, Bitstream]): def __setitem__(
self, bit_index: int, value: Union[bool, List[bool], bytes, Bitstream]
):
if bit_index >= self._bit_length: if bit_index >= self._bit_length:
raise IndexError("bit index out of range") raise IndexError("bit index out of range")
@ -101,7 +107,9 @@ class Bitstream:
self._bytearray[bit_index >> 3] |= mask self._bytearray[bit_index >> 3] |= mask
else: else:
self._bytearray[bit_index >> 3] &= ~mask self._bytearray[bit_index >> 3] &= ~mask
elif isinstance(value, Bitstream) or (type(value) is list and all(type(v) is bool for v in value)): elif isinstance(value, Bitstream) or (
type(value) is list and all(type(v) is bool for v in value)
):
for i in range(value_bit_length): for i in range(value_bit_length):
self[bit_index + i] = value[i] self[bit_index + i] = value[i]
else: else:
@ -164,7 +172,10 @@ class Bitstream:
def __repr__(self) -> str: def __repr__(self) -> str:
return "<Bitstream addr:0x%012x offset:%d len:%d data:'%s'>" % ( return "<Bitstream addr:0x%012x offset:%d len:%d data:'%s'>" % (
id(self), self._offset, len(self), bytes(self).hex(" ") id(self),
self._offset,
len(self),
bytes(self).hex(" "),
) )
@property @property
@ -223,7 +234,7 @@ class Bitstream:
self, self,
type: Type[Union[bool, bytes, Bitstream, int]], type: Type[Union[bool, bytes, Bitstream, int]],
bit_length: int, bit_length: int,
bit_index: Optional[int] = None bit_index: Optional[int] = None,
) -> Union[List[bool], bytes, Bitstream, int]: ) -> Union[List[bool], bytes, Bitstream, int]:
start = self._offset if bit_index is None else bit_index start = self._offset if bit_index is None else bit_index
@ -248,7 +259,11 @@ class Bitstream:
raise TypeError("Invalid data type") raise TypeError("Invalid data type")
def write(self, value: Union[bool, List[bool], bytes, Bitstream], bit_index: Optional[int] = None): def write(
self,
value: Union[bool, List[bool], bytes, Bitstream],
bit_index: Optional[int] = None,
):
start = self._offset if bit_index is None else bit_index start = self._offset if bit_index is None else bit_index
if start < 0: if start < 0:

91
sampy/server.py Normal file
View File

@ -0,0 +1,91 @@
from __future__ import annotations
import asyncio
import logging
from typing import TYPE_CHECKING, Optional, Tuple, Type
from sampy.client.player import Player
from sampy.config import Config
if TYPE_CHECKING:
from sampy.network.protocol import Protocol
class UDPProtocol:
transport: asyncio.transports.DatagramTransport
def __init__(self, protocol: Protocol, local_addr: Tuple[str, int]):
self.protocol = protocol
self.local_addr = local_addr
def start(self):
loop = asyncio.get_event_loop()
logging.debug("Creating datagram endpoint")
connect = loop.create_datagram_endpoint(
lambda: self,
local_addr=self.local_addr,
)
loop.run_until_complete(connect)
def stop(self): # TODO: Shutdown code
if self.transport is None:
raise Exception("Cannot stop a server that hasn't been started")
logging.debug("Shutting down")
self.transport.close()
def connection_made(self, transport: asyncio.transports.DatagramTransport):
logging.debug("UDP Protocol: connection_made")
self.transport = transport
def connection_lost(self, exc: Exception | None):
logging.debug("UDP Protocol: connection_lost")
def datagram_received(self, data: bytes, addr: Tuple[str, int]):
raise NotImplementedError
def sendto(self, data: bytes | bytearray | memoryview, addr: Tuple[str, int]):
self.transport.sendto(data, addr)
class Server(UDPProtocol):
config: Config
def __init__(self, protocol: Type[Protocol], config: Optional[Config] = None):
if config is None:
config = Config()
logging.warn("Server was initialized with default config")
super().__init__(
protocol=protocol(),
local_addr=(
config.get("sampy", "host"),
config.getint("sampy", "port"),
),
)
self.config = config
def datagram_received(self, data: bytes, addr: Tuple[str, int]):
loop = asyncio.get_event_loop()
loop.create_task(self.protocol.on_packet(self, data, addr))
@property
def players(self) -> list[Player]: # TODO
return []
class InteractiveServer(Server):
def __init__(self, protocol: Type[Protocol], config: Optional[Config] = None):
super().__init__(protocol=protocol, config=config)
loop = asyncio.get_event_loop()
loop.create_task(self.run_input_loop())
async def run_input_loop(self):
loop = asyncio.get_event_loop()
while True:
command = await loop.run_in_executor(None, input)
if command in ("quit", "exit", "stop"):
self.stop()
loop.stop()
return

View File

@ -1,47 +1,62 @@
from typing import List, Union, Type, Optional from typing import List, Optional, Type, Union
import pytest import pytest
from sampy.raknet.bitstream import Bitstream, get_bit_length from sampy.raknet.bitstream import Bitstream, get_bit_length
@pytest.mark.parametrize("value", [ @pytest.mark.parametrize(
"value",
[
b"", b"",
b"A", b"A",
b"AB", b"AB",
b"\xFF\x00\xAA\x55", b"\xFF\x00\xAA\x55",
]) ],
)
def test_from_bytes(value: bytes): def test_from_bytes(value: bytes):
bitstream = Bitstream.from_bytes(value) bitstream = Bitstream.from_bytes(value)
assert len(bitstream) == (len(value) << 3) assert len(bitstream) == (len(value) << 3)
assert bytes(bitstream) == value assert bytes(bitstream) == value
@pytest.mark.parametrize("value,length,expected", [ @pytest.mark.parametrize(
"value,length,expected",
[
(0b0, 0, None), (0b0, 0, None),
(0b1, 1, None), (0b1, 1, None),
(0b0, 1, None), (0b0, 1, None),
(0b1011, 4, None), (0b1011, 4, None),
(0b10111011, 8, None), (0b10111011, 8, None),
(0b10111011, 7, 0b0111011), # -> 0b0111011 (0b10111011, 7, 0b0111011), # -> 0b0111011
]) ],
)
def test_from_int(value: int, length: int, expected: int): def test_from_int(value: int, length: int, expected: int):
bitstream = Bitstream.from_int(value, length) bitstream = Bitstream.from_int(value, length)
assert len(bitstream) == length assert len(bitstream) == length
assert len(bytes(bitstream)) == ((length + 7) >> 3) assert len(bytes(bitstream)) == ((length + 7) >> 3)
assert bitstream.bits == [bool(value & (1 << i)) for i in range(length-1, -1, -1)] assert bitstream.bits == [bool(value & (1 << i)) for i in range(length - 1, -1, -1)]
if expected is not None: if expected is not None:
expected_bitstream = Bitstream.from_int(expected, length) expected_bitstream = Bitstream.from_int(expected, length)
assert bitstream.bits == expected_bitstream.bits assert bitstream.bits == expected_bitstream.bits
assert bytes(bitstream) == bytes(expected_bitstream) assert bytes(bitstream) == bytes(expected_bitstream)
@pytest.mark.parametrize("values", [ @pytest.mark.parametrize(
"values",
[
[True, False, False, True, True], [True, False, False, True, True],
[[True], [False, True], [False, False]], [[True], [False, True], [False, False]],
[b"a", b"bc"], [b"a", b"bc"],
[b"aa", b"b"], [b"aa", b"b"],
[Bitstream.from_bytes(b"A"), Bitstream.from_int(0b1011, 4), Bitstream.from_bytes(b"B")], [
Bitstream.from_bytes(b"A"),
Bitstream.from_int(0b1011, 4),
Bitstream.from_bytes(b"B"),
],
[True, b"A", b"B", [False, True], Bitstream.from_int(0b10110, 5), False], [True, b"A", b"B", [False, True], Bitstream.from_int(0b10110, 5), False],
]) ],
)
def test_init(values: List[Union[bool, List[bool], bytes, Bitstream]]): def test_init(values: List[Union[bool, List[bool], bytes, Bitstream]]):
bitstream = Bitstream(*values) bitstream = Bitstream(*values)
@ -75,12 +90,15 @@ def test_getitem_index_error():
bitstream[-2] # Inverse IndexError bitstream[-2] # Inverse IndexError
@pytest.mark.parametrize("value", [ @pytest.mark.parametrize(
"value",
[
True, True,
[True, True], [True, True],
b"A", b"A",
Bitstream(True, b"C"), Bitstream(True, b"C"),
]) ],
)
def test_setitem(value: Union[bool, List[bool], bytes, Bitstream]): def test_setitem(value: Union[bool, List[bool], bytes, Bitstream]):
bit_length = get_bit_length(value) bit_length = get_bit_length(value)
read_type = bool if type(value) is list else type(value) read_type = bool if type(value) is list else type(value)
@ -88,41 +106,38 @@ def test_setitem(value: Union[bool, List[bool], bytes, Bitstream]):
for i in range(length - bit_length + 1): for i in range(length - bit_length + 1):
bitstream = Bitstream([False] * length) bitstream = Bitstream([False] * length)
bitstream[i] = value bitstream[i] = value
assert bitstream.read(read_type, bit_length, i) == (value if type(value) is not bool else [value]) assert bitstream.read(read_type, bit_length, i) == (
value if type(value) is not bool else [value]
)
assert len(bitstream) == length assert len(bitstream) == length
@pytest.mark.parametrize("value,equals,not_equals", [ @pytest.mark.parametrize(
( "value,equals,not_equals",
[], [
[[]], ([], [[]], [[True], [False]]),
[[True], [False]] (True, [[True]], [[False], [True, False], b"A"]),
), ([True, False], [[True, False]], [[False], [False, False], b"A"]),
(
True,
[[True]],
[[False], [True, False], b"A"]
),
(
[True, False],
[[True, False]],
[[False], [False, False], b"A"]
),
( (
b"A", b"A",
[b"A", b"A", Bitstream.from_bytes(b"A"), Bitstream.from_bytes(b"A").bits], [b"A", b"A", Bitstream.from_bytes(b"A"), Bitstream.from_bytes(b"A").bits],
[[], [False], b"B", b"B", Bitstream.from_bytes(b"B"), Bitstream.from_bytes(b"B").bits] [
[],
[False],
b"B",
b"B",
Bitstream.from_bytes(b"B"),
Bitstream.from_bytes(b"B").bits,
],
), ),
( (b"ABC", [b"ABC"], [[], [False], b"A"]),
b"ABC", ],
[b"ABC"], )
[[], [False], b"A"]
),
])
def test_eq( def test_eq(
value: Union[bool, List[bool], bytes, Bitstream], value: Union[bool, List[bool], bytes, Bitstream],
equals: List[Union[bool, List[bool], bytes, Bitstream]], equals: List[Union[bool, List[bool], bytes, Bitstream]],
not_equals: List[Union[bool, List[bool], bytes, Bitstream]]): not_equals: List[Union[bool, List[bool], bytes, Bitstream]],
):
bitstream = Bitstream(value) bitstream = Bitstream(value)
for equal in equals: for equal in equals:
assert bitstream == equal assert bitstream == equal
@ -130,16 +145,20 @@ def test_eq(
assert bitstream != not_equal assert bitstream != not_equal
@pytest.mark.parametrize("values,number,expected", [ @pytest.mark.parametrize(
"values,number,expected",
[
([], 2, []), ([], 2, []),
([True, False], 2, [True, False, True, False]), ([True, False], 2, [True, False, True, False]),
(b"A", 2, b"AA"), (b"A", 2, b"AA"),
([False], 3, [False, False, False]), ([False], 3, [False, False, False]),
]) ],
)
def test_mul( def test_mul(
values: Union[bool, List[bool], bytes, Bitstream], values: Union[bool, List[bool], bytes, Bitstream],
number: int, number: int,
expected: Union[List[bool], bytes, Bitstream]): expected: Union[List[bool], bytes, Bitstream],
):
bitstream = Bitstream(values) bitstream = Bitstream(values)
assert (bitstream * number) == expected assert (bitstream * number) == expected
@ -149,15 +168,19 @@ def test_mul(
assert bitstream == expected assert bitstream == expected
@pytest.mark.parametrize("init_values,add_value,expected", [ @pytest.mark.parametrize(
"init_values,add_value,expected",
[
([], True, [True]), ([], True, [True]),
([True, False], False, [True, False, False]), ([True, False], False, [True, False, False]),
(b"A", b"B", b"AB"), (b"A", b"B", b"AB"),
]) ],
)
def test_add( def test_add(
init_values: Union[bool, List[bool], bytes, Bitstream], init_values: Union[bool, List[bool], bytes, Bitstream],
add_value: Union[bool, List[bool], bytes, Bitstream], add_value: Union[bool, List[bool], bytes, Bitstream],
expected: Union[List[bool], bytes, Bitstream]): expected: Union[List[bool], bytes, Bitstream],
):
bitstream = Bitstream(init_values) bitstream = Bitstream(init_values)
assert (bitstream + add_value) == expected assert (bitstream + add_value) == expected
@ -167,12 +190,15 @@ def test_add(
assert bitstream == expected assert bitstream == expected
@pytest.mark.parametrize("init_values", [ @pytest.mark.parametrize(
"init_values",
[
[], [],
[True, False], [True, False],
[b"A"], [b"A"],
[True, b"B"], [True, b"B"],
]) ],
)
def test_iter(init_values: List[Union[bool, List[bool], bytes, Bitstream]]): def test_iter(init_values: List[Union[bool, List[bool], bytes, Bitstream]]):
bitstream = Bitstream(*init_values) bitstream = Bitstream(*init_values)
assert len(list(bitstream)) == len(bitstream) assert len(list(bitstream)) == len(bitstream)
@ -180,29 +206,35 @@ def test_iter(init_values: List[Union[bool, List[bool], bytes, Bitstream]]):
assert bit == bitstream[index] assert bit == bitstream[index]
@pytest.mark.parametrize("init_values,expected", [ @pytest.mark.parametrize(
"init_values,expected",
[
([], b""), ([], b""),
([b"A"], b"A"), ([b"A"], b"A"),
([False], b"\x00"), ([False], b"\x00"),
([[False] * 6, True], b"\x02"), ([[False] * 6, True], b"\x02"),
([[False] * 7, True], b"\x01"), ([[False] * 7, True], b"\x01"),
]) ],
)
def test_bytes( def test_bytes(
init_values: List[Union[bool, List[bool], bytes, Bitstream]], init_values: List[Union[bool, List[bool], bytes, Bitstream]], expected: bytes
expected: bytes): ):
assert bytes(Bitstream(*init_values)) == expected assert bytes(Bitstream(*init_values)) == expected
@pytest.mark.parametrize("init_values,expected", [ @pytest.mark.parametrize(
"init_values,expected",
[
([], []), ([], []),
([b"\x01"], [False] * 7 + [True]), ([b"\x01"], [False] * 7 + [True]),
([False], [False]), ([False], [False]),
([False] * 6 + [True], [False] * 6 + [True]), ([False] * 6 + [True], [False] * 6 + [True]),
([False] * 7 + [True], [False] * 7 + [True]), ([False] * 7 + [True], [False] * 7 + [True]),
]) ],
)
def test_bits( def test_bits(
init_values: List[Union[bool, List[bool], bytes, Bitstream]], init_values: List[Union[bool, List[bool], bytes, Bitstream]], expected: List[bool]
expected: List[bool]): ):
assert Bitstream(*init_values).bits == expected assert Bitstream(*init_values).bits == expected
@ -213,12 +245,15 @@ def test_copy():
assert id(bitstream) != id(bitstream.copy) assert id(bitstream) != id(bitstream.copy)
@pytest.mark.parametrize("lst", [ @pytest.mark.parametrize(
"lst",
[
[False] * 10, [False] * 10,
[True] * 10, [True] * 10,
[True, False] * 5, [True, False] * 5,
[False, True] * 5, [False, True] * 5,
]) ],
)
def test_append(lst: List[bool]): def test_append(lst: List[bool]):
bitstream = Bitstream() bitstream = Bitstream()
for i in range(len(lst)): for i in range(len(lst)):
@ -227,17 +262,24 @@ def test_append(lst: List[bool]):
assert len(bitstream) == (i + 1) assert len(bitstream) == (i + 1)
@pytest.mark.parametrize("extend_value", [ @pytest.mark.parametrize(
"extend_value",
[
[True, False, True], [True, False, True],
b"A", b"A",
b"BB", b"BB",
Bitstream(True, b"C"), Bitstream(True, b"C"),
]) ],
)
def test_extend(extend_value: Union[List[bool], bytes, Bitstream]): def test_extend(extend_value: Union[List[bool], bytes, Bitstream]):
for i in range(10): for i in range(10):
bitstream = Bitstream([False] * i) bitstream = Bitstream([False] * i)
bitstream.extend(extend_value) bitstream.extend(extend_value)
value = bitstream.read(bool if type(extend_value) is list else type(extend_value), get_bit_length(extend_value), i) value = bitstream.read(
bool if type(extend_value) is list else type(extend_value),
get_bit_length(extend_value),
i,
)
assert value == extend_value assert value == extend_value
@ -256,12 +298,15 @@ def test_clear():
assert bitstream._offset == 0 assert bitstream._offset == 0
@pytest.mark.parametrize("lst", [ @pytest.mark.parametrize(
"lst",
[
[False] * 10, [False] * 10,
[True] * 10, [True] * 10,
[True, False] * 5, [True, False] * 5,
[False, True] * 5, [False, True] * 5,
]) ],
)
def test_pop(lst: List[bool]): def test_pop(lst: List[bool]):
bitstream = Bitstream(lst) bitstream = Bitstream(lst)
for i in range(len(bitstream)): for i in range(len(bitstream)):
@ -279,42 +324,40 @@ def test_read():
assert bitstream.read(int, 3) == 2 assert bitstream.read(int, 3) == 2
bitstream.seek(0) bitstream.seek(0)
assert bitstream.read(Bitstream, len(bitstream) - 1) == Bitstream(True, b"A", False, True) assert bitstream.read(Bitstream, len(bitstream) - 1) == Bitstream(
True, b"A", False, True
)
@pytest.mark.parametrize("bitstream,read_type,bit_length,bit_index,expected", [ @pytest.mark.parametrize(
"bitstream,read_type,bit_length,bit_index,expected",
[
(Bitstream(True, b"A", False, True, False), bool, 1, 0, [True]),
( (
Bitstream(True, b"A", False, True, False), Bitstream(True, b"A", False, True, False),
bool, 1, 0, bool,
[True] 3,
8 + 1,
[False, True, False],
), ),
(Bitstream(True, b"A", False, True, False), bytes, 8, 1, b"A"),
( (
Bitstream(True, b"A", False, True, False), Bitstream(True, b"A", False, True, False),
bool, 3, 8 + 1, Bitstream,
[False, True, False] 1 + 8 + 3,
), 0,
(
Bitstream(True, b"A", False, True, False), Bitstream(True, b"A", False, True, False),
bytes, 8, 1,
b"A"
), ),
( (Bitstream(True, b"A", False, True, False), int, 3, 1 + 8, 2),
Bitstream(True, b"A", False, True, False), ],
Bitstream, 1 + 8 + 3, 0, )
Bitstream(True, b"A", False, True, False)
),
(
Bitstream(True, b"A", False, True, False),
int, 3, 1 + 8,
2
),
])
def test_read_index( def test_read_index(
bitstream: Bitstream, bitstream: Bitstream,
read_type: Type[Union[bool, bytes, Bitstream, int]], read_type: Type[Union[bool, bytes, Bitstream, int]],
bit_length: int, bit_length: int,
bit_index: Optional[int], bit_index: Optional[int],
expected: Union[List[bool], bytes, Bitstream, int]): expected: Union[List[bool], bytes, Bitstream, int],
):
value = bitstream.read(read_type, bit_length, bit_index) value = bitstream.read(read_type, bit_length, bit_index)
if read_type is bool: if read_type is bool:
@ -337,7 +380,9 @@ def test_read_error():
bitstream.read(list, 1, 0) bitstream.read(list, 1, 0)
@pytest.mark.parametrize("bitstream,value,bit_index,current_offset", [ @pytest.mark.parametrize(
"bitstream,value,bit_index,current_offset",
[
(Bitstream(), True, None, None), (Bitstream(), True, None, None),
(Bitstream(), b"A", None, None), (Bitstream(), b"A", None, None),
(Bitstream(), True, 0, None), (Bitstream(), True, 0, None),
@ -348,8 +393,14 @@ def test_read_error():
(Bitstream(False), True, 0, None), (Bitstream(False), True, 0, None),
(Bitstream(False), True, 0, 0), (Bitstream(False), True, 0, 0),
(Bitstream(False), True, None, 0), (Bitstream(False), True, None, 0),
]) ],
def test_write(bitstream: Bitstream, value: Union[bool, List[bool], bytes, Bitstream], bit_index: Optional[int], current_offset: Optional[int]): )
def test_write(
bitstream: Bitstream,
value: Union[bool, List[bool], bytes, Bitstream],
bit_index: Optional[int],
current_offset: Optional[int],
):
if current_offset is not None: if current_offset is not None:
bitstream.seek(current_offset) bitstream.seek(current_offset)
@ -358,9 +409,13 @@ def test_write(bitstream: Bitstream, value: Union[bool, List[bool], bytes, Bitst
if bit_index is None: if bit_index is None:
bitstream.seek(-bit_length, 1) bitstream.seek(-bit_length, 1)
read_value = bitstream.read(bool if type(value) is list else type(value), bit_length) read_value = bitstream.read(
bool if type(value) is list else type(value), bit_length
)
else: else:
read_value = bitstream.read(bool if type(value) is list else type(value), bit_length, bit_index) read_value = bitstream.read(
bool if type(value) is list else type(value), bit_length, bit_index
)
assert read_value == ([value] if type(value) is bool else value) assert read_value == ([value] if type(value) is bool else value)