Compare commits
12 Commits
Author | SHA1 | Date | |
---|---|---|---|
ee04ff6e6d | |||
e975a5c492 | |||
eca447fd49 | |||
2422bee5fa | |||
faf9ccffbc | |||
8e59db2ad6 | |||
f1152cfb25 | |||
75824c306f | |||
b3fedb8214 | |||
446ab11a61 | |||
404f71dddd | |||
73eaa0e89b |
6
.gitignore
vendored
6
.gitignore
vendored
|
@ -130,3 +130,9 @@ dmypy.json
|
||||||
|
|
||||||
# VSC
|
# VSC
|
||||||
/.vscode
|
/.vscode
|
||||||
|
|
||||||
|
# Ignore ini files
|
||||||
|
*.ini
|
||||||
|
|
||||||
|
# Ignore logs directory
|
||||||
|
logs/
|
||||||
|
|
113
sampy/__main__.py
Normal file
113
sampy/__main__.py
Normal file
|
@ -0,0 +1,113 @@
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import textwrap
|
||||||
|
|
||||||
|
|
||||||
|
def main(args: argparse.Namespace) -> int:
|
||||||
|
from sampy.config import Config
|
||||||
|
|
||||||
|
config = Config(*args.config, logging_level=args.log)
|
||||||
|
|
||||||
|
from sampy.network.protocol import Protocol
|
||||||
|
from sampy.server import InteractiveServer
|
||||||
|
|
||||||
|
server = InteractiveServer(
|
||||||
|
Protocol,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
asyncio.get_event_loop().run_forever()
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
class Formatter(
|
||||||
|
argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
prog="sampy",
|
||||||
|
description=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
A SAMP server made in python
|
||||||
|
|
||||||
|
SAMP (or SA-MP) is a free multiplayer mod for the PC port of GTA: San Andreas.
|
||||||
|
GTA: San Andreas was developed by Rockstar North and released in 2005.
|
||||||
|
SA-MP is an unofficial multiplayer mod made by the `SA-MP.com` team released @[sa-mp.com](https://www.sa-mp.com/)
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
epilog=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
example:
|
||||||
|
%(prog)s --config default.ini secret.ini race.ini
|
||||||
|
|
||||||
|
This will first load default.ini configuration,
|
||||||
|
secret.ini might override the rcon- and server-password,
|
||||||
|
race.ini might change the hostname and rules.
|
||||||
|
|
||||||
|
example default.ini:
|
||||||
|
[sampy]
|
||||||
|
hostname = My SAMP Server
|
||||||
|
password =
|
||||||
|
rcon_password = VerySecretPassword
|
||||||
|
|
||||||
|
[sampy.rules]
|
||||||
|
lagcomp = off
|
||||||
|
|
||||||
|
[logging.loggers]
|
||||||
|
keys = root
|
||||||
|
|
||||||
|
[logging.handlers]
|
||||||
|
keys = console, file
|
||||||
|
|
||||||
|
[logging.formatters]
|
||||||
|
keys = simple, color
|
||||||
|
|
||||||
|
[logging.logger_root]
|
||||||
|
level = DEBUG
|
||||||
|
handlers = console, file
|
||||||
|
|
||||||
|
[logging.handler_console]
|
||||||
|
class = StreamHandler
|
||||||
|
formatter = color
|
||||||
|
args = (sys.stdout,)
|
||||||
|
|
||||||
|
[logging.handler_file]
|
||||||
|
class = logging.handlers.TimedRotatingFileHandler
|
||||||
|
formatter = simple
|
||||||
|
args = ("logs/sampy.log", "d", 1, 7,) # Note that the logs folder has to exist
|
||||||
|
|
||||||
|
[logging.formatter_simple]
|
||||||
|
format = %%(asctime)s.%%(msecs)03d | %%(levelname)-8s | %%(message)s
|
||||||
|
datefmt = %%Y-%%m-%%d %%H:%%M:%%S
|
||||||
|
|
||||||
|
[logging.formatter_color]
|
||||||
|
class = sampy.config.ColorFormatter
|
||||||
|
format = §6%%(asctime)s.%%(msecs)03d §1| %%(levelname)-8s §1|§r %%(message)s
|
||||||
|
datefmt = %%Y-%%m-%%d %%H:%%M:%%S
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
formatter_class=Formatter,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--config",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
default=[],
|
||||||
|
help="Config filenames",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--log",
|
||||||
|
type=str,
|
||||||
|
nargs="?",
|
||||||
|
help="Global logging level (Overrides any config)",
|
||||||
|
choices=logging._nameToLevel.keys(),
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
raise SystemExit(main(args))
|
15
sampy/client/player.py
Normal file
15
sampy/client/player.py
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
from ctypes import c_ubyte
|
||||||
|
|
||||||
|
|
||||||
|
class Player:
|
||||||
|
id: c_ubyte
|
||||||
|
username: str
|
||||||
|
score: int
|
||||||
|
ping: int
|
||||||
|
health: float
|
||||||
|
armor: float
|
||||||
|
# position: Vector3 # TODO
|
||||||
|
rotation: float
|
||||||
|
|
||||||
|
def __init__(self, username: str):
|
||||||
|
self.username = username
|
182
sampy/config.py
Normal file
182
sampy/config.py
Normal file
|
@ -0,0 +1,182 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
from configparser import ConfigParser
|
||||||
|
from typing import Any, Dict, Mapping, Optional, Union
|
||||||
|
|
||||||
|
|
||||||
|
class Config(ConfigParser):
|
||||||
|
DEFAULTS: Mapping[str, Mapping[str, Union[str, int]]] = {
|
||||||
|
"sampy": {
|
||||||
|
"host": "0.0.0.0",
|
||||||
|
"port": 7777,
|
||||||
|
"hostname": "sam.py",
|
||||||
|
"password": "",
|
||||||
|
"rcon_password": "changeme",
|
||||||
|
"max_players": 50,
|
||||||
|
"mode": "",
|
||||||
|
"language": "English",
|
||||||
|
},
|
||||||
|
"sampy.rules": {
|
||||||
|
"weburl": "https://git.osufx.com/Sunpy/sampy",
|
||||||
|
},
|
||||||
|
"logging.loggers": {
|
||||||
|
"keys": "root",
|
||||||
|
},
|
||||||
|
"logging.handlers": {
|
||||||
|
"keys": "console",
|
||||||
|
},
|
||||||
|
"logging.formatters": {
|
||||||
|
"keys": "simple",
|
||||||
|
},
|
||||||
|
"logging.logger_root": {
|
||||||
|
"level": "INFO",
|
||||||
|
"handlers": "console",
|
||||||
|
},
|
||||||
|
"logging.handler_console": {
|
||||||
|
"class": "StreamHandler",
|
||||||
|
"formatter": "simple",
|
||||||
|
"args": "(sys.stdout,)",
|
||||||
|
},
|
||||||
|
"logging.formatter_simple": {
|
||||||
|
"format": "%(levelname)s - %(message)s",
|
||||||
|
"datefmt": "%Y-%m-%d %H:%M:%S",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*filenames,
|
||||||
|
dictionary: Mapping[str, Mapping[str, Union[str, int]]] = {},
|
||||||
|
logging_level: Optional[int] = None,
|
||||||
|
):
|
||||||
|
super().__init__(interpolation=None)
|
||||||
|
if logging_level is not None:
|
||||||
|
logging.root.setLevel(logging_level)
|
||||||
|
|
||||||
|
self.read_dict(self.DEFAULTS)
|
||||||
|
self.read_dict(dictionary)
|
||||||
|
|
||||||
|
found = self.read(filenames, encoding="utf-8-sig")
|
||||||
|
missing = set(filenames) - set(found)
|
||||||
|
|
||||||
|
if len(missing):
|
||||||
|
logging.warn("Config files not found: %s" % missing)
|
||||||
|
|
||||||
|
logging_config = self.get_logging_config()
|
||||||
|
if logging_config:
|
||||||
|
logging.config.fileConfig(logging_config)
|
||||||
|
|
||||||
|
if logging_level is not None:
|
||||||
|
logging.root.setLevel(logging_level)
|
||||||
|
|
||||||
|
logging.debug("Logging module has been configured")
|
||||||
|
else:
|
||||||
|
logging.warn("Logging module was not configured")
|
||||||
|
|
||||||
|
def get_logging_config(self) -> ConfigParser:
|
||||||
|
config = ConfigParser(interpolation=None)
|
||||||
|
for section in self.sections():
|
||||||
|
if not section.startswith("logging."):
|
||||||
|
continue
|
||||||
|
config[section.replace("logging.", "")] = self[section]
|
||||||
|
return config
|
||||||
|
|
||||||
|
@property
|
||||||
|
def host(self) -> str:
|
||||||
|
return self.get("sampy", "host")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def port(self) -> int:
|
||||||
|
return self.getint("sampy", "port")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hostname(self) -> str:
|
||||||
|
return self.get("sampy", "hostname")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def password(self) -> str:
|
||||||
|
return self.get("sampy", "password")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def rcon_password(self) -> str:
|
||||||
|
return self.get("sampy", "rcon_password")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_players(self) -> int:
|
||||||
|
return self.getint("sampy", "max_players")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mode(self) -> str:
|
||||||
|
return self.get("sampy", "mode")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def language(self) -> str:
|
||||||
|
return self.get("sampy", "language")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def rules(self) -> Dict[str, str]:
|
||||||
|
return self["sampy.rules"]
|
||||||
|
|
||||||
|
|
||||||
|
class LogRecordProxy:
|
||||||
|
def __init__(self, record: logging.LogRecord):
|
||||||
|
self._record = record
|
||||||
|
|
||||||
|
def __getattribute__(self, name: str) -> Any:
|
||||||
|
attr = {
|
||||||
|
k: v
|
||||||
|
for k, v in object.__getattribute__(self, "__dict__").items()
|
||||||
|
if k != "_record"
|
||||||
|
}
|
||||||
|
if name in attr:
|
||||||
|
return attr[name]
|
||||||
|
elif name == "__dict__": # Combine dicts
|
||||||
|
return {**object.__getattribute__(self, "_record").__dict__, **attr}
|
||||||
|
return object.__getattribute__(self, "_record").__getattribute__(name)
|
||||||
|
|
||||||
|
|
||||||
|
class ColorFormatter(logging.Formatter):
|
||||||
|
COLORS: Dict[str, str] = {
|
||||||
|
"0": "30", # Black
|
||||||
|
"1": "34", # Blue
|
||||||
|
"2": "32", # Green
|
||||||
|
"3": "36", # Cyan
|
||||||
|
"4": "31", # Red,
|
||||||
|
"5": "35", # Purple/Magenta
|
||||||
|
"6": "33", # Yellow/Gold
|
||||||
|
"7": "37", # White/Light Gray
|
||||||
|
"8": "30;1", # Dark Gray
|
||||||
|
"9": "34;1", # Light Blue
|
||||||
|
"a": "32;1", # Light Green
|
||||||
|
"b": "36;1", # Light Cyan
|
||||||
|
"c": "31;1", # Light Red
|
||||||
|
"d": "35;1", # Light Purple/Magenta
|
||||||
|
"e": "33;1", # Yellow
|
||||||
|
"f": "37;1", # White
|
||||||
|
"r": "0", # Reset
|
||||||
|
"l": "1", # Bold
|
||||||
|
"n": "4", # Underline
|
||||||
|
}
|
||||||
|
|
||||||
|
LEVEL_COLOR = {
|
||||||
|
logging.CRITICAL: "31",
|
||||||
|
logging.ERROR: "31",
|
||||||
|
logging.WARNING: "33",
|
||||||
|
logging.INFO: "32",
|
||||||
|
logging.DEBUG: "35",
|
||||||
|
logging.NOTSET: "37",
|
||||||
|
}
|
||||||
|
|
||||||
|
def format(self, record: logging.LogRecord) -> str:
|
||||||
|
record = LogRecordProxy(record)
|
||||||
|
level_color = ColorFormatter.LEVEL_COLOR.get(record.levelno, None)
|
||||||
|
if level_color is not None:
|
||||||
|
record.levelname = "\x1b[%sm%s\x1b[0m" % (level_color, record.levelname)
|
||||||
|
|
||||||
|
message = super().format(record)
|
||||||
|
for k, v in ColorFormatter.COLORS.items():
|
||||||
|
message = message.replace("§%s" % k, "\x1b[%sm" % v)
|
||||||
|
|
||||||
|
return message
|
0
sampy/network/__init__.py
Normal file
0
sampy/network/__init__.py
Normal file
75
sampy/network/game.py
Normal file
75
sampy/network/game.py
Normal file
|
@ -0,0 +1,75 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Tuple
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sampy.server import Server
|
||||||
|
|
||||||
|
|
||||||
|
class Game:
|
||||||
|
@staticmethod
|
||||||
|
async def on_packet(server: Server, packet: bytes, addr: Tuple[str, int]) -> bool:
|
||||||
|
packet = Obfuscation.deobfuscate(server, packet)
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class Obfuscation:
|
||||||
|
# Found @ addr 0x004C88E0 in windows server executable
|
||||||
|
LOOKUP_TABLE = bytes.fromhex(
|
||||||
|
"""
|
||||||
|
b46207e59daf63dde3d0ccfedcdb6b2e6a40ab47c9d153d52091a50e4adf1889
|
||||||
|
fd6f2512b713770065366d49ec572aa9115ffa7895a4bd1ed97944cdde81eb09
|
||||||
|
3ef6eeda7fa31aa72da6adc14693d21b9caad74e4b4d4cf3b834c0ca88f494cb
|
||||||
|
04393082d673b0bf2201416e482ca875b10aae9f278010cef02928850d05f735
|
||||||
|
bbbc1506f56071031fea5a33928de7905be9cf9ed35ded311c0b5216510f86c5
|
||||||
|
689b210c8b4287ff4fbec8e8c7d47ae0552f8a8eba9837e4b238a1b632833a7b
|
||||||
|
843c61fb8c143d433b1dc3a296b3f8c4f2262bd87cfc232466ef6964505459f1
|
||||||
|
a074acc67db5e6e2c27e67175ee1b93f6c700899455676f99a9719725c028f58
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# I think this used to be zlib compression, but has been swapped out with obfuscation instead
|
||||||
|
@staticmethod
|
||||||
|
def deobfuscate(server: Server, packet: bytes) -> bytes:
|
||||||
|
checksum, data = packet[0], bytearray(packet[1:])
|
||||||
|
|
||||||
|
data = Obfuscation.xor_every_other_byte(
|
||||||
|
Obfuscation.get_port_xor_key(server), data
|
||||||
|
)
|
||||||
|
data = bytes(Obfuscation.LOOKUP_TABLE[b] for b in packet)
|
||||||
|
|
||||||
|
if checksum != Obfuscation.calc_checksum(data):
|
||||||
|
logging.error("Checksum failed!")
|
||||||
|
raise Exception("Checksum fail")
|
||||||
|
|
||||||
|
return bytes(data)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def obfuscate(server: Server, packet: bytes) -> bytes:
|
||||||
|
data = bytes(Obfuscation.LOOKUP_TABLE.index(b) for b in packet)
|
||||||
|
data = Obfuscation.xor_every_other_byte(
|
||||||
|
Obfuscation.get_port_xor_key(server), data
|
||||||
|
)
|
||||||
|
|
||||||
|
checksum = Obfuscation.calc_checksum(data)
|
||||||
|
return bytes([checksum]) + data
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def xor_every_other_byte(xor: int, packet: bytearray) -> bytearray:
|
||||||
|
for i in range(1, len(packet), 2):
|
||||||
|
packet[i] ^= xor
|
||||||
|
return packet
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def calc_checksum(packet: bytearray) -> int:
|
||||||
|
checksum = 0
|
||||||
|
for byte in packet:
|
||||||
|
checksum ^= byte & 0xAA
|
||||||
|
return checksum
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_port_xor_key(server: Server) -> int:
|
||||||
|
return (server.config.port ^ 0xCCCC) & 0xFF
|
26
sampy/network/protocol.py
Normal file
26
sampy/network/protocol.py
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Tuple
|
||||||
|
|
||||||
|
from sampy.network.game import Game
|
||||||
|
from sampy.network.query import Query
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sampy.server import Server
|
||||||
|
|
||||||
|
|
||||||
|
class Protocol:
|
||||||
|
VERSION = "0.3.7"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def on_packet(server: Server, packet: bytes, addr: Tuple[str, int]) -> bool:
|
||||||
|
logging.debug("on_packet")
|
||||||
|
|
||||||
|
if await Query.on_packet(server, packet, addr):
|
||||||
|
return True
|
||||||
|
if await Game.on_packet(server, packet, addr):
|
||||||
|
return True
|
||||||
|
|
||||||
|
logging.debug("Unhandled: %r" % packet)
|
||||||
|
return False
|
185
sampy/network/query.py
Normal file
185
sampy/network/query.py
Normal file
|
@ -0,0 +1,185 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import logging
|
||||||
|
import struct
|
||||||
|
from typing import TYPE_CHECKING, Callable, Dict, Tuple
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sampy.server import Server
|
||||||
|
|
||||||
|
|
||||||
|
# Holds a dict of handlers for different packet IDs. Since every query packet id is a single byte, we use an int here.
|
||||||
|
HANDLERS: Dict[int, Callable[[Server, bytes, Tuple[str, int]], bool]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
# Decorator that adds the function to the HANDLERS dict.
|
||||||
|
def handler(packet_id: bytes):
|
||||||
|
if len(packet_id) > 1:
|
||||||
|
raise Exception("Query opcode length cannot be bigger then 1")
|
||||||
|
|
||||||
|
def outer(func):
|
||||||
|
@functools.wraps(func)
|
||||||
|
def inner(
|
||||||
|
server: Server, packet: bytes, addr: Tuple[str, int], *args, **kwargs
|
||||||
|
):
|
||||||
|
return func(server, packet, addr, *args, **kwargs)
|
||||||
|
|
||||||
|
HANDLERS[packet_id[0]] = func
|
||||||
|
logging.debug("Added Query handler: %s -> %s" % (packet_id, func))
|
||||||
|
return inner
|
||||||
|
|
||||||
|
return outer
|
||||||
|
|
||||||
|
|
||||||
|
class Query:
|
||||||
|
"""Query handler
|
||||||
|
|
||||||
|
Reference: https://team.sa-mp.com/wiki/Query_Mechanism.html (Unable to archive due to https://sa-mp.com/robots.txt disallowing ia_archiver)
|
||||||
|
|
||||||
|
The Query protocol is *mostly* used for the samp server browser and other systems that queries for server info.
|
||||||
|
The exception to this is would be the "player_details"(d) packet on version 0.2.x and below where an ingame player pressing TAB would use this packet.
|
||||||
|
|
||||||
|
Structure:
|
||||||
|
- "SAMP" header (4 bytes)
|
||||||
|
- Server's ipv4 (4 bytes)
|
||||||
|
- Server's port (2 bytes)
|
||||||
|
- Packet type (1 byte)
|
||||||
|
- Packet data
|
||||||
|
Note that the server and client will both use the same first 4 parts when communicating.
|
||||||
|
Not all packets have packet data, and most packets doesn't have data when the client sends it as its a "request".
|
||||||
|
The server will always have packet data attached.
|
||||||
|
|
||||||
|
List of packets:
|
||||||
|
- i: Information packet. This includes hostname, players (online/max), mode, language and whether password is required.
|
||||||
|
- r: Rules packet. This is a list of "rules". The name "rules" is subjective in this case, as most are general optional information.
|
||||||
|
- c: Client list packet. This is a list of players and scores. Players being just their username and score a number.
|
||||||
|
- d: Detailed player list packet. Extends the client list packet with player id and ping.
|
||||||
|
- p: Ping packet. A client uses this packet with 4 random bytes and measures how long it takes before it gets the same packet back.
|
||||||
|
- x: Remote console packet. A client can send and receive anything on this packet. Usually used for remote console.
|
||||||
|
|
||||||
|
Additional findings:
|
||||||
|
- On info packet, strings can not be of size 0. This will make parsing of the rest of the packet fail.
|
||||||
|
- This is strange due to how we are sending the length of the string first, which should allow this...
|
||||||
|
- Fix: If string size is 0, we replace the string with a single space (Considering using a NULL byte, but unsure if this could cause other issues)
|
||||||
|
"""
|
||||||
|
|
||||||
|
HEADER_LENGTH = 11
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def on_packet(server: Server, packet: bytes, addr: Tuple[str, int]) -> bool:
|
||||||
|
if len(packet) < 11: # Packet is too small
|
||||||
|
return False
|
||||||
|
|
||||||
|
magic, _ip, _port, packet_id = struct.unpack(
|
||||||
|
b"<4sIHB", packet[: Query.HEADER_LENGTH]
|
||||||
|
) # Unpack packet
|
||||||
|
|
||||||
|
if magic != b"SAMP": # Validate magic
|
||||||
|
return False
|
||||||
|
|
||||||
|
return HANDLERS.get(packet_id, lambda *_: False)(server, packet, addr)
|
||||||
|
|
||||||
|
@handler(b"i")
|
||||||
|
@staticmethod
|
||||||
|
def info(server: Server, packet: bytes, addr: Tuple[str, int]) -> bool:
|
||||||
|
packet = packet[: Query.HEADER_LENGTH] # Discard additional data if passed
|
||||||
|
|
||||||
|
hostname = server.config.hostname.encode()
|
||||||
|
mode = server.config.mode.encode()
|
||||||
|
language = server.config.language.encode()
|
||||||
|
|
||||||
|
if len(hostname) == 0:
|
||||||
|
hostname = b" "
|
||||||
|
if len(mode) == 0:
|
||||||
|
mode = b" "
|
||||||
|
if len(language) == 0:
|
||||||
|
language = b" "
|
||||||
|
|
||||||
|
packet += struct.pack(
|
||||||
|
b"<?HHI%dsI%dsI%ds" % (len(hostname), len(mode), len(language)),
|
||||||
|
len(server.config.password) != 0,
|
||||||
|
len(server.players),
|
||||||
|
server.config.max_players,
|
||||||
|
len(hostname),
|
||||||
|
hostname,
|
||||||
|
len(mode),
|
||||||
|
mode,
|
||||||
|
len(language),
|
||||||
|
language,
|
||||||
|
)
|
||||||
|
server.sendto(packet, addr)
|
||||||
|
return True
|
||||||
|
|
||||||
|
@handler(b"r")
|
||||||
|
@staticmethod
|
||||||
|
def rules(server: Server, packet: bytes, addr: Tuple[str, int]) -> bool:
|
||||||
|
packet = packet[: Query.HEADER_LENGTH] # Discard additional data if passed
|
||||||
|
|
||||||
|
rules = server.config.rules
|
||||||
|
rules["version"] = server.protocol.VERSION # Add game version (read-only)
|
||||||
|
rules["worldtime"] = ""
|
||||||
|
rules["weather"] = "10"
|
||||||
|
|
||||||
|
packet += struct.pack(b"<H", len(rules))
|
||||||
|
for item in (item.encode() for pair in rules.items() for item in pair):
|
||||||
|
packet += struct.pack(b"<B%ds" % len(item), len(item), item)
|
||||||
|
|
||||||
|
server.sendto(packet, addr)
|
||||||
|
return True
|
||||||
|
|
||||||
|
@handler(b"c")
|
||||||
|
@staticmethod
|
||||||
|
def client_list(server: Server, packet: bytes, addr: Tuple[str, int]) -> bool:
|
||||||
|
packet = packet[: Query.HEADER_LENGTH] # Discard additional data if passed
|
||||||
|
|
||||||
|
players = server.players
|
||||||
|
packet += struct.pack(b"<H", len(players))
|
||||||
|
|
||||||
|
for player in players:
|
||||||
|
username = player.username.encode()
|
||||||
|
packet += struct.pack(b"<B%dsI", len(username), username, player.score)
|
||||||
|
|
||||||
|
server.sendto(packet, addr)
|
||||||
|
return True
|
||||||
|
|
||||||
|
@handler(b"d")
|
||||||
|
@staticmethod
|
||||||
|
def player_details(server: Server, packet: bytes, addr: Tuple[str, int]) -> bool:
|
||||||
|
packet = packet[: Query.HEADER_LENGTH] # Discard additional data if passed
|
||||||
|
|
||||||
|
players = server.players
|
||||||
|
packet += struct.pack(b"<H", len(players))
|
||||||
|
|
||||||
|
for player in players:
|
||||||
|
username = player.username.encode()
|
||||||
|
packet += struct.pack(
|
||||||
|
b"<BB%dsII",
|
||||||
|
player.id,
|
||||||
|
len(username),
|
||||||
|
username,
|
||||||
|
player.score,
|
||||||
|
player.ping,
|
||||||
|
)
|
||||||
|
|
||||||
|
server.sendto(packet, addr)
|
||||||
|
return True
|
||||||
|
|
||||||
|
@handler(b"p")
|
||||||
|
@staticmethod
|
||||||
|
def ping(server: Server, packet: bytes, addr: Tuple[str, int]) -> bool:
|
||||||
|
if (
|
||||||
|
len(packet) < Query.HEADER_LENGTH + 4
|
||||||
|
): # Packet is too small (Missing random)
|
||||||
|
return False
|
||||||
|
|
||||||
|
packet = packet[
|
||||||
|
: Query.HEADER_LENGTH + 4
|
||||||
|
] # Discard additional data if passed (+4 to include random)
|
||||||
|
server.sendto(packet, addr)
|
||||||
|
return True
|
||||||
|
|
||||||
|
@handler(b"x")
|
||||||
|
@staticmethod
|
||||||
|
def rcon(server: Server, packet: bytes, addr: Tuple[str, int]) -> bool:
|
||||||
|
return False # TODO
|
0
sampy/raknet/__init__.py
Normal file
0
sampy/raknet/__init__.py
Normal file
299
sampy/raknet/bitstream.py
Normal file
299
sampy/raknet/bitstream.py
Normal file
|
@ -0,0 +1,299 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, List, Literal, Optional, Type, Union
|
||||||
|
|
||||||
|
# for seek()
|
||||||
|
SEEK_SET = 0
|
||||||
|
SEEK_CUR = 1
|
||||||
|
SEEK_END = 2
|
||||||
|
|
||||||
|
|
||||||
|
def get_bit_length(value: Union[bool, List[bool], bytes, Bitstream]) -> int:
|
||||||
|
if type(value) is bool:
|
||||||
|
return 1
|
||||||
|
elif type(value) is list and all(type(v) is bool for v in value):
|
||||||
|
return len(value)
|
||||||
|
elif type(value) is bytes:
|
||||||
|
return len(value) << 3
|
||||||
|
elif type(value) is Bitstream:
|
||||||
|
return len(value)
|
||||||
|
else:
|
||||||
|
raise TypeError("Invalid data type")
|
||||||
|
|
||||||
|
|
||||||
|
def bits_to_int(bits: List[bool], little_endian: bool = False) -> int:
|
||||||
|
num = 0
|
||||||
|
if little_endian:
|
||||||
|
bits = bits[::-1]
|
||||||
|
for bit in bits:
|
||||||
|
num = (num << 1) | bit
|
||||||
|
return num
|
||||||
|
|
||||||
|
|
||||||
|
def int_to_bits(
|
||||||
|
number: int, bit_length: int, little_endian: bool = False
|
||||||
|
) -> List[bool]:
|
||||||
|
bits = [bool(number & (0b1 << i)) for i in range(bit_length - 1, -1, -1)]
|
||||||
|
return bits[::-1] if little_endian else bits
|
||||||
|
|
||||||
|
|
||||||
|
class Bitstream:
|
||||||
|
def __init__(self, *values: Union[bool, List[bool], bytes, Bitstream]):
|
||||||
|
self._bytearray = bytearray()
|
||||||
|
self._bit_length = 0
|
||||||
|
self._offset = 0
|
||||||
|
|
||||||
|
for item in values:
|
||||||
|
if type(item) is bool:
|
||||||
|
self.append(item)
|
||||||
|
else:
|
||||||
|
self.extend(item)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_bytes(value: bytes) -> Bitstream:
|
||||||
|
bitstream = Bitstream()
|
||||||
|
bitstream._bytearray += value
|
||||||
|
bitstream._bit_length = len(bitstream._bytearray) << 3
|
||||||
|
return bitstream
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_int(value: int, bit_length: int) -> Bitstream:
|
||||||
|
return Bitstream(
|
||||||
|
[bool(value & (0b1 << i)) for i in range(bit_length - 1, -1, -1)]
|
||||||
|
)
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return self._bit_length
|
||||||
|
|
||||||
|
def __bytes__(self) -> bytes:
|
||||||
|
while len(self._bytearray) > ((len(self) + 7) >> 3):
|
||||||
|
self._bytearray.pop()
|
||||||
|
return bytes(self._bytearray)
|
||||||
|
|
||||||
|
def __int__(self) -> int:
|
||||||
|
num = 0
|
||||||
|
for bit in self:
|
||||||
|
num = (num << 1) | bit
|
||||||
|
return num
|
||||||
|
|
||||||
|
def __getitem__(self, bit_index: int) -> bool:
|
||||||
|
if bit_index < 0:
|
||||||
|
bit_index = len(self) + bit_index
|
||||||
|
|
||||||
|
if bit_index >= self._bit_length or bit_index < 0:
|
||||||
|
raise IndexError("bit index out of range")
|
||||||
|
|
||||||
|
mask = 1 << (7 - (bit_index % 8))
|
||||||
|
return bool(self._bytearray[bit_index >> 3] & mask)
|
||||||
|
|
||||||
|
def __setitem__(
|
||||||
|
self, bit_index: int, value: Union[bool, List[bool], bytes, Bitstream]
|
||||||
|
):
|
||||||
|
if bit_index >= self._bit_length:
|
||||||
|
raise IndexError("bit index out of range")
|
||||||
|
|
||||||
|
if type(value) is bytes:
|
||||||
|
self[bit_index] = Bitstream.from_bytes(value)
|
||||||
|
return
|
||||||
|
|
||||||
|
value_bit_length = get_bit_length(value)
|
||||||
|
|
||||||
|
if (bit_index + value_bit_length) > self._bit_length:
|
||||||
|
raise IndexError("Cannot write bits that extends the size of bitstream")
|
||||||
|
|
||||||
|
if type(value) is bool:
|
||||||
|
mask = 1 << (7 - (bit_index % 8))
|
||||||
|
if value:
|
||||||
|
self._bytearray[bit_index >> 3] |= mask
|
||||||
|
else:
|
||||||
|
self._bytearray[bit_index >> 3] &= ~mask
|
||||||
|
elif isinstance(value, Bitstream) or (
|
||||||
|
type(value) is list and all(type(v) is bool for v in value)
|
||||||
|
):
|
||||||
|
for i in range(value_bit_length):
|
||||||
|
self[bit_index + i] = value[i]
|
||||||
|
else:
|
||||||
|
raise TypeError("Expected bool, list[bool], bytes or Bitstream")
|
||||||
|
|
||||||
|
def __eq__(self, other: Any) -> bool:
|
||||||
|
if type(other) is bool:
|
||||||
|
other = [other]
|
||||||
|
|
||||||
|
if isinstance(other, Bitstream):
|
||||||
|
pass
|
||||||
|
elif type(other) is bytes:
|
||||||
|
pass
|
||||||
|
elif type(other) is list and all(type(v) is bool for v in other):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise TypeError("Expected Bitstream, bytes or list[bool]")
|
||||||
|
|
||||||
|
if isinstance(other, Bitstream) or type(other) is list:
|
||||||
|
if len(self) != len(other):
|
||||||
|
return False
|
||||||
|
for left, right in zip(self, other):
|
||||||
|
if left != right:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
if bytes(self) != other:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def __mul__(self, number: int) -> Bitstream:
|
||||||
|
copy = self.copy()
|
||||||
|
for _ in range(1, number):
|
||||||
|
copy.extend(self)
|
||||||
|
return copy
|
||||||
|
|
||||||
|
def __imul__(self, number: int) -> Bitstream:
|
||||||
|
copy = self.copy()
|
||||||
|
for _ in range(1, number):
|
||||||
|
self.extend(copy)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __add__(self, value: Union[bool, List[bool], bytes, Bitstream]) -> Bitstream:
|
||||||
|
copy = self.copy()
|
||||||
|
if type(value) is bool:
|
||||||
|
copy.append(value)
|
||||||
|
else:
|
||||||
|
copy.extend(value)
|
||||||
|
return copy
|
||||||
|
|
||||||
|
def __iadd__(self, value: Union[bool, List[bool], bytes, Bitstream]) -> Bitstream:
|
||||||
|
if type(value) is bool:
|
||||||
|
self.append(value)
|
||||||
|
else:
|
||||||
|
self.extend(value)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __iter__(self) -> Bitstream:
|
||||||
|
for i in range(len(self)):
|
||||||
|
yield self[i]
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return "<Bitstream addr:0x%012x offset:%d len:%d data:'%s'>" % (
|
||||||
|
id(self),
|
||||||
|
self._offset,
|
||||||
|
len(self),
|
||||||
|
bytes(self).hex(" "),
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def bits(self) -> List[bool]:
|
||||||
|
return [b for b in self]
|
||||||
|
|
||||||
|
def copy(self) -> Bitstream:
|
||||||
|
bitstream = Bitstream()
|
||||||
|
bitstream._bytearray = self._bytearray.copy()
|
||||||
|
bitstream._bit_length = self._bit_length
|
||||||
|
bitstream._offset = self._offset
|
||||||
|
return bitstream
|
||||||
|
|
||||||
|
def append(self, value: bool):
|
||||||
|
if type(value) is not bool:
|
||||||
|
raise TypeError("Expected bool")
|
||||||
|
|
||||||
|
while (len(self._bytearray) << 3) < (self._bit_length + 1):
|
||||||
|
self._bytearray.append(0)
|
||||||
|
prev_bit_length = self._bit_length
|
||||||
|
self._bit_length += 1
|
||||||
|
self[prev_bit_length] = value
|
||||||
|
|
||||||
|
def extend(self, value: Union[List[bool], bytes, Bitstream]):
|
||||||
|
if type(value) is bytes:
|
||||||
|
pass
|
||||||
|
elif isinstance(value, Bitstream):
|
||||||
|
pass
|
||||||
|
elif type(value) is list and all(type(v) is bool for v in value):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise TypeError("Expected list[bool], bytes or Bitstream")
|
||||||
|
|
||||||
|
value_bit_length = get_bit_length(value)
|
||||||
|
if value_bit_length == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
while (len(self._bytearray) << 3) < (self._bit_length + value_bit_length):
|
||||||
|
self._bytearray.append(0)
|
||||||
|
prev_bit_length = self._bit_length
|
||||||
|
self._bit_length += value_bit_length
|
||||||
|
self[prev_bit_length] = value
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self._bytearray.clear()
|
||||||
|
self._bit_length = 0
|
||||||
|
self._offset = 0
|
||||||
|
|
||||||
|
def pop(self) -> bool:
|
||||||
|
value = self[-1]
|
||||||
|
self[-1] = False
|
||||||
|
self._bit_length -= 1
|
||||||
|
return value
|
||||||
|
|
||||||
|
def read(
|
||||||
|
self,
|
||||||
|
type: Type[Union[bool, bytes, Bitstream, int]],
|
||||||
|
bit_length: int,
|
||||||
|
bit_index: Optional[int] = None,
|
||||||
|
) -> Union[List[bool], bytes, Bitstream, int]:
|
||||||
|
start = self._offset if bit_index is None else bit_index
|
||||||
|
|
||||||
|
if (start + bit_length) > self._bit_length or (start < 0):
|
||||||
|
raise IndexError("bit index out of range")
|
||||||
|
|
||||||
|
bits = [self[i] for i in range(start, start + bit_length)]
|
||||||
|
|
||||||
|
if bit_index is None:
|
||||||
|
self._offset += bit_length
|
||||||
|
|
||||||
|
if type is bool:
|
||||||
|
return bits
|
||||||
|
|
||||||
|
bitstream = Bitstream(bits)
|
||||||
|
if type is Bitstream:
|
||||||
|
return bitstream
|
||||||
|
elif type is bytes:
|
||||||
|
return bytes(bitstream)
|
||||||
|
elif type is int:
|
||||||
|
return int(bitstream)
|
||||||
|
|
||||||
|
raise TypeError("Invalid data type")
|
||||||
|
|
||||||
|
def write(
|
||||||
|
self,
|
||||||
|
value: Union[bool, List[bool], bytes, Bitstream],
|
||||||
|
bit_index: Optional[int] = None,
|
||||||
|
):
|
||||||
|
start = self._offset if bit_index is None else bit_index
|
||||||
|
|
||||||
|
if start < 0:
|
||||||
|
raise IndexError("bit index out of range")
|
||||||
|
|
||||||
|
value_bit_length = get_bit_length(value)
|
||||||
|
if value_bit_length == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
while (len(self._bytearray) << 3) < (start + value_bit_length):
|
||||||
|
self._bytearray.append(0)
|
||||||
|
|
||||||
|
if bit_index is None:
|
||||||
|
self._offset += value_bit_length
|
||||||
|
_bit_length = self._offset
|
||||||
|
else:
|
||||||
|
_bit_length = start + value_bit_length
|
||||||
|
|
||||||
|
if _bit_length > self._bit_length:
|
||||||
|
self._bit_length = _bit_length
|
||||||
|
self[start] = value
|
||||||
|
|
||||||
|
def seek(self, position: int, whence: Literal[0, 1, 2] = 0):
|
||||||
|
if whence == 0:
|
||||||
|
self._offset = position
|
||||||
|
elif whence == 1:
|
||||||
|
self._offset += position
|
||||||
|
elif whence == 2:
|
||||||
|
self._offset = len(self) + position
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid whence")
|
||||||
|
|
||||||
|
self._offset = min(max(self._offset, 0), len(self))
|
91
sampy/server.py
Normal file
91
sampy/server.py
Normal file
|
@ -0,0 +1,91 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Optional, Tuple, Type
|
||||||
|
|
||||||
|
from sampy.client.player import Player
|
||||||
|
from sampy.config import Config
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sampy.network.protocol import Protocol
|
||||||
|
|
||||||
|
|
||||||
|
class UDPProtocol:
|
||||||
|
transport: asyncio.transports.DatagramTransport
|
||||||
|
|
||||||
|
def __init__(self, protocol: Protocol, local_addr: Tuple[str, int]):
|
||||||
|
self.protocol = protocol
|
||||||
|
self.local_addr = local_addr
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
logging.debug("Creating datagram endpoint")
|
||||||
|
connect = loop.create_datagram_endpoint(
|
||||||
|
lambda: self,
|
||||||
|
local_addr=self.local_addr,
|
||||||
|
)
|
||||||
|
loop.run_until_complete(connect)
|
||||||
|
|
||||||
|
def stop(self): # TODO: Shutdown code
|
||||||
|
if self.transport is None:
|
||||||
|
raise Exception("Cannot stop a server that hasn't been started")
|
||||||
|
|
||||||
|
logging.debug("Shutting down")
|
||||||
|
self.transport.close()
|
||||||
|
|
||||||
|
def connection_made(self, transport: asyncio.transports.DatagramTransport):
|
||||||
|
logging.debug("UDP Protocol: connection_made")
|
||||||
|
self.transport = transport
|
||||||
|
|
||||||
|
def connection_lost(self, exc: Exception | None):
|
||||||
|
logging.debug("UDP Protocol: connection_lost")
|
||||||
|
|
||||||
|
def datagram_received(self, data: bytes, addr: Tuple[str, int]):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def sendto(self, data: bytes | bytearray | memoryview, addr: Tuple[str, int]):
|
||||||
|
self.transport.sendto(data, addr)
|
||||||
|
|
||||||
|
|
||||||
|
class Server(UDPProtocol):
|
||||||
|
config: Config
|
||||||
|
|
||||||
|
def __init__(self, protocol: Type[Protocol], config: Optional[Config] = None):
|
||||||
|
if config is None:
|
||||||
|
config = Config()
|
||||||
|
logging.warn("Server was initialized with default config")
|
||||||
|
super().__init__(
|
||||||
|
protocol=protocol(),
|
||||||
|
local_addr=(
|
||||||
|
config.get("sampy", "host"),
|
||||||
|
config.getint("sampy", "port"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def datagram_received(self, data: bytes, addr: Tuple[str, int]):
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
loop.create_task(self.protocol.on_packet(self, data, addr))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def players(self) -> list[Player]: # TODO
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class InteractiveServer(Server):
|
||||||
|
def __init__(self, protocol: Type[Protocol], config: Optional[Config] = None):
|
||||||
|
super().__init__(protocol=protocol, config=config)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
loop.create_task(self.run_input_loop())
|
||||||
|
|
||||||
|
async def run_input_loop(self):
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
while True:
|
||||||
|
command = await loop.run_in_executor(None, input)
|
||||||
|
|
||||||
|
if command in ("quit", "exit", "stop"):
|
||||||
|
self.stop()
|
||||||
|
loop.stop()
|
||||||
|
return
|
433
tests/test_bitstream.py
Normal file
433
tests/test_bitstream.py
Normal file
|
@ -0,0 +1,433 @@
|
||||||
|
from typing import List, Optional, Type, Union
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from sampy.raknet.bitstream import Bitstream, get_bit_length
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"value",
|
||||||
|
[
|
||||||
|
b"",
|
||||||
|
b"A",
|
||||||
|
b"AB",
|
||||||
|
b"\xFF\x00\xAA\x55",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_from_bytes(value: bytes):
|
||||||
|
bitstream = Bitstream.from_bytes(value)
|
||||||
|
assert len(bitstream) == (len(value) << 3)
|
||||||
|
assert bytes(bitstream) == value
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"value,length,expected",
|
||||||
|
[
|
||||||
|
(0b0, 0, None),
|
||||||
|
(0b1, 1, None),
|
||||||
|
(0b0, 1, None),
|
||||||
|
(0b1011, 4, None),
|
||||||
|
(0b10111011, 8, None),
|
||||||
|
(0b10111011, 7, 0b0111011), # -> 0b0111011
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_from_int(value: int, length: int, expected: int):
|
||||||
|
bitstream = Bitstream.from_int(value, length)
|
||||||
|
assert len(bitstream) == length
|
||||||
|
assert len(bytes(bitstream)) == ((length + 7) >> 3)
|
||||||
|
assert bitstream.bits == [bool(value & (1 << i)) for i in range(length - 1, -1, -1)]
|
||||||
|
if expected is not None:
|
||||||
|
expected_bitstream = Bitstream.from_int(expected, length)
|
||||||
|
assert bitstream.bits == expected_bitstream.bits
|
||||||
|
assert bytes(bitstream) == bytes(expected_bitstream)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"values",
|
||||||
|
[
|
||||||
|
[True, False, False, True, True],
|
||||||
|
[[True], [False, True], [False, False]],
|
||||||
|
[b"a", b"bc"],
|
||||||
|
[b"aa", b"b"],
|
||||||
|
[
|
||||||
|
Bitstream.from_bytes(b"A"),
|
||||||
|
Bitstream.from_int(0b1011, 4),
|
||||||
|
Bitstream.from_bytes(b"B"),
|
||||||
|
],
|
||||||
|
[True, b"A", b"B", [False, True], Bitstream.from_int(0b10110, 5), False],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_init(values: List[Union[bool, List[bool], bytes, Bitstream]]):
|
||||||
|
bitstream = Bitstream(*values)
|
||||||
|
|
||||||
|
bitstream.seek(0)
|
||||||
|
for item in values:
|
||||||
|
bit_length = get_bit_length(item)
|
||||||
|
if type(item) is bool:
|
||||||
|
item = [item]
|
||||||
|
|
||||||
|
read_type = bool if type(item) is list else type(item)
|
||||||
|
value = bitstream.read(read_type, bit_length)
|
||||||
|
assert value == item
|
||||||
|
|
||||||
|
|
||||||
|
def test_getitem():
|
||||||
|
sample_values = [True, False, False, True, False, True, True, True, True, False]
|
||||||
|
|
||||||
|
for length in range(1, len(sample_values)):
|
||||||
|
values = sample_values[:length]
|
||||||
|
bitstream = Bitstream(values)
|
||||||
|
for i in range(length):
|
||||||
|
assert bitstream[i] == values[i]
|
||||||
|
assert bitstream[-i - 1] == values[-i - 1] # For inverse lookup
|
||||||
|
|
||||||
|
|
||||||
|
def test_getitem_index_error():
|
||||||
|
bitstream = Bitstream(True)
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
bitstream[1]
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
bitstream[-2] # Inverse IndexError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"value",
|
||||||
|
[
|
||||||
|
True,
|
||||||
|
[True, True],
|
||||||
|
b"A",
|
||||||
|
Bitstream(True, b"C"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_setitem(value: Union[bool, List[bool], bytes, Bitstream]):
|
||||||
|
bit_length = get_bit_length(value)
|
||||||
|
read_type = bool if type(value) is list else type(value)
|
||||||
|
for length in range(10):
|
||||||
|
for i in range(length - bit_length + 1):
|
||||||
|
bitstream = Bitstream([False] * length)
|
||||||
|
bitstream[i] = value
|
||||||
|
assert bitstream.read(read_type, bit_length, i) == (
|
||||||
|
value if type(value) is not bool else [value]
|
||||||
|
)
|
||||||
|
assert len(bitstream) == length
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"value,equals,not_equals",
|
||||||
|
[
|
||||||
|
([], [[]], [[True], [False]]),
|
||||||
|
(True, [[True]], [[False], [True, False], b"A"]),
|
||||||
|
([True, False], [[True, False]], [[False], [False, False], b"A"]),
|
||||||
|
(
|
||||||
|
b"A",
|
||||||
|
[b"A", b"A", Bitstream.from_bytes(b"A"), Bitstream.from_bytes(b"A").bits],
|
||||||
|
[
|
||||||
|
[],
|
||||||
|
[False],
|
||||||
|
b"B",
|
||||||
|
b"B",
|
||||||
|
Bitstream.from_bytes(b"B"),
|
||||||
|
Bitstream.from_bytes(b"B").bits,
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(b"ABC", [b"ABC"], [[], [False], b"A"]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_eq(
|
||||||
|
value: Union[bool, List[bool], bytes, Bitstream],
|
||||||
|
equals: List[Union[bool, List[bool], bytes, Bitstream]],
|
||||||
|
not_equals: List[Union[bool, List[bool], bytes, Bitstream]],
|
||||||
|
):
|
||||||
|
bitstream = Bitstream(value)
|
||||||
|
for equal in equals:
|
||||||
|
assert bitstream == equal
|
||||||
|
for not_equal in not_equals:
|
||||||
|
assert bitstream != not_equal
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"values,number,expected",
|
||||||
|
[
|
||||||
|
([], 2, []),
|
||||||
|
([True, False], 2, [True, False, True, False]),
|
||||||
|
(b"A", 2, b"AA"),
|
||||||
|
([False], 3, [False, False, False]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_mul(
|
||||||
|
values: Union[bool, List[bool], bytes, Bitstream],
|
||||||
|
number: int,
|
||||||
|
expected: Union[List[bool], bytes, Bitstream],
|
||||||
|
):
|
||||||
|
bitstream = Bitstream(values)
|
||||||
|
|
||||||
|
assert (bitstream * number) == expected
|
||||||
|
assert bitstream == values
|
||||||
|
|
||||||
|
bitstream *= number
|
||||||
|
assert bitstream == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"init_values,add_value,expected",
|
||||||
|
[
|
||||||
|
([], True, [True]),
|
||||||
|
([True, False], False, [True, False, False]),
|
||||||
|
(b"A", b"B", b"AB"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_add(
|
||||||
|
init_values: Union[bool, List[bool], bytes, Bitstream],
|
||||||
|
add_value: Union[bool, List[bool], bytes, Bitstream],
|
||||||
|
expected: Union[List[bool], bytes, Bitstream],
|
||||||
|
):
|
||||||
|
bitstream = Bitstream(init_values)
|
||||||
|
|
||||||
|
assert (bitstream + add_value) == expected
|
||||||
|
assert len(bitstream) == get_bit_length(init_values)
|
||||||
|
|
||||||
|
bitstream += add_value
|
||||||
|
assert bitstream == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"init_values",
|
||||||
|
[
|
||||||
|
[],
|
||||||
|
[True, False],
|
||||||
|
[b"A"],
|
||||||
|
[True, b"B"],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_iter(init_values: List[Union[bool, List[bool], bytes, Bitstream]]):
|
||||||
|
bitstream = Bitstream(*init_values)
|
||||||
|
assert len(list(bitstream)) == len(bitstream)
|
||||||
|
for index, bit in enumerate(bitstream):
|
||||||
|
assert bit == bitstream[index]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"init_values,expected",
|
||||||
|
[
|
||||||
|
([], b""),
|
||||||
|
([b"A"], b"A"),
|
||||||
|
([False], b"\x00"),
|
||||||
|
([[False] * 6, True], b"\x02"),
|
||||||
|
([[False] * 7, True], b"\x01"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_bytes(
|
||||||
|
init_values: List[Union[bool, List[bool], bytes, Bitstream]], expected: bytes
|
||||||
|
):
|
||||||
|
assert bytes(Bitstream(*init_values)) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"init_values,expected",
|
||||||
|
[
|
||||||
|
([], []),
|
||||||
|
([b"\x01"], [False] * 7 + [True]),
|
||||||
|
([False], [False]),
|
||||||
|
([False] * 6 + [True], [False] * 6 + [True]),
|
||||||
|
([False] * 7 + [True], [False] * 7 + [True]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_bits(
|
||||||
|
init_values: List[Union[bool, List[bool], bytes, Bitstream]], expected: List[bool]
|
||||||
|
):
|
||||||
|
assert Bitstream(*init_values).bits == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_copy():
|
||||||
|
bitstream = Bitstream(True, False)
|
||||||
|
bitstream_copy = bitstream.copy()
|
||||||
|
assert bitstream == bitstream_copy
|
||||||
|
assert id(bitstream) != id(bitstream.copy)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"lst",
|
||||||
|
[
|
||||||
|
[False] * 10,
|
||||||
|
[True] * 10,
|
||||||
|
[True, False] * 5,
|
||||||
|
[False, True] * 5,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_append(lst: List[bool]):
|
||||||
|
bitstream = Bitstream()
|
||||||
|
for i in range(len(lst)):
|
||||||
|
bitstream.append(lst[i])
|
||||||
|
assert bitstream[i] == lst[i]
|
||||||
|
assert len(bitstream) == (i + 1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"extend_value",
|
||||||
|
[
|
||||||
|
[True, False, True],
|
||||||
|
b"A",
|
||||||
|
b"BB",
|
||||||
|
Bitstream(True, b"C"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_extend(extend_value: Union[List[bool], bytes, Bitstream]):
|
||||||
|
for i in range(10):
|
||||||
|
bitstream = Bitstream([False] * i)
|
||||||
|
bitstream.extend(extend_value)
|
||||||
|
value = bitstream.read(
|
||||||
|
bool if type(extend_value) is list else type(extend_value),
|
||||||
|
get_bit_length(extend_value),
|
||||||
|
i,
|
||||||
|
)
|
||||||
|
assert value == extend_value
|
||||||
|
|
||||||
|
|
||||||
|
def test_extend_error():
|
||||||
|
bitstream = Bitstream()
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
bitstream.extend(True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_clear():
|
||||||
|
bitstream = Bitstream(b"A", [True, False, False, True])
|
||||||
|
bitstream._offset = 3
|
||||||
|
bitstream.clear()
|
||||||
|
assert len(bitstream) == 0
|
||||||
|
assert len(bytes(bitstream)) == 0
|
||||||
|
assert bitstream._offset == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"lst",
|
||||||
|
[
|
||||||
|
[False] * 10,
|
||||||
|
[True] * 10,
|
||||||
|
[True, False] * 5,
|
||||||
|
[False, True] * 5,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_pop(lst: List[bool]):
|
||||||
|
bitstream = Bitstream(lst)
|
||||||
|
for i in range(len(bitstream)):
|
||||||
|
assert bitstream.pop() == lst[-i - 1]
|
||||||
|
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
assert bitstream.pop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_read():
|
||||||
|
bitstream = Bitstream(True, b"A", False, True, False)
|
||||||
|
|
||||||
|
assert bitstream.read(bool, 1) == [True]
|
||||||
|
assert bitstream.read(bytes, 8) == b"A"
|
||||||
|
assert bitstream.read(int, 3) == 2
|
||||||
|
|
||||||
|
bitstream.seek(0)
|
||||||
|
assert bitstream.read(Bitstream, len(bitstream) - 1) == Bitstream(
|
||||||
|
True, b"A", False, True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"bitstream,read_type,bit_length,bit_index,expected",
|
||||||
|
[
|
||||||
|
(Bitstream(True, b"A", False, True, False), bool, 1, 0, [True]),
|
||||||
|
(
|
||||||
|
Bitstream(True, b"A", False, True, False),
|
||||||
|
bool,
|
||||||
|
3,
|
||||||
|
8 + 1,
|
||||||
|
[False, True, False],
|
||||||
|
),
|
||||||
|
(Bitstream(True, b"A", False, True, False), bytes, 8, 1, b"A"),
|
||||||
|
(
|
||||||
|
Bitstream(True, b"A", False, True, False),
|
||||||
|
Bitstream,
|
||||||
|
1 + 8 + 3,
|
||||||
|
0,
|
||||||
|
Bitstream(True, b"A", False, True, False),
|
||||||
|
),
|
||||||
|
(Bitstream(True, b"A", False, True, False), int, 3, 1 + 8, 2),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_read_index(
|
||||||
|
bitstream: Bitstream,
|
||||||
|
read_type: Type[Union[bool, bytes, Bitstream, int]],
|
||||||
|
bit_length: int,
|
||||||
|
bit_index: Optional[int],
|
||||||
|
expected: Union[List[bool], bytes, Bitstream, int],
|
||||||
|
):
|
||||||
|
value = bitstream.read(read_type, bit_length, bit_index)
|
||||||
|
|
||||||
|
if read_type is bool:
|
||||||
|
assert type(value) == list
|
||||||
|
else:
|
||||||
|
assert type(value) == read_type
|
||||||
|
assert value == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_error():
|
||||||
|
bitstream = Bitstream(True, b"A", False, True, False)
|
||||||
|
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
bitstream.read(bool, 1, len(bitstream))
|
||||||
|
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
bitstream.read(bool, 1, -1)
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
bitstream.read(list, 1, 0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"bitstream,value,bit_index,current_offset",
|
||||||
|
[
|
||||||
|
(Bitstream(), True, None, None),
|
||||||
|
(Bitstream(), b"A", None, None),
|
||||||
|
(Bitstream(), True, 0, None),
|
||||||
|
(Bitstream(), True, 3, None),
|
||||||
|
(Bitstream(), True, 7, None),
|
||||||
|
(Bitstream(), True, 8, None),
|
||||||
|
(Bitstream(False), True, None, None),
|
||||||
|
(Bitstream(False), True, 0, None),
|
||||||
|
(Bitstream(False), True, 0, 0),
|
||||||
|
(Bitstream(False), True, None, 0),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_write(
|
||||||
|
bitstream: Bitstream,
|
||||||
|
value: Union[bool, List[bool], bytes, Bitstream],
|
||||||
|
bit_index: Optional[int],
|
||||||
|
current_offset: Optional[int],
|
||||||
|
):
|
||||||
|
if current_offset is not None:
|
||||||
|
bitstream.seek(current_offset)
|
||||||
|
|
||||||
|
bit_length = get_bit_length(value)
|
||||||
|
bitstream.write(value, bit_index)
|
||||||
|
|
||||||
|
if bit_index is None:
|
||||||
|
bitstream.seek(-bit_length, 1)
|
||||||
|
read_value = bitstream.read(
|
||||||
|
bool if type(value) is list else type(value), bit_length
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
read_value = bitstream.read(
|
||||||
|
bool if type(value) is list else type(value), bit_length, bit_index
|
||||||
|
)
|
||||||
|
|
||||||
|
assert read_value == ([value] if type(value) is bool else value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seek():
|
||||||
|
bitstream = Bitstream([True] * 8)
|
||||||
|
|
||||||
|
bitstream.seek(4, 0)
|
||||||
|
assert bitstream._offset == 4
|
||||||
|
|
||||||
|
bitstream.seek(-1, 1)
|
||||||
|
assert bitstream._offset == 3
|
||||||
|
|
||||||
|
bitstream.seek(-1, 2)
|
||||||
|
assert bitstream._offset == 7
|
Loading…
Reference in New Issue
Block a user