Browse Source
* Add test for extension hooks Add afterConfigureGraph callback * fix commentpull/2144/head
pythongosssss
12 months ago
committed by
GitHub
3 changed files with 201 additions and 3 deletions
@ -0,0 +1,196 @@
|
||||
// @ts-check
|
||||
/// <reference path="../node_modules/@types/jest/index.d.ts" />
|
||||
const { start } = require("../utils"); |
||||
const lg = require("../utils/litegraph"); |
||||
|
||||
describe("extensions", () => { |
||||
beforeEach(() => { |
||||
lg.setup(global); |
||||
}); |
||||
|
||||
afterEach(() => { |
||||
lg.teardown(global); |
||||
}); |
||||
|
||||
it("calls each extension hook", async () => { |
||||
const mockExtension = { |
||||
name: "TestExtension", |
||||
init: jest.fn(), |
||||
setup: jest.fn(), |
||||
addCustomNodeDefs: jest.fn(), |
||||
getCustomWidgets: jest.fn(), |
||||
beforeRegisterNodeDef: jest.fn(), |
||||
registerCustomNodes: jest.fn(), |
||||
loadedGraphNode: jest.fn(), |
||||
nodeCreated: jest.fn(), |
||||
beforeConfigureGraph: jest.fn(), |
||||
afterConfigureGraph: jest.fn(), |
||||
}; |
||||
|
||||
const { app, ez, graph } = await start({ |
||||
async preSetup(app) { |
||||
app.registerExtension(mockExtension); |
||||
}, |
||||
}); |
||||
|
||||
// Basic initialisation hooks should be called once, with app
|
||||
expect(mockExtension.init).toHaveBeenCalledTimes(1); |
||||
expect(mockExtension.init).toHaveBeenCalledWith(app); |
||||
|
||||
// Adding custom node defs should be passed the full list of nodes
|
||||
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1); |
||||
expect(mockExtension.addCustomNodeDefs.mock.calls[0][1]).toStrictEqual(app); |
||||
const defs = mockExtension.addCustomNodeDefs.mock.calls[0][0]; |
||||
expect(defs).toHaveProperty("KSampler"); |
||||
expect(defs).toHaveProperty("LoadImage"); |
||||
|
||||
// Get custom widgets is called once and should return new widget types
|
||||
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1); |
||||
expect(mockExtension.getCustomWidgets).toHaveBeenCalledWith(app); |
||||
|
||||
// Before register node def will be called once per node type
|
||||
const nodeNames = Object.keys(defs); |
||||
const nodeCount = nodeNames.length; |
||||
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount); |
||||
for (let i = 0; i < nodeCount; i++) { |
||||
// It should be send the JS class and the original JSON definition
|
||||
const nodeClass = mockExtension.beforeRegisterNodeDef.mock.calls[i][0]; |
||||
const nodeDef = mockExtension.beforeRegisterNodeDef.mock.calls[i][1]; |
||||
|
||||
expect(nodeClass.name).toBe("ComfyNode"); |
||||
expect(nodeClass.comfyClass).toBe(nodeNames[i]); |
||||
expect(nodeDef.name).toBe(nodeNames[i]); |
||||
expect(nodeDef).toHaveProperty("input"); |
||||
expect(nodeDef).toHaveProperty("output"); |
||||
} |
||||
|
||||
// Register custom nodes is called once after registerNode defs to allow adding other frontend nodes
|
||||
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1); |
||||
|
||||
// Before configure graph will be called here as the default graph is being loaded
|
||||
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(1); |
||||
// it gets sent the graph data that is going to be loaded
|
||||
const graphData = mockExtension.beforeConfigureGraph.mock.calls[0][0]; |
||||
|
||||
// A node created is fired for each node constructor that is called
|
||||
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length); |
||||
for (let i = 0; i < graphData.nodes.length; i++) { |
||||
expect(mockExtension.nodeCreated.mock.calls[i][0].type).toBe(graphData.nodes[i].type); |
||||
} |
||||
|
||||
// Each node then calls loadedGraphNode to allow them to be updated
|
||||
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length); |
||||
for (let i = 0; i < graphData.nodes.length; i++) { |
||||
expect(mockExtension.loadedGraphNode.mock.calls[i][0].type).toBe(graphData.nodes[i].type); |
||||
} |
||||
|
||||
// After configure is then called once all the setup is done
|
||||
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(1); |
||||
|
||||
expect(mockExtension.setup).toHaveBeenCalledTimes(1); |
||||
expect(mockExtension.setup).toHaveBeenCalledWith(app); |
||||
|
||||
// Ensure hooks are called in the correct order
|
||||
const callOrder = [ |
||||
"init", |
||||
"addCustomNodeDefs", |
||||
"getCustomWidgets", |
||||
"beforeRegisterNodeDef", |
||||
"registerCustomNodes", |
||||
"beforeConfigureGraph", |
||||
"nodeCreated", |
||||
"loadedGraphNode", |
||||
"afterConfigureGraph", |
||||
"setup", |
||||
]; |
||||
for (let i = 1; i < callOrder.length; i++) { |
||||
const fn1 = mockExtension[callOrder[i - 1]]; |
||||
const fn2 = mockExtension[callOrder[i]]; |
||||
expect(fn1.mock.invocationCallOrder[0]).toBeLessThan(fn2.mock.invocationCallOrder[0]); |
||||
} |
||||
|
||||
graph.clear(); |
||||
|
||||
// Ensure adding a new node calls the correct callback
|
||||
ez.LoadImage(); |
||||
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length); |
||||
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 1); |
||||
expect(mockExtension.nodeCreated.mock.lastCall[0].type).toBe("LoadImage"); |
||||
|
||||
// Reload the graph to ensure correct hooks are fired
|
||||
await graph.reload(); |
||||
|
||||
// These hooks should not be fired again
|
||||
expect(mockExtension.init).toHaveBeenCalledTimes(1); |
||||
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1); |
||||
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1); |
||||
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1); |
||||
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount); |
||||
expect(mockExtension.setup).toHaveBeenCalledTimes(1); |
||||
|
||||
// These should be called again
|
||||
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(2); |
||||
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 2); |
||||
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length + 1); |
||||
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2); |
||||
}); |
||||
|
||||
it("allows custom nodeDefs and widgets to be registered", async () => { |
||||
const widgetMock = jest.fn((node, inputName, inputData, app) => { |
||||
expect(node.constructor.comfyClass).toBe("TestNode"); |
||||
expect(inputName).toBe("test_input"); |
||||
expect(inputData[0]).toBe("CUSTOMWIDGET"); |
||||
expect(inputData[1]?.hello).toBe("world"); |
||||
expect(app).toStrictEqual(app); |
||||
|
||||
return { |
||||
widget: node.addWidget("button", inputName, "hello", () => {}), |
||||
}; |
||||
}); |
||||
|
||||
// Register our extension that adds a custom node + widget type
|
||||
const mockExtension = { |
||||
name: "TestExtension", |
||||
addCustomNodeDefs: (nodeDefs) => { |
||||
nodeDefs["TestNode"] = { |
||||
output: [], |
||||
output_name: [], |
||||
output_is_list: [], |
||||
name: "TestNode", |
||||
display_name: "TestNode", |
||||
category: "Test", |
||||
input: { |
||||
required: { |
||||
test_input: ["CUSTOMWIDGET", { hello: "world" }], |
||||
}, |
||||
}, |
||||
}; |
||||
}, |
||||
getCustomWidgets: jest.fn(() => { |
||||
return { |
||||
CUSTOMWIDGET: widgetMock, |
||||
}; |
||||
}), |
||||
}; |
||||
|
||||
const { graph, ez } = await start({ |
||||
async preSetup(app) { |
||||
app.registerExtension(mockExtension); |
||||
}, |
||||
}); |
||||
|
||||
expect(mockExtension.getCustomWidgets).toBeCalledTimes(1); |
||||
|
||||
graph.clear(); |
||||
expect(widgetMock).toBeCalledTimes(0); |
||||
const node = ez.TestNode(); |
||||
expect(widgetMock).toBeCalledTimes(1); |
||||
|
||||
// Ensure our custom widget is created
|
||||
expect(node.inputs.length).toBe(0); |
||||
expect(node.widgets.length).toBe(1); |
||||
const w = node.widgets[0].widget; |
||||
expect(w.name).toBe("test_input"); |
||||
expect(w.type).toBe("button"); |
||||
}); |
||||
}); |
Loading…
Reference in new issue