You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
276 lines
9.7 KiB
276 lines
9.7 KiB
import os |
|
import sys |
|
import asyncio |
|
import nodes |
|
import execution |
|
import uuid |
|
import json |
|
import glob |
|
try:
    import aiohttp
    from aiohttp import web
except ImportError:
    print("Module 'aiohttp' not installed. Please install it via:")
    print("pip install aiohttp")
    print("or")
    print("pip install -r requirements.txt")
    # Exit with a non-zero status so launchers/scripts can detect the failure;
    # a bare sys.exit() reports success (status 0).
    sys.exit(1)
|
|
|
import mimetypes |
|
|
|
|
|
@web.middleware
async def cache_control(request: web.Request, handler):
    """Middleware that marks JavaScript and CSS responses as ``no-cache``.

    The header is only added when the handler did not already set its own
    ``Cache-Control`` value; every other response passes through unchanged.
    """
    response: web.Response = await handler(request)
    is_script_or_style = request.path.endswith(('.js', '.css'))
    if is_script_or_style:
        response.headers.setdefault('Cache-Control', 'no-cache')
    return response
|
|
|
class PromptServer():
    """aiohttp server exposing the prompt queue over HTTP and WebSocket.

    Serves the static UI from ``<this dir>/web``, accepts prompt and
    queue/history commands over HTTP, and pushes JSON messages of the form
    ``{"type": ..., "data": ...}`` to connected WebSocket clients.
    """

    def __init__(self, loop):
        """Build the aiohttp application and register every route.

        Args:
            loop: event loop the server runs on; ``send_sync`` uses it to
                hand messages over from other threads.
        """
        mimetypes.init()
        # Serve .js files with an explicit UTF-8 JavaScript MIME type.
        mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
        # Assigned externally; the route handlers call its methods.
        self.prompt_queue = None
        self.loop = loop
        # (event, data, sid) tuples produced by send_sync, drained by publish_loop.
        self.messages = asyncio.Queue()
        # Default ordering number for prompts posted without an explicit "number".
        self.number = 0
        # client_max_size = 20 MiB maximum request body.
        self.app = web.Application(client_max_size=20971520, middlewares=[cache_control])
        # client id -> WebSocketResponse
        self.sockets = dict()
        self.web_root = os.path.join(os.path.dirname(
            os.path.realpath(__file__)), "web")
        routes = web.RouteTableDef()
        # Set externally; replayed to the executing client when it reconnects.
        self.last_node_id = None
        self.client_id = None

        @routes.get('/ws')
        async def websocket_handler(request):
            """Register a WebSocket client and hold the connection open.

            A ``clientId`` query parameter resumes an existing session;
            otherwise a fresh id is generated and reported in the initial
            ``status`` message.
            """
            ws = web.WebSocketResponse()
            await ws.prepare(request)
            sid = request.rel_url.query.get('clientId', '')
            if sid:
                # Reusing existing session, remove old
                self.sockets.pop(sid, None)
            else:
                sid = uuid.uuid4().hex

            self.sockets[sid] = ws

            try:
                # Send initial state to the new client
                await self.send("status", { "status": self.get_queue_info(), 'sid': sid }, sid)
                # On reconnect if we are the currently executing client send the current node
                if self.client_id == sid and self.last_node_id is not None:
                    await self.send("executing", { "node": self.last_node_id }, sid)

                async for msg in ws:
                    if msg.type == aiohttp.WSMsgType.ERROR:
                        print('ws connection closed with exception %s' % ws.exception())
            finally:
                # A reconnect with the same clientId may already have replaced
                # this socket; only unregister it if it is still ours, so the
                # old connection's cleanup does not drop the new one.
                if self.sockets.get(sid) is ws:
                    del self.sockets[sid]
            return ws

        @routes.get("/")
        async def get_root(request):
            """Serve the UI entry page."""
            return web.FileResponse(os.path.join(self.web_root, "index.html"))

        @routes.get("/extensions")
        async def get_extensions(request):
            """List every ``.js`` file under ``web/extensions`` as a URL path."""
            files = glob.glob(os.path.join(self.web_root, 'extensions/**/*.js'), recursive=True)
            # Web-root-relative paths with forward slashes (Windows-safe URLs).
            return web.json_response(
                ["/" + os.path.relpath(f, self.web_root).replace("\\", "/") for f in files])

        @routes.post("/upload/image")
        async def upload_image(request):
            """Store the multipart ``image`` field in the ``input`` directory.

            Responds ``{"name": <stored filename>}``. Existing files are never
            overwritten: " (1)", " (2)", ... is appended to the stem instead.
            Responds 400 when the field or a usable filename is missing.
            """
            upload_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
            os.makedirs(upload_dir, exist_ok=True)

            post = await request.post()
            image = post.get("image")
            # A plain-text form field has no .file; treat it as a bad request.
            if not image or not getattr(image, "file", None):
                return web.Response(status=400)

            # The filename is client-controlled: keep only its final component
            # so names like "../../x" cannot escape upload_dir.
            filename = os.path.basename(image.filename or "")
            if filename in ("", ".", ".."):
                return web.Response(status=400)

            stem, ext = os.path.splitext(filename)
            i = 1
            while os.path.exists(os.path.join(upload_dir, filename)):
                filename = f"{stem} ({i}){ext}"
                i += 1

            filepath = os.path.join(upload_dir, filename)
            with open(filepath, "wb") as f:
                f.write(image.file.read())

            return web.json_response({"name" : filename})

        @routes.get("/view/{file}")
        async def view_image(request):
            """Serve a file from the ``output`` (default) or ``input`` directory.

            The ``type`` query parameter picks the directory; any other value
            is a 400. Directory components of ``{file}`` are stripped. A file
            that does not exist yields 404.
            """
            if "file" in request.match_info:
                dir_type = request.rel_url.query.get("type", "output")
                if dir_type not in ("output", "input"):
                    return web.Response(status=400)

                output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), dir_type)
                file = os.path.basename(request.match_info["file"])
                file = os.path.join(output_dir, file)
                if os.path.isfile(file):
                    return web.FileResponse(file)

            return web.Response(status=404)

        @routes.get("/prompt")
        async def get_prompt(request):
            """Return the current queue status (see ``get_queue_info``)."""
            return web.json_response(self.get_queue_info())

        @routes.get("/object_info")
        async def get_object_info(request):
            """Describe every registered node class: inputs, outputs, category."""
            out = {}
            for x in nodes.NODE_CLASS_MAPPINGS:
                obj_class = nodes.NODE_CLASS_MAPPINGS[x]
                info = {}
                info['input'] = obj_class.INPUT_TYPES()
                info['output'] = obj_class.RETURN_TYPES
                info['name'] = x #TODO
                info['description'] = ''
                # Classes without a CATEGORY attribute fall back to 'sd'.
                info['category'] = getattr(obj_class, 'CATEGORY', 'sd')
                out[x] = info
            return web.json_response(out)

        @routes.get("/history")
        async def get_history(request):
            """Return the prompt queue's history as JSON."""
            return web.json_response(self.prompt_queue.get_history())

        @routes.get("/queue")
        async def get_queue(request):
            """Return the running and pending queue entries."""
            current_queue = self.prompt_queue.get_current_queue()
            queue_info = {
                'queue_running': current_queue[0],
                'queue_pending': current_queue[1],
            }
            return web.json_response(queue_info)

        @routes.post("/prompt")
        async def post_prompt(request):
            """Validate a submitted prompt and add it to the queue.

            JSON body keys:
                prompt: required; validated with ``execution.validate_prompt``.
                number: optional ordering value (cast to float). When absent,
                    the server counter is used (negated if ``front`` is truthy)
                    and then incremented.
                extra_data, client_id: optional, stored with the prompt.

            Responds 400 with the validation message for an invalid prompt,
            or 400 "no prompt" when the key is missing.
            """
            print("got prompt")
            resp_code = 200
            out_string = ""
            json_data = await request.json()

            if "number" in json_data:
                number = float(json_data['number'])
            else:
                number = self.number
                if json_data.get('front'):
                    number = -number
                self.number += 1

            if "prompt" in json_data:
                prompt = json_data["prompt"]
                valid = execution.validate_prompt(prompt)
                extra_data = {}
                if "extra_data" in json_data:
                    extra_data = json_data["extra_data"]

                if "client_id" in json_data:
                    extra_data["client_id"] = json_data["client_id"]
                if valid[0]:
                    self.prompt_queue.put((number, id(prompt), prompt, extra_data))
                else:
                    resp_code = 400
                    out_string = valid[1]
                    print("invalid prompt:", valid[1])
            else:
                # Previously a request with nothing to queue was acknowledged
                # with 200; report it as a client error instead.
                resp_code = 400
                out_string = "no prompt"

            return web.Response(body=out_string, status=resp_code)

        @routes.post("/queue")
        async def post_queue(request):
            """Clear the queue (``clear``) and/or delete entries by id (``delete``)."""
            json_data = await request.json()
            if json_data.get("clear"):
                self.prompt_queue.wipe_queue()
            if "delete" in json_data:
                for id_to_delete in json_data['delete']:
                    # Bind the id per iteration so the predicate stays correct
                    # even if the queue keeps it beyond this loop step.
                    delete_func = lambda a, target=id_to_delete: a[1] == int(target)
                    self.prompt_queue.delete_queue_item(delete_func)

            return web.Response(status=200)

        @routes.post("/interrupt")
        async def post_interrupt(request):
            """Ask the node runtime to interrupt the current processing."""
            nodes.interrupt_processing()
            return web.Response(status=200)

        @routes.post("/history")
        async def post_history(request):
            """Clear the history (``clear``) and/or delete items by id (``delete``)."""
            json_data = await request.json()
            if json_data.get("clear"):
                self.prompt_queue.wipe_history()
            if "delete" in json_data:
                for id_to_delete in json_data['delete']:
                    self.prompt_queue.delete_history_item(id_to_delete)

            return web.Response(status=200)

        self.app.add_routes(routes)
        # Static files last, so the explicit routes above take precedence.
        self.app.add_routes([
            web.static('/', self.web_root),
        ])

    def get_queue_info(self):
        """Return ``{"exec_info": {"queue_remaining": <tasks remaining>}}``."""
        return {'exec_info': {'queue_remaining': self.prompt_queue.get_tasks_remaining()}}

    async def send(self, event, data, sid=None):
        """Send ``{"type": event, "data": data}`` as JSON text over WebSocket.

        Args:
            event: message type string.
            data: JSON-serializable payload.
            sid: target client id; ``None`` broadcasts to every client.
                An unknown id is silently ignored.
        """
        message = json.dumps({"type": event, "data": data})

        if sid is None:
            # Iterate a snapshot: each send awaits, and clients may connect or
            # disconnect meanwhile, which would otherwise raise
            # "dictionary changed size during iteration".
            for ws in list(self.sockets.values()):
                await ws.send_str(message)
        elif sid in self.sockets:
            await self.sockets[sid].send_str(message)

    def send_sync(self, event, data, sid=None):
        """Enqueue a message for ``publish_loop``; callable from any thread."""
        self.loop.call_soon_threadsafe(
            self.messages.put_nowait, (event, data, sid))

    def queue_updated(self):
        """Broadcast the current queue status to all clients via ``send_sync``."""
        self.send_sync("status", { "status": self.get_queue_info() })

    async def publish_loop(self):
        """Forward queued messages to WebSocket clients forever.

        A failing send (e.g. a connection closing mid-write) is logged and
        skipped so it cannot terminate the loop and silence all later updates.
        """
        while True:
            msg = await self.messages.get()
            try:
                await self.send(*msg)
            except Exception as e:
                print("failed to send message:", e)

    async def start(self, address, port, verbose=True, call_on_start=None):
        """Start serving on ``address:port``.

        Args:
            address: bind host; ``''`` is reported as ``0.0.0.0``.
            port: TCP port.
            verbose: print the GUI URL once listening.
            call_on_start: optional callback invoked with ``(address, port)``.
        """
        runner = web.AppRunner(self.app)
        await runner.setup()
        site = web.TCPSite(runner, address, port)
        await site.start()

        if address == '':
            address = '0.0.0.0'
        if verbose:
            print("Starting server\n")
            print("To see the GUI go to: http://{}:{}".format(address, port))
        if call_on_start is not None:
            call_on_start(address, port)
|
|
|
|