import os
from comfy_extras . chainner_models import model_loading
from comfy import model_management
import torch
import comfy . utils
import folder_paths
class UpscaleModelLoader:
    """Node that loads an upscale model file from the ``upscale_models`` folder."""

    RETURN_TYPES = ("UPSCALE_MODEL",)
    FUNCTION = "load_model"
    CATEGORY = "loaders"

    @classmethod
    def INPUT_TYPES(cls):
        """Offer every file listed under ``upscale_models`` as the ``model_name`` choice."""
        choices = folder_paths.get_filename_list("upscale_models")
        return {"required": {"model_name": (choices,)}}

    def load_model(self, model_name):
        """Load ``model_name`` and return the model, switched to eval mode, as a 1-tuple.

        Args:
            model_name: file name as listed by ``INPUT_TYPES``; it is resolved
                to a full path inside the ``upscale_models`` folder.
        """
        full_path = folder_paths.get_full_path("upscale_models", model_name)
        state_dict = comfy.utils.load_torch_file(full_path, safe_load=True)

        probe_key = "module.layers.0.residual_group.blocks.0.norm1.weight"
        if probe_key in state_dict:
            # Checkpoints containing this key have their weights stored under a
            # "module." prefix (presumably a DataParallel-style wrapper -- TODO
            # confirm); strip it so load_state_dict receives un-prefixed keys.
            state_dict = comfy.utils.state_dict_prefix_replace(state_dict, {"module.": ""})

        model = model_loading.load_state_dict(state_dict)
        return (model.eval(),)
class ImageUpscaleWithModel:
    """Node that upscales an IMAGE batch with an UPSCALE_MODEL using tiled inference.

    Inference runs tile by tile. On a device out-of-memory error the tile edge is
    halved and the whole batch is retried, until the edge would drop below
    ``_MIN_TILE``, at which point the error is re-raised.
    """

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "upscale"
    CATEGORY = "image/upscaling"

    # Tile edge (px) tried first; also the patch size used in the memory estimate.
    _INITIAL_TILE = 512
    # Smallest tile edge attempted before an OOM is propagated to the caller.
    _MIN_TILE = 128
    # Overlap (px) between neighbouring tiles.
    _TILE_OVERLAP = 32

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node inputs: an upscale model and an image."""
        return {"required": {"upscale_model": ("UPSCALE_MODEL",),
                             "image": ("IMAGE",),
                             }}

    def upscale(self, upscale_model, image):
        """Upscale ``image`` with ``upscale_model`` and return the result as a 1-tuple.

        Args:
            upscale_model: callable model exposing a ``scale`` attribute (the
                upscale factor passed to ``tiled_scale``) and ``.to()``/``.cpu()``.
            image: image tensor; its last axis is moved to position -3 before
                inference (presumably channels-last (B, H, W, C) -- TODO confirm).

        Returns:
            1-tuple with the upscaled tensor, axes moved back to the original
            layout and values clamped to [0, 1].

        Raises:
            model_management.OOM_EXCEPTION: if inference still runs out of memory
                at the smallest tile size. The model is moved back to the CPU in
                every case, including this one.
        """
        device = model_management.get_torch_device()
        tile = self._INITIAL_TILE
        overlap = self._TILE_OVERLAP

        memory_required = model_management.module_size(upscale_model)
        # The 384.0 is an estimate of how much some of these models take per tile,
        # TODO: make it more accurate.
        memory_required += (tile * tile * 3) * image.element_size() * max(upscale_model.scale, 1.0) * 384.0
        # Room for the input batch, which is copied to the device below.
        memory_required += image.nelement() * image.element_size()
        model_management.free_memory(memory_required, device)

        try:
            upscale_model.to(device)
            in_img = image.movedim(-1, -3).to(device)
            while True:
                try:
                    steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(
                        in_img.shape[3], in_img.shape[2],
                        tile_x=tile, tile_y=tile, overlap=overlap)
                    pbar = comfy.utils.ProgressBar(steps)
                    s = comfy.utils.tiled_scale(
                        in_img, lambda a: upscale_model(a),
                        tile_x=tile, tile_y=tile, overlap=overlap,
                        upscale_amount=upscale_model.scale, pbar=pbar)
                    break
                except model_management.OOM_EXCEPTION:
                    # Retry the whole batch with smaller tiles; give up below the minimum.
                    tile //= 2
                    if tile < self._MIN_TILE:
                        raise
        finally:
            # Release device memory even when inference failed; previously the
            # model was left on the device whenever an exception escaped.
            upscale_model.cpu()

        s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
        return (s,)
# Node-type name -> node class for the nodes defined in this module
# (presumably read by the host application to register them -- TODO confirm).
NODE_CLASS_MAPPINGS = dict(
    UpscaleModelLoader=UpscaleModelLoader,
    ImageUpscaleWithModel=ImageUpscaleWithModel,
)