From 3bc9d80cdf298675fa736b3edfcd157a5b863c46 Mon Sep 17 00:00:00 2001 From: Tobias Haugeland Date: Mon, 2 Dec 2024 11:01:59 +0100 Subject: [PATCH] fix docker networking, and use app internal state instead of globals --- backend/psn_server.py | 83 +++++++++++++++++++++++++++---------------- compose.yaml | 1 + 2 files changed, 54 insertions(+), 30 deletions(-) diff --git a/backend/psn_server.py b/backend/psn_server.py index 3e91006..996822a 100644 --- a/backend/psn_server.py +++ b/backend/psn_server.py @@ -4,7 +4,8 @@ import socket import json from aiohttp import web import logging -from dataclasses import dataclass, make_dataclass +import asyncio +from dataclasses import dataclass PSN_DEFAULT_UDP_PORT = 56565 PSN_DEFAULT_UDP_MCAST_ADDRESS = "236.10.10.10" @@ -12,6 +13,7 @@ PORT = 8000 IP = "0.0.0.0" NUM_TRACKERS = 3 +# Internal state is a list of TrackerData objects @dataclass class TrackerData: id: int @@ -19,22 +21,21 @@ class TrackerData: y: float -trackers = {} -ws_clients = set() + def to_tracker(self): + tracker = psn.Tracker(self.id, f"Tracker {self.id}") + x, y = pic_to_scene_coords(self.x, self.y) + tracker.set_pos(psn.Float3(x, y, 0)) + return tracker -for i in range(NUM_TRACKERS): - trackers[i] = psn.Tracker(i, f"Tracker {i}") - trackers[i].set_pos(psn.Float3(0, 0, 0)) -def update_tracker(tracker_data_json: str): - global trackers +def update_tracker(tracker_data_json: str, app: web.Application): tracker = TrackerData(**json.loads(tracker_data_json)) - trackers[tracker.id].set_pos(psn.Float3(tracker.x, tracker.y, 0)) + app["trackers"][tracker.id] = tracker -def trackers_to_json(): - return json.dumps([{"id": tracker.get_id(), "x": tracker.get_pos().x, "y": tracker.get_pos().y} for tracker in trackers.values()]) -sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) +def trackers_to_json(app: web.Application): + return json.dumps([tracker.__dict__ for tracker in app["trackers"].values()]) + def get_time_ms(): return int(time.time() * 1000) @@ -50,16 +51,10 @@ def get_elapsed_time_ms(): def pic_to_scene_coords(x, y): return x / 200, y / 200 -async def update_all_clients(): - for ws in ws_clients: - await ws.send_str(trackers_to_json()) -def send_psn_positions(): - encoder = psn.Encoder("Server 1") - - packets = encoder.encode_data(trackers, get_elapsed_time_ms()) - for packet in packets: - sock.sendto(packet, (PSN_DEFAULT_UDP_MCAST_ADDRESS, PSN_DEFAULT_UDP_PORT)) +async def update_all_clients(app: web.Application): + for ws in app["ws_clients"]: + await ws.send_str(trackers_to_json(app)) async def handle_websocket(request): @@ -68,17 +63,16 @@ async def handle_websocket(request): await ws.prepare(request) logging.debug("Websocket connection ready") - ws_clients.add(ws) + request.app["ws_clients"].add(ws) - await ws.send_str(trackers_to_json()) + await ws.send_str(trackers_to_json(request.app)) try: async for msg in ws: - logging.debug(f"Websocket data: {msg}") if msg.type == web.WSMsgType.TEXT: # Each message is a single tracker object - update_tracker(msg.data) - await update_all_clients() + update_tracker(msg.data, request.app) + await update_all_clients(request.app) elif msg.type == web.WSMsgType.ERROR: logging.error("ws connection closed with exception %s" % ws.exception()) @@ -90,25 +84,54 @@ async def handle_websocket(request): logging.debug("Websocket connection closing") await ws.close() - ws_clients.remove(ws) + request.app["ws_clients"].remove(ws) return ws - - async def handle_root(request): return web.FileResponse("./static/index.html") +async def broadcast_psn_data(app): + encoder = psn.Encoder("Server 1") + while True: + trackers = {} + for tracker_data in app["trackers"].values(): + trackers[tracker_data.id] = tracker_data.to_tracker() + packets = encoder.encode_data(trackers, get_elapsed_time_ms()) + for packet in packets: + app["sock"].sendto(packet, (PSN_DEFAULT_UDP_MCAST_ADDRESS, PSN_DEFAULT_UDP_PORT)) + await asyncio.sleep(0.033) # ~30fps + + +async def background_tasks(app: web.Application): + app["broadcast_psn_data"] = asyncio.create_task(broadcast_psn_data(app)) + yield + app["broadcast_psn_data"].cancel() + await app["broadcast_psn_data"] + + def create_app(): app = web.Application() app.router.add_get("/", handle_root) app.router.add_get("/ws", handle_websocket) app.router.add_static("/", "./static") + + # Setup app state + app["ws_clients"] = set() + app["trackers"] = {} + app["sock"] = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + + for i in range(NUM_TRACKERS): + app["trackers"][i] = TrackerData(i, 0, 0) + + app.cleanup_ctx.append(background_tasks) + return app if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) - web.run_app(create_app(), host=IP, port=PORT) + app = create_app() + web.run_app(app, host=IP, port=PORT) diff --git a/compose.yaml b/compose.yaml index 3e3537f..cee8e62 100644 --- a/compose.yaml +++ b/compose.yaml @@ -3,3 +3,4 @@ services: build: . ports: - "8000:8000" + network_mode: host