import hashlib
import math
import numpy as np
import open_clip
import os
import requests
import time
import torch
from dataclasses import dataclass
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM, BlipForConditionalGeneration, Blip2ForConditionalGeneration
from tqdm import tqdm
from typing import List, Optional

from safetensors.numpy import load_file, save_file

CAPTION_MODELS = {
    'blip-base': 'Salesforce/blip-image-captioning-base',    # 990MB
    'blip-large': 'Salesforce/blip-image-captioning-large',  # 1.9GB
    'blip2-2.7b': 'Salesforce/blip2-opt-2.7b',                # 15.5GB
    'blip2-flan-t5-xl': 'Salesforce/blip2-flan-t5-xl',        # 15.77GB
    'git-large-coco': 'microsoft/git-large-coco',             # 1.58GB
}

CACHE_URL_BASE = 'https://huggingface.co/pharmapsychotic/ci-preprocess/resolve/main/'

@dataclass
class Config:
    # models can optionally be passed in directly
    caption_model = None
    caption_processor = None
    clip_model = None
    clip_preprocess = None

    # blip settings
    caption_max_length: int = 32
    caption_model_name: Optional[str] = 'blip-large'  # use a key from CAPTION_MODELS or None
    caption_offload: bool = False

    # clip settings
    clip_model_name: str = 'ViT-L-14/openai'
    clip_model_path: Optional[str] = None
    clip_offload: bool = False

    # interrogator settings
    cache_path: str = 'cache'    # path to store cached text embeddings
    download_cache: bool = True  # when true, cached embeds are downloaded from huggingface
    chunk_size: int = 2048       # batch size for CLIP, use smaller for lower VRAM
    data_path: str = os.path.join(os.path.dirname(__file__), 'data')
    device: str = ("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
    flavor_intermediate_count: int = 2048
    quiet: bool = False  # when quiet, progress bars are not shown

    def apply_low_vram_defaults(self):
        self.caption_model_name = 'blip-base'
        self.caption_offload = True
        self.clip_offload = True
        self.chunk_size = 1024
        self.flavor_intermediate_count = 1024
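
# A minimal usage sketch (assumes the public API defined below; the image path
# is hypothetical):
#
#   config = Config(clip_model_name='ViT-L-14/openai')
#   config.apply_low_vram_defaults()  # optional: smaller caption model + offloading
#   ci = Interrogator(config)
#   print(ci.interrogate(Image.open('example.jpg').convert('RGB')))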

class Interrogator():
    def __init__(self, config: Config):
        self.config = config
        self.device = config.device
        self.dtype = torch.float16 if self.device == 'cuda' else torch.float32
        self.caption_offloaded = True
        self.clip_offloaded = True
        self.load_caption_model()
        self.load_clip_model()

    def load_caption_model(self):
        if self.config.caption_model is None and self.config.caption_model_name:
            if not self.config.quiet:
                print(f"Loading caption model {self.config.caption_model_name}...")

            model_path = CAPTION_MODELS[self.config.caption_model_name]
            if self.config.caption_model_name.startswith('git-'):
                caption_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float32)
            elif self.config.caption_model_name.startswith('blip2-'):
                caption_model = Blip2ForConditionalGeneration.from_pretrained(model_path, torch_dtype=self.dtype)
            else:
                caption_model = BlipForConditionalGeneration.from_pretrained(model_path, torch_dtype=self.dtype)
            self.caption_processor = AutoProcessor.from_pretrained(model_path)

            caption_model.eval()
            if not self.config.caption_offload:
                caption_model = caption_model.to(self.config.device)
            self.caption_model = caption_model
        else:
            self.caption_model = self.config.caption_model
            self.caption_processor = self.config.caption_processor

    def load_clip_model(self):
        start_time = time.time()
        config = self.config

        clip_model_name, clip_model_pretrained_name = config.clip_model_name.split('/', 2)

        if config.clip_model is None:
            if not config.quiet:
                print(f"Loading CLIP model {config.clip_model_name}...")

            self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
                clip_model_name,
                pretrained=clip_model_pretrained_name,
                precision='fp16' if config.device == 'cuda' else 'fp32',
                device=config.device,
                jit=False,
                cache_dir=config.clip_model_path
            )
            self.clip_model.eval()
        else:
            self.clip_model = config.clip_model
            self.clip_preprocess = config.clip_preprocess
        self.tokenize = open_clip.get_tokenizer(clip_model_name)

        sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribbble',
                 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount',
                 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central']
        trending_list = [site for site in sites]
        trending_list.extend(["trending on " + site for site in sites])
        trending_list.extend(["featured on " + site for site in sites])
        trending_list.extend([site + " contest winner" for site in sites])

        raw_artists = load_list(config.data_path, 'artists.txt')
        artists = [f"by {a}" for a in raw_artists]
        artists.extend([f"inspired by {a}" for a in raw_artists])

        self._prepare_clip()
        self.artists = LabelTable(artists, "artists", self)
        self.flavors = LabelTable(load_list(config.data_path, 'flavors.txt'), "flavors", self)
        self.mediums = LabelTable(load_list(config.data_path, 'mediums.txt'), "mediums", self)
        self.movements = LabelTable(load_list(config.data_path, 'movements.txt'), "movements", self)
        self.trendings = LabelTable(trending_list, "trendings", self)
        self.negative = LabelTable(load_list(config.data_path, 'negative.txt'), "negative", self)

        end_time = time.time()
        if not config.quiet:
            print(f"Loaded CLIP model and data in {end_time-start_time:.2f} seconds.")

    def chain(
        self,
        image_features: torch.Tensor,
        phrases: List[str],
        best_prompt: str = "",
        best_sim: float = 0,
        min_count: int = 8,
        max_count: int = 32,
        desc="Chaining",
        reverse: bool = False
    ) -> str:
        self._prepare_clip()

        phrases = set(phrases)
        if not best_prompt:
            best_prompt = self.rank_top(image_features, [f for f in phrases], reverse=reverse)
            best_sim = self.similarity(image_features, best_prompt)
            phrases.remove(best_prompt)
        curr_prompt, curr_sim = best_prompt, best_sim

        def check(addition: str, idx: int) -> bool:
            nonlocal best_prompt, best_sim, curr_prompt, curr_sim
            prompt = curr_prompt + ", " + addition
            sim = self.similarity(image_features, prompt)
            if reverse:
                sim = -sim

            if sim > best_sim:
                best_prompt, best_sim = prompt, sim
            if sim > curr_sim or idx < min_count:
                curr_prompt, curr_sim = prompt, sim
                return True
            return False

        for idx in tqdm(range(max_count), desc=desc, disable=self.config.quiet):
            best = self.rank_top(image_features, [f"{curr_prompt}, {f}" for f in phrases], reverse=reverse)
            flave = best[len(curr_prompt) + 2:]
            if not check(flave, idx):
                break
            if _prompt_at_max_len(curr_prompt, self.tokenize):
                break
            phrases.remove(flave)

        return best_prompt
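
    # chain() is a greedy search: each iteration ranks every remaining phrase
    # appended to the current prompt ("{curr_prompt}, {phrase}") and keeps the
    # winner while similarity still improves (or until min_count phrases have
    # been added). With reverse=True the same loop minimizes similarity, which
    # is how interrogate_negative() assembles negative prompts.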

    def generate_caption(self, pil_image: Image) -> str:
        assert self.caption_model is not None, "No caption model loaded."
        self._prepare_caption()
        inputs = self.caption_processor(images=pil_image, return_tensors="pt").to(self.device)
        if not self.config.caption_model_name.startswith('git-'):
            inputs = inputs.to(self.dtype)
        tokens = self.caption_model.generate(**inputs, max_new_tokens=self.config.caption_max_length)
        return self.caption_processor.batch_decode(tokens, skip_special_tokens=True)[0].strip()

    def image_to_features(self, image: Image) -> torch.Tensor:
        self._prepare_clip()
        images = self.clip_preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            image_features = self.clip_model.encode_image(images)
            image_features /= image_features.norm(dim=-1, keepdim=True)
        return image_features
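
    # Note: the image embedding returned above is L2-normalized, and the text
    # embeddings used throughout are normalized the same way, so the plain dot
    # products computed below (text_features @ image_features.T) are cosine
    # similarities.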

    def interrogate_classic(self, image: Image, max_flavors: int = 3, caption: Optional[str] = None) -> str:
        """Classic mode creates a prompt in a standard format first describing the image,
        then listing the artist, trending, movement, and flavor text modifiers."""
        caption = caption or self.generate_caption(image)
        image_features = self.image_to_features(image)

        medium = self.mediums.rank(image_features, 1)[0]
        artist = self.artists.rank(image_features, 1)[0]
        trending = self.trendings.rank(image_features, 1)[0]
        movement = self.movements.rank(image_features, 1)[0]
        flaves = ", ".join(self.flavors.rank(image_features, max_flavors))

        if caption.startswith(medium):
            prompt = f"{caption} {artist}, {trending}, {movement}, {flaves}"
        else:
            prompt = f"{caption}, {medium} {artist}, {trending}, {movement}, {flaves}"

        return _truncate_to_fit(prompt, self.tokenize)

    def interrogate_fast(self, image: Image, max_flavors: int = 32, caption: Optional[str] = None) -> str:
        """Fast mode simply adds the top ranked terms after a caption. It generally results in
        better similarity between generated prompt and image than classic mode, but the prompts
        are less readable."""
        caption = caption or self.generate_caption(image)
        image_features = self.image_to_features(image)
        merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self)
        tops = merged.rank(image_features, max_flavors)
        return _truncate_to_fit(caption + ", " + ", ".join(tops), self.tokenize)

    def interrogate_negative(self, image: Image, max_flavors: int = 32) -> str:
        """Negative mode chains together the terms that are most dissimilar to the image. It can
        be used to build a negative prompt to pair with the regular positive prompt, which often
        improves the results of generated images, particularly with Stable Diffusion 2."""
        image_features = self.image_to_features(image)
        flaves = self.flavors.rank(image_features, self.config.flavor_intermediate_count, reverse=True)
        flaves = flaves + self.negative.labels
        return self.chain(image_features, flaves, max_count=max_flavors, reverse=True, desc="Negative chain")

    def interrogate(self, image: Image, min_flavors: int = 8, max_flavors: int = 32, caption: Optional[str] = None) -> str:
        caption = caption or self.generate_caption(image)
        image_features = self.image_to_features(image)

        merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self)
        flaves = merged.rank(image_features, self.config.flavor_intermediate_count)
        best_prompt, best_sim = caption, self.similarity(image_features, caption)
        best_prompt = self.chain(image_features, flaves, best_prompt, best_sim, min_count=min_flavors, max_count=max_flavors, desc="Flavor chain")

        fast_prompt = self.interrogate_fast(image, max_flavors, caption=caption)
        classic_prompt = self.interrogate_classic(image, max_flavors, caption=caption)
        candidates = [caption, classic_prompt, fast_prompt, best_prompt]
        return candidates[np.argmax(self.similarities(image_features, candidates))]
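
    # interrogate() is effectively a small ensemble: it chains flavors onto the
    # caption, then also scores the raw caption plus the classic and fast
    # prompts, and returns whichever candidate is most similar to the image.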

    def rank_top(self, image_features: torch.Tensor, text_array: List[str], reverse: bool = False) -> str:
        self._prepare_clip()
        text_tokens = self.tokenize([text for text in text_array]).to(self.device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            text_features = self.clip_model.encode_text(text_tokens)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            similarity = text_features @ image_features.T
            if reverse:
                similarity = -similarity
        return text_array[similarity.argmax().item()]

    def similarity(self, image_features: torch.Tensor, text: str) -> float:
        self._prepare_clip()
        text_tokens = self.tokenize([text]).to(self.device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            text_features = self.clip_model.encode_text(text_tokens)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            similarity = text_features @ image_features.T
        return similarity[0][0].item()

    def similarities(self, image_features: torch.Tensor, text_array: List[str]) -> List[float]:
        self._prepare_clip()
        text_tokens = self.tokenize([text for text in text_array]).to(self.device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            text_features = self.clip_model.encode_text(text_tokens)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            similarity = text_features @ image_features.T
        return similarity.T[0].tolist()

    def _prepare_caption(self):
        if self.config.clip_offload and not self.clip_offloaded:
            self.clip_model = self.clip_model.to('cpu')
            self.clip_offloaded = True
        if self.caption_offloaded:
            self.caption_model = self.caption_model.to(self.device)
            self.caption_offloaded = False

    def _prepare_clip(self):
        if self.config.caption_offload and not self.caption_offloaded:
            self.caption_model = self.caption_model.to('cpu')
            self.caption_offloaded = True
        if self.clip_offloaded:
            self.clip_model = self.clip_model.to(self.device)
            self.clip_offloaded = False
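
    # The _prepare_* helpers implement a simple see-saw for the *_offload
    # options: at most one of the caption model and the CLIP model occupies
    # the GPU at a time, and whichever is needed next is swapped in while the
    # other is moved back to the CPU.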

class LabelTable():
    def __init__(self, labels: List[str], desc: str, ci: Interrogator):
        clip_model, config = ci.clip_model, ci.config
        self.chunk_size = config.chunk_size
        self.config = config
        self.device = config.device
        self.embeds = []
        self.labels = labels
        self.tokenize = ci.tokenize

        hash = hashlib.sha256(",".join(labels).encode()).hexdigest()
        sanitized_name = self.config.clip_model_name.replace('/', '_').replace('@', '_')
        self._load_cached(desc, hash, sanitized_name)

        if len(self.labels) != len(self.embeds):
            self.embeds = []
            chunks = np.array_split(self.labels, max(1, len(self.labels)/config.chunk_size))
            for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None, disable=self.config.quiet):
                text_tokens = self.tokenize(chunk).to(self.device)
                with torch.no_grad(), torch.cuda.amp.autocast():
                    text_features = clip_model.encode_text(text_tokens)
                    text_features /= text_features.norm(dim=-1, keepdim=True)
                    text_features = text_features.half().cpu().numpy()
                for i in range(text_features.shape[0]):
                    self.embeds.append(text_features[i])

            if desc and self.config.cache_path:
                os.makedirs(self.config.cache_path, exist_ok=True)
                cache_filepath = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.safetensors")
                tensors = {
                    "embeds": np.stack(self.embeds),
                    "hash": np.array([ord(c) for c in hash], dtype=np.int8)
                }
                save_file(tensors, cache_filepath)

        if self.device == 'cpu' or self.device == torch.device('cpu'):
            self.embeds = [e.astype(np.float32) for e in self.embeds]

    def _load_cached(self, desc: str, hash: str, sanitized_name: str) -> bool:
        if self.config.cache_path is None or desc is None:
            return False

        cached_safetensors = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.safetensors")

        if self.config.download_cache and not os.path.exists(cached_safetensors):
            download_url = CACHE_URL_BASE + f"{sanitized_name}_{desc}.safetensors"
            try:
                os.makedirs(self.config.cache_path, exist_ok=True)
                _download_file(download_url, cached_safetensors, quiet=self.config.quiet)
            except Exception as e:
                print(f"Failed to download {download_url}")
                print(e)
                return False

        if os.path.exists(cached_safetensors):
            try:
                tensors = load_file(cached_safetensors)
            except Exception as e:
                print(f"Failed to load {cached_safetensors}")
                print(e)
                return False
            if 'hash' in tensors and 'embeds' in tensors:
                if np.array_equal(tensors['hash'], np.array([ord(c) for c in hash], dtype=np.int8)):
                    self.embeds = tensors['embeds']
                    if len(self.embeds.shape) == 2:
                        self.embeds = [self.embeds[i] for i in range(self.embeds.shape[0])]
                    return True

        return False
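
    # Cache layout (sketch): each safetensors file stores 'embeds', a
    # (num_labels, dim) fp16 array, and 'hash', the sha256 hex digest of the
    # label list encoded as int8 character codes; a mismatched hash simply
    # forces the embeddings to be recomputed.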

    def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int = 1, reverse: bool = False) -> str:
        top_count = min(top_count, len(text_embeds))
        text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device)
        with torch.cuda.amp.autocast():
            similarity = image_features @ text_embeds.T
            if reverse:
                similarity = -similarity
        _, top_labels = similarity.float().cpu().topk(top_count, dim=-1)
        return [top_labels[0][i].numpy() for i in range(top_count)]

    def rank(self, image_features: torch.Tensor, top_count: int = 1, reverse: bool = False) -> List[str]:
        if len(self.labels) <= self.chunk_size:
            tops = self._rank(image_features, self.embeds, top_count=top_count, reverse=reverse)
            return [self.labels[i] for i in tops]

        num_chunks = int(math.ceil(len(self.labels) / self.chunk_size))
        keep_per_chunk = int(self.chunk_size / num_chunks)

        top_labels, top_embeds = [], []
        for chunk_idx in tqdm(range(num_chunks), disable=self.config.quiet):
            start = chunk_idx * self.chunk_size
            stop = min(start + self.chunk_size, len(self.embeds))
            tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk, reverse=reverse)
            top_labels.extend([self.labels[start + i] for i in tops])
            top_embeds.extend([self.embeds[start + i] for i in tops])

        tops = self._rank(image_features, top_embeds, top_count=top_count, reverse=reverse)
        return [top_labels[i] for i in tops]

def _download_file(url: str, filepath: str, chunk_size: int = 4*1024*1024, quiet: bool = False):
    r = requests.get(url, stream=True)
    if r.status_code != 200:
        return

    file_size = int(r.headers.get("Content-Length", 0))
    filename = url.split("/")[-1]
    progress = tqdm(total=file_size, unit="B", unit_scale=True, desc=filename, disable=quiet)
    with open(filepath, "wb") as f:
        for chunk in r.iter_content(chunk_size=chunk_size):
            if chunk:
                f.write(chunk)
                progress.update(len(chunk))
    progress.close()

def _merge_tables(tables: List[LabelTable], ci: Interrogator) -> LabelTable:
    m = LabelTable([], None, ci)
    for table in tables:
        m.labels.extend(table.labels)
        m.embeds.extend(table.embeds)
    return m

def _prompt_at_max_len(text: str, tokenize) -> bool:
    tokens = tokenize([text])
    return tokens[0][-1] != 0

def _truncate_to_fit(text: str, tokenize) -> str:
    parts = text.split(', ')
    new_text = parts[0]
    for part in parts[1:]:
        if _prompt_at_max_len(new_text + part, tokenize):
            break
        new_text += ', ' + part
    return new_text

def list_caption_models() -> List[str]:
    return list(CAPTION_MODELS.keys())

def list_clip_models() -> List[str]:
    return ['/'.join(x) for x in open_clip.list_pretrained()]

def load_list(data_path: str, filename: Optional[str] = None) -> List[str]:
    """Load a list of strings from a file."""
    if filename is not None:
        data_path = os.path.join(data_path, filename)
    with open(data_path, 'r', encoding='utf-8', errors='replace') as f:
        items = [line.strip() for line in f.readlines()]
    return items
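
if __name__ == '__main__':
    # A minimal smoke test, not part of the library API: caption and
    # interrogate a single image passed on the command line. Assumes the
    # default Config (models are downloaded on first run).
    import sys
    ci = Interrogator(Config())
    image = Image.open(sys.argv[1]).convert('RGB')
    print(ci.interrogate_fast(image))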