Image to prompt with BLIP and CLIP
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

116 lines
3.7 KiB

#!/usr/bin/env python3
import argparse
import csv
import io
import os
import sys

import open_clip
import requests
import torch
from PIL import Image
from clip_interrogator import Interrogator, Config
def inference(ci: "Interrogator", image: "Image.Image", mode: str) -> str:
    """Generate a prompt for a single image with the chosen interrogation mode.

    Args:
        ci: the Interrogator whose method produces the prompt.
        image: a PIL image; it is converted to RGB before interrogation.
        mode: "best" -> ``ci.interrogate``, "classic" -> ``ci.interrogate_classic``;
            any other value (including "fast") -> ``ci.interrogate_fast``.

    Returns:
        Whatever the selected Interrogator method returns (annotated as the
        prompt string).
    """
    # Quoted annotations: ``Image`` is the PIL *module*, the class is
    # ``Image.Image``; quoting also avoids evaluating these names at def time.
    rgb_image = image.convert("RGB")
    if mode == "best":
        return ci.interrogate(rgb_image)
    if mode == "classic":
        return ci.interrogate_classic(rgb_image)
    # Unrecognised modes deliberately fall back to the fast interrogator.
    return ci.interrogate_fast(rgb_image)
def inference_batch(
    ci: "Interrogator", images: "list[Image.Image]", mode: str
) -> list[str]:
    """Generate prompts for a batch of images with the chosen interrogation mode.

    Args:
        ci: the Interrogator whose batch method produces the prompts.
        images: PIL images; any that are not already RGB are converted, to
            match what ``inference`` does for a single image.
        mode: "best" -> ``ci.interrogate_batch``, "classic" ->
            ``ci.interrogate_classic_batch``; any other value (including
            "fast") -> ``ci.interrogate_fast_batch``.

    Returns:
        Whatever the selected batch method returns (annotated as one prompt
        string per input image).
    """
    # Only convert when needed: convert("RGB") on an RGB image makes a full
    # copy, which would double memory for a batch that is already RGB.
    rgb_images = [img if img.mode == "RGB" else img.convert("RGB") for img in images]
    if mode == "best":
        return ci.interrogate_batch(rgb_images)
    if mode == "classic":
        return ci.interrogate_classic_batch(rgb_images)
    # Unrecognised modes deliberately fall back to the fast interrogator.
    return ci.interrogate_fast_batch(rgb_images)
# Lower-case file extensions treated as images when scanning a folder.
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

# Seconds to wait for a remote image server before giving up.
_DOWNLOAD_TIMEOUT = 30


def _load_image(image_path: str) -> "Image.Image":
    """Open a local image file or an http(s) URL and return it as RGB.

    Raises:
        requests.RequestException: the download failed or returned an HTTP
            error status.
        OSError: the file could not be read or is not a recognised image
            (PIL's ``UnidentifiedImageError`` subclasses ``OSError``).
    """
    if image_path.startswith(("http://", "https://")):
        response = requests.get(image_path, timeout=_DOWNLOAD_TIMEOUT)
        # Without this an HTML error page would reach PIL and fail with a
        # confusing "cannot identify image file" error.
        response.raise_for_status()
        source = io.BytesIO(response.content)
    else:
        source = image_path
    with Image.open(source) as img:
        # convert() returns an independent image, so closing the source is safe.
        return img.convert("RGB")


def _load_folder_images(folder: str) -> "tuple[list[str], list[Image.Image]]":
    """Load every image file directly inside *folder*, in sorted name order.

    Files whose lower-cased name ends with one of ``_IMAGE_EXTENSIONS`` are
    loaded; unreadable ones are reported and skipped.

    Returns:
        Parallel lists of file names and RGB images.
    """
    names, images = [], []
    for name in sorted(os.listdir(folder)):
        path = os.path.join(folder, name)
        if not (name.lower().endswith(_IMAGE_EXTENSIONS) and os.path.isfile(path)):
            continue
        try:
            image = _load_image(path)
        except OSError as err:
            # One corrupt file should not abort the whole batch.
            print(f"Skipping {name}: {err}")
        else:
            names.append(name)
            images.append(image)
    return names, images


def _build_interrogator(device_name: str, clip_model_name: str) -> "Interrogator":
    """Resolve the torch device ("auto" picks CUDA when available) and build the Interrogator."""
    if device_name == "auto":
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if use_cuda else "cpu")
        if not use_cuda:
            print("CUDA is not available, using CPU. Warning: this will be very slow!")
    else:
        device = torch.device(device_name)
    return Interrogator(Config(device=device, clip_model_name=clip_model_name))


def main():
    """Run the image-to-prompt command-line interface.

    With ``--image`` the prompt for that single file or URL is printed. With
    ``--folder`` a prompt is printed for every .jpg/.jpeg/.png file in the
    folder, and the results are also written to ``desc.csv`` (columns
    ``image`` and ``prompt``) inside that folder. Invalid arguments print an
    error and exit with status 1.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c", "--clip", default="ViT-L-14/openai", help="name of CLIP model to use"
    )
    parser.add_argument(
        "-d", "--device", default="auto", help="device to use (auto, cuda or cpu)"
    )
    parser.add_argument("-f", "--folder", help="path to folder of images")
    parser.add_argument("-i", "--image", help="image file or url")
    parser.add_argument(
        "-m",
        "--mode",
        default="best",
        # Reject typos up front; inference() would otherwise silently fall
        # back to "fast" for any unrecognised mode.
        choices=["best", "classic", "fast"],
        help="best, classic, or fast",
    )
    args = parser.parse_args()
    if not args.folder and not args.image:
        parser.print_help()
        sys.exit(1)
    if args.folder is not None and args.image is not None:
        print("Specify a folder for batch processing or a single image, not both")
        sys.exit(1)

    # validate clip model name
    models = ["/".join(x) for x in open_clip.list_pretrained()]
    if args.clip not in models:
        print(f"Could not find CLIP model {args.clip}!")
        print(f" available models: {models}")
        sys.exit(1)

    # Inputs are loaded before the Interrogator is constructed (presumably the
    # expensive step) so a bad path or URL fails immediately.

    # process single image
    if args.image is not None:
        try:
            image = _load_image(args.image)
        except (OSError, requests.RequestException) as err:
            print(f"Error opening image {args.image}: {err}")
            sys.exit(1)
        ci = _build_interrogator(args.device, args.clip)
        print(inference(ci, image, args.mode))
        return

    # process folder of images
    if not os.path.isdir(args.folder):
        print(f"The folder {args.folder} does not exist or is not a directory!")
        sys.exit(1)
    files, images = _load_folder_images(args.folder)
    if not images:
        print(f"No readable .jpg, .jpeg or .png images found in {args.folder}")
        return

    ci = _build_interrogator(args.device, args.clip)
    prompts = inference_batch(ci, images, args.mode)
    for prompt in prompts:
        print(prompt)

    if prompts:
        csv_path = os.path.join(args.folder, "desc.csv")
        with open(csv_path, "w", encoding="utf-8", newline="") as f:
            writer = csv.writer(f, quoting=csv.QUOTE_MINIMAL)
            writer.writerow(["image", "prompt"])
            # files and images are parallel, so file names line up with prompts.
            writer.writerows(zip(files, prompts))
        print(f"\n\n\n\nGenerated {len(prompts)} prompts and saved to {csv_path}, enjoy!")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()