# fly-socket/handlers/mainHandler.py
# 2019-08-01 01:09:26 +02:00
#
# 273 lines
# 9.5 KiB
# Python

import asyncio
import logging
import struct
from time import time
from objects import glob
# Flights seen since startup, keyed by uuid (populated by the UDP start
# handler; nothing in this module removes entries).
flights = dict()
# Client wrappers for the currently connected websockets.
clients = set()
# struct format fragments (no byte-order prefix) for each discovery payload.
# "start" carries a %d placeholder for the fixed playername field width.
STRUCTS = dict(
    start=b"fHB%ds",
    end=b"",
    flight_data=b"2BHhH3f",
    server_open=b"",
    server_close=b"",
)
class Client:
    """A connected websocket peer plus the flight updates it asked for."""

    def __init__(self, ws):
        # When True the peer is told about every flight start/end.
        self.get_flight_changes = False
        # UUIDs of flights whose frames are forwarded to this peer.
        self.subscriptions = set()
        # Connection used for every outgoing packet.
        self.ws = ws
async def broadcast_flight_change(uuid, active):
    """Tell interested clients that flight ``uuid`` started or stopped.

    Recipients are clients following all flight changes plus clients
    subscribed to this flight. The packet is ``<BI?``: packet id 4, the
    flight uuid and whether it is running.

    This coroutine is scheduled fire-and-forget, so a failing send (e.g. a
    peer that dropped mid-broadcast) is logged and skipped rather than
    aborting delivery to the remaining clients.
    """
    logging.info("[%d] Broadcasting flight running: %r" % (uuid, active))
    packet = struct.pack(b"<BI?", 4, uuid, active)  # packet_id 4
    recipients = [cli for cli in clients
                  if cli.get_flight_changes or uuid in cli.subscriptions]
    for cli in recipients:
        try:
            await cli.ws.send(packet)
        except asyncio.CancelledError:
            raise
        except Exception:
            logging.exception("[%d] Failed to send flight change to a client" % uuid)
async def broadcast_flight_data(uuid, data):
    """Forward one timeline frame of flight ``uuid`` to its subscribers.

    The packet is ``<BI%ds``: packet id 5, the flight uuid, then the raw
    frame bytes. Like the change broadcast this runs fire-and-forget, so a
    failing send is logged and skipped instead of starving later clients.
    """
    logging.info("[%d] Broadcasting flight data on %d bytes" % (uuid, len(data)))
    packet = struct.pack(b"<BI%ds" % len(data), 5, uuid, data)  # packet_id 5
    recipients = [cli for cli in clients if uuid in cli.subscriptions]
    for cli in recipients:
        try:
            await cli.ws.send(packet)
        except asyncio.CancelledError:
            raise
        except Exception:
            logging.exception("[%d] Failed to send flight data to a client" % uuid)
async def handle(ws, path):
    """Serve one websocket client until its connection closes.

    The client is kept in ``clients`` while connected so the broadcast
    helpers can reach it. Every message is a one-byte packet id followed by
    a packet-specific payload (all formats little-endian):

    * 0 -- list running flights; reply ``<2B%dI``: 0, count, uuids
    * 1 -- ``<?`` follow every flight start/end; reply ``<B?``: 1, state
    * 2 -- ``<I?`` (un)subscribe to one flight's frames; reply ``<BI?``:
      2, uuid, subscribed
    * 3 -- ``<I`` fetch a flight's header + timeline; reply ``<BI%ds``:
      3, byte length, data

    Problems are reported with ``<BH%ds``: 0xff, message length, message.
    A non-empty URL path (``/1234``) subscribes to that flight on connect.
    """
    if path:
        path = path[1:]  # drop the leading "/"
    # Only host and port: IPv6 peers report a 4-tuple.
    peer = "<ws://%s:%d/%s>" % (*ws.remote_address[:2], path)
    cli = Client(ws)

    async def send_error(msg):
        logging.debug("%s -> ERROR: %s" % (peer, msg.decode()))
        await cli.ws.send(struct.pack(b"<BH%ds" % len(msg),
                                      0xff,  # packet_id: Error
                                      len(msg),
                                      msg))

    async def send_all_flights(_):
        # Only running flights are listed: the reply has no per-uuid
        # "active" flag, and ended flights stay in `flights` forever, which
        # would also eventually overflow the one-byte count.
        active = [uuid for uuid, flight in flights.items() if flight.active]
        logging.info("%s -> All active flights" % peer)
        await cli.ws.send(struct.pack(b"<2B%dI" % len(active),
                                      0,  # packet_id
                                      len(active),
                                      *active))

    async def subscribe_flight_changes(data):
        if len(data) < 1:
            return await send_error(b"Invalid data")
        (cli.get_flight_changes,) = struct.unpack(b"<?", data[:1])
        logging.info("%s -> %s to all flight changes"
                     % (peer, "Subscribed" if cli.get_flight_changes else "Unsubscribed"))
        await cli.ws.send(struct.pack(b"<B?",
                                      1,  # packet_id
                                      cli.get_flight_changes))

    async def subscribe_flight(data):
        if len(data) < 5:  # uint32 uuid + bool
            return await send_error(b"Invalid data")
        uuid, subscribe = struct.unpack(b"<I?", data[:5])
        if subscribe:
            if uuid in flights:
                cli.subscriptions.add(uuid)
            else:
                await send_error(b"%d not found" % uuid)
        else:
            cli.subscriptions.discard(uuid)
        subscribed = uuid in cli.subscriptions
        logging.info("%s -> %s to %d"
                     % (peer, "Subscribed" if subscribed else "Unsubscribed", uuid))
        await cli.ws.send(struct.pack(b"<BI?",
                                      2,  # packet_id
                                      uuid,
                                      subscribed))

    async def fetch_data(data):
        if len(data) < 4:  # uint32 uuid
            return await send_error(b"Invalid data")
        (uuid,) = struct.unpack(b"<I", data[:4])
        flight = flights.get(uuid)
        if flight is None:
            return await send_error(b"Flight not found")
        payload = flight.get_all()
        await cli.ws.send(struct.pack(b"<BI%ds" % len(payload),
                                      3,  # packet_id
                                      len(payload),
                                      payload))

    async def default(data):
        await send_error(b"Unimplemented function")
        logging.error("Invalid packet received: %s" % data)

    handlers = {
        0: send_all_flights,          # fetch all active flights
        1: subscribe_flight_changes,  # next byte: bool, follow flight starts & ends
        2: subscribe_flight,          # next uint32: uuid, then bool subscribe flag
        3: fetch_data,                # next uint32: uuid of the timeline to fetch
    }

    async def query(data):
        if not data:
            return await send_error(b"Invalid data")
        packet_id, payload = data[0], data[1:]
        logging.debug("Received packetID: %d" % packet_id)
        try:
            await handlers.get(packet_id, default)(payload)
        except struct.error:
            # One malformed payload must not tear down the whole connection.
            logging.exception("%s Malformed packet %d" % (peer, packet_id))
            await send_error(b"Invalid data")

    # Register only right before `try` so the `finally` always unregisters.
    clients.add(cli)
    logging.info("%s Connected." % peer)
    try:
        if path:
            try:
                initial = struct.pack(b"<I?", int(path), True)
            except (ValueError, struct.error):  # not a uint32 flight id
                await send_error(b"Invalid flight id")
            else:
                await subscribe_flight(initial)
        while not cli.ws.closed:
            async for message in cli.ws:
                # websockets yields str for text frames and bytes for binary
                # frames; the protocol handlers work on bytes.
                if isinstance(message, str):
                    message = message.encode()
                await query(message)
            await asyncio.sleep(.1)
    finally:
        clients.discard(cli)
        logging.info("%s Disconnected." % peer)
class Flight:
    """One tracked flight and its delta-compressed timeline.

    Built from a ``start`` discovery payload. Samples are appended with
    :meth:`add`, and :meth:`end` writes header + timeline to
    ``<save_path>/<uuid>.flt``.
    """

    # Fixed width of the playername field in the ``start`` payload; the real
    # name length travels separately as ``playername_len``.
    PLAYERNAME_SIZE = 24

    def __init__(self, uuid, data):
        self.uuid = uuid
        fmt = b"<" + STRUCTS["start"] % self.PLAYERNAME_SIZE
        (
            self.max_fuel,
            self.model_id,
            self.playername_len,
            self.playername,
        ) = struct.unpack(fmt, data[:struct.calcsize(fmt)])
        self.active = True
        # Last value seen per flight_data field; add() drops repeats.
        self.last_timeline_values = [None] * 8
        self.timeline = b""
        # The name arrives padded to PLAYERNAME_SIZE; keep only its real bytes.
        self.playername = self.playername[:self.playername_len]
        logging.info("[%d] Flight started" % self.uuid)
        asyncio.ensure_future(broadcast_flight_change(self.uuid, self.active))

    def get_head(self):
        """Return ``<I`` uuid followed by the ``start`` fields.

        The name is packed into the same fixed PLAYERNAME_SIZE-byte field it
        arrived in (struct pads it with zero bytes), so the header is always
        4 + 31 = 35 bytes.
        """
        return struct.pack(b"<I" + STRUCTS["start"] % self.PLAYERNAME_SIZE,
                           self.uuid,
                           self.max_fuel,
                           self.model_id,
                           self.playername_len,
                           self.playername)

    def get_all(self):
        """Return the header followed by every recorded frame."""
        return self.get_head() + self.timeline

    def add(self, data):
        """Record one ``flight_data`` sample taken from the start of ``data``.

        Values equal to the previous sample are replaced by ``None`` so only
        changes are stored. A non-empty frame is broadcast to subscribers and
        appended to the timeline.
        """
        fmt = b"<" + STRUCTS["flight_data"]
        sample = struct.unpack(fmt, data[:struct.calcsize(fmt)])
        changed = []
        for i, value in enumerate(sample):
            if value == self.last_timeline_values[i]:
                changed.append(None)
            else:
                self.last_timeline_values[i] = value
                changed.append(value)
        frame = self.format_flight_data(changed)
        if not frame:
            logging.debug("[%d] Empty frame (skipping)" % self.uuid)
            return
        logging.debug("[%d] New frame: %s" % (self.uuid, frame))
        asyncio.ensure_future(broadcast_flight_data(self.uuid, frame))
        self.timeline += frame

    @staticmethod
    def _field_codes(fmt):
        """Expand a struct format such as ``b"2BHhH3f"`` to one code per value.

        Repeat counts are expanded (``2B`` -> ``B``, ``B``). Only valid for
        formats without a byte-order prefix and without ``s``/``p``/``x``
        codes, whose counts do not mean "repeat this value".
        """
        codes = []
        count = ""
        for ch in fmt.decode("ascii"):
            if ch.isdigit():
                count += ch
            else:
                codes.extend(ch * int(count or "1"))
                count = ""
        return codes

    @staticmethod
    def format_flight_data(data):
        """Pack the changed values of one sample into a compact frame.

        ``data`` has one entry per ``flight_data`` field, ``None`` meaning
        "unchanged". The frame is ``<IB`` (unix time in whole seconds, bitmask
        with bit i set when field i is present) followed by the present values,
        each packed with its own field code. Returns ``None`` when nothing
        changed, so no timestamp is stored either.
        """
        codes = Flight._field_codes(STRUCTS["flight_data"])
        flag = 0
        fmt = ""
        values = []
        for i, value in enumerate(data):
            if value is not None:
                flag |= 1 << i
                fmt += codes[i]
                values.append(value)
        if flag == 0:
            return None
        # "<" means no alignment padding, so one pack equals the per-field packs.
        return struct.pack("<IB" + fmt, int(time()), flag, *values)

    def end(self):
        """Mark the flight finished, announce it and save it to disk."""
        self.active = False
        logging.info("[%d] Flight ended" % self.uuid)
        asyncio.ensure_future(broadcast_flight_change(self.uuid, self.active))
        with open("%s/%s.flt" % (glob.config["save_path"], self.uuid), "wb") as f:
            f.write(self.get_all())
class DiscoveryProtocol(asyncio.DatagramProtocol):
    """UDP endpoint receiving flight packets.

    Each datagram starts with a 4-byte ``<3sB`` header: the magic ``b"FLY"``
    and a packet id 0-5; the remainder is the packet payload. Malformed or
    unexpected packets are logged and dropped.
    """

    def __init__(self):
        super().__init__()
        self.transport = None

    def connection_made(self, transport):
        self.transport = transport

    def datagram_received(self, data, addr):
        # Only host and port: IPv6 addresses are 4-tuples.
        logging.info("Received data from %s:%d" % addr[:2])
        if len(data) < 4:  # not even a full header
            logging.debug("Noise received: %r" % data)
            return
        magic, packet_id = struct.unpack(b"<3sB", data[:4])
        handlers = {
            0: self.unimplemented,  # ping?
            1: self.handle_start,
            2: self.handle_flight_data,
            3: self.handle_end,
            4: self.handle_server_open,
            5: self.handle_server_close,
        }
        if magic != b"FLY" or packet_id not in handlers:
            logging.error("Invalid packet header: %r" % data[:4])
            return
        try:
            handlers[packet_id](data[4:])
        except struct.error as err:
            # Truncated payload; drop it instead of erroring in the event loop.
            logging.error("Malformed packet %d from %s:%d: %s"
                          % (packet_id, *addr[:2], err))

    @staticmethod
    def unimplemented(data):
        logging.warning("Unimplemented: %r" % data)

    @staticmethod
    def handle_start(data):
        (uuid,) = struct.unpack(b"<I", data[:4])
        flights[uuid] = flight = Flight(uuid, data[4:])
        logging.debug("Added %d into flights dict" % flight.uuid)

    @staticmethod
    def handle_end(data):
        (uuid,) = struct.unpack(b"<I", data[:4])
        flight = flights.get(uuid)
        if flight is None:
            logging.warning("[%d] End received for unknown flight" % uuid)
            return
        flight.end()

    @staticmethod
    def handle_flight_data(data):
        (uuid,) = struct.unpack(b"<I", data[:4])
        flight = flights.get(uuid)
        if flight is None:
            # e.g. samples for a flight whose start packet was never seen.
            logging.debug("[%d] Data received for unknown flight" % uuid)
            return
        flight.add(data[4:])

    @staticmethod
    def handle_server_open(_):
        pass

    @staticmethod
    def handle_server_close(_):
        pass