Browse Source

Add more features to the backend queue code.

The queue can now be queried, entries can be deleted and prompts easily
queued to the front of the queue.

Just need to expose it in the UI next.
pull/3/head
comfyanonymous 2 years ago
parent
commit
4b08314257
  1. 91
      main.py

91
main.py

@ -3,7 +3,7 @@ import sys
import copy
import json
import threading
import queue
import heapq
import traceback
if '--dont-upcast-attention' in sys.argv:
@ -148,6 +148,7 @@ class PromptExecutor:
to_execute += [(0, x)]
while len(to_execute) > 0:
#always execute the output that depends on the least amount of unexecuted nodes first
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
x = to_execute.pop(0)[-1]
@ -266,10 +267,63 @@ def validate_prompt(prompt):
def prompt_worker(q):
e = PromptExecutor()
while True:
item = q.get()
item, item_id = q.get()
e.execute(item[-2], item[-1])
q.task_done()
q.task_done(item_id)
class PromptQueue:
def __init__(self):
self.mutex = threading.RLock()
self.not_empty = threading.Condition(self.mutex)
self.task_counter = 0
self.queue = []
self.currently_running = {}
def put(self, item):
with self.mutex:
heapq.heappush(self.queue, item)
self.not_empty.notify()
def get(self):
with self.not_empty:
while len(self.queue) == 0:
self.not_empty.wait()
item = heapq.heappop(self.queue)
i = self.task_counter
self.currently_running[i] = copy.deepcopy(item)
self.task_counter += 1
return (item, i)
def task_done(self, item_id):
with self.mutex:
self.currently_running.pop(item_id)
def get_current_queue(self):
with self.mutex:
out = []
for x in self.currently_running.values():
out += [x]
return (out, copy.deepcopy(self.queue))
def get_tasks_remaining(self):
with self.mutex:
return len(self.queue) + len(self.currently_running)
def wipe_queue(self):
with self.mutex:
self.queue = []
def delete_queue_item(self, function):
with self.mutex:
for x in range(len(self.queue)):
if function(self.queue[x]):
if len(self.queue) == 1:
self.wipe_queue()
else:
self.queue.pop(x)
heapq.heapify(self.queue)
return True
return False
from http.server import BaseHTTPRequestHandler, HTTPServer
@ -285,9 +339,16 @@ class PromptServer(BaseHTTPRequestHandler):
self._set_headers(ct='application/json')
prompt_info = {}
exec_info = {}
exec_info['queue_remaining'] = self.server.prompt_queue.unfinished_tasks
exec_info['queue_remaining'] = self.server.prompt_queue.get_tasks_remaining()
prompt_info['exec_info'] = exec_info
self.wfile.write(json.dumps(prompt_info).encode('utf-8'))
elif self.path == "/queue":
self._set_headers(ct='application/json')
queue_info = {}
current_queue = self.server.prompt_queue.get_current_queue()
queue_info['queue_running'] = current_queue[0]
queue_info['queue_pending'] = current_queue[1]
self.wfile.write(json.dumps(queue_info).encode('utf-8'))
elif self.path == "/object_info":
self._set_headers(ct='application/json')
out = {}
@ -325,12 +386,16 @@ class PromptServer(BaseHTTPRequestHandler):
out_string = ""
if self.path == "/prompt":
print("got prompt")
self.data_string = self.rfile.read(int(self.headers['Content-Length']))
json_data = json.loads(self.data_string)
data_string = self.rfile.read(int(self.headers['Content-Length']))
json_data = json.loads(data_string)
if "number" in json_data:
number = float(json_data['number'])
else:
number = self.server.number
if "front" in json_data:
if json_data['front']:
number = -number
self.server.number += 1
if "prompt" in json_data:
prompt = json_data["prompt"]
@ -344,6 +409,18 @@ class PromptServer(BaseHTTPRequestHandler):
resp_code = 400
out_string = valid[1]
print("invalid prompt:", valid[1])
elif self.path == "/queue":
data_string = self.rfile.read(int(self.headers['Content-Length']))
json_data = json.loads(data_string)
if "clear" in json_data:
if json_data["clear"]:
self.server.prompt_queue.wipe_queue()
if "delete" in json_data:
to_delete = json_data['delete']
for id_to_delete in to_delete:
delete_func = lambda a: a[1] == int(id_to_delete)
self.server.prompt_queue.delete_queue_item(delete_func)
self._set_headers(code=resp_code)
self.end_headers()
self.wfile.write(out_string.encode('utf8'))
@ -366,7 +443,7 @@ def run(prompt_queue, address='', port=8188):
if __name__ == "__main__":
q = queue.PriorityQueue()
q = PromptQueue()
threading.Thread(target=prompt_worker, daemon=True, args=(q,)).start()
run(q, address='127.0.0.1', port=8188)

Loading…
Cancel
Save