#!/usr/bin/env python3
|
import argparse
import csv
import io
import os
import sys

import open_clip
import requests
import torch
from PIL import Image

from clip_interrogator import Interrogator, Config
|
|
|
|
|
def inference(ci: Interrogator, image: Image.Image, mode: str) -> str:
    """Generate a prompt for a single image.

    Args:
        ci: Interrogator instance used to describe the image.
        image: Image to describe. It is converted to RGB first if it is not
            already in that mode.
        mode: "best" calls ``ci.interrogate``; "classic" calls
            ``ci.interrogate_classic``; any other value falls back to
            ``ci.interrogate_fast`` (kept for backward compatibility).

    Returns:
        Whatever the selected ``ci.interrogate*`` method returns
        (the prompt text).
    """
    # Note: ``Image`` is the PIL module; the image class is ``Image.Image``.
    # ``convert`` always returns a copy, so skip it for images that are
    # already RGB (main() hands over images it has already converted).
    if image.mode != "RGB":
        image = image.convert("RGB")
    if mode == "best":
        return ci.interrogate(image)
    if mode == "classic":
        return ci.interrogate_classic(image)
    return ci.interrogate_fast(image)
|
|
|
|
|
def inference_batch(
    ci: Interrogator, images: list[Image.Image], mode: str
) -> list[str]:
    """Generate prompts for several images in one batched call.

    Args:
        ci: Interrogator instance used to describe the images.
        images: Images to describe. Any image not already in RGB mode is
            converted, matching what ``inference`` does for a single image.
        mode: "best" calls ``ci.interrogate_batch``; "classic" calls
            ``ci.interrogate_classic_batch``; any other value falls back to
            ``ci.interrogate_fast_batch``.

    Returns:
        The list returned by the selected batch method, or ``[]`` when
        *images* is empty.
    """
    if not images:
        # Nothing to describe; don't hand the model an empty batch.
        return []
    # Keep parity with ``inference`` (which normalises to RGB) without
    # copying images that are already RGB.
    rgb_images = [
        img if img.mode == "RGB" else img.convert("RGB") for img in images
    ]
    if mode == "best":
        return ci.interrogate_batch(rgb_images)
    if mode == "classic":
        return ci.interrogate_classic_batch(rgb_images)
    return ci.interrogate_fast_batch(rgb_images)
|
|
|
|
|
# Lower-case file extensions treated as images when scanning --folder.
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

# Seconds to wait on the network when --image is a URL. Without a timeout,
# ``requests`` can block forever on an unresponsive server.
_DOWNLOAD_TIMEOUT = 30


def _fail(message: str):
    """Print *message* to stderr and terminate with exit status 1."""
    print(message, file=sys.stderr)
    sys.exit(1)


def _load_image(source: str) -> Image.Image:
    """Load *source* (a local path or an http(s) URL) as an RGB image.

    Exits via ``_fail`` with status 1 if the image cannot be downloaded,
    read or decoded.
    """
    try:
        if source.startswith(("http://", "https://")):
            response = requests.get(source, timeout=_DOWNLOAD_TIMEOUT)
            # Otherwise an error page (404, 403, ...) would reach PIL and
            # surface as a confusing "cannot identify image file" error.
            response.raise_for_status()
            fp = io.BytesIO(response.content)
        else:
            fp = source
        # ``with`` releases the file handle once the pixels are converted.
        with Image.open(fp) as img:
            return img.convert("RGB")
    except (OSError, requests.RequestException) as err:
        # OSError covers missing files and PIL's UnidentifiedImageError.
        _fail(f"Error opening image {source}: {err}")


def _select_device(name: str) -> torch.device:
    """Translate the --device argument into a ``torch.device``.

    "auto" picks CUDA when ``torch.cuda.is_available()``; otherwise it warns
    and falls back to the CPU. Any other value goes to ``torch.device``
    unchanged.
    """
    if name != "auto":
        return torch.device(name)
    if torch.cuda.is_available():
        return torch.device("cuda")
    print(
        "CUDA is not available, using CPU. Warning: this will be very slow!",
        file=sys.stderr,
    )
    return torch.device("cpu")


def _find_images(folder: str) -> list[str]:
    """Return the sorted names of the image files directly inside *folder*."""
    # os.listdir order is arbitrary; sorting keeps output and CSV rows stable.
    return sorted(
        name
        for name in os.listdir(folder)
        if name.lower().endswith(_IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(folder, name))
    )


def _write_csv(csv_path: str, files: list[str], prompts: list[str]) -> None:
    """Write an ``image,prompt`` CSV pairing each file name with its prompt."""
    with open(csv_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f, quoting=csv.QUOTE_MINIMAL)
        writer.writerow(["image", "prompt"])
        writer.writerows(zip(files, prompts))


def main():
    """Command-line entry point: generate prompts for one image or a folder.

    ``--image`` (a path or http(s) URL) prints a single prompt. ``--folder``
    prints one prompt per .jpg/.jpeg/.png file and also writes them to
    ``desc.csv`` inside that folder. Errors go to stderr and exit with
    status 1.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c", "--clip", default="ViT-L-14/openai", help="name of CLIP model to use"
    )
    parser.add_argument(
        "-d", "--device", default="auto", help="device to use (auto, cuda or cpu)"
    )
    parser.add_argument("-f", "--folder", help="path to folder of images")
    parser.add_argument("-i", "--image", help="image file or url")
    # ``choices`` rejects typos up front instead of silently running the
    # fast interrogator for any unrecognised mode.
    parser.add_argument(
        "-m",
        "--mode",
        default="best",
        choices=["best", "classic", "fast"],
        help="best, classic, or fast",
    )
    args = parser.parse_args()

    if not args.folder and not args.image:
        parser.print_help(sys.stderr)
        sys.exit(1)
    if args.folder is not None and args.image is not None:
        _fail("Specify a folder for batch processing or a single image, not both")

    # validate clip model name
    models = ["/".join(x) for x in open_clip.list_pretrained()]
    if args.clip not in models:
        _fail(f"Could not find CLIP model {args.clip}!\n available models: {models}")

    # Resolve and load all inputs before building the Interrogator, so a bad
    # path or URL is reported without waiting for model construction.
    if args.image is not None:
        paths = [args.image]
    else:
        if not os.path.isdir(args.folder):
            _fail(f"The folder {args.folder} does not exist or is not a directory!")
        files = _find_images(args.folder)
        if not files:
            print(
                f"No .jpg/.jpeg/.png images found in {args.folder}",
                file=sys.stderr,
            )
            return
        paths = [os.path.join(args.folder, name) for name in files]
    images = [_load_image(path) for path in paths]

    # generate a nice prompt
    config = Config(device=_select_device(args.device), clip_model_name=args.clip)
    ci = Interrogator(config)

    # process single image
    if args.image is not None:
        print(inference(ci, images[0], args.mode))
        return

    # process folder of images
    prompts = inference_batch(ci, images, args.mode)
    for prompt in prompts:
        print(prompt)

    if prompts:
        csv_path = os.path.join(args.folder, "desc.csv")
        _write_csv(csv_path, files, prompts)
        print(
            f"\n\n\n\nGenerated {len(prompts)} prompts and saved to {csv_path}, "
            "enjoy!"
        )
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script; importing this module
    # just defines the functions above.
    main()
|
|
|