* Add inference tests
* Clean up
* Rename test graph file
* Add readme for tests
* Separate server fixture
* test file name change
* Assert images are generated
* Clean up comments
* Add __init__.py so tests can run with command line `pytest`
* Fix command line args for pytest
* Loop all samplers/schedulers in test_inference.py
* Ci quality workflows compare (#1)
* Add image comparison tests
* Comparison tests do not pass with empty metadata
* Ensure tests are run in correct order
* Save image files with test name
* Update tests readme
* Reduce step counts in tests to ~halve runtime
* Ci quality workflows build (#2)
* Add build test github workflow
Authored by enzymezoo-code, committed by GitHub 1 year ago. 10 changed files with 728 additions and 0 deletions.
@@ -0,0 +1,31 @@
```yaml
name: Build package

#
# This workflow is a test of the python package build.
# Install Python dependencies across different Python versions.
#

on:
  push:
    paths:
      - "requirements.txt"
      - ".github/workflows/test-build.yml"

jobs:
  build:
    name: Build Test
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.8", "3.9", "3.10", "3.11"]
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt
```
@@ -0,0 +1,5 @@
```ini
[pytest]
markers =
    inference: mark as inference test (deselect with '-m "not inference"')
testpaths = tests
addopts = -s
```
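
With the `inference` marker registered, the long-running generation tests can be deselected, e.g. to run only the comparison suite (usage example):
```
pytest -m "not inference"
```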
@@ -0,0 +1,29 @@
# Automated Testing

## Running tests locally

Additional requirements for running tests:
```
pip install pytest
pip install websocket-client==1.6.1 opencv-python==4.6.0.66 scikit-image==0.21.0
```
Run inference tests:
```
pytest tests/inference
```

## Quality regression test
Compares images in two directories by file name to ensure they are the same.

1) Run an inference test to save a directory of "ground truth" images:
```
pytest tests/inference --output_dir tests/inference/baseline
```
2) Make code edits.

3) Run inference and quality comparison tests:
```
pytest
```
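
The comparison tests also accept the directory options defined in their conftest (shown below), so any two sample sets can be compared; for example, using the default paths:
```
pytest tests/compare --baseline_dir tests/inference/baseline --test_dir tests/inference/samples --metrics_file tests/metrics.md --img_output_dir tests/compare/samples
```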
@@ -0,0 +1,41 @@
```python
import os
import pytest


# Command line arguments for pytest
def pytest_addoption(parser):
    parser.addoption('--baseline_dir', action="store", default='tests/inference/baseline', help='Directory for ground-truth images')
    parser.addoption('--test_dir', action="store", default='tests/inference/samples', help='Directory for images to test')
    parser.addoption('--metrics_file', action="store", default='tests/metrics.md', help='Output file for metrics')
    parser.addoption('--img_output_dir', action="store", default='tests/compare/samples', help='Output directory for diff metric images')


# This initializes args at the beginning of the test session
@pytest.fixture(scope="session", autouse=True)
def args_pytest(pytestconfig):
    args = {}
    args['baseline_dir'] = pytestconfig.getoption('baseline_dir')
    args['test_dir'] = pytestconfig.getoption('test_dir')
    args['metrics_file'] = pytestconfig.getoption('metrics_file')
    args['img_output_dir'] = pytestconfig.getoption('img_output_dir')

    # Initialize metrics file
    with open(args['metrics_file'], 'a') as f:
        # If the file is empty, write the markdown table header
        if os.stat(args['metrics_file']).st_size == 0:
            f.write("| date | run | metric | status | value |\n")
            f.write("| --- | --- | --- | --- | --- |\n")

    return args


def gather_file_basenames(directory: str):
    files = []
    for file in os.listdir(directory):
        if file.endswith(".png"):
            files.append(file)
    return files


# Creates the list of baseline file names to use as a fixture
def pytest_generate_tests(metafunc):
    if "baseline_fname" in metafunc.fixturenames:
        baseline_fnames = gather_file_basenames(metafunc.config.getoption("baseline_dir"))
        metafunc.parametrize("baseline_fname", baseline_fnames)
```
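
After a run, the metrics file accumulates one row per compared file and metric, written by `test_pipeline_compare` below. An illustrative excerpt (values invented for the example):

| date | run | metric | status | value |
| --- | --- | --- | --- | --- |
| 2023-09-01 12:00:00 | test_comfy[...]_00001_ | ssim | PASS ✅ | 0.9871 |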
@@ -0,0 +1,195 @@
```python
import datetime
import numpy as np
import os
from PIL import Image
import pytest
from pytest import fixture
from typing import Tuple, List

from cv2 import imread, cvtColor, COLOR_BGR2RGB
from skimage.metrics import structural_similarity as ssim


"""
This test suite compares images in 2 directories by file name
The directories are specified by the command line arguments --baseline_dir and --test_dir
"""

# ssim: Structural Similarity Index
# Returns a tuple of (ssim, diff_image)
def ssim_score(img0: np.ndarray, img1: np.ndarray) -> Tuple[float, np.ndarray]:
    score, diff = ssim(img0, img1, channel_axis=-1, full=True)
    # Rescale the difference image to the 0-255 range
    diff = (diff * 255).astype("uint8")
    return score, diff


# Metrics must return a tuple of (score, diff_image)
METRICS = {"ssim": ssim_score}
METRICS_PASS_THRESHOLD = {"ssim": 0.95}


class TestCompareImageMetrics:
    @fixture(scope="class")
    def test_file_names(self, args_pytest):
        test_dir = args_pytest['test_dir']
        fnames = self.gather_file_basenames(test_dir)
        yield fnames
        del fnames

    @fixture(scope="class", autouse=True)
    def teardown(self, args_pytest):
        yield
        # Runs after all tests are complete
        # Aggregate output files into a grid of images
        baseline_dir = args_pytest['baseline_dir']
        test_dir = args_pytest['test_dir']
        img_output_dir = args_pytest['img_output_dir']
        metrics_file = args_pytest['metrics_file']

        grid_dir = os.path.join(img_output_dir, "grid")
        os.makedirs(grid_dir, exist_ok=True)

        for metric_dir in METRICS.keys():
            metric_path = os.path.join(img_output_dir, metric_dir)
            for file in os.listdir(metric_path):
                if file.endswith(".png"):
                    score = self.lookup_score_from_fname(file, metrics_file)
                    image_file_list = []
                    image_file_list.append([
                        os.path.join(baseline_dir, file),
                        os.path.join(test_dir, file),
                        os.path.join(metric_path, file)
                    ])
                    # Create grid
                    image_list = [[Image.open(file) for file in files] for files in image_file_list]
                    grid = self.image_grid(image_list)
                    grid.save(os.path.join(grid_dir, f"{metric_dir}_{score:.3f}_{file}"))

    # Tests run for each baseline file name
    @fixture()
    def fname(self, baseline_fname):
        yield baseline_fname
        del baseline_fname

    def test_directories_not_empty(self, args_pytest):
        baseline_dir = args_pytest['baseline_dir']
        test_dir = args_pytest['test_dir']
        assert len(os.listdir(baseline_dir)) != 0, f"Baseline directory {baseline_dir} is empty"
        assert len(os.listdir(test_dir)) != 0, f"Test directory {test_dir} is empty"

    def test_dir_has_all_matching_metadata(self, fname, test_file_names, args_pytest):
        # Check that all files in baseline_dir have a file in test_dir with matching metadata
        baseline_file_path = os.path.join(args_pytest['baseline_dir'], fname)
        file_paths = [os.path.join(args_pytest['test_dir'], f) for f in test_file_names]
        file_match = self.find_file_match(baseline_file_path, file_paths)
        assert file_match is not None, f"Could not find a file in {args_pytest['test_dir']} with matching metadata to {baseline_file_path}"

    # For a baseline image file, finds the corresponding file name in test_dir and
    # compares the images using the metrics in METRICS
    @pytest.mark.parametrize("metric", METRICS.keys())
    def test_pipeline_compare(
        self,
        args_pytest,
        fname,
        test_file_names,
        metric,
    ):
        baseline_dir = args_pytest['baseline_dir']
        test_dir = args_pytest['test_dir']
        metrics_output_file = args_pytest['metrics_file']
        img_output_dir = args_pytest['img_output_dir']

        baseline_file_path = os.path.join(baseline_dir, fname)

        # Find file match
        file_paths = [os.path.join(test_dir, f) for f in test_file_names]
        test_file = self.find_file_match(baseline_file_path, file_paths)

        # Run metrics
        sample_baseline = self.read_img(baseline_file_path)
        sample_secondary = self.read_img(test_file)

        score, metric_img = METRICS[metric](sample_baseline, sample_secondary)
        metric_status = score > METRICS_PASS_THRESHOLD[metric]

        # Save metric values
        with open(metrics_output_file, 'a') as f:
            run_info = os.path.splitext(fname)[0]
            metric_status_str = "PASS ✅" if metric_status else "FAIL ❌"
            date_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            f.write(f"| {date_str} | {run_info} | {metric} | {metric_status_str} | {score} |\n")

        # Save metric image
        metric_img_dir = os.path.join(img_output_dir, metric)
        os.makedirs(metric_img_dir, exist_ok=True)
        output_filename = f'{fname}'
        Image.fromarray(metric_img).save(os.path.join(metric_img_dir, output_filename))

        assert score > METRICS_PASS_THRESHOLD[metric]

    def read_img(self, filename: str) -> np.ndarray:
        # OpenCV loads BGR; convert to RGB before comparing
        cvImg = imread(filename)
        cvImg = cvtColor(cvImg, COLOR_BGR2RGB)
        return cvImg

    def image_grid(self, img_list: List[List[Image.Image]]):
        # img_list is a 2D list of images
        # Assumes the input images are a rectangular grid of equal sized images
        rows = len(img_list)
        cols = len(img_list[0])

        w, h = img_list[0][0].size
        grid = Image.new('RGB', size=(cols*w, rows*h))

        for i, row in enumerate(img_list):
            for j, img in enumerate(row):
                grid.paste(img, box=(j*w, i*h))
        return grid

    def lookup_score_from_fname(self,
                                fname: str,
                                metrics_output_file: str
                                ) -> float:
        fname_basestr = os.path.splitext(fname)[0]
        with open(metrics_output_file, 'r') as f:
            for line in f:
                if fname_basestr in line:
                    score = float(line.split('|')[5])
                    return score
        raise ValueError(f"Could not find score for {fname} in {metrics_output_file}")

    def gather_file_basenames(self, directory: str):
        files = []
        for file in os.listdir(directory):
            if file.endswith(".png"):
                files.append(file)
        return files

    def read_file_prompt(self, fname: str) -> str:
        # Read prompt from image file metadata
        img = Image.open(fname)
        img.load()
        return img.info['prompt']

    def find_file_match(self, baseline_file: str, file_paths: List[str]):
        # Find a file in file_paths with matching metadata to baseline_file
        baseline_prompt = self.read_file_prompt(baseline_file)

        # Do not match empty prompts
        if baseline_prompt is None or baseline_prompt == "":
            return None

        # Find file match
        # Reorder test_file_names so that the file with matching name is first
        # This is an optimization because matching file names are more likely
        # to have matching metadata if they were generated with the same script
        basename = os.path.basename(baseline_file)
        file_path_basenames = [os.path.basename(f) for f in file_paths]
        if basename in file_path_basenames:
            match_index = file_path_basenames.index(basename)
            file_paths.insert(0, file_paths.pop(match_index))

        for f in file_paths:
            test_file_prompt = self.read_file_prompt(f)
            if baseline_prompt == test_file_prompt:
                return f
        # No file with matching metadata found
        return None
```
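
The same SSIM comparison can be run standalone outside pytest; a minimal sketch (the two file paths are placeholders):
```python
from cv2 import imread, cvtColor, COLOR_BGR2RGB
from skimage.metrics import structural_similarity as ssim

# Load both images as RGB (OpenCV reads BGR by default)
img0 = cvtColor(imread("baseline.png"), COLOR_BGR2RGB)
img1 = cvtColor(imread("sample.png"), COLOR_BGR2RGB)

score, _ = ssim(img0, img1, channel_axis=-1, full=True)
print(f"ssim: {score:.4f}")  # the suite passes when score > 0.95
```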
@@ -0,0 +1,36 @@
```python
import os
import pytest


# Command line arguments for pytest
def pytest_addoption(parser):
    parser.addoption('--output_dir', action="store", default='tests/inference/samples', help='Output directory for generated images')
    parser.addoption("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0 (listens on all interfaces).")
    parser.addoption("--port", type=int, default=8188, help="Set the listen port.")


# This initializes args at the beginning of the test session
@pytest.fixture(scope="session", autouse=True)
def args_pytest(pytestconfig):
    args = {}
    args['output_dir'] = pytestconfig.getoption('output_dir')
    args['listen'] = pytestconfig.getoption('listen')
    args['port'] = pytestconfig.getoption('port')

    os.makedirs(args['output_dir'], exist_ok=True)

    return args


def pytest_collection_modifyitems(items):
    # Modifies items so tests run in the correct order

    LAST_TESTS = ['test_quality']

    # Move the last items to the end
    last_items = []
    for test_name in LAST_TESTS:
        for item in items.copy():
            if item.module.__name__ == test_name:
                last_items.append(item)
                items.remove(item)

    items.extend(last_items)
```
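
With these options, the inference suite can target a non-default server configuration; for example (the port value is arbitrary):
```
pytest tests/inference --listen 127.0.0.1 --port 8189 --output_dir tests/inference/baseline
```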
@@ -0,0 +1,144 @@
```json
{
  "4": {
    "inputs": {
      "ckpt_name": "sd_xl_base_1.0.safetensors"
    },
    "class_type": "CheckpointLoaderSimple"
  },
  "5": {
    "inputs": {
      "width": 1024,
      "height": 1024,
      "batch_size": 1
    },
    "class_type": "EmptyLatentImage"
  },
  "6": {
    "inputs": {
      "text": "a photo of a cat",
      "clip": [
        "4",
        1
      ]
    },
    "class_type": "CLIPTextEncode"
  },
  "10": {
    "inputs": {
      "add_noise": "enable",
      "noise_seed": 42,
      "steps": 20,
      "cfg": 7.5,
      "sampler_name": "euler",
      "scheduler": "normal",
      "start_at_step": 0,
      "end_at_step": 32,
      "return_with_leftover_noise": "enable",
      "model": [
        "4",
        0
      ],
      "positive": [
        "6",
        0
      ],
      "negative": [
        "15",
        0
      ],
      "latent_image": [
        "5",
        0
      ]
    },
    "class_type": "KSamplerAdvanced"
  },
  "12": {
    "inputs": {
      "samples": [
        "14",
        0
      ],
      "vae": [
        "4",
        2
      ]
    },
    "class_type": "VAEDecode"
  },
  "13": {
    "inputs": {
      "filename_prefix": "test_inference",
      "images": [
        "12",
        0
      ]
    },
    "class_type": "SaveImage"
  },
  "14": {
    "inputs": {
      "add_noise": "disable",
      "noise_seed": 42,
      "steps": 20,
      "cfg": 7.5,
      "sampler_name": "euler",
      "scheduler": "normal",
      "start_at_step": 32,
      "end_at_step": 10000,
      "return_with_leftover_noise": "disable",
      "model": [
        "16",
        0
      ],
      "positive": [
        "17",
        0
      ],
      "negative": [
        "20",
        0
      ],
      "latent_image": [
        "10",
        0
      ]
    },
    "class_type": "KSamplerAdvanced"
  },
  "15": {
    "inputs": {
      "conditioning": [
        "6",
        0
      ]
    },
    "class_type": "ConditioningZeroOut"
  },
  "16": {
    "inputs": {
      "ckpt_name": "sd_xl_refiner_1.0.safetensors"
    },
    "class_type": "CheckpointLoaderSimple"
  },
  "17": {
    "inputs": {
      "text": "a photo of a cat",
      "clip": [
        "16",
        1
      ]
    },
    "class_type": "CLIPTextEncode"
  },
  "20": {
    "inputs": {
      "text": "",
      "clip": [
        "16",
        1
      ]
    },
    "class_type": "CLIPTextEncode"
  }
}
```
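
Each two-element array in the graph is a link: the first element is the source node id, the second is the index of that node's output. A minimal sketch of following one link in the graph above:
```python
import json

with open('tests/inference/graphs/default_graph_sdxl1_0.json') as f:
    graph = json.load(f)

# The base sampler's "positive" input is output 0 of node "6" (a CLIPTextEncode node)
source_id, output_index = graph["10"]["inputs"]["positive"]
print(graph[source_id]["class_type"], output_index)  # -> CLIPTextEncode 0
```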
@@ -0,0 +1,247 @@
```python
from copy import deepcopy
from io import BytesIO
import numpy
import os
from PIL import Image
import pytest
from pytest import fixture
import time
import torch
from typing import List, Tuple, Union
import json
import subprocess
import websocket  # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import urllib.request
import urllib.parse


# Currently causes an error when running pytest with built-in pytest args
# TODO: modify cli_args.py to not parse args on import
# We will hard-code sampler and scheduler lists for now
# from comfy.samplers import KSampler

"""
These tests generate and save images through a range of parameters
"""

class ComfyGraph:
    def __init__(self,
                 graph: dict,
                 sampler_nodes: List[str],
                 ):
        self.graph = graph
        self.sampler_nodes = sampler_nodes

    def set_prompt(self, prompt, negative_prompt=None):
        # Sets the prompt for the sampler nodes (eg. base and refiner)
        for node in self.sampler_nodes:
            prompt_node = self.graph[node]['inputs']['positive'][0]
            self.graph[prompt_node]['inputs']['text'] = prompt
            if negative_prompt:
                negative_prompt_node = self.graph[node]['inputs']['negative'][0]
                self.graph[negative_prompt_node]['inputs']['text'] = negative_prompt

    def set_sampler_name(self, sampler_name: str):
        # Sets the sampler name for the sampler nodes (eg. base and refiner)
        for node in self.sampler_nodes:
            self.graph[node]['inputs']['sampler_name'] = sampler_name

    def set_scheduler(self, scheduler: str):
        # Sets the scheduler for the sampler nodes (eg. base and refiner)
        for node in self.sampler_nodes:
            self.graph[node]['inputs']['scheduler'] = scheduler

    def set_filename_prefix(self, prefix: str):
        # Sets the filename prefix for the save nodes
        for node in self.graph:
            if self.graph[node]['class_type'] == 'SaveImage':
                self.graph[node]['inputs']['filename_prefix'] = prefix


class ComfyClient:
    # From examples/websockets_api_example.py

    def connect(self,
                listen: str = '127.0.0.1',
                port: Union[str, int] = 8188,
                client_id: str = str(uuid.uuid4())
                ):
        self.client_id = client_id
        self.server_address = f"{listen}:{port}"
        ws = websocket.WebSocket()
        ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id))
        self.ws = ws

    def queue_prompt(self, prompt):
        p = {"prompt": prompt, "client_id": self.client_id}
        data = json.dumps(p).encode('utf-8')
        req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data)
        return json.loads(urllib.request.urlopen(req).read())

    def get_image(self, filename, subfolder, folder_type):
        data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
        url_values = urllib.parse.urlencode(data)
        with urllib.request.urlopen("http://{}/view?{}".format(self.server_address, url_values)) as response:
            return response.read()

    def get_history(self, prompt_id):
        with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response:
            return json.loads(response.read())

    def get_images(self, graph, save=True):
        prompt = graph
        if not save:
            # Replace save nodes with preview nodes
            prompt_str = json.dumps(prompt)
            prompt_str = prompt_str.replace('SaveImage', 'PreviewImage')
            prompt = json.loads(prompt_str)

        prompt_id = self.queue_prompt(prompt)['prompt_id']
        output_images = {}
        while True:
            out = self.ws.recv()
            if isinstance(out, str):
                message = json.loads(out)
                if message['type'] == 'executing':
                    data = message['data']
                    if data['node'] is None and data['prompt_id'] == prompt_id:
                        break  # Execution is done
            else:
                continue  # Previews are binary data

        history = self.get_history(prompt_id)[prompt_id]
        for node_id in history['outputs']:
            node_output = history['outputs'][node_id]
            if 'images' in node_output:
                images_output = []
                for image in node_output['images']:
                    image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
                    images_output.append(image_data)
                output_images[node_id] = images_output

        return output_images


#
# Initialize graphs
#
default_graph_file = 'tests/inference/graphs/default_graph_sdxl1_0.json'
with open(default_graph_file, 'r') as file:
    default_graph = json.loads(file.read())
DEFAULT_COMFY_GRAPH = ComfyGraph(graph=default_graph, sampler_nodes=['10', '14'])
DEFAULT_COMFY_GRAPH_ID = os.path.splitext(os.path.basename(default_graph_file))[0]

#
# Loop through these variables
#
comfy_graph_list = [DEFAULT_COMFY_GRAPH]
comfy_graph_ids = [DEFAULT_COMFY_GRAPH_ID]
prompt_list = [
    'a painting of a cat',
]
# TODO: use sampler and scheduler list from comfy.samplers.KSampler
# sampler_list = KSampler.SAMPLERS
# scheduler_list = KSampler.SCHEDULERS
# Hard-coded sampler and scheduler lists for now:
SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
            "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
            "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"]
sampler_list = SAMPLERS
scheduler_list = SCHEDULERS


@pytest.mark.inference
@pytest.mark.parametrize("sampler", sampler_list)
@pytest.mark.parametrize("scheduler", scheduler_list)
@pytest.mark.parametrize("prompt", prompt_list)
class TestInference:
    #
    # Initialize server and client
    #
    @fixture(scope="class", autouse=True)
    def _server(self, args_pytest):
        # Start server
        p = subprocess.Popen([
            'python', 'main.py',
            '--output-directory', args_pytest["output_dir"],
            '--listen', args_pytest["listen"],
            '--port', str(args_pytest["port"]),
        ])
        yield
        p.kill()
        torch.cuda.empty_cache()

    def start_client(self, listen: str, port: int):
        # Start client
        comfy_client = ComfyClient()
        # Connect to server (with retries)
        n_tries = 5
        for i in range(n_tries):
            time.sleep(4)
            try:
                comfy_client.connect(listen=listen, port=port)
            except ConnectionRefusedError as e:
                print(e)
                print(f"({i+1}/{n_tries}) Retrying...")
            else:
                break
        return comfy_client

    #
    # Client and graph fixtures with server warmup
    #
    # Returns a "_client_graph", which is a client-graph pair corresponding to an initialized server
    # The "graph" is the default graph
    @fixture(scope="class", params=comfy_graph_list, ids=comfy_graph_ids, autouse=True)
    def _client_graph(self, request, args_pytest, _server) -> Tuple[ComfyClient, ComfyGraph]:
        comfy_graph = request.param

        # Start client
        comfy_client = self.start_client(args_pytest["listen"], args_pytest["port"])

        # Warm up pipeline
        comfy_client.get_images(graph=comfy_graph.graph, save=False)

        yield comfy_client, comfy_graph
        del comfy_client
        del comfy_graph
        torch.cuda.empty_cache()

    @fixture
    def client(self, _client_graph):
        client = _client_graph[0]
        yield client

    @fixture
    def comfy_graph(self, _client_graph):
        # Avoid mutating the shared graph
        graph = deepcopy(_client_graph[1])
        yield graph

    def test_comfy(
        self,
        client,
        comfy_graph,
        sampler,
        scheduler,
        prompt,
        request
    ):
        test_info = request.node.name
        comfy_graph.set_filename_prefix(test_info)
        # Settings for comfy graph
        comfy_graph.set_sampler_name(sampler)
        comfy_graph.set_scheduler(scheduler)
        comfy_graph.set_prompt(prompt)

        # Generate
        images = client.get_images(comfy_graph.graph)

        assert len(images) != 0, "No images generated"
        # Assert all images are not blank
        for images_output in images.values():
            for image_data in images_output:
                pil_image = Image.open(BytesIO(image_data))
                assert numpy.array(pil_image).any() != 0, "Image is blank"
```
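
Because the samplers, schedulers, and prompts are parametrized, each combination becomes its own test case; pytest's standard `-k` expression can narrow a run to a few combinations during development, e.g.:
```
pytest tests/inference -k "dpmpp_2m and karras"
```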