Browse Source

Make VAE code closer to sgm.

pull/1789/head
comfyanonymous 1 year ago
parent
commit
d44a2de49f
  1. 3
      comfy/diffusers_load.py
  2. 370
      comfy/ldm/models/autoencoder.py
  3. 31
      comfy/ldm/modules/diffusionmodules/model.py
  4. 39
      comfy/sd.py
  5. 11
      comfy/utils.py
  6. 3
      nodes.py

3
comfy/diffusers_load.py

@ -31,6 +31,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
vae = None vae = None
if output_vae: if output_vae:
vae = comfy.sd.VAE(ckpt_path=vae_path) sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)
return (unet, clip, vae) return (unet, clip, vae)

370
comfy/ldm/models/autoencoder.py

@ -2,67 +2,66 @@ import torch
# import pytorch_lightning as pl # import pytorch_lightning as pl
import torch.nn.functional as F import torch.nn.functional as F
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union
from comfy.ldm.modules.diffusionmodules.model import Encoder, Decoder
from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from comfy.ldm.util import instantiate_from_config from comfy.ldm.util import instantiate_from_config
from comfy.ldm.modules.ema import LitEma from comfy.ldm.modules.ema import LitEma
# class AutoencoderKL(pl.LightningModule): class DiagonalGaussianRegularizer(torch.nn.Module):
class AutoencoderKL(torch.nn.Module): def __init__(self, sample: bool = True):
def __init__(self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
ema_decay=None,
learn_logvar=False
):
super().__init__() super().__init__()
self.learn_logvar = learn_logvar self.sample = sample
self.image_key = image_key
self.encoder = Encoder(**ddconfig) def get_trainable_parameters(self) -> Any:
self.decoder = Decoder(**ddconfig) yield from ()
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"] def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) log = dict()
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) posterior = DiagonalGaussianDistribution(z)
self.embed_dim = embed_dim if self.sample:
if colorize_nlabels is not None: z = posterior.sample()
assert type(colorize_nlabels)==int else:
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) z = posterior.mode()
kl_loss = posterior.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
log["kl_loss"] = kl_loss
return z, log
class AbstractAutoencoder(torch.nn.Module):
"""
This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
unCLIP models, etc. Hence, it is fairly general, and specific features
(e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
"""
def __init__(
self,
ema_decay: Union[None, float] = None,
monitor: Union[None, str] = None,
input_key: str = "jpg",
**kwargs,
):
super().__init__()
self.input_key = input_key
self.use_ema = ema_decay is not None
if monitor is not None: if monitor is not None:
self.monitor = monitor self.monitor = monitor
self.use_ema = ema_decay is not None
if self.use_ema: if self.use_ema:
self.ema_decay = ema_decay
assert 0. < ema_decay < 1.
self.model_ema = LitEma(self, decay=ema_decay) self.model_ema = LitEma(self, decay=ema_decay)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None: def get_input(self, batch) -> Any:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) raise NotImplementedError()
def init_from_ckpt(self, path, ignore_keys=list()): def on_train_batch_end(self, *args, **kwargs):
if path.lower().endswith(".safetensors"): # for EMA computation
import safetensors.torch if self.use_ema:
sd = safetensors.torch.load_file(path, device="cpu") self.model_ema(self)
else:
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
@contextmanager @contextmanager
def ema_scope(self, context=None): def ema_scope(self, context=None):
@ -70,154 +69,159 @@ class AutoencoderKL(torch.nn.Module):
self.model_ema.store(self.parameters()) self.model_ema.store(self.parameters())
self.model_ema.copy_to(self) self.model_ema.copy_to(self)
if context is not None: if context is not None:
print(f"{context}: Switched to EMA weights") logpy.info(f"{context}: Switched to EMA weights")
try: try:
yield None yield None
finally: finally:
if self.use_ema: if self.use_ema:
self.model_ema.restore(self.parameters()) self.model_ema.restore(self.parameters())
if context is not None: if context is not None:
print(f"{context}: Restored training weights") logpy.info(f"{context}: Restored training weights")
def on_train_batch_end(self, *args, **kwargs): def encode(self, *args, **kwargs) -> torch.Tensor:
if self.use_ema: raise NotImplementedError("encode()-method of abstract base class called")
self.model_ema(self)
def decode(self, *args, **kwargs) -> torch.Tensor:
def encode(self, x): raise NotImplementedError("decode()-method of abstract base class called")
h = self.encoder(x)
moments = self.quant_conv(h) def instantiate_optimizer_from_config(self, params, lr, cfg):
posterior = DiagonalGaussianDistribution(moments) logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
return posterior return get_obj_from_str(cfg["target"])(
params, lr=lr, **cfg.get("params", dict())
def decode(self, z): )
z = self.post_quant_conv(z)
dec = self.decoder(z) def configure_optimizers(self) -> Any:
return dec raise NotImplementedError()
def forward(self, input, sample_posterior=True):
posterior = self.encode(input) class AutoencodingEngine(AbstractAutoencoder):
if sample_posterior: """
z = posterior.sample() Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
else: (we also restore them explicitly as special cases for legacy reasons).
z = posterior.mode() Regularizations such as KL or VQ are moved to the regularizer class.
dec = self.decode(z) """
return dec, posterior
def __init__(
def get_input(self, batch, k): self,
x = batch[k] *args,
if len(x.shape) == 3: encoder_config: Dict,
x = x[..., None] decoder_config: Dict,
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() regularizer_config: Dict,
return x **kwargs,
):
def training_step(self, batch, batch_idx, optimizer_idx): super().__init__(*args, **kwargs)
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs) self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
if optimizer_idx == 0: self.regularization: AbstractRegularizer = instantiate_from_config(
# train encoder+decoder+logvar regularizer_config
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, )
last_layer=self.get_last_layer(), split="train")
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return aeloss
if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return discloss
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
return log_dict
def _validation_step(self, batch, batch_idx, postfix=""):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
last_layer=self.get_last_layer(), split="val"+postfix)
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
last_layer=self.get_last_layer(), split="val"+postfix)
self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr = self.learning_rate
ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
if self.learn_logvar:
print(f"{self.__class__.__name__}: Learning logvar")
ae_params_list.append(self.loss.logvar)
opt_ae = torch.optim.Adam(ae_params_list,
lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr, betas=(0.5, 0.9))
return [opt_ae, opt_disc], []
def get_last_layer(self): def get_last_layer(self):
return self.decoder.conv_out.weight return self.decoder.get_last_layer()
@torch.no_grad() def encode(
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): self,
log = dict() x: torch.Tensor,
x = self.get_input(batch, self.image_key) return_reg_log: bool = False,
x = x.to(self.device) unregularized: bool = False,
if not only_inputs: ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
xrec, posterior = self(x) z = self.encoder(x)
if x.shape[1] > 3: if unregularized:
# colorize with random projection return z, dict()
assert xrec.shape[1] > 3 z, reg_log = self.regularization(z)
x = self.to_rgb(x) if return_reg_log:
xrec = self.to_rgb(xrec) return z, reg_log
log["samples"] = self.decode(torch.randn_like(posterior.sample())) return z
log["reconstructions"] = xrec
if log_ema or self.use_ema: def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
with self.ema_scope(): x = self.decoder(z, **kwargs)
xrec_ema, posterior_ema = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec_ema.shape[1] > 3
xrec_ema = self.to_rgb(xrec_ema)
log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
log["reconstructions_ema"] = xrec_ema
log["inputs"] = x
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
return x return x
def forward(
self, x: torch.Tensor, **additional_decode_kwargs
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
z, reg_log = self.encode(x, return_reg_log=True)
dec = self.decode(z, **additional_decode_kwargs)
return z, dec, reg_log
class AutoencodingEngineLegacy(AutoencodingEngine):
def __init__(self, embed_dim: int, **kwargs):
self.max_batch_size = kwargs.pop("max_batch_size", None)
ddconfig = kwargs.pop("ddconfig")
super().__init__(
encoder_config={
"target": "comfy.ldm.modules.diffusionmodules.model.Encoder",
"params": ddconfig,
},
decoder_config={
"target": "comfy.ldm.modules.diffusionmodules.model.Decoder",
"params": ddconfig,
},
**kwargs,
)
self.quant_conv = torch.nn.Conv2d(
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
(1 + ddconfig["double_z"]) * embed_dim,
1,
)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
class IdentityFirstStage(torch.nn.Module): def get_autoencoder_params(self) -> list:
def __init__(self, *args, vq_interface=False, **kwargs): params = super().get_autoencoder_params()
self.vq_interface = vq_interface return params
super().__init__()
def encode(self, x, *args, **kwargs):
return x
def decode(self, x, *args, **kwargs): def encode(
return x self, x: torch.Tensor, return_reg_log: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
if self.max_batch_size is None:
z = self.encoder(x)
z = self.quant_conv(z)
else:
N = x.shape[0]
bs = self.max_batch_size
n_batches = int(math.ceil(N / bs))
z = list()
for i_batch in range(n_batches):
z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
z_batch = self.quant_conv(z_batch)
z.append(z_batch)
z = torch.cat(z, 0)
z, reg_log = self.regularization(z)
if return_reg_log:
return z, reg_log
return z
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
if self.max_batch_size is None:
dec = self.post_quant_conv(z)
dec = self.decoder(dec, **decoder_kwargs)
else:
N = z.shape[0]
bs = self.max_batch_size
n_batches = int(math.ceil(N / bs))
dec = list()
for i_batch in range(n_batches):
dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
dec_batch = self.decoder(dec_batch, **decoder_kwargs)
dec.append(dec_batch)
dec = torch.cat(dec, 0)
def quantize(self, x, *args, **kwargs): return dec
if self.vq_interface:
return x, None, [None, None, None]
return x
def forward(self, x, *args, **kwargs):
return x
class AutoencoderKL(AutoencodingEngineLegacy):
def __init__(self, **kwargs):
if "lossconfig" in kwargs:
kwargs["loss_config"] = kwargs.pop("lossconfig")
super().__init__(
regularizer_config={
"target": (
"comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"
)
},
**kwargs,
)

31
comfy/ldm/modules/diffusionmodules/model.py

@ -541,7 +541,10 @@ class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
attn_type="vanilla", **ignorekwargs): conv_out_op=comfy.ops.Conv2d,
resnet_op=ResnetBlock,
attn_op=AttnBlock,
**ignorekwargs):
super().__init__() super().__init__()
if use_linear_attn: attn_type = "linear" if use_linear_attn: attn_type = "linear"
self.ch = ch self.ch = ch
@ -570,12 +573,12 @@ class Decoder(nn.Module):
# middle # middle
self.mid = nn.Module() self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, self.mid.block_1 = resnet_op(in_channels=block_in,
out_channels=block_in, out_channels=block_in,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout) dropout=dropout)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) self.mid.attn_1 = attn_op(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, self.mid.block_2 = resnet_op(in_channels=block_in,
out_channels=block_in, out_channels=block_in,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout) dropout=dropout)
@ -587,13 +590,13 @@ class Decoder(nn.Module):
attn = nn.ModuleList() attn = nn.ModuleList()
block_out = ch*ch_mult[i_level] block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks+1): for i_block in range(self.num_res_blocks+1):
block.append(ResnetBlock(in_channels=block_in, block.append(resnet_op(in_channels=block_in,
out_channels=block_out, out_channels=block_out,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout)) dropout=dropout))
block_in = block_out block_in = block_out
if curr_res in attn_resolutions: if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type)) attn.append(attn_op(block_in))
up = nn.Module() up = nn.Module()
up.block = block up.block = block
up.attn = attn up.attn = attn
@ -604,13 +607,13 @@ class Decoder(nn.Module):
# end # end
self.norm_out = Normalize(block_in) self.norm_out = Normalize(block_in)
self.conv_out = comfy.ops.Conv2d(block_in, self.conv_out = conv_out_op(block_in,
out_ch, out_ch,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
padding=1) padding=1)
def forward(self, z): def forward(self, z, **kwargs):
#assert z.shape[1:] == self.z_shape[1:] #assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape self.last_z_shape = z.shape
@ -621,16 +624,16 @@ class Decoder(nn.Module):
h = self.conv_in(z) h = self.conv_in(z)
# middle # middle
h = self.mid.block_1(h, temb) h = self.mid.block_1(h, temb, **kwargs)
h = self.mid.attn_1(h) h = self.mid.attn_1(h, **kwargs)
h = self.mid.block_2(h, temb) h = self.mid.block_2(h, temb, **kwargs)
# upsampling # upsampling
for i_level in reversed(range(self.num_resolutions)): for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1): for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](h, temb) h = self.up[i_level].block[i_block](h, temb, **kwargs)
if len(self.up[i_level].attn) > 0: if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h) h = self.up[i_level].attn[i_block](h, **kwargs)
if i_level != 0: if i_level != 0:
h = self.up[i_level].upsample(h) h = self.up[i_level].upsample(h)
@ -640,7 +643,7 @@ class Decoder(nn.Module):
h = self.norm_out(h) h = self.norm_out(h)
h = nonlinearity(h) h = nonlinearity(h)
h = self.conv_out(h) h = self.conv_out(h, **kwargs)
if self.tanh_out: if self.tanh_out:
h = torch.tanh(h) h = torch.tanh(h)
return h return h

39
comfy/sd.py

@ -4,7 +4,7 @@ import math
from comfy import model_management from comfy import model_management
from .ldm.util import instantiate_from_config from .ldm.util import instantiate_from_config
from .ldm.models.autoencoder import AutoencoderKL from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
import yaml import yaml
import comfy.utils import comfy.utils
@ -140,21 +140,24 @@ class CLIP:
return self.patcher.get_key_patches() return self.patcher.get_key_patches()
class VAE: class VAE:
def __init__(self, ckpt_path=None, device=None, config=None): def __init__(self, sd=None, device=None, config=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
if config is None: if config is None:
#default SD1.x/SD2.x VAE parameters #default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss") self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
else: else:
self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval() self.first_stage_model = self.first_stage_model.eval()
if ckpt_path is not None:
sd = comfy.utils.load_torch_file(ckpt_path) m, u = self.first_stage_model.load_state_dict(sd, strict=False)
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format if len(m) > 0:
sd = diffusers_convert.convert_vae_state_dict(sd) print("Missing VAE keys", m)
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
if len(m) > 0: if len(u) > 0:
print("Missing VAE keys", m) print("Leftover VAE keys", u)
if device is None: if device is None:
device = model_management.vae_device() device = model_management.vae_device()
@ -183,7 +186,7 @@ class VAE:
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).sample().float() encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
@ -229,7 +232,7 @@ class VAE:
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu") samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
for x in range(0, pixel_samples.shape[0], batch_number): for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device) pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu().float() samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
except model_management.OOM_EXCEPTION as e: except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
@ -375,10 +378,8 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
model.load_model_weights(state_dict, "model.diffusion_model.") model.load_model_weights(state_dict, "model.diffusion_model.")
if output_vae: if output_vae:
w = WeightsLoader() vae_sd = comfy.utils.state_dict_prefix_replace(state_dict, {"first_stage_model.": ""}, filter_keys=True)
vae = VAE(config=vae_config) vae = VAE(sd=vae_sd, config=vae_config)
w.first_stage_model = vae.first_stage_model
load_model_weights(w, state_dict)
if output_clip: if output_clip:
w = WeightsLoader() w = WeightsLoader()
@ -427,10 +428,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
model.load_model_weights(sd, "model.diffusion_model.") model.load_model_weights(sd, "model.diffusion_model.")
if output_vae: if output_vae:
vae = VAE() vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
w = WeightsLoader() vae = VAE(sd=vae_sd)
w.first_stage_model = vae.first_stage_model
load_model_weights(w, sd)
if output_clip: if output_clip:
w = WeightsLoader() w = WeightsLoader()

11
comfy/utils.py

@ -47,12 +47,17 @@ def state_dict_key_replace(state_dict, keys_to_replace):
state_dict[keys_to_replace[x]] = state_dict.pop(x) state_dict[keys_to_replace[x]] = state_dict.pop(x)
return state_dict return state_dict
def state_dict_prefix_replace(state_dict, replace_prefix): def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
if filter_keys:
out = {}
else:
out = state_dict
for rp in replace_prefix: for rp in replace_prefix:
replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys()))) replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys())))
for x in replace: for x in replace:
state_dict[x[1]] = state_dict.pop(x[0]) w = state_dict.pop(x[0])
return state_dict out[x[1]] = w
return out
def transformers_convert(sd, prefix_from, prefix_to, number): def transformers_convert(sd, prefix_from, prefix_to, number):

3
nodes.py

@ -584,7 +584,8 @@ class VAELoader:
#TODO: scale factor? #TODO: scale factor?
def load_vae(self, vae_name): def load_vae(self, vae_name):
vae_path = folder_paths.get_full_path("vae", vae_name) vae_path = folder_paths.get_full_path("vae", vae_name)
vae = comfy.sd.VAE(ckpt_path=vae_path) sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)
return (vae,) return (vae,)
class ControlNetLoader: class ControlNetLoader:

Loading…
Cancel
Save