@@ -317,9 +317,7 @@ class VAE:
         pixel_samples = pixel_samples.cpu().movedim(1,-1)
         return pixel_samples
 
-    def decode_tiled(self, samples):
-        tile_x = tile_y = 64
-        overlap = 8
+    def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap=8):
         model_management.unload_model()
         output = torch.empty((samples.shape[0], 3, samples.shape[2] * 8, samples.shape[3] * 8), device="cpu")
         self.first_stage_model = self.first_stage_model.to(self.device)
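
The tiled decoder's tile size and overlap were hard-coded; they are now keyword
arguments, so callers can trade memory use against visible seams at tile
boundaries. For orientation, a minimal sketch of the tile-origin arithmetic an
overlap scheme like this implies; tile_starts is an illustrative helper, not
part of this patch:

    def tile_starts(size, tile, overlap):
        # Advance so that consecutive tiles share `overlap` latent pixels, and
        # add a final start so the last tile ends exactly at the border.
        step = tile - overlap
        starts = list(range(0, max(size - tile, 0) + 1, step))
        if size > tile and starts[-1] != size - tile:
            starts.append(size - tile)
        return starts

    # e.g. a 96x96 latent decoded with tile_x = tile_y = 64, overlap = 8:
    # tile_starts(96, 64, 8) -> [0, 32]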

@@ -656,3 +654,115 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
     sd = load_torch_file(ckpt_path)
     model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
     return (ModelPatcher(model), clip, vae)
+
+
+def load_checkpoint_guess_config(ckpt_path, fp16=False, output_vae=True, output_clip=True, embedding_directory=None):
+    sd = load_torch_file(ckpt_path)
+    sd_keys = sd.keys()
+    clip = None
+    vae = None
+
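+    # WeightsLoader is an empty module shell; hanging the VAE/CLIP submodules on
+    # it lets load_model_weights route the matching state dict keys into them.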
+    class WeightsLoader(torch.nn.Module):
+        pass
+
+    w = WeightsLoader()
+    load_state_dict_to = []
+    if output_vae:
+        vae = VAE()
+        w.first_stage_model = vae.first_stage_model
+        load_state_dict_to = [w]
+
+    if output_clip:
+        clip_config = {}
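+        # A resblocks.22 key only exists in the 24-layer OpenCLIP ViT-H text
+        # encoder used by SD2.x; otherwise assume the SD1.x CLIP ViT-L encoder.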
+        if "cond_stage_model.model.transformer.resblocks.22.attn.out_proj.weight" in sd_keys:
+            clip_config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
+        else:
+            clip_config['target'] = 'ldm.modules.encoders.modules.FrozenCLIPEmbedder'
+        clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
+        w.cond_stage_model = clip.cond_stage_model
+        load_state_dict_to = [w]
+
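+    # Fixed LatentDiffusion hyperparameters shared by SD1.x and SD2.x checkpoints.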
+    sd_config = {
+        "linear_start": 0.00085,
+        "linear_end": 0.012,
+        "num_timesteps_cond": 1,
+        "log_every_t": 200,
+        "timesteps": 1000,
+        "first_stage_key": "jpg",
+        "cond_stage_key": "txt",
+        "image_size": 64,
+        "channels": 4,
+        "cond_stage_trainable": False,
+        "monitor": "val/loss_simple_ema",
+        "scale_factor": 0.18215,
+        "use_ema": False,
+    }
+
+    unet_config = {
+        "use_checkpoint": True,
+        "image_size": 32,
+        "out_channels": 4,
+        "attention_resolutions": [
+            4,
+            2,
+            1
+        ],
+        "num_res_blocks": 2,
+        "channel_mult": [
+            1,
+            2,
+            4,
+            4
+        ],
+        "use_spatial_transformer": True,
+        "transformer_depth": 1,
+        "legacy": False
+    }
+
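+    # proj_in is 2D (nn.Linear) in SD2.x UNets and 4D (a 1x1 nn.Conv2d) in SD1.x.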
+    if len(sd['model.diffusion_model.input_blocks.1.1.proj_in.weight'].shape) == 2:
+        unet_config['use_linear_in_transformer'] = True
+
+    unet_config["use_fp16"] = fp16
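+    # Infer model width, input channels and conditioning width from weight shapes.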
+    unet_config["model_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[0]
+    unet_config["in_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[1]
+    unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]
+
+    sd_config["unet_config"] = {"target": "ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
+    model_config = {"target": "ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
+
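+    # Inpainting checkpoints concatenate the mask and masked-image latent to the
+    # noise latent, so their UNet takes more than 4 input channels.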
+    if unet_config["in_channels"] > 4: #inpainting model
+        sd_config["conditioning_key"] = "hybrid"
+        sd_config["finetune_keys"] = None
+        model_config["target"] = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
+    else:
+        sd_config["conditioning_key"] = "crossattn"
+
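+    # context_dim 1024 means OpenCLIP ViT-H (SD2.x); 768 means CLIP ViT-L (SD1.x).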
+    if unet_config["context_dim"] == 1024:
+        unet_config["num_head_channels"] = 64 #SD2.x
+    else:
+        unet_config["num_heads"] = 8 #SD1.x
+
+
+    model = instantiate_from_config(model_config)
+    model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
+
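+    # Heuristic v-prediction probe: denoise random data at t=999; an eps model's
+    # output mean is near 0 while a v-prediction model's is near -1.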
+    if unet_config["context_dim"] == 1024 and unet_config["in_channels"] == 4: #only SD2.x non inpainting models are v prediction
+        cond = torch.zeros((1, 2, unet_config["context_dim"]), device="cpu")
+        x_in = torch.rand((1, unet_config["in_channels"], 8, 8), device="cpu", generator=torch.manual_seed(1))
+        out = model.apply_model(x_in, torch.tensor([999], device="cpu"), cond)
+        if out.mean() < -0.6: #mean of eps should be ~0 and mean of v prediction should be ~-1
+            model.parameterization = 'v'
+
+    return (ModelPatcher(model), clip, vae)
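
With this in place a raw .ckpt can be loaded without a matching YAML config. A
minimal usage sketch, assuming the module is importable as comfy.sd; the file
paths below are placeholders:

    from comfy import sd

    model_patcher, clip, vae = sd.load_checkpoint_guess_config(
        "models/checkpoints/v2-1_768-ema-pruned.ckpt",  # placeholder path
        fp16=True,
        embedding_directory="models/embeddings",        # placeholder path
    )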