fix docker networking, and use app internal state instead of globals

This commit is contained in:
2024-12-02 11:01:59 +01:00
parent a0adb644df
commit 3bc9d80cdf
2 changed files with 54 additions and 30 deletions

View File

@@ -4,7 +4,8 @@ import socket
import json import json
from aiohttp import web from aiohttp import web
import logging import logging
from dataclasses import dataclass, make_dataclass import asyncio
from dataclasses import dataclass
PSN_DEFAULT_UDP_PORT = 56565 PSN_DEFAULT_UDP_PORT = 56565
PSN_DEFAULT_UDP_MCAST_ADDRESS = "236.10.10.10" PSN_DEFAULT_UDP_MCAST_ADDRESS = "236.10.10.10"
@@ -12,6 +13,7 @@ PORT = 8000
IP = "0.0.0.0" IP = "0.0.0.0"
NUM_TRACKERS = 3 NUM_TRACKERS = 3
# Internal state is a list of TrackerData objects
@dataclass @dataclass
class TrackerData: class TrackerData:
id: int id: int
@@ -19,22 +21,21 @@ class TrackerData:
y: float y: float
trackers = {} def to_tracker(self):
ws_clients = set() tracker = psn.Tracker(self.id, f"Tracker {self.id}")
x, y = pic_to_scene_coords(self.x, self.y)
tracker.set_pos(psn.Float3(x, y, 0))
return tracker
for i in range(NUM_TRACKERS):
trackers[i] = psn.Tracker(i, f"Tracker {i}")
trackers[i].set_pos(psn.Float3(0, 0, 0))
def update_tracker(tracker_data_json: str): def update_tracker(tracker_data_json: str, app: web.Application):
global trackers
tracker = TrackerData(**json.loads(tracker_data_json)) tracker = TrackerData(**json.loads(tracker_data_json))
trackers[tracker.id].set_pos(psn.Float3(tracker.x, tracker.y, 0)) app["trackers"][tracker.id] = tracker
def trackers_to_json():
return json.dumps([{"id": tracker.get_id(), "x": tracker.get_pos().x, "y": tracker.get_pos().y} for tracker in trackers.values()])
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) def trackers_to_json(app: web.Application):
return json.dumps([tracker.__dict__ for tracker in app["trackers"].values()])
def get_time_ms(): def get_time_ms():
return int(time.time() * 1000) return int(time.time() * 1000)
@@ -50,16 +51,10 @@ def get_elapsed_time_ms():
def pic_to_scene_coords(x, y): def pic_to_scene_coords(x, y):
return x / 200, y / 200 return x / 200, y / 200
async def update_all_clients():
for ws in ws_clients:
await ws.send_str(trackers_to_json())
def send_psn_positions(): async def update_all_clients(app: web.Application):
encoder = psn.Encoder("Server 1") for ws in app["ws_clients"]:
await ws.send_str(trackers_to_json(app))
packets = encoder.encode_data(trackers, get_elapsed_time_ms())
for packet in packets:
sock.sendto(packet, (PSN_DEFAULT_UDP_MCAST_ADDRESS, PSN_DEFAULT_UDP_PORT))
async def handle_websocket(request): async def handle_websocket(request):
@@ -68,17 +63,16 @@ async def handle_websocket(request):
await ws.prepare(request) await ws.prepare(request)
logging.debug("Websocket connection ready") logging.debug("Websocket connection ready")
ws_clients.add(ws) request.app["ws_clients"].add(ws)
await ws.send_str(trackers_to_json()) await ws.send_str(trackers_to_json(request.app))
try: try:
async for msg in ws: async for msg in ws:
logging.debug(f"Websocket data: {msg}")
if msg.type == web.WSMsgType.TEXT: if msg.type == web.WSMsgType.TEXT:
# Each message is a single tracker object # Each message is a single tracker object
update_tracker(msg.data) update_tracker(msg.data, request.app)
await update_all_clients() await update_all_clients(request.app)
elif msg.type == web.WSMsgType.ERROR: elif msg.type == web.WSMsgType.ERROR:
logging.error("ws connection closed with exception %s" % ws.exception()) logging.error("ws connection closed with exception %s" % ws.exception())
@@ -90,25 +84,54 @@ async def handle_websocket(request):
logging.debug("Websocket connection closing") logging.debug("Websocket connection closing")
await ws.close() await ws.close()
ws_clients.remove(ws) request.app["ws_clients"].remove(ws)
return ws return ws
async def handle_root(request): async def handle_root(request):
return web.FileResponse("./static/index.html") return web.FileResponse("./static/index.html")
async def broadcast_psn_data(app):
encoder = psn.Encoder("Server 1")
while True:
trackers = {}
for tracker_data in app["trackers"].values():
trackers[tracker_data.id] = tracker_data.to_tracker()
packets = encoder.encode_data(trackers, get_elapsed_time_ms())
for packet in packets:
app["sock"].sendto(packet, (PSN_DEFAULT_UDP_MCAST_ADDRESS, PSN_DEFAULT_UDP_PORT))
await asyncio.sleep(0.033) # ~30fps
async def background_tasks(app: web.Application):
app["broadcast_psn_data"] = asyncio.create_task(broadcast_psn_data(app))
yield
app["broadcast_psn_data"].cancel()
await app["broadcast_psn_data"]
def create_app(): def create_app():
app = web.Application() app = web.Application()
app.router.add_get("/", handle_root) app.router.add_get("/", handle_root)
app.router.add_get("/ws", handle_websocket) app.router.add_get("/ws", handle_websocket)
app.router.add_static("/", "./static") app.router.add_static("/", "./static")
# Setup app state
app["ws_clients"] = set()
app["trackers"] = {}
app["sock"] = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
for i in range(NUM_TRACKERS):
app["trackers"][i] = TrackerData(i, 0, 0)
app.cleanup_ctx.append(background_tasks)
return app return app
if __name__ == "__main__": if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
web.run_app(create_app(), host=IP, port=PORT) app = create_app()
web.run_app(app, host=IP, port=PORT)

View File

@@ -3,3 +3,4 @@ services:
build: . build: .
ports: ports:
- "8000:8000" - "8000:8000"
network_mode: host