diff --git a/execution.py b/execution.py index 43cab207..30eeb630 100644 --- a/execution.py +++ b/execution.py @@ -55,6 +55,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}): input_data_all = get_input_data(inputs, class_def, outputs, prompt, extra_data) if server.client_id is not None: + server.last_node_id = unique_id server.send_sync("executing", { "node": unique_id }, server.client_id) obj = class_def() @@ -188,6 +189,7 @@ class PromptExecutor: for x in executed: self.old_prompt[x] = copy.deepcopy(prompt[x]) finally: + self.server.last_node_id = None if self.server.client_id is not None: self.server.send_sync("executing", { "node": None }, self.server.client_id) diff --git a/server.py b/server.py index 94ac12b5..84b0941f 100644 --- a/server.py +++ b/server.py @@ -32,21 +32,34 @@ class PromptServer(): self.web_root = os.path.join(os.path.dirname( os.path.realpath(__file__)), "web") routes = web.RouteTableDef() + self.last_node_id = None + self.client_id = None @routes.get('/ws') async def websocket_handler(request): ws = web.WebSocketResponse() await ws.prepare(request) - sid = uuid.uuid4().hex + sid = request.rel_url.query.get('clientId', '') + if sid: + # Reusing existing session, remove old + self.sockets.pop(sid, None) + else: + sid = uuid.uuid4().hex + self.sockets[sid] = ws + try: # Send initial state to the new client await self.send("status", { "status": self.get_queue_info(), 'sid': sid }, sid) + # On reconnect if we are the currently executing client send the current node + if self.client_id == sid and self.last_node_id is not None: + await self.send("executing", { "node": self.last_node_id }, sid) + async for msg in ws: if msg.type == aiohttp.WSMsgType.ERROR: print('ws connection closed with exception %s' % ws.exception()) finally: - self.sockets.pop(sid) + self.sockets.pop(sid, None) return ws @routes.get("/") diff --git a/web/scripts/api.js b/web/scripts/api.js index 61b786f6..39f48d4a 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -28,7 +28,13 @@ class ComfyApi extends EventTarget { } let opened = false; - this.socket = new WebSocket(`ws${window.location.protocol === "https:" ? "s" : ""}://${location.host}/ws`); + let existingSession = sessionStorage["Comfy.SessionId"] || ""; + if (existingSession) { + existingSession = "?clientId=" + existingSession; + } + this.socket = new WebSocket( + `ws${window.location.protocol === "https:" ? "s" : ""}://${location.host}/ws${existingSession}` + ); this.socket.addEventListener("open", () => { opened = true; @@ -62,6 +68,7 @@ class ComfyApi extends EventTarget { case "status": if (msg.data.sid) { this.clientId = msg.data.sid; + sessionStorage["Comfy.SessionId"] = this.clientId; } this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status })); break;