|
|
|
@ -55,14 +55,18 @@ class ModelPatcher:
|
|
|
|
|
def memory_required(self, input_shape): |
|
|
|
|
return self.model.memory_required(input_shape=input_shape) |
|
|
|
|
|
|
|
|
|
def set_model_sampler_cfg_function(self, sampler_cfg_function): |
|
|
|
|
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False): |
|
|
|
|
if len(inspect.signature(sampler_cfg_function).parameters) == 3: |
|
|
|
|
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way |
|
|
|
|
else: |
|
|
|
|
self.model_options["sampler_cfg_function"] = sampler_cfg_function |
|
|
|
|
if disable_cfg1_optimization: |
|
|
|
|
self.model_options["disable_cfg1_optimization"] = True |
|
|
|
|
|
|
|
|
|
def set_model_sampler_post_cfg_function(self, post_cfg_function): |
|
|
|
|
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False): |
|
|
|
|
self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function] |
|
|
|
|
if disable_cfg1_optimization: |
|
|
|
|
self.model_options["disable_cfg1_optimization"] = True |
|
|
|
|
|
|
|
|
|
def set_model_unet_function_wrapper(self, unet_wrapper_function): |
|
|
|
|
self.model_options["model_function_wrapper"] = unet_wrapper_function |
|
|
|
|