From 390078904c791c7c66c08478ed3d657b42ba7888 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Wed, 13 Dec 2023 05:56:39 +0000 Subject: [PATCH] Group node fixes (#2259) * Prevent cleaning graph state on undo/redo * Remove pause rendering due to LG bug * Fix crash on disconnected internal reroutes * Fix widget inputs being incorrect order and value * Fix initial primitive values on connect * basic support for basic rerouted converted inputs * Populate primitive to reroute input * dont crash on bad primitive links * Fix convert to group changing control value * reduce restrictions * fix random crash in tests --- tests-ui/tests/groupNode.test.js | 134 ++++++++++++++++++++++++++-- tests-ui/utils/ezgraph.js | 8 ++ tests-ui/utils/index.js | 9 ++ web/extensions/core/groupNode.js | 127 +++++++++++++++++++++----- web/extensions/core/rerouteNode.js | 1 + web/extensions/core/widgetInputs.js | 19 +++- web/scripts/app.js | 4 +- 7 files changed, 275 insertions(+), 27 deletions(-) diff --git a/tests-ui/tests/groupNode.test.js b/tests-ui/tests/groupNode.test.js index dc9d4bd4..625890a0 100644 --- a/tests-ui/tests/groupNode.test.js +++ b/tests-ui/tests/groupNode.test.js @@ -1,7 +1,7 @@ // @ts-check /// -const { start, createDefaultWorkflow } = require("../utils"); +const { start, createDefaultWorkflow, getNodeDef, checkBeforeAndAfterReload } = require("../utils"); const lg = require("../utils/litegraph"); describe("group node", () => { @@ -273,7 +273,7 @@ describe("group node", () => { let reroutes = []; let prevNode = nodes.ckpt; - for(let i = 0; i < 5; i++) { + for (let i = 0; i < 5; i++) { const reroute = ez.Reroute(); prevNode.outputs[0].connectTo(reroute.inputs[0]); prevNode = reroute; @@ -283,7 +283,7 @@ describe("group node", () => { const group = await convertToGroup(app, graph, "test", [...reroutes, ...Object.values(nodes)]); expect((await graph.toPrompt()).output).toEqual(getOutput()); - + group.menu["Convert to nodes"].call(); expect((await graph.toPrompt()).output).toEqual(getOutput()); }); @@ -407,12 +407,18 @@ describe("group node", () => { const decode = ez.VAEDecode(group2.outputs.LATENT, group2.outputs.VAE); const preview = ez.PreviewImage(decode.outputs[0]); - expect((await graph.toPrompt()).output).toEqual({ + const output = { [latent.id]: { inputs: { width: 512, height: 512, batch_size: 1 }, class_type: "EmptyLatentImage" }, [vae.id]: { inputs: { vae_name: "vae1.safetensors" }, class_type: "VAELoader" }, [decode.id]: { inputs: { samples: [latent.id + "", 0], vae: [vae.id + "", 0] }, class_type: "VAEDecode" }, [preview.id]: { inputs: { images: [decode.id + "", 0] }, class_type: "PreviewImage" }, - }); + }; + expect((await graph.toPrompt()).output).toEqual(output); + + // Ensure missing connections dont cause errors + group2.inputs.VAE.disconnect(); + delete output[decode.id].inputs.vae; + expect((await graph.toPrompt()).output).toEqual(output); }); test("displays generated image on group node", async () => { const { ez, graph, app } = await start(); @@ -673,6 +679,55 @@ describe("group node", () => { 2: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" }, }); }); + test("correctly handles widget inputs", async () => { + const { ez, graph, app } = await start(); + const upscaleMethods = (await getNodeDef("ImageScaleBy")).input.required["upscale_method"][0]; + + const image = ez.LoadImage(); + const scale1 = ez.ImageScaleBy(image.outputs[0]); + const scale2 = ez.ImageScaleBy(image.outputs[0]); + const preview1 = ez.PreviewImage(scale1.outputs[0]); + const preview2 = ez.PreviewImage(scale2.outputs[0]); + scale1.widgets.upscale_method.value = upscaleMethods[1]; + scale1.widgets.upscale_method.convertToInput(); + + const group = await convertToGroup(app, graph, "test", [scale1, scale2]); + expect(group.inputs.length).toBe(3); + expect(group.inputs[0].input.type).toBe("IMAGE"); + expect(group.inputs[1].input.type).toBe("IMAGE"); + expect(group.inputs[2].input.type).toBe("COMBO"); + + // Ensure links are maintained + expect(group.inputs[0].connection?.originNode?.id).toBe(image.id); + expect(group.inputs[1].connection?.originNode?.id).toBe(image.id); + expect(group.inputs[2].connection).toBeFalsy(); + + // Ensure primitive gets correct type + const primitive = ez.PrimitiveNode(); + primitive.outputs[0].connectTo(group.inputs[2]); + expect(primitive.widgets.value.widget.options.values).toBe(upscaleMethods); + expect(primitive.widgets.value.value).toBe(upscaleMethods[1]); // Ensure value is copied + primitive.widgets.value.value = upscaleMethods[1]; + + await checkBeforeAndAfterReload(graph, async (r) => { + const scale1id = r ? `${group.id}:0` : scale1.id; + const scale2id = r ? `${group.id}:1` : scale2.id; + // Ensure widget value is applied to prompt + expect((await graph.toPrompt()).output).toStrictEqual({ + [image.id]: { inputs: { image: "example.png", upload: "image" }, class_type: "LoadImage" }, + [scale1id]: { + inputs: { upscale_method: upscaleMethods[1], scale_by: 1, image: [`${image.id}`, 0] }, + class_type: "ImageScaleBy", + }, + [scale2id]: { + inputs: { upscale_method: "nearest-exact", scale_by: 1, image: [`${image.id}`, 0] }, + class_type: "ImageScaleBy", + }, + [preview1.id]: { inputs: { images: [`${scale1id}`, 0] }, class_type: "PreviewImage" }, + [preview2.id]: { inputs: { images: [`${scale2id}`, 0] }, class_type: "PreviewImage" }, + }); + }); + }); test("adds widgets in node execution order", async () => { const { ez, graph, app } = await start(); const scale = ez.LatentUpscale(); @@ -846,4 +901,73 @@ describe("group node", () => { expect(p2.widgets.control_after_generate.value).toBe("randomize"); expect(p2.widgets.control_filter_list.value).toBe("/.+/"); }); + test("internal reroutes work with converted inputs and merge options", async () => { + const { ez, graph, app } = await start(); + const vae = ez.VAELoader(); + const latent = ez.EmptyLatentImage(); + const decode = ez.VAEDecode(latent.outputs.LATENT, vae.outputs.VAE); + const scale = ez.ImageScale(decode.outputs.IMAGE); + ez.PreviewImage(scale.outputs.IMAGE); + + const r1 = ez.Reroute(); + const r2 = ez.Reroute(); + + latent.widgets.width.value = 64; + latent.widgets.height.value = 128; + + latent.widgets.width.convertToInput(); + latent.widgets.height.convertToInput(); + latent.widgets.batch_size.convertToInput(); + + scale.widgets.width.convertToInput(); + scale.widgets.height.convertToInput(); + + r1.inputs[0].input.label = "hbw"; + r1.outputs[0].connectTo(latent.inputs.height); + r1.outputs[0].connectTo(latent.inputs.batch_size); + r1.outputs[0].connectTo(scale.inputs.width); + + r2.inputs[0].input.label = "wh"; + r2.outputs[0].connectTo(latent.inputs.width); + r2.outputs[0].connectTo(scale.inputs.height); + + const group = await convertToGroup(app, graph, "test", [r1, r2, latent, decode, scale]); + + expect(group.inputs[0].input.type).toBe("VAE"); + expect(group.inputs[1].input.type).toBe("INT"); + expect(group.inputs[2].input.type).toBe("INT"); + + const p1 = ez.PrimitiveNode(); + const p2 = ez.PrimitiveNode(); + p1.outputs[0].connectTo(group.inputs[1]); + p2.outputs[0].connectTo(group.inputs[2]); + + expect(p1.widgets.value.widget.options?.min).toBe(16); // width/height min + expect(p1.widgets.value.widget.options?.max).toBe(4096); // batch max + expect(p1.widgets.value.widget.options?.step).toBe(80); // width/height step * 10 + + expect(p2.widgets.value.widget.options?.min).toBe(16); // width/height min + expect(p2.widgets.value.widget.options?.max).toBe(8192); // width/height max + expect(p2.widgets.value.widget.options?.step).toBe(80); // width/height step * 10 + + expect(p1.widgets.value.value).toBe(128); + expect(p2.widgets.value.value).toBe(64); + + p1.widgets.value.value = 16; + p2.widgets.value.value = 32; + + await checkBeforeAndAfterReload(graph, async (r) => { + const id = (v) => (r ? `${group.id}:` : "") + v; + expect((await graph.toPrompt()).output).toStrictEqual({ + 1: { inputs: { vae_name: "vae1.safetensors" }, class_type: "VAELoader" }, + [id(2)]: { inputs: { width: 32, height: 16, batch_size: 16 }, class_type: "EmptyLatentImage" }, + [id(3)]: { inputs: { samples: [id(2), 0], vae: ["1", 0] }, class_type: "VAEDecode" }, + [id(4)]: { + inputs: { upscale_method: "nearest-exact", width: 16, height: 32, crop: "disabled", image: [id(3), 0] }, + class_type: "ImageScale", + }, + 5: { inputs: { images: [id(4), 0] }, class_type: "PreviewImage" }, + }); + }); + }); }); diff --git a/tests-ui/utils/ezgraph.js b/tests-ui/utils/ezgraph.js index 3101aa29..8a55246e 100644 --- a/tests-ui/utils/ezgraph.js +++ b/tests-ui/utils/ezgraph.js @@ -78,6 +78,14 @@ export class EzInput extends EzSlot { this.input = input; } + get connection() { + const link = this.node.node.inputs?.[this.index]?.link; + if (link == null) { + return null; + } + return new EzConnection(this.node.app, this.node.app.graph.links[link]); + } + disconnect() { this.node.node.disconnectInput(this.index); } diff --git a/tests-ui/utils/index.js b/tests-ui/utils/index.js index 3a018f56..6a08e859 100644 --- a/tests-ui/utils/index.js +++ b/tests-ui/utils/index.js @@ -104,3 +104,12 @@ export function createDefaultWorkflow(ez, graph) { return { ckpt, pos, neg, empty, sampler, decode, save }; } + +export async function getNodeDefs() { + const { api } = require("../../web/scripts/api"); + return api.getNodeDefs(); +} + +export async function getNodeDef(nodeId) { + return (await getNodeDefs())[nodeId]; +} \ No newline at end of file diff --git a/web/extensions/core/groupNode.js b/web/extensions/core/groupNode.js index 9a1d9b20..dc962ac2 100644 --- a/web/extensions/core/groupNode.js +++ b/web/extensions/core/groupNode.js @@ -174,6 +174,11 @@ export class GroupNodeConfig { node.index = i; this.processNode(node, seenInputs, seenOutputs); } + + for (const p of this.#convertedToProcess) { + p(); + } + this.#convertedToProcess = null; await app.registerNodeDef("workflow/" + this.name, this.nodeDef); } @@ -192,7 +197,10 @@ export class GroupNodeConfig { if (!this.linksFrom[sourceNodeId]) { this.linksFrom[sourceNodeId] = {}; } - this.linksFrom[sourceNodeId][sourceNodeSlot] = l; + if (!this.linksFrom[sourceNodeId][sourceNodeSlot]) { + this.linksFrom[sourceNodeId][sourceNodeSlot] = []; + } + this.linksFrom[sourceNodeId][sourceNodeSlot].push(l); if (!this.linksTo[targetNodeId]) { this.linksTo[targetNodeId] = {}; @@ -230,11 +238,11 @@ export class GroupNodeConfig { // Skip as its not linked if (!linksFrom) return; - let type = linksFrom["0"][5]; + let type = linksFrom["0"][0][5]; if (type === "COMBO") { // Use the array items const source = node.outputs[0].widget.name; - const fromTypeName = this.nodeData.nodes[linksFrom["0"][2]].type; + const fromTypeName = this.nodeData.nodes[linksFrom["0"][0][2]].type; const fromType = globalDefs[fromTypeName]; const input = fromType.input.required[source] ?? fromType.input.optional[source]; type = input[0]; @@ -258,10 +266,33 @@ export class GroupNodeConfig { return null; } + let config = {}; let rerouteType = "*"; if (linksFrom) { - const [, , id, slot] = linksFrom["0"]; - rerouteType = this.nodeData.nodes[id].inputs[slot].type; + for (const [, , id, slot] of linksFrom["0"]) { + const node = this.nodeData.nodes[id]; + const input = node.inputs[slot]; + if (rerouteType === "*") { + rerouteType = input.type; + } + if (input.widget) { + const targetDef = globalDefs[node.type]; + const targetWidget = + targetDef.input.required[input.widget.name] ?? targetDef.input.optional[input.widget.name]; + + const widget = [targetWidget[0], config]; + const res = mergeIfValid( + { + widget, + }, + targetWidget, + false, + null, + widget + ); + config = res?.customConfig ?? config; + } + } } else if (linksTo) { const [id, slot] = linksTo["0"]; rerouteType = this.nodeData.nodes[id].outputs[slot].type; @@ -282,10 +313,11 @@ export class GroupNodeConfig { } } + config.forceInput = true; return { input: { required: { - [rerouteType]: [rerouteType, {}], + [rerouteType]: [rerouteType, config], }, }, output: [rerouteType], @@ -420,10 +452,18 @@ export class GroupNodeConfig { defaultInput: true, }); this.nodeDef.input.required[name] = config; + this.newToOldWidgetMap[name] = { node, inputName }; + + if (!this.oldToNewWidgetMap[node.index]) { + this.oldToNewWidgetMap[node.index] = {}; + } + this.oldToNewWidgetMap[node.index][inputName] = name; + inputMap[slots.length + i] = this.inputCount++; } } + #convertedToProcess = []; processNodeInputs(node, seenInputs, inputs) { const inputMapping = []; @@ -434,7 +474,11 @@ export class GroupNodeConfig { const linksTo = this.linksTo[node.index] ?? {}; const inputMap = (this.oldToNewInputMap[node.index] = {}); this.processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs); - this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs); + + // Converted inputs have to be processed after all other nodes as they'll be at the end of the list + this.#convertedToProcess.push(() => + this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs) + ); return inputMapping; } @@ -597,11 +641,15 @@ export class GroupNodeHandler { const output = this.groupData.newToOldOutputMap[link.origin_slot]; let innerNode = this.innerNodes[output.node.index]; let l; - while (innerNode.type === "Reroute") { + while (innerNode?.type === "Reroute") { l = innerNode.getInputLink(0); innerNode = innerNode.getInputNode(0); } + if (!innerNode) { + return null; + } + if (l && GroupNodeHandler.isGroupNode(innerNode)) { return innerNode.updateLink(l); } @@ -669,6 +717,8 @@ export class GroupNodeHandler { top = newNode.pos[1]; } + if (!newNode.widgets) continue; + const map = this.groupData.oldToNewWidgetMap[innerNode.index]; if (map) { const widgets = Object.keys(map); @@ -725,7 +775,7 @@ export class GroupNodeHandler { } }; - const reconnectOutputs = () => { + const reconnectOutputs = (selectedIds) => { for (let groupOutputId = 0; groupOutputId < node.outputs?.length; groupOutputId++) { const output = node.outputs[groupOutputId]; if (!output.links) continue; @@ -865,7 +915,7 @@ export class GroupNodeHandler { if (innerNode.type === "PrimitiveNode") { innerNode.primitiveValue = newValue; const primitiveLinked = this.groupData.primitiveToWidget[old.node.index]; - for (const linked of primitiveLinked) { + for (const linked of primitiveLinked ?? []) { const node = this.innerNodes[linked.nodeId]; const widget = node.widgets.find((w) => w.name === linked.inputName); @@ -874,6 +924,18 @@ export class GroupNodeHandler { } } continue; + } else if (innerNode.type === "Reroute") { + const rerouteLinks = this.groupData.linksFrom[old.node.index]; + for (const [_, , targetNodeId, targetSlot] of rerouteLinks["0"]) { + const node = this.innerNodes[targetNodeId]; + const input = node.inputs[targetSlot]; + if (input.widget) { + const widget = node.widgets?.find((w) => w.name === input.widget.name); + if (widget) { + widget.value = newValue; + } + } + } } const widget = innerNode.widgets?.find((w) => w.name === old.inputName); @@ -901,33 +963,58 @@ export class GroupNodeHandler { this.node.widgets[targetWidgetIndex + i].value = primitiveNode.widgets[i].value; } } + return true; + } + + populateReroute(node, nodeId, map) { + if (node.type !== "Reroute") return; + + const link = this.groupData.linksFrom[nodeId]?.[0]?.[0]; + if (!link) return; + const [, , targetNodeId, targetNodeSlot] = link; + const targetNode = this.groupData.nodeData.nodes[targetNodeId]; + const inputs = targetNode.inputs; + const targetWidget = inputs?.[targetNodeSlot].widget; + if (!targetWidget) return; + + const offset = inputs.length - (targetNode.widgets_values?.length ?? 0); + const v = targetNode.widgets_values?.[targetNodeSlot - offset]; + if (v == null) return; + + const widgetName = Object.values(map)[0]; + const widget = this.node.widgets.find(w => w.name === widgetName); + if(widget) { + widget.value = v; + } } + populateWidgets() { + if (!this.node.widgets) return; + for (let nodeId = 0; nodeId < this.groupData.nodeData.nodes.length; nodeId++) { const node = this.groupData.nodeData.nodes[nodeId]; - - if (!node.widgets_values?.length) continue; - - const map = this.groupData.oldToNewWidgetMap[nodeId]; + const map = this.groupData.oldToNewWidgetMap[nodeId] ?? {}; const widgets = Object.keys(map); + if (!node.widgets_values?.length) { + // special handling for populating values into reroutes + // this allows primitives connect to them to pick up the correct value + this.populateReroute(node, nodeId, map); + continue; + } + let linkedShift = 0; for (let i = 0; i < widgets.length; i++) { const oldName = widgets[i]; const newName = map[oldName]; const widgetIndex = this.node.widgets.findIndex((w) => w.name === newName); const mainWidget = this.node.widgets[widgetIndex]; - if (!newName) { - // New name will be null if its a converted widget - this.populatePrimitive(node, nodeId, oldName, i, linkedShift); - + if (this.populatePrimitive(node, nodeId, oldName, i, linkedShift)) { // Find the inner widget and shift by the number of linked widgets as they will have been removed too const innerWidget = this.innerNodes[nodeId].widgets?.find((w) => w.name === oldName); linkedShift += innerWidget.linkedWidgets?.length ?? 0; - continue; } - if (widgetIndex === -1) { continue; } diff --git a/web/extensions/core/rerouteNode.js b/web/extensions/core/rerouteNode.js index cfa952f3..4feff91e 100644 --- a/web/extensions/core/rerouteNode.js +++ b/web/extensions/core/rerouteNode.js @@ -54,6 +54,7 @@ app.registerExtension({ const linkId = currentNode.inputs[0].link; if (linkId !== null) { const link = app.graph.links[linkId]; + if (!link) return; const node = app.graph.getNodeById(link.origin_id); const type = node.constructor.type; if (type === "Reroute") { diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index 865db492..3f1c1f8c 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -180,7 +180,7 @@ export function mergeIfValid(output, config2, forceUpdate, recreateWidget, confi const isNumber = config1[0] === "INT" || config1[0] === "FLOAT"; for (const k of keys.values()) { - if (k !== "default" && k !== "forceInput" && k !== "defaultInput") { + if (k !== "default" && k !== "forceInput" && k !== "defaultInput" && k !== "control_after_generate" && k !== "multiline") { let v1 = config1[1][k]; let v2 = config2[1]?.[k]; @@ -633,6 +633,14 @@ app.registerExtension({ } } + // Restore any saved control values + const controlValues = this.controlValues; + if(this.lastType === this.widgets[0].type && controlValues?.length === this.widgets.length - 1) { + for(let i = 0; i < controlValues.length; i++) { + this.widgets[i + 1].value = controlValues[i]; + } + } + // When our value changes, update other widgets to reflect our changes // e.g. so LoadImage shows correct image const callback = widget.callback; @@ -721,6 +729,15 @@ app.registerExtension({ w.onRemove(); } } + + // Temporarily store the current values in case the node is being recreated + // e.g. by group node conversion + this.controlValues = []; + this.lastType = this.widgets[0]?.type; + for(let i = 1; i < this.widgets.length; i++) { + this.controlValues.push(this.widgets[i].value); + } + setTimeout(() => { delete this.lastType; delete this.controlValues }, 15); this.widgets.length = 0; } } diff --git a/web/scripts/app.js b/web/scripts/app.js index d2a6f4de..62169abf 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1774,7 +1774,9 @@ export class ComfyApp { if (parent?.updateLink) { link = parent.updateLink(link); } - inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)]; + if (link) { + inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)]; + } } } }