diff --git a/.github/workflows/test-build.yml b/.github/workflows/test-build.yml new file mode 100644 index 00000000..421dd5ee --- /dev/null +++ b/.github/workflows/test-build.yml @@ -0,0 +1,31 @@ +name: Build package + +# +# This workflow is a test of the python package build. +# Install Python dependencies across different Python versions. +# + +on: + push: + paths: + - "requirements.txt" + - ".github/workflows/test-build.yml" + +jobs: + build: + name: Build Test + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.8", "3.9", "3.10", "3.11"] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt \ No newline at end of file diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..b5a68e0f --- /dev/null +++ b/pytest.ini @@ -0,0 +1,5 @@ +[pytest] +markers = + inference: mark as inference test (deselect with '-m "not inference"') +testpaths = tests +addopts = -s \ No newline at end of file diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 00000000..2005fd45 --- /dev/null +++ b/tests/README.md @@ -0,0 +1,29 @@ +# Automated Testing + +## Running tests locally + +Additional requirements for running tests: +``` +pip install pytest +pip install websocket-client==1.6.1 +opencv-python==4.6.0.66 +scikit-image==0.21.0 +``` +Run inference tests: +``` +pytest tests/inference +``` + +## Quality regression test +Compares images in 2 directories to ensure they are the same + +1) Run an inference test to save a directory of "ground truth" images +``` + pytest tests/inference --output_dir tests/inference/baseline +``` +2) Make code edits + +3) Run inference and quality comparison tests +``` +pytest +``` \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/compare/conftest.py b/tests/compare/conftest.py new file mode 100644 index 00000000..dd5078c9 --- /dev/null +++ b/tests/compare/conftest.py @@ -0,0 +1,41 @@ +import os +import pytest + +# Command line arguments for pytest +def pytest_addoption(parser): + parser.addoption('--baseline_dir', action="store", default='tests/inference/baseline', help='Directory for ground-truth images') + parser.addoption('--test_dir', action="store", default='tests/inference/samples', help='Directory for images to test') + parser.addoption('--metrics_file', action="store", default='tests/metrics.md', help='Output file for metrics') + parser.addoption('--img_output_dir', action="store", default='tests/compare/samples', help='Output directory for diff metric images') + +# This initializes args at the beginning of the test session +@pytest.fixture(scope="session", autouse=True) +def args_pytest(pytestconfig): + args = {} + args['baseline_dir'] = pytestconfig.getoption('baseline_dir') + args['test_dir'] = pytestconfig.getoption('test_dir') + args['metrics_file'] = pytestconfig.getoption('metrics_file') + args['img_output_dir'] = pytestconfig.getoption('img_output_dir') + + # Initialize metrics file + with open(args['metrics_file'], 'a') as f: + # if file is empty, write header + if os.stat(args['metrics_file']).st_size == 0: + f.write("| date | run | file | status | value | \n") + f.write("| --- | --- | --- | --- | --- | \n") + + return args + + +def gather_file_basenames(directory: str): + files = [] + for file in os.listdir(directory): + if file.endswith(".png"): + files.append(file) + return files + +# Creates the list of baseline file names to use as a fixture +def pytest_generate_tests(metafunc): + if "baseline_fname" in metafunc.fixturenames: + baseline_fnames = gather_file_basenames(metafunc.config.getoption("baseline_dir")) + metafunc.parametrize("baseline_fname", baseline_fnames) diff --git a/tests/compare/test_quality.py b/tests/compare/test_quality.py new file mode 100644 index 00000000..92a2d5a8 --- /dev/null +++ b/tests/compare/test_quality.py @@ -0,0 +1,195 @@ +import datetime +import numpy as np +import os +from PIL import Image +import pytest +from pytest import fixture +from typing import Tuple, List + +from cv2 import imread, cvtColor, COLOR_BGR2RGB +from skimage.metrics import structural_similarity as ssim + + +""" +This test suite compares images in 2 directories by file name +The directories are specified by the command line arguments --baseline_dir and --test_dir + +""" +# ssim: Structural Similarity Index +# Returns a tuple of (ssim, diff_image) +def ssim_score(img0: np.ndarray, img1: np.ndarray) -> Tuple[float, np.ndarray]: + score, diff = ssim(img0, img1, channel_axis=-1, full=True) + # rescale the difference image to 0-255 range + diff = (diff * 255).astype("uint8") + return score, diff + +# Metrics must return a tuple of (score, diff_image) +METRICS = {"ssim": ssim_score} +METRICS_PASS_THRESHOLD = {"ssim": 0.95} + + +class TestCompareImageMetrics: + @fixture(scope="class") + def test_file_names(self, args_pytest): + test_dir = args_pytest['test_dir'] + fnames = self.gather_file_basenames(test_dir) + yield fnames + del fnames + + @fixture(scope="class", autouse=True) + def teardown(self, args_pytest): + yield + # Runs after all tests are complete + # Aggregate output files into a grid of images + baseline_dir = args_pytest['baseline_dir'] + test_dir = args_pytest['test_dir'] + img_output_dir = args_pytest['img_output_dir'] + metrics_file = args_pytest['metrics_file'] + + grid_dir = os.path.join(img_output_dir, "grid") + os.makedirs(grid_dir, exist_ok=True) + + for metric_dir in METRICS.keys(): + metric_path = os.path.join(img_output_dir, metric_dir) + for file in os.listdir(metric_path): + if file.endswith(".png"): + score = self.lookup_score_from_fname(file, metrics_file) + image_file_list = [] + image_file_list.append([ + os.path.join(baseline_dir, file), + os.path.join(test_dir, file), + os.path.join(metric_path, file) + ]) + # Create grid + image_list = [[Image.open(file) for file in files] for files in image_file_list] + grid = self.image_grid(image_list) + grid.save(os.path.join(grid_dir, f"{metric_dir}_{score:.3f}_{file}")) + + # Tests run for each baseline file name + @fixture() + def fname(self, baseline_fname): + yield baseline_fname + del baseline_fname + + def test_directories_not_empty(self, args_pytest): + baseline_dir = args_pytest['baseline_dir'] + test_dir = args_pytest['test_dir'] + assert len(os.listdir(baseline_dir)) != 0, f"Baseline directory {baseline_dir} is empty" + assert len(os.listdir(test_dir)) != 0, f"Test directory {test_dir} is empty" + + def test_dir_has_all_matching_metadata(self, fname, test_file_names, args_pytest): + # Check that all files in baseline_dir have a file in test_dir with matching metadata + baseline_file_path = os.path.join(args_pytest['baseline_dir'], fname) + file_paths = [os.path.join(args_pytest['test_dir'], f) for f in test_file_names] + file_match = self.find_file_match(baseline_file_path, file_paths) + assert file_match is not None, f"Could not find a file in {args_pytest['test_dir']} with matching metadata to {baseline_file_path}" + + # For a baseline image file, finds the corresponding file name in test_dir and + # compares the images using the metrics in METRICS + @pytest.mark.parametrize("metric", METRICS.keys()) + def test_pipeline_compare( + self, + args_pytest, + fname, + test_file_names, + metric, + ): + baseline_dir = args_pytest['baseline_dir'] + test_dir = args_pytest['test_dir'] + metrics_output_file = args_pytest['metrics_file'] + img_output_dir = args_pytest['img_output_dir'] + + baseline_file_path = os.path.join(baseline_dir, fname) + + # Find file match + file_paths = [os.path.join(test_dir, f) for f in test_file_names] + test_file = self.find_file_match(baseline_file_path, file_paths) + + # Run metrics + sample_baseline = self.read_img(baseline_file_path) + sample_secondary = self.read_img(test_file) + + score, metric_img = METRICS[metric](sample_baseline, sample_secondary) + metric_status = score > METRICS_PASS_THRESHOLD[metric] + + # Save metric values + with open(metrics_output_file, 'a') as f: + run_info = os.path.splitext(fname)[0] + metric_status_str = "PASS ✅" if metric_status else "FAIL ❌" + date_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + f.write(f"| {date_str} | {run_info} | {metric} | {metric_status_str} | {score} | \n") + + # Save metric image + metric_img_dir = os.path.join(img_output_dir, metric) + os.makedirs(metric_img_dir, exist_ok=True) + output_filename = f'{fname}' + Image.fromarray(metric_img).save(os.path.join(metric_img_dir, output_filename)) + + assert score > METRICS_PASS_THRESHOLD[metric] + + def read_img(self, filename: str) -> np.ndarray: + cvImg = imread(filename) + cvImg = cvtColor(cvImg, COLOR_BGR2RGB) + return cvImg + + def image_grid(self, img_list: list[list[Image.Image]]): + # imgs is a 2D list of images + # Assumes the input images are a rectangular grid of equal sized images + rows = len(img_list) + cols = len(img_list[0]) + + w, h = img_list[0][0].size + grid = Image.new('RGB', size=(cols*w, rows*h)) + + for i, row in enumerate(img_list): + for j, img in enumerate(row): + grid.paste(img, box=(j*w, i*h)) + return grid + + def lookup_score_from_fname(self, + fname: str, + metrics_output_file: str + ) -> float: + fname_basestr = os.path.splitext(fname)[0] + with open(metrics_output_file, 'r') as f: + for line in f: + if fname_basestr in line: + score = float(line.split('|')[5]) + return score + raise ValueError(f"Could not find score for {fname} in {metrics_output_file}") + + def gather_file_basenames(self, directory: str): + files = [] + for file in os.listdir(directory): + if file.endswith(".png"): + files.append(file) + return files + + def read_file_prompt(self, fname:str) -> str: + # Read prompt from image file metadata + img = Image.open(fname) + img.load() + return img.info['prompt'] + + def find_file_match(self, baseline_file: str, file_paths: List[str]): + # Find a file in file_paths with matching metadata to baseline_file + baseline_prompt = self.read_file_prompt(baseline_file) + + # Do not match empty prompts + if baseline_prompt is None or baseline_prompt == "": + return None + + # Find file match + # Reorder test_file_names so that the file with matching name is first + # This is an optimization because matching file names are more likely + # to have matching metadata if they were generated with the same script + basename = os.path.basename(baseline_file) + file_path_basenames = [os.path.basename(f) for f in file_paths] + if basename in file_path_basenames: + match_index = file_path_basenames.index(basename) + file_paths.insert(0, file_paths.pop(match_index)) + + for f in file_paths: + test_file_prompt = self.read_file_prompt(f) + if baseline_prompt == test_file_prompt: + return f \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..1a35880a --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,36 @@ +import os +import pytest + +# Command line arguments for pytest +def pytest_addoption(parser): + parser.addoption('--output_dir', action="store", default='tests/inference/samples', help='Output directory for generated images') + parser.addoption("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)") + parser.addoption("--port", type=int, default=8188, help="Set the listen port.") + +# This initializes args at the beginning of the test session +@pytest.fixture(scope="session", autouse=True) +def args_pytest(pytestconfig): + args = {} + args['output_dir'] = pytestconfig.getoption('output_dir') + args['listen'] = pytestconfig.getoption('listen') + args['port'] = pytestconfig.getoption('port') + + os.makedirs(args['output_dir'], exist_ok=True) + + return args + +def pytest_collection_modifyitems(items): + # Modifies items so tests run in the correct order + + LAST_TESTS = ['test_quality'] + + # Move the last items to the end + last_items = [] + for test_name in LAST_TESTS: + for item in items.copy(): + print(item.module.__name__, item) + if item.module.__name__ == test_name: + last_items.append(item) + items.remove(item) + + items.extend(last_items) diff --git a/tests/inference/__init__.py b/tests/inference/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/inference/graphs/default_graph_sdxl1_0.json b/tests/inference/graphs/default_graph_sdxl1_0.json new file mode 100644 index 00000000..c06c6829 --- /dev/null +++ b/tests/inference/graphs/default_graph_sdxl1_0.json @@ -0,0 +1,144 @@ +{ + "4": { + "inputs": { + "ckpt_name": "sd_xl_base_1.0.safetensors" + }, + "class_type": "CheckpointLoaderSimple" + }, + "5": { + "inputs": { + "width": 1024, + "height": 1024, + "batch_size": 1 + }, + "class_type": "EmptyLatentImage" + }, + "6": { + "inputs": { + "text": "a photo of a cat", + "clip": [ + "4", + 1 + ] + }, + "class_type": "CLIPTextEncode" + }, + "10": { + "inputs": { + "add_noise": "enable", + "noise_seed": 42, + "steps": 20, + "cfg": 7.5, + "sampler_name": "euler", + "scheduler": "normal", + "start_at_step": 0, + "end_at_step": 32, + "return_with_leftover_noise": "enable", + "model": [ + "4", + 0 + ], + "positive": [ + "6", + 0 + ], + "negative": [ + "15", + 0 + ], + "latent_image": [ + "5", + 0 + ] + }, + "class_type": "KSamplerAdvanced" + }, + "12": { + "inputs": { + "samples": [ + "14", + 0 + ], + "vae": [ + "4", + 2 + ] + }, + "class_type": "VAEDecode" + }, + "13": { + "inputs": { + "filename_prefix": "test_inference", + "images": [ + "12", + 0 + ] + }, + "class_type": "SaveImage" + }, + "14": { + "inputs": { + "add_noise": "disable", + "noise_seed": 42, + "steps": 20, + "cfg": 7.5, + "sampler_name": "euler", + "scheduler": "normal", + "start_at_step": 32, + "end_at_step": 10000, + "return_with_leftover_noise": "disable", + "model": [ + "16", + 0 + ], + "positive": [ + "17", + 0 + ], + "negative": [ + "20", + 0 + ], + "latent_image": [ + "10", + 0 + ] + }, + "class_type": "KSamplerAdvanced" + }, + "15": { + "inputs": { + "conditioning": [ + "6", + 0 + ] + }, + "class_type": "ConditioningZeroOut" + }, + "16": { + "inputs": { + "ckpt_name": "sd_xl_refiner_1.0.safetensors" + }, + "class_type": "CheckpointLoaderSimple" + }, + "17": { + "inputs": { + "text": "a photo of a cat", + "clip": [ + "16", + 1 + ] + }, + "class_type": "CLIPTextEncode" + }, + "20": { + "inputs": { + "text": "", + "clip": [ + "16", + 1 + ] + }, + "class_type": "CLIPTextEncode" + } + } \ No newline at end of file diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py new file mode 100644 index 00000000..a96f9455 --- /dev/null +++ b/tests/inference/test_inference.py @@ -0,0 +1,247 @@ +from copy import deepcopy +from io import BytesIO +from urllib import request +import numpy +import os +from PIL import Image +import pytest +from pytest import fixture +import time +import torch +from typing import Union +import json +import subprocess +import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client) +import uuid +import urllib.request +import urllib.parse + +# Currently causes an error when running pytest with built-in pytest args +# TODO: modify cli_args.py to not parse args on import +# We will hard-code sampler and scheduler lists for now +# from comfy.samplers import KSampler + +""" +These tests generate and save images through a range of parameters +""" + +class ComfyGraph: + def __init__(self, + graph: dict, + sampler_nodes: list[str], + ): + self.graph = graph + self.sampler_nodes = sampler_nodes + + def set_prompt(self, prompt, negative_prompt=None): + # Sets the prompt for the sampler nodes (eg. base and refiner) + for node in self.sampler_nodes: + prompt_node = self.graph[node]['inputs']['positive'][0] + self.graph[prompt_node]['inputs']['text'] = prompt + if negative_prompt: + negative_prompt_node = self.graph[node]['inputs']['negative'][0] + self.graph[negative_prompt_node]['inputs']['text'] = negative_prompt + + def set_sampler_name(self, sampler_name:str, ): + # sets the sampler name for the sampler nodes (eg. base and refiner) + for node in self.sampler_nodes: + self.graph[node]['inputs']['sampler_name'] = sampler_name + + def set_scheduler(self, scheduler:str): + # sets the sampler name for the sampler nodes (eg. base and refiner) + for node in self.sampler_nodes: + self.graph[node]['inputs']['scheduler'] = scheduler + + def set_filename_prefix(self, prefix:str): + # sets the filename prefix for the save nodes + for node in self.graph: + if self.graph[node]['class_type'] == 'SaveImage': + self.graph[node]['inputs']['filename_prefix'] = prefix + + +class ComfyClient: + # From examples/websockets_api_example.py + + def connect(self, + listen:str = '127.0.0.1', + port:Union[str,int] = 8188, + client_id: str = str(uuid.uuid4()) + ): + self.client_id = client_id + self.server_address = f"{listen}:{port}" + ws = websocket.WebSocket() + ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id)) + self.ws = ws + + def queue_prompt(self, prompt): + p = {"prompt": prompt, "client_id": self.client_id} + data = json.dumps(p).encode('utf-8') + req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data) + return json.loads(urllib.request.urlopen(req).read()) + + def get_image(self, filename, subfolder, folder_type): + data = {"filename": filename, "subfolder": subfolder, "type": folder_type} + url_values = urllib.parse.urlencode(data) + with urllib.request.urlopen("http://{}/view?{}".format(self.server_address, url_values)) as response: + return response.read() + + def get_history(self, prompt_id): + with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response: + return json.loads(response.read()) + + def get_images(self, graph, save=True): + prompt = graph + if not save: + # Replace save nodes with preview nodes + prompt_str = json.dumps(prompt) + prompt_str = prompt_str.replace('SaveImage', 'PreviewImage') + prompt = json.loads(prompt_str) + + prompt_id = self.queue_prompt(prompt)['prompt_id'] + output_images = {} + while True: + out = self.ws.recv() + if isinstance(out, str): + message = json.loads(out) + if message['type'] == 'executing': + data = message['data'] + if data['node'] is None and data['prompt_id'] == prompt_id: + break #Execution is done + else: + continue #previews are binary data + + history = self.get_history(prompt_id)[prompt_id] + for o in history['outputs']: + for node_id in history['outputs']: + node_output = history['outputs'][node_id] + if 'images' in node_output: + images_output = [] + for image in node_output['images']: + image_data = self.get_image(image['filename'], image['subfolder'], image['type']) + images_output.append(image_data) + output_images[node_id] = images_output + + return output_images + +# +# Initialize graphs +# +default_graph_file = 'tests/inference/graphs/default_graph_sdxl1_0.json' +with open(default_graph_file, 'r') as file: + default_graph = json.loads(file.read()) +DEFAULT_COMFY_GRAPH = ComfyGraph(graph=default_graph, sampler_nodes=['10','14']) +DEFAULT_COMFY_GRAPH_ID = os.path.splitext(os.path.basename(default_graph_file))[0] + +# +# Loop through these variables +# +comfy_graph_list = [DEFAULT_COMFY_GRAPH] +comfy_graph_ids = [DEFAULT_COMFY_GRAPH_ID] +prompt_list = [ + 'a painting of a cat', +] +#TODO use sampler and scheduler list from comfy.samplers.KSampler +# sampler_list = KSampler.SAMPLERS +# scheduler_list = KSampler.SCHEDULERS +# Hard coded sampler and scheduler lists for now +SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"] +SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", + "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", + "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"] +sampler_list = SAMPLERS +scheduler_list = SCHEDULERS +@pytest.mark.inference +@pytest.mark.parametrize("sampler", sampler_list) +@pytest.mark.parametrize("scheduler", scheduler_list) +@pytest.mark.parametrize("prompt", prompt_list) +class TestInference: + # + # Initialize server and client + # + @fixture(scope="class", autouse=True) + def _server(self, args_pytest): + # Start server + p = subprocess.Popen([ + 'python','main.py', + '--output-directory', args_pytest["output_dir"], + '--listen', args_pytest["listen"], + '--port', str(args_pytest["port"]), + ]) + yield + p.kill() + torch.cuda.empty_cache() + + def start_client(self, listen:str, port:int): + # Start client + comfy_client = ComfyClient() + # Connect to server (with retries) + n_tries = 5 + for i in range(n_tries): + time.sleep(4) + try: + comfy_client.connect(listen=listen, port=port) + except ConnectionRefusedError as e: + print(e) + print(f"({i+1}/{n_tries}) Retrying...") + else: + break + return comfy_client + + # + # Client and graph fixtures with server warmup + # + # Returns a "_client_graph", which is client-graph pair corresponding to an initialized server + # The "graph" is the default graph + @fixture(scope="class", params=comfy_graph_list, ids=comfy_graph_ids, autouse=True) + def _client_graph(self, request, args_pytest, _server) -> (ComfyClient, ComfyGraph): + comfy_graph = request.param + + # Start client + comfy_client = self.start_client(args_pytest["listen"], args_pytest["port"]) + + # Warm up pipeline + comfy_client.get_images(graph=comfy_graph.graph, save=False) + + yield comfy_client, comfy_graph + del comfy_client + del comfy_graph + torch.cuda.empty_cache() + + @fixture + def client(self, _client_graph): + client = _client_graph[0] + yield client + + @fixture + def comfy_graph(self, _client_graph): + # avoid mutating the graph + graph = deepcopy(_client_graph[1]) + yield graph + + def test_comfy( + self, + client, + comfy_graph, + sampler, + scheduler, + prompt, + request + ): + test_info = request.node.name + comfy_graph.set_filename_prefix(test_info) + # Settings for comfy graph + comfy_graph.set_sampler_name(sampler) + comfy_graph.set_scheduler(scheduler) + comfy_graph.set_prompt(prompt) + + # Generate + images = client.get_images(comfy_graph.graph) + + assert len(images) != 0, "No images generated" + # assert all images are not blank + for images_output in images.values(): + for image_data in images_output: + pil_image = Image.open(BytesIO(image_data)) + assert numpy.array(pil_image).any() != 0, "Image is blank" + +