diff --git a/server.py b/server.py index 5aba5761..5dc8e150 100644 --- a/server.py +++ b/server.py @@ -73,6 +73,14 @@ class PromptServer(): async def get_root(request): return web.FileResponse(os.path.join(self.web_root, "index.html")) + @routes.get("/embeddings") + def get_embeddings(self): + models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") + embed_dir = os.path.join(models_dir, "embeddings") + embeddings = nodes.filter_files_extensions(nodes.recursive_search(embed_dir), nodes.supported_pt_extensions) + + return web.json_response(list(map(lambda a: os.path.splitext(a)[0].lower(), embeddings))) + @routes.get("/extensions") async def get_extensions(request): files = glob.glob(os.path.join(self.web_root, 'extensions/**/*.js'), recursive=True) diff --git a/web/scripts/api.js b/web/scripts/api.js index 39f48d4a..b90b1c65 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -106,6 +106,15 @@ class ComfyApi extends EventTarget { return await resp.json(); } + /** + * Gets a list of embedding names + * @returns An array of script urls to import + */ + async getEmbeddings() { + const resp = await fetch("/embeddings", { cache: "no-store" }); + return await resp.json(); + } + /** * Loads node object definitions for the graph * @returns The node definitions diff --git a/web/scripts/app.js b/web/scripts/app.js index e70e1c15..00e1baa3 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -2,7 +2,7 @@ import { ComfyWidgets } from "./widgets.js"; import { ComfyUI } from "./ui.js"; import { api } from "./api.js"; import { defaultGraph } from "./defaultGraph.js"; -import { getPngMetadata } from "./pnginfo.js"; +import { getPngMetadata, importA1111 } from "./pnginfo.js"; class ComfyApp { constructor() { @@ -675,8 +675,12 @@ class ComfyApp { async handleFile(file) { if (file.type === "image/png") { const pngInfo = await getPngMetadata(file); - if (pngInfo && pngInfo.workflow) { - this.loadGraphData(JSON.parse(pngInfo.workflow)); + if (pngInfo) { + if (pngInfo.workflow) { + this.loadGraphData(JSON.parse(pngInfo.workflow)); + } else if (pngInfo.parameters) { + importA1111(this.graph, pngInfo.parameters); + } } } else if (file.type === "application/json" || file.name.endsWith(".json")) { const reader = new FileReader(); diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js index 923f8745..580030d8 100644 --- a/web/scripts/pnginfo.js +++ b/web/scripts/pnginfo.js @@ -1,3 +1,5 @@ +import { api } from "./api.js"; + export function getPngMetadata(file) { return new Promise((r) => { const reader = new FileReader(); @@ -43,3 +45,262 @@ export function getPngMetadata(file) { reader.readAsArrayBuffer(file); }); } + +export async function importA1111(graph, parameters) { + const p = parameters.lastIndexOf("\nSteps:"); + if (p > -1) { + const embeddings = await api.getEmbeddings(); + const opts = parameters + .substr(p) + .split(",") + .reduce((p, n) => { + const s = n.split(":"); + p[s[0].trim().toLowerCase()] = s[1].trim(); + return p; + }, {}); + const p2 = parameters.lastIndexOf("\nNegative prompt:", p); + if (p2 > -1) { + let positive = parameters.substr(0, p2).trim(); + let negative = parameters.substring(p2 + 18, p).trim(); + + const ckptNode = LiteGraph.createNode("CheckpointLoaderSimple"); + const clipSkipNode = LiteGraph.createNode("CLIPSetLastLayer"); + const positiveNode = LiteGraph.createNode("CLIPTextEncode"); + const negativeNode = LiteGraph.createNode("CLIPTextEncode"); + const samplerNode = LiteGraph.createNode("KSampler"); + const imageNode = LiteGraph.createNode("EmptyLatentImage"); + const vaeNode = LiteGraph.createNode("VAEDecode"); + const vaeLoaderNode = LiteGraph.createNode("VAELoader"); + const saveNode = LiteGraph.createNode("SaveImage"); + let hrSamplerNode = null; + + const ceil64 = (v) => Math.ceil(v / 64) * 64; + + function getWidget(node, name) { + return node.widgets.find((w) => w.name === name); + } + + function setWidgetValue(node, name, value, isOptionPrefix) { + const w = getWidget(node, name); + if (isOptionPrefix) { + const o = w.options.values.find((w) => w.startsWith(value)); + if (o) { + w.value = o; + } else { + console.warn(`Unknown value '${value}' for widget '${name}'`, node); + w.value = value; + } + } else { + w.value = value; + } + } + + function createLoraNodes(clipNode, text, prevClip, prevModel) { + const loras = []; + text = text.replace(/]+)>/g, function (m, c) { + const s = c.split(":"); + const weight = parseFloat(s[1]); + if (isNaN(weight)) { + console.warn("Invalid LORA", m); + } else { + loras.push({ name: s[0], weight }); + } + return ""; + }); + + for (const l of loras) { + const loraNode = LiteGraph.createNode("LoraLoader"); + graph.add(loraNode); + setWidgetValue(loraNode, "lora_name", l.name, true); + setWidgetValue(loraNode, "strength_model", l.weight); + setWidgetValue(loraNode, "strength_clip", l.weight); + prevModel.node.connect(prevModel.index, loraNode, 0); + prevClip.node.connect(prevClip.index, loraNode, 1); + prevModel = { node: loraNode, index: 0 }; + prevClip = { node: loraNode, index: 1 }; + } + + prevClip.node.connect(1, clipNode, 0); + prevModel.node.connect(0, samplerNode, 0); + if (hrSamplerNode) { + prevModel.node.connect(0, hrSamplerNode, 0); + } + + return { text, prevModel, prevClip }; + } + + function replaceEmbeddings(text) { + return text.replaceAll( + new RegExp( + "\\b(" + embeddings.map((e) => e.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")).join("\\b|\\b") + ")\\b", + "ig" + ), + "embedding:$1" + ); + } + + function popOpt(name) { + const v = opts[name]; + delete opts[name]; + return v; + } + + graph.clear(); + graph.add(ckptNode); + graph.add(clipSkipNode); + graph.add(positiveNode); + graph.add(negativeNode); + graph.add(samplerNode); + graph.add(imageNode); + graph.add(vaeNode); + graph.add(vaeLoaderNode); + graph.add(saveNode); + + ckptNode.connect(1, clipSkipNode, 0); + clipSkipNode.connect(0, positiveNode, 0); + clipSkipNode.connect(0, negativeNode, 0); + ckptNode.connect(0, samplerNode, 0); + positiveNode.connect(0, samplerNode, 1); + negativeNode.connect(0, samplerNode, 2); + imageNode.connect(0, samplerNode, 3); + vaeNode.connect(0, saveNode, 0); + samplerNode.connect(0, vaeNode, 0); + vaeLoaderNode.connect(0, vaeNode, 1); + + const handlers = { + model(v) { + setWidgetValue(ckptNode, "ckpt_name", v, true); + }, + "cfg scale"(v) { + setWidgetValue(samplerNode, "cfg", +v); + }, + "clip skip"(v) { + setWidgetValue(clipSkipNode, "stop_at_clip_layer", -v); + }, + sampler(v) { + let name = v.toLowerCase().replace("++", "pp").replaceAll(" ", "_"); + if (name.includes("karras")) { + name = name.replace("karras", "").replace(/_+$/, ""); + setWidgetValue(samplerNode, "scheduler", "karras"); + } else { + setWidgetValue(samplerNode, "scheduler", "normal"); + } + const w = getWidget(samplerNode, "sampler_name"); + const o = w.options.values.find((w) => w === name || w === "sample_" + name); + if (o) { + setWidgetValue(samplerNode, "sampler_name", o); + } + }, + size(v) { + const wxh = v.split("x"); + const w = ceil64(+wxh[0]); + const h = ceil64(+wxh[1]); + const hrUp = popOpt("hires upscale"); + const hrSz = popOpt("hires resize"); + let hrMethod = popOpt("hires upscaler"); + + setWidgetValue(imageNode, "width", w); + setWidgetValue(imageNode, "height", h); + + if (hrUp || hrSz) { + let uw, uh; + if (hrUp) { + uw = w * hrUp; + uh = h * hrUp; + } else { + const s = hrSz.split("x"); + uw = +s[0]; + uh = +s[1]; + } + + let upscaleNode; + let latentNode; + + if (hrMethod.startsWith("Latent")) { + latentNode = upscaleNode = LiteGraph.createNode("LatentUpscale"); + graph.add(upscaleNode); + samplerNode.connect(0, upscaleNode, 0); + + switch (hrMethod) { + case "Latent (nearest-exact)": + hrMethod = "nearest-exact"; + break; + } + setWidgetValue(upscaleNode, "upscale_method", hrMethod, true); + } else { + const decode = LiteGraph.createNode("VAEDecodeTiled"); + graph.add(decode); + samplerNode.connect(0, decode, 0); + vaeLoaderNode.connect(0, decode, 1); + + const upscaleLoaderNode = LiteGraph.createNode("UpscaleModelLoader"); + graph.add(upscaleLoaderNode); + setWidgetValue(upscaleLoaderNode, "model_name", hrMethod, true); + + const modelUpscaleNode = LiteGraph.createNode("ImageUpscaleWithModel"); + graph.add(modelUpscaleNode); + decode.connect(0, modelUpscaleNode, 1); + upscaleLoaderNode.connect(0, modelUpscaleNode, 0); + + upscaleNode = LiteGraph.createNode("ImageScale"); + graph.add(upscaleNode); + modelUpscaleNode.connect(0, upscaleNode, 0); + + const vaeEncodeNode = (latentNode = LiteGraph.createNode("VAEEncodeTiled")); + graph.add(vaeEncodeNode); + upscaleNode.connect(0, vaeEncodeNode, 0); + vaeLoaderNode.connect(0, vaeEncodeNode, 1); + } + + setWidgetValue(upscaleNode, "width", ceil64(uw)); + setWidgetValue(upscaleNode, "height", ceil64(uh)); + + hrSamplerNode = LiteGraph.createNode("KSampler"); + graph.add(hrSamplerNode); + ckptNode.connect(0, hrSamplerNode, 0); + positiveNode.connect(0, hrSamplerNode, 1); + negativeNode.connect(0, hrSamplerNode, 2); + latentNode.connect(0, hrSamplerNode, 3); + hrSamplerNode.connect(0, vaeNode, 0); + } + }, + steps(v) { + setWidgetValue(samplerNode, "steps", +v); + }, + seed(v) { + setWidgetValue(samplerNode, "seed", +v); + }, + }; + + for (const opt in opts) { + if (opt in handlers) { + handlers[opt](popOpt(opt)); + } + } + + if (hrSamplerNode) { + setWidgetValue(hrSamplerNode, "steps", getWidget(samplerNode, "steps").value); + setWidgetValue(hrSamplerNode, "cfg", getWidget(samplerNode, "cfg").value); + setWidgetValue(hrSamplerNode, "scheduler", getWidget(samplerNode, "scheduler").value); + setWidgetValue(hrSamplerNode, "sampler_name", getWidget(samplerNode, "sampler_name").value); + setWidgetValue(hrSamplerNode, "denoise", +(popOpt("denoising strength") || "1")); + } + + let n = createLoraNodes(positiveNode, positive, { node: clipSkipNode, index: 0 }, { node: ckptNode, index: 0 }); + positive = n.text; + n = createLoraNodes(negativeNode, negative, n.prevClip, n.prevModel); + negative = n.text; + + setWidgetValue(positiveNode, "text", replaceEmbeddings(positive)); + setWidgetValue(negativeNode, "text", replaceEmbeddings(negative)); + + graph.arrange(); + + for (const opt of ["model hash", "ensd"]) { + delete opts[opt]; + } + + console.warn("Unhandled parameters:", opts); + } + } +}