@@ -3,7 +3,7 @@ import sys
 import copy
 import json
 import threading
-import queue
+import heapq
 import traceback

 if '--dont-upcast-attention' in sys.argv:
@@ -148,6 +148,7 @@ class PromptExecutor:
                     to_execute += [(0, x)]

             while len(to_execute) > 0:
+                #always execute the output that depends on the least amount of unexecuted nodes first
                 to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
                 x = to_execute.pop(0)[-1]

@@ -266,10 +267,63 @@ def validate_prompt(prompt):

 def prompt_worker(q):
     e = PromptExecutor()
     while True:
-        item = q.get()
+        item, item_id = q.get()
         e.execute(item[-2], item[-1])
-        q.task_done()
+        q.task_done(item_id)

+class PromptQueue:
+    def __init__(self):
+        self.mutex = threading.RLock()
+        self.not_empty = threading.Condition(self.mutex)
+        self.task_counter = 0
+        self.queue = []
+        self.currently_running = {}
+
+    def put(self, item):
+        with self.mutex:
+            heapq.heappush(self.queue, item)
+            self.not_empty.notify()
+
+    def get(self):
+        with self.not_empty:
+            while len(self.queue) == 0:
+                self.not_empty.wait()
+            item = heapq.heappop(self.queue)
+            i = self.task_counter
+            self.currently_running[i] = copy.deepcopy(item)
+            self.task_counter += 1
+            return (item, i)
+
+    def task_done(self, item_id):
+        with self.mutex:
+            self.currently_running.pop(item_id)
+
+    def get_current_queue(self):
+        with self.mutex:
+            out = []
+            for x in self.currently_running.values():
+                out += [x]
+            return (out, copy.deepcopy(self.queue))
+
+    def get_tasks_remaining(self):
+        with self.mutex:
+            return len(self.queue) + len(self.currently_running)
+
+    def wipe_queue(self):
+        with self.mutex:
+            self.queue = []
+
+    def delete_queue_item(self, function):
+        with self.mutex:
+            for x in range(len(self.queue)):
+                if function(self.queue[x]):
+                    if len(self.queue) == 1:
+                        self.wipe_queue()
+                    else:
+                        self.queue.pop(x)
+                        heapq.heapify(self.queue)
+                    return True
+            return False

 from http.server import BaseHTTPRequestHandler, HTTPServer
@@ -285,9 +339,16 @@ class PromptServer(BaseHTTPRequestHandler):
             self._set_headers(ct='application/json')
             prompt_info = {}
             exec_info = {}
-            exec_info['queue_remaining'] = self.server.prompt_queue.unfinished_tasks
+            exec_info['queue_remaining'] = self.server.prompt_queue.get_tasks_remaining()
             prompt_info['exec_info'] = exec_info
             self.wfile.write(json.dumps(prompt_info).encode('utf-8'))
+        elif self.path == "/queue":
+            self._set_headers(ct='application/json')
+            queue_info = {}
+            current_queue = self.server.prompt_queue.get_current_queue()
+            queue_info['queue_running'] = current_queue[0]
+            queue_info['queue_pending'] = current_queue[1]
+            self.wfile.write(json.dumps(queue_info).encode('utf-8'))
         elif self.path == "/object_info":
             self._set_headers(ct='application/json')
             out = {}
@@ -325,12 +386,16 @@ class PromptServer(BaseHTTPRequestHandler):
         out_string = ""
         if self.path == "/prompt":
             print("got prompt")
-            self.data_string = self.rfile.read(int(self.headers['Content-Length']))
-            json_data = json.loads(self.data_string)
+            data_string = self.rfile.read(int(self.headers['Content-Length']))
+            json_data = json.loads(data_string)
             if "number" in json_data:
                 number = float(json_data['number'])
             else:
                 number = self.server.number
+                if "front" in json_data:
+                    if json_data['front']:
+                        number = -number
+
                 self.server.number += 1
             if "prompt" in json_data:
                 prompt = json_data["prompt"]
@@ -344,6 +409,18 @@ class PromptServer(BaseHTTPRequestHandler):
                     resp_code = 400
                     out_string = valid[1]
                     print("invalid prompt:", valid[1])
+        elif self.path == "/queue":
+            data_string = self.rfile.read(int(self.headers['Content-Length']))
+            json_data = json.loads(data_string)
+            if "clear" in json_data:
+                if json_data["clear"]:
+                    self.server.prompt_queue.wipe_queue()
+            if "delete" in json_data:
+                to_delete = json_data['delete']
+                for id_to_delete in to_delete:
+                    delete_func = lambda a: a[1] == int(id_to_delete)
+                    self.server.prompt_queue.delete_queue_item(delete_func)
+
         self._set_headers(code=resp_code)
         self.end_headers()
         self.wfile.write(out_string.encode('utf8'))
@@ -366,7 +443,7 @@ def run(prompt_queue, address='', port=8188):


 if __name__ == "__main__":
-    q = queue.PriorityQueue()
+    q = PromptQueue()
     threading.Thread(target=prompt_worker, daemon=True, args=(q,)).start()
     run(q, address='127.0.0.1', port=8188)
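
Below is a small client sketch, not part of the patch, showing how the /queue endpoint added in this diff might be exercised. It assumes the server is running at 127.0.0.1:8188 as in the run() call above; BASE and the helper names get_queue, clear_queue and delete_queue_items are illustrative and do not exist in the codebase.

import json
import urllib.request

BASE = "http://127.0.0.1:8188"  # matches run(q, address='127.0.0.1', port=8188)

def get_queue():
    # GET /queue -> {"queue_running": [...], "queue_pending": [...]}
    with urllib.request.urlopen(BASE + "/queue") as resp:
        return json.loads(resp.read())

def clear_queue():
    # POST /queue with {"clear": true} triggers PromptQueue.wipe_queue()
    body = json.dumps({"clear": True}).encode('utf-8')
    with urllib.request.urlopen(urllib.request.Request(BASE + "/queue", data=body)) as resp:
        resp.read()

def delete_queue_items(item_ids):
    # POST /queue with {"delete": [...]}; the server compares int(id) against
    # the second element of each queued tuple via delete_queue_item()
    body = json.dumps({"delete": [str(i) for i in item_ids]}).encode('utf-8')
    with urllib.request.urlopen(urllib.request.Request(BASE + "/queue", data=body)) as resp:
        resp.read()

if __name__ == "__main__":
    q = get_queue()
    print("running:", len(q["queue_running"]), "pending:", len(q["queue_pending"]))

Note that GET /prompt still reports exec_info.queue_remaining, now backed by PromptQueue.get_tasks_remaining() (pending plus currently running items) instead of the old PriorityQueue.unfinished_tasks counter.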