comfyanonymous
1 year ago
8 changed files with 0 additions and 870 deletions
@ -1,105 +0,0 @@
|
||||
from functools import reduce |
||||
import math |
||||
import operator |
||||
|
||||
import numpy as np |
||||
from skimage import transform |
||||
import torch |
||||
from torch import nn |
||||
|
||||
|
||||
def translate2d(tx, ty): |
||||
mat = [[1, 0, tx], |
||||
[0, 1, ty], |
||||
[0, 0, 1]] |
||||
return torch.tensor(mat, dtype=torch.float32) |
||||
|
||||
|
||||
def scale2d(sx, sy): |
||||
mat = [[sx, 0, 0], |
||||
[ 0, sy, 0], |
||||
[ 0, 0, 1]] |
||||
return torch.tensor(mat, dtype=torch.float32) |
||||
|
||||
|
||||
def rotate2d(theta): |
||||
mat = [[torch.cos(theta), torch.sin(-theta), 0], |
||||
[torch.sin(theta), torch.cos(theta), 0], |
||||
[ 0, 0, 1]] |
||||
return torch.tensor(mat, dtype=torch.float32) |
||||
|
||||
|
||||
class KarrasAugmentationPipeline: |
||||
def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8): |
||||
self.a_prob = a_prob |
||||
self.a_scale = a_scale |
||||
self.a_aniso = a_aniso |
||||
self.a_trans = a_trans |
||||
|
||||
def __call__(self, image): |
||||
h, w = image.size |
||||
mats = [translate2d(h / 2 - 0.5, w / 2 - 0.5)] |
||||
|
||||
# x-flip |
||||
a0 = torch.randint(2, []).float() |
||||
mats.append(scale2d(1 - 2 * a0, 1)) |
||||
# y-flip |
||||
do = (torch.rand([]) < self.a_prob).float() |
||||
a1 = torch.randint(2, []).float() * do |
||||
mats.append(scale2d(1, 1 - 2 * a1)) |
||||
# scaling |
||||
do = (torch.rand([]) < self.a_prob).float() |
||||
a2 = torch.randn([]) * do |
||||
mats.append(scale2d(self.a_scale ** a2, self.a_scale ** a2)) |
||||
# rotation |
||||
do = (torch.rand([]) < self.a_prob).float() |
||||
a3 = (torch.rand([]) * 2 * math.pi - math.pi) * do |
||||
mats.append(rotate2d(-a3)) |
||||
# anisotropy |
||||
do = (torch.rand([]) < self.a_prob).float() |
||||
a4 = (torch.rand([]) * 2 * math.pi - math.pi) * do |
||||
a5 = torch.randn([]) * do |
||||
mats.append(rotate2d(a4)) |
||||
mats.append(scale2d(self.a_aniso ** a5, self.a_aniso ** -a5)) |
||||
mats.append(rotate2d(-a4)) |
||||
# translation |
||||
do = (torch.rand([]) < self.a_prob).float() |
||||
a6 = torch.randn([]) * do |
||||
a7 = torch.randn([]) * do |
||||
mats.append(translate2d(self.a_trans * w * a6, self.a_trans * h * a7)) |
||||
|
||||
# form the transformation matrix and conditioning vector |
||||
mats.append(translate2d(-h / 2 + 0.5, -w / 2 + 0.5)) |
||||
mat = reduce(operator.matmul, mats) |
||||
cond = torch.stack([a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7]) |
||||
|
||||
# apply the transformation |
||||
image_orig = np.array(image, dtype=np.float32) / 255 |
||||
if image_orig.ndim == 2: |
||||
image_orig = image_orig[..., None] |
||||
tf = transform.AffineTransform(mat.numpy()) |
||||
image = transform.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True) |
||||
image_orig = torch.as_tensor(image_orig).movedim(2, 0) * 2 - 1 |
||||
image = torch.as_tensor(image).movedim(2, 0) * 2 - 1 |
||||
return image, image_orig, cond |
||||
|
||||
|
||||
class KarrasAugmentWrapper(nn.Module): |
||||
def __init__(self, model): |
||||
super().__init__() |
||||
self.inner_model = model |
||||
|
||||
def forward(self, input, sigma, aug_cond=None, mapping_cond=None, **kwargs): |
||||
if aug_cond is None: |
||||
aug_cond = input.new_zeros([input.shape[0], 9]) |
||||
if mapping_cond is None: |
||||
mapping_cond = aug_cond |
||||
else: |
||||
mapping_cond = torch.cat([aug_cond, mapping_cond], dim=1) |
||||
return self.inner_model(input, sigma, mapping_cond=mapping_cond, **kwargs) |
||||
|
||||
def set_skip_stages(self, skip_stages): |
||||
return self.inner_model.set_skip_stages(skip_stages) |
||||
|
||||
def set_patch_size(self, patch_size): |
||||
return self.inner_model.set_patch_size(patch_size) |
@ -1,110 +0,0 @@
|
||||
from functools import partial |
||||
import json |
||||
import math |
||||
import warnings |
||||
|
||||
from jsonmerge import merge |
||||
|
||||
from . import augmentation, layers, models, utils |
||||
|
||||
|
||||
def load_config(file): |
||||
defaults = { |
||||
'model': { |
||||
'sigma_data': 1., |
||||
'patch_size': 1, |
||||
'dropout_rate': 0., |
||||
'augment_wrapper': True, |
||||
'augment_prob': 0., |
||||
'mapping_cond_dim': 0, |
||||
'unet_cond_dim': 0, |
||||
'cross_cond_dim': 0, |
||||
'cross_attn_depths': None, |
||||
'skip_stages': 0, |
||||
'has_variance': False, |
||||
}, |
||||
'dataset': { |
||||
'type': 'imagefolder', |
||||
}, |
||||
'optimizer': { |
||||
'type': 'adamw', |
||||
'lr': 1e-4, |
||||
'betas': [0.95, 0.999], |
||||
'eps': 1e-6, |
||||
'weight_decay': 1e-3, |
||||
}, |
||||
'lr_sched': { |
||||
'type': 'inverse', |
||||
'inv_gamma': 20000., |
||||
'power': 1., |
||||
'warmup': 0.99, |
||||
}, |
||||
'ema_sched': { |
||||
'type': 'inverse', |
||||
'power': 0.6667, |
||||
'max_value': 0.9999 |
||||
}, |
||||
} |
||||
config = json.load(file) |
||||
return merge(defaults, config) |
||||
|
||||
|
||||
def make_model(config): |
||||
config = config['model'] |
||||
assert config['type'] == 'image_v1' |
||||
model = models.ImageDenoiserModelV1( |
||||
config['input_channels'], |
||||
config['mapping_out'], |
||||
config['depths'], |
||||
config['channels'], |
||||
config['self_attn_depths'], |
||||
config['cross_attn_depths'], |
||||
patch_size=config['patch_size'], |
||||
dropout_rate=config['dropout_rate'], |
||||
mapping_cond_dim=config['mapping_cond_dim'] + (9 if config['augment_wrapper'] else 0), |
||||
unet_cond_dim=config['unet_cond_dim'], |
||||
cross_cond_dim=config['cross_cond_dim'], |
||||
skip_stages=config['skip_stages'], |
||||
has_variance=config['has_variance'], |
||||
) |
||||
if config['augment_wrapper']: |
||||
model = augmentation.KarrasAugmentWrapper(model) |
||||
return model |
||||
|
||||
|
||||
def make_denoiser_wrapper(config): |
||||
config = config['model'] |
||||
sigma_data = config.get('sigma_data', 1.) |
||||
has_variance = config.get('has_variance', False) |
||||
if not has_variance: |
||||
return partial(layers.Denoiser, sigma_data=sigma_data) |
||||
return partial(layers.DenoiserWithVariance, sigma_data=sigma_data) |
||||
|
||||
|
||||
def make_sample_density(config): |
||||
sd_config = config['sigma_sample_density'] |
||||
sigma_data = config['sigma_data'] |
||||
if sd_config['type'] == 'lognormal': |
||||
loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc'] |
||||
scale = sd_config['std'] if 'std' in sd_config else sd_config['scale'] |
||||
return partial(utils.rand_log_normal, loc=loc, scale=scale) |
||||
if sd_config['type'] == 'loglogistic': |
||||
loc = sd_config['loc'] if 'loc' in sd_config else math.log(sigma_data) |
||||
scale = sd_config['scale'] if 'scale' in sd_config else 0.5 |
||||
min_value = sd_config['min_value'] if 'min_value' in sd_config else 0. |
||||
max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf') |
||||
return partial(utils.rand_log_logistic, loc=loc, scale=scale, min_value=min_value, max_value=max_value) |
||||
if sd_config['type'] == 'loguniform': |
||||
min_value = sd_config['min_value'] if 'min_value' in sd_config else config['sigma_min'] |
||||
max_value = sd_config['max_value'] if 'max_value' in sd_config else config['sigma_max'] |
||||
return partial(utils.rand_log_uniform, min_value=min_value, max_value=max_value) |
||||
if sd_config['type'] == 'v-diffusion': |
||||
min_value = sd_config['min_value'] if 'min_value' in sd_config else 0. |
||||
max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf') |
||||
return partial(utils.rand_v_diffusion, sigma_data=sigma_data, min_value=min_value, max_value=max_value) |
||||
if sd_config['type'] == 'split-lognormal': |
||||
loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc'] |
||||
scale_1 = sd_config['std_1'] if 'std_1' in sd_config else sd_config['scale_1'] |
||||
scale_2 = sd_config['std_2'] if 'std_2' in sd_config else sd_config['scale_2'] |
||||
return partial(utils.rand_split_log_normal, loc=loc, scale_1=scale_1, scale_2=scale_2) |
||||
raise ValueError('Unknown sample density type') |
@ -1,134 +0,0 @@
|
||||
import math |
||||
import os |
||||
from pathlib import Path |
||||
|
||||
from cleanfid.inception_torchscript import InceptionV3W |
||||
import clip |
||||
from resize_right import resize |
||||
import torch |
||||
from torch import nn |
||||
from torch.nn import functional as F |
||||
from torchvision import transforms |
||||
from tqdm.auto import trange |
||||
|
||||
from . import utils |
||||
|
||||
|
||||
class InceptionV3FeatureExtractor(nn.Module): |
||||
def __init__(self, device='cpu'): |
||||
super().__init__() |
||||
path = Path(os.environ.get('XDG_CACHE_HOME', Path.home() / '.cache')) / 'k-diffusion' |
||||
url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' |
||||
digest = 'f58cb9b6ec323ed63459aa4fb441fe750cfe39fafad6da5cb504a16f19e958f4' |
||||
utils.download_file(path / 'inception-2015-12-05.pt', url, digest) |
||||
self.model = InceptionV3W(str(path), resize_inside=False).to(device) |
||||
self.size = (299, 299) |
||||
|
||||
def forward(self, x): |
||||
if x.shape[2:4] != self.size: |
||||
x = resize(x, out_shape=self.size, pad_mode='reflect') |
||||
if x.shape[1] == 1: |
||||
x = torch.cat([x] * 3, dim=1) |
||||
x = (x * 127.5 + 127.5).clamp(0, 255) |
||||
return self.model(x) |
||||
|
||||
|
||||
class CLIPFeatureExtractor(nn.Module): |
||||
def __init__(self, name='ViT-L/14@336px', device='cpu'): |
||||
super().__init__() |
||||
self.model = clip.load(name, device=device)[0].eval().requires_grad_(False) |
||||
self.normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), |
||||
std=(0.26862954, 0.26130258, 0.27577711)) |
||||
self.size = (self.model.visual.input_resolution, self.model.visual.input_resolution) |
||||
|
||||
def forward(self, x): |
||||
if x.shape[2:4] != self.size: |
||||
x = resize(x.add(1).div(2), out_shape=self.size, pad_mode='reflect').clamp(0, 1) |
||||
x = self.normalize(x) |
||||
x = self.model.encode_image(x).float() |
||||
x = F.normalize(x) * x.shape[1] ** 0.5 |
||||
return x |
||||
|
||||
|
||||
def compute_features(accelerator, sample_fn, extractor_fn, n, batch_size): |
||||
n_per_proc = math.ceil(n / accelerator.num_processes) |
||||
feats_all = [] |
||||
try: |
||||
for i in trange(0, n_per_proc, batch_size, disable=not accelerator.is_main_process): |
||||
cur_batch_size = min(n - i, batch_size) |
||||
samples = sample_fn(cur_batch_size)[:cur_batch_size] |
||||
feats_all.append(accelerator.gather(extractor_fn(samples))) |
||||
except StopIteration: |
||||
pass |
||||
return torch.cat(feats_all)[:n] |
||||
|
||||
|
||||
def polynomial_kernel(x, y): |
||||
d = x.shape[-1] |
||||
dot = x @ y.transpose(-2, -1) |
||||
return (dot / d + 1) ** 3 |
||||
|
||||
|
||||
def squared_mmd(x, y, kernel=polynomial_kernel): |
||||
m = x.shape[-2] |
||||
n = y.shape[-2] |
||||
kxx = kernel(x, x) |
||||
kyy = kernel(y, y) |
||||
kxy = kernel(x, y) |
||||
kxx_sum = kxx.sum([-1, -2]) - kxx.diagonal(dim1=-1, dim2=-2).sum(-1) |
||||
kyy_sum = kyy.sum([-1, -2]) - kyy.diagonal(dim1=-1, dim2=-2).sum(-1) |
||||
kxy_sum = kxy.sum([-1, -2]) |
||||
term_1 = kxx_sum / m / (m - 1) |
||||
term_2 = kyy_sum / n / (n - 1) |
||||
term_3 = kxy_sum * 2 / m / n |
||||
return term_1 + term_2 - term_3 |
||||
|
||||
|
||||
@utils.tf32_mode(matmul=False) |
||||
def kid(x, y, max_size=5000): |
||||
x_size, y_size = x.shape[0], y.shape[0] |
||||
n_partitions = math.ceil(max(x_size / max_size, y_size / max_size)) |
||||
total_mmd = x.new_zeros([]) |
||||
for i in range(n_partitions): |
||||
cur_x = x[round(i * x_size / n_partitions):round((i + 1) * x_size / n_partitions)] |
||||
cur_y = y[round(i * y_size / n_partitions):round((i + 1) * y_size / n_partitions)] |
||||
total_mmd = total_mmd + squared_mmd(cur_x, cur_y) |
||||
return total_mmd / n_partitions |
||||
|
||||
|
||||
class _MatrixSquareRootEig(torch.autograd.Function): |
||||
@staticmethod |
||||
def forward(ctx, a): |
||||
vals, vecs = torch.linalg.eigh(a) |
||||
ctx.save_for_backward(vals, vecs) |
||||
return vecs @ vals.abs().sqrt().diag_embed() @ vecs.transpose(-2, -1) |
||||
|
||||
@staticmethod |
||||
def backward(ctx, grad_output): |
||||
vals, vecs = ctx.saved_tensors |
||||
d = vals.abs().sqrt().unsqueeze(-1).repeat_interleave(vals.shape[-1], -1) |
||||
vecs_t = vecs.transpose(-2, -1) |
||||
return vecs @ (vecs_t @ grad_output @ vecs / (d + d.transpose(-2, -1))) @ vecs_t |
||||
|
||||
|
||||
def sqrtm_eig(a): |
||||
if a.ndim < 2: |
||||
raise RuntimeError('tensor of matrices must have at least 2 dimensions') |
||||
if a.shape[-2] != a.shape[-1]: |
||||
raise RuntimeError('tensor must be batches of square matrices') |
||||
return _MatrixSquareRootEig.apply(a) |
||||
|
||||
|
||||
@utils.tf32_mode(matmul=False) |
||||
def fid(x, y, eps=1e-8): |
||||
x_mean = x.mean(dim=0) |
||||
y_mean = y.mean(dim=0) |
||||
mean_term = (x_mean - y_mean).pow(2).sum() |
||||
x_cov = torch.cov(x.T) |
||||
y_cov = torch.cov(y.T) |
||||
eps_eye = torch.eye(x_cov.shape[0], device=x_cov.device, dtype=x_cov.dtype) * eps |
||||
x_cov = x_cov + eps_eye |
||||
y_cov = y_cov + eps_eye |
||||
x_cov_sqrt = sqrtm_eig(x_cov) |
||||
cov_term = torch.trace(x_cov + y_cov - 2 * sqrtm_eig(x_cov_sqrt @ y_cov @ x_cov_sqrt)) |
||||
return mean_term + cov_term |
@ -1,99 +0,0 @@
|
||||
import torch |
||||
from torch import nn |
||||
|
||||
|
||||
class DDPGradientStatsHook: |
||||
def __init__(self, ddp_module): |
||||
try: |
||||
ddp_module.register_comm_hook(self, self._hook_fn) |
||||
except AttributeError: |
||||
raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules') |
||||
self._clear_state() |
||||
|
||||
def _clear_state(self): |
||||
self.bucket_sq_norms_small_batch = [] |
||||
self.bucket_sq_norms_large_batch = [] |
||||
|
||||
@staticmethod |
||||
def _hook_fn(self, bucket): |
||||
buf = bucket.buffer() |
||||
self.bucket_sq_norms_small_batch.append(buf.pow(2).sum()) |
||||
fut = torch.distributed.all_reduce(buf, op=torch.distributed.ReduceOp.AVG, async_op=True).get_future() |
||||
def callback(fut): |
||||
buf = fut.value()[0] |
||||
self.bucket_sq_norms_large_batch.append(buf.pow(2).sum()) |
||||
return buf |
||||
return fut.then(callback) |
||||
|
||||
def get_stats(self): |
||||
sq_norm_small_batch = sum(self.bucket_sq_norms_small_batch) |
||||
sq_norm_large_batch = sum(self.bucket_sq_norms_large_batch) |
||||
self._clear_state() |
||||
stats = torch.stack([sq_norm_small_batch, sq_norm_large_batch]) |
||||
torch.distributed.all_reduce(stats, op=torch.distributed.ReduceOp.AVG) |
||||
return stats[0].item(), stats[1].item() |
||||
|
||||
|
||||
class GradientNoiseScale: |
||||
"""Calculates the gradient noise scale (1 / SNR), or critical batch size, |
||||
from _An Empirical Model of Large-Batch Training_, |
||||
https://arxiv.org/abs/1812.06162). |
||||
|
||||
Args: |
||||
beta (float): The decay factor for the exponential moving averages used to |
||||
calculate the gradient noise scale. |
||||
Default: 0.9998 |
||||
eps (float): Added for numerical stability. |
||||
Default: 1e-8 |
||||
""" |
||||
|
||||
def __init__(self, beta=0.9998, eps=1e-8): |
||||
self.beta = beta |
||||
self.eps = eps |
||||
self.ema_sq_norm = 0. |
||||
self.ema_var = 0. |
||||
self.beta_cumprod = 1. |
||||
self.gradient_noise_scale = float('nan') |
||||
|
||||
def state_dict(self): |
||||
"""Returns the state of the object as a :class:`dict`.""" |
||||
return dict(self.__dict__.items()) |
||||
|
||||
def load_state_dict(self, state_dict): |
||||
"""Loads the object's state. |
||||
Args: |
||||
state_dict (dict): object state. Should be an object returned |
||||
from a call to :meth:`state_dict`. |
||||
""" |
||||
self.__dict__.update(state_dict) |
||||
|
||||
def update(self, sq_norm_small_batch, sq_norm_large_batch, n_small_batch, n_large_batch): |
||||
"""Updates the state with a new batch's gradient statistics, and returns the |
||||
current gradient noise scale. |
||||
|
||||
Args: |
||||
sq_norm_small_batch (float): The mean of the squared 2-norms of microbatch or |
||||
per sample gradients. |
||||
sq_norm_large_batch (float): The squared 2-norm of the mean of the microbatch or |
||||
per sample gradients. |
||||
n_small_batch (int): The batch size of the individual microbatch or per sample |
||||
gradients (1 if per sample). |
||||
n_large_batch (int): The total batch size of the mean of the microbatch or |
||||
per sample gradients. |
||||
""" |
||||
est_sq_norm = (n_large_batch * sq_norm_large_batch - n_small_batch * sq_norm_small_batch) / (n_large_batch - n_small_batch) |
||||
est_var = (sq_norm_small_batch - sq_norm_large_batch) / (1 / n_small_batch - 1 / n_large_batch) |
||||
self.ema_sq_norm = self.beta * self.ema_sq_norm + (1 - self.beta) * est_sq_norm |
||||
self.ema_var = self.beta * self.ema_var + (1 - self.beta) * est_var |
||||
self.beta_cumprod *= self.beta |
||||
self.gradient_noise_scale = max(self.ema_var, self.eps) / max(self.ema_sq_norm, self.eps) |
||||
return self.gradient_noise_scale |
||||
|
||||
def get_gns(self): |
||||
"""Returns the current gradient noise scale.""" |
||||
return self.gradient_noise_scale |
||||
|
||||
def get_stats(self): |
||||
"""Returns the current (debiased) estimates of the squared mean gradient |
||||
and gradient variance.""" |
||||
return self.ema_sq_norm / (1 - self.beta_cumprod), self.ema_var / (1 - self.beta_cumprod) |
@ -1,246 +0,0 @@
|
||||
import math |
||||
|
||||
from einops import rearrange, repeat |
||||
import torch |
||||
from torch import nn |
||||
from torch.nn import functional as F |
||||
|
||||
from . import utils |
||||
|
||||
# Karras et al. preconditioned denoiser |
||||
|
||||
class Denoiser(nn.Module): |
||||
"""A Karras et al. preconditioner for denoising diffusion models.""" |
||||
|
||||
def __init__(self, inner_model, sigma_data=1.): |
||||
super().__init__() |
||||
self.inner_model = inner_model |
||||
self.sigma_data = sigma_data |
||||
|
||||
def get_scalings(self, sigma): |
||||
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) |
||||
c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 |
||||
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 |
||||
return c_skip, c_out, c_in |
||||
|
||||
def loss(self, input, noise, sigma, **kwargs): |
||||
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] |
||||
noised_input = input + noise * utils.append_dims(sigma, input.ndim) |
||||
model_output = self.inner_model(noised_input * c_in, sigma, **kwargs) |
||||
target = (input - c_skip * noised_input) / c_out |
||||
return (model_output - target).pow(2).flatten(1).mean(1) |
||||
|
||||
def forward(self, input, sigma, **kwargs): |
||||
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] |
||||
return self.inner_model(input * c_in, sigma, **kwargs) * c_out + input * c_skip |
||||
|
||||
|
||||
class DenoiserWithVariance(Denoiser): |
||||
def loss(self, input, noise, sigma, **kwargs): |
||||
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] |
||||
noised_input = input + noise * utils.append_dims(sigma, input.ndim) |
||||
model_output, logvar = self.inner_model(noised_input * c_in, sigma, return_variance=True, **kwargs) |
||||
logvar = utils.append_dims(logvar, model_output.ndim) |
||||
target = (input - c_skip * noised_input) / c_out |
||||
losses = ((model_output - target) ** 2 / logvar.exp() + logvar) / 2 |
||||
return losses.flatten(1).mean(1) |
||||
|
||||
|
||||
# Residual blocks |
||||
|
||||
class ResidualBlock(nn.Module): |
||||
def __init__(self, *main, skip=None): |
||||
super().__init__() |
||||
self.main = nn.Sequential(*main) |
||||
self.skip = skip if skip else nn.Identity() |
||||
|
||||
def forward(self, input): |
||||
return self.main(input) + self.skip(input) |
||||
|
||||
|
||||
# Noise level (and other) conditioning |
||||
|
||||
class ConditionedModule(nn.Module): |
||||
pass |
||||
|
||||
|
||||
class UnconditionedModule(ConditionedModule): |
||||
def __init__(self, module): |
||||
super().__init__() |
||||
self.module = module |
||||
|
||||
def forward(self, input, cond=None): |
||||
return self.module(input) |
||||
|
||||
|
||||
class ConditionedSequential(nn.Sequential, ConditionedModule): |
||||
def forward(self, input, cond): |
||||
for module in self: |
||||
if isinstance(module, ConditionedModule): |
||||
input = module(input, cond) |
||||
else: |
||||
input = module(input) |
||||
return input |
||||
|
||||
|
||||
class ConditionedResidualBlock(ConditionedModule): |
||||
def __init__(self, *main, skip=None): |
||||
super().__init__() |
||||
self.main = ConditionedSequential(*main) |
||||
self.skip = skip if skip else nn.Identity() |
||||
|
||||
def forward(self, input, cond): |
||||
skip = self.skip(input, cond) if isinstance(self.skip, ConditionedModule) else self.skip(input) |
||||
return self.main(input, cond) + skip |
||||
|
||||
|
||||
class AdaGN(ConditionedModule): |
||||
def __init__(self, feats_in, c_out, num_groups, eps=1e-5, cond_key='cond'): |
||||
super().__init__() |
||||
self.num_groups = num_groups |
||||
self.eps = eps |
||||
self.cond_key = cond_key |
||||
self.mapper = nn.Linear(feats_in, c_out * 2) |
||||
|
||||
def forward(self, input, cond): |
||||
weight, bias = self.mapper(cond[self.cond_key]).chunk(2, dim=-1) |
||||
input = F.group_norm(input, self.num_groups, eps=self.eps) |
||||
return torch.addcmul(utils.append_dims(bias, input.ndim), input, utils.append_dims(weight, input.ndim) + 1) |
||||
|
||||
|
||||
# Attention |
||||
|
||||
class SelfAttention2d(ConditionedModule): |
||||
def __init__(self, c_in, n_head, norm, dropout_rate=0.): |
||||
super().__init__() |
||||
assert c_in % n_head == 0 |
||||
self.norm_in = norm(c_in) |
||||
self.n_head = n_head |
||||
self.qkv_proj = nn.Conv2d(c_in, c_in * 3, 1) |
||||
self.out_proj = nn.Conv2d(c_in, c_in, 1) |
||||
self.dropout = nn.Dropout(dropout_rate) |
||||
|
||||
def forward(self, input, cond): |
||||
n, c, h, w = input.shape |
||||
qkv = self.qkv_proj(self.norm_in(input, cond)) |
||||
qkv = qkv.view([n, self.n_head * 3, c // self.n_head, h * w]).transpose(2, 3) |
||||
q, k, v = qkv.chunk(3, dim=1) |
||||
scale = k.shape[3] ** -0.25 |
||||
att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3) |
||||
att = self.dropout(att) |
||||
y = (att @ v).transpose(2, 3).contiguous().view([n, c, h, w]) |
||||
return input + self.out_proj(y) |
||||
|
||||
|
||||
class CrossAttention2d(ConditionedModule): |
||||
def __init__(self, c_dec, c_enc, n_head, norm_dec, dropout_rate=0., |
||||
cond_key='cross', cond_key_padding='cross_padding'): |
||||
super().__init__() |
||||
assert c_dec % n_head == 0 |
||||
self.cond_key = cond_key |
||||
self.cond_key_padding = cond_key_padding |
||||
self.norm_enc = nn.LayerNorm(c_enc) |
||||
self.norm_dec = norm_dec(c_dec) |
||||
self.n_head = n_head |
||||
self.q_proj = nn.Conv2d(c_dec, c_dec, 1) |
||||
self.kv_proj = nn.Linear(c_enc, c_dec * 2) |
||||
self.out_proj = nn.Conv2d(c_dec, c_dec, 1) |
||||
self.dropout = nn.Dropout(dropout_rate) |
||||
|
||||
def forward(self, input, cond): |
||||
n, c, h, w = input.shape |
||||
q = self.q_proj(self.norm_dec(input, cond)) |
||||
q = q.view([n, self.n_head, c // self.n_head, h * w]).transpose(2, 3) |
||||
kv = self.kv_proj(self.norm_enc(cond[self.cond_key])) |
||||
kv = kv.view([n, -1, self.n_head * 2, c // self.n_head]).transpose(1, 2) |
||||
k, v = kv.chunk(2, dim=1) |
||||
scale = k.shape[3] ** -0.25 |
||||
att = ((q * scale) @ (k.transpose(2, 3) * scale)) |
||||
att = att - (cond[self.cond_key_padding][:, None, None, :]) * 10000 |
||||
att = att.softmax(3) |
||||
att = self.dropout(att) |
||||
y = (att @ v).transpose(2, 3) |
||||
y = y.contiguous().view([n, c, h, w]) |
||||
return input + self.out_proj(y) |
||||
|
||||
|
||||
# Downsampling/upsampling |
||||
|
||||
_kernels = { |
||||
'linear': |
||||
[1 / 8, 3 / 8, 3 / 8, 1 / 8], |
||||
'cubic': |
||||
[-0.01171875, -0.03515625, 0.11328125, 0.43359375, |
||||
0.43359375, 0.11328125, -0.03515625, -0.01171875], |
||||
'lanczos3': |
||||
[0.003689131001010537, 0.015056144446134567, -0.03399861603975296, |
||||
-0.066637322306633, 0.13550527393817902, 0.44638532400131226, |
||||
0.44638532400131226, 0.13550527393817902, -0.066637322306633, |
||||
-0.03399861603975296, 0.015056144446134567, 0.003689131001010537] |
||||
} |
||||
_kernels['bilinear'] = _kernels['linear'] |
||||
_kernels['bicubic'] = _kernels['cubic'] |
||||
|
||||
|
||||
class Downsample2d(nn.Module): |
||||
def __init__(self, kernel='linear', pad_mode='reflect'): |
||||
super().__init__() |
||||
self.pad_mode = pad_mode |
||||
kernel_1d = torch.tensor([_kernels[kernel]]) |
||||
self.pad = kernel_1d.shape[1] // 2 - 1 |
||||
self.register_buffer('kernel', kernel_1d.T @ kernel_1d) |
||||
|
||||
def forward(self, x): |
||||
x = F.pad(x, (self.pad,) * 4, self.pad_mode) |
||||
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) |
||||
indices = torch.arange(x.shape[1], device=x.device) |
||||
weight[indices, indices] = self.kernel.to(weight) |
||||
return F.conv2d(x, weight, stride=2) |
||||
|
||||
|
||||
class Upsample2d(nn.Module): |
||||
def __init__(self, kernel='linear', pad_mode='reflect'): |
||||
super().__init__() |
||||
self.pad_mode = pad_mode |
||||
kernel_1d = torch.tensor([_kernels[kernel]]) * 2 |
||||
self.pad = kernel_1d.shape[1] // 2 - 1 |
||||
self.register_buffer('kernel', kernel_1d.T @ kernel_1d) |
||||
|
||||
def forward(self, x): |
||||
x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode) |
||||
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) |
||||
indices = torch.arange(x.shape[1], device=x.device) |
||||
weight[indices, indices] = self.kernel.to(weight) |
||||
return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1) |
||||
|
||||
|
||||
# Embeddings |
||||
|
||||
class FourierFeatures(nn.Module): |
||||
def __init__(self, in_features, out_features, std=1.): |
||||
super().__init__() |
||||
assert out_features % 2 == 0 |
||||
self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std) |
||||
|
||||
def forward(self, input): |
||||
f = 2 * math.pi * input @ self.weight.T |
||||
return torch.cat([f.cos(), f.sin()], dim=-1) |
||||
|
||||
|
||||
# U-Nets |
||||
|
||||
class UNet(ConditionedModule): |
||||
def __init__(self, d_blocks, u_blocks, skip_stages=0): |
||||
super().__init__() |
||||
self.d_blocks = nn.ModuleList(d_blocks) |
||||
self.u_blocks = nn.ModuleList(u_blocks) |
||||
self.skip_stages = skip_stages |
||||
|
||||
def forward(self, input, cond): |
||||
skips = [] |
||||
for block in self.d_blocks[self.skip_stages:]: |
||||
input = block(input, cond) |
||||
skips.append(input) |
||||
for i, (block, skip) in enumerate(zip(self.u_blocks, reversed(skips))): |
||||
input = block(input, cond, skip if i > 0 else None) |
||||
return input |
@ -1 +0,0 @@
|
||||
from .image_v1 import ImageDenoiserModelV1 |
@ -1,156 +0,0 @@
|
||||
import math |
||||
|
||||
import torch |
||||
from torch import nn |
||||
from torch.nn import functional as F |
||||
|
||||
from .. import layers, utils |
||||
|
||||
|
||||
def orthogonal_(module): |
||||
nn.init.orthogonal_(module.weight) |
||||
return module |
||||
|
||||
|
||||
class ResConvBlock(layers.ConditionedResidualBlock): |
||||
def __init__(self, feats_in, c_in, c_mid, c_out, group_size=32, dropout_rate=0.): |
||||
skip = None if c_in == c_out else orthogonal_(nn.Conv2d(c_in, c_out, 1, bias=False)) |
||||
super().__init__( |
||||
layers.AdaGN(feats_in, c_in, max(1, c_in // group_size)), |
||||
nn.GELU(), |
||||
nn.Conv2d(c_in, c_mid, 3, padding=1), |
||||
nn.Dropout2d(dropout_rate, inplace=True), |
||||
layers.AdaGN(feats_in, c_mid, max(1, c_mid // group_size)), |
||||
nn.GELU(), |
||||
nn.Conv2d(c_mid, c_out, 3, padding=1), |
||||
nn.Dropout2d(dropout_rate, inplace=True), |
||||
skip=skip) |
||||
|
||||
|
||||
class DBlock(layers.ConditionedSequential): |
||||
def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., downsample=False, self_attn=False, cross_attn=False, c_enc=0): |
||||
modules = [nn.Identity()] |
||||
for i in range(n_layers): |
||||
my_c_in = c_in if i == 0 else c_mid |
||||
my_c_out = c_mid if i < n_layers - 1 else c_out |
||||
modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate)) |
||||
if self_attn: |
||||
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) |
||||
modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate)) |
||||
if cross_attn: |
||||
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) |
||||
modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate)) |
||||
super().__init__(*modules) |
||||
self.set_downsample(downsample) |
||||
|
||||
def set_downsample(self, downsample): |
||||
self[0] = layers.Downsample2d() if downsample else nn.Identity() |
||||
return self |
||||
|
||||
|
||||
class UBlock(layers.ConditionedSequential): |
||||
def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., upsample=False, self_attn=False, cross_attn=False, c_enc=0): |
||||
modules = [] |
||||
for i in range(n_layers): |
||||
my_c_in = c_in if i == 0 else c_mid |
||||
my_c_out = c_mid if i < n_layers - 1 else c_out |
||||
modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate)) |
||||
if self_attn: |
||||
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) |
||||
modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate)) |
||||
if cross_attn: |
||||
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) |
||||
modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate)) |
||||
modules.append(nn.Identity()) |
||||
super().__init__(*modules) |
||||
self.set_upsample(upsample) |
||||
|
||||
def forward(self, input, cond, skip=None): |
||||
if skip is not None: |
||||
input = torch.cat([input, skip], dim=1) |
||||
return super().forward(input, cond) |
||||
|
||||
def set_upsample(self, upsample): |
||||
self[-1] = layers.Upsample2d() if upsample else nn.Identity() |
||||
return self |
||||
|
||||
|
||||
class MappingNet(nn.Sequential): |
||||
def __init__(self, feats_in, feats_out, n_layers=2): |
||||
layers = [] |
||||
for i in range(n_layers): |
||||
layers.append(orthogonal_(nn.Linear(feats_in if i == 0 else feats_out, feats_out))) |
||||
layers.append(nn.GELU()) |
||||
super().__init__(*layers) |
||||
|
||||
|
||||
class ImageDenoiserModelV1(nn.Module): |
||||
def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, cross_attn_depths=None, mapping_cond_dim=0, unet_cond_dim=0, cross_cond_dim=0, dropout_rate=0., patch_size=1, skip_stages=0, has_variance=False): |
||||
super().__init__() |
||||
self.c_in = c_in |
||||
self.channels = channels |
||||
self.unet_cond_dim = unet_cond_dim |
||||
self.patch_size = patch_size |
||||
self.has_variance = has_variance |
||||
self.timestep_embed = layers.FourierFeatures(1, feats_in) |
||||
if mapping_cond_dim > 0: |
||||
self.mapping_cond = nn.Linear(mapping_cond_dim, feats_in, bias=False) |
||||
self.mapping = MappingNet(feats_in, feats_in) |
||||
self.proj_in = nn.Conv2d((c_in + unet_cond_dim) * self.patch_size ** 2, channels[max(0, skip_stages - 1)], 1) |
||||
self.proj_out = nn.Conv2d(channels[max(0, skip_stages - 1)], c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1) |
||||
nn.init.zeros_(self.proj_out.weight) |
||||
nn.init.zeros_(self.proj_out.bias) |
||||
if cross_cond_dim == 0: |
||||
cross_attn_depths = [False] * len(self_attn_depths) |
||||
d_blocks, u_blocks = [], [] |
||||
for i in range(len(depths)): |
||||
my_c_in = channels[max(0, i - 1)] |
||||
d_blocks.append(DBlock(depths[i], feats_in, my_c_in, channels[i], channels[i], downsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate)) |
||||
for i in range(len(depths)): |
||||
my_c_in = channels[i] * 2 if i < len(depths) - 1 else channels[i] |
||||
my_c_out = channels[max(0, i - 1)] |
||||
u_blocks.append(UBlock(depths[i], feats_in, my_c_in, channels[i], my_c_out, upsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate)) |
||||
self.u_net = layers.UNet(d_blocks, reversed(u_blocks), skip_stages=skip_stages) |
||||
|
||||
def forward(self, input, sigma, mapping_cond=None, unet_cond=None, cross_cond=None, cross_cond_padding=None, return_variance=False): |
||||
c_noise = sigma.log() / 4 |
||||
timestep_embed = self.timestep_embed(utils.append_dims(c_noise, 2)) |
||||
mapping_cond_embed = torch.zeros_like(timestep_embed) if mapping_cond is None else self.mapping_cond(mapping_cond) |
||||
mapping_out = self.mapping(timestep_embed + mapping_cond_embed) |
||||
cond = {'cond': mapping_out} |
||||
if unet_cond is not None: |
||||
input = torch.cat([input, unet_cond], dim=1) |
||||
if cross_cond is not None: |
||||
cond['cross'] = cross_cond |
||||
cond['cross_padding'] = cross_cond_padding |
||||
if self.patch_size > 1: |
||||
input = F.pixel_unshuffle(input, self.patch_size) |
||||
input = self.proj_in(input) |
||||
input = self.u_net(input, cond) |
||||
input = self.proj_out(input) |
||||
if self.has_variance: |
||||
input, logvar = input[:, :-1], input[:, -1].flatten(1).mean(1) |
||||
if self.patch_size > 1: |
||||
input = F.pixel_shuffle(input, self.patch_size) |
||||
if self.has_variance and return_variance: |
||||
return input, logvar |
||||
return input |
||||
|
||||
def set_skip_stages(self, skip_stages): |
||||
self.proj_in = nn.Conv2d(self.proj_in.in_channels, self.channels[max(0, skip_stages - 1)], 1) |
||||
self.proj_out = nn.Conv2d(self.channels[max(0, skip_stages - 1)], self.proj_out.out_channels, 1) |
||||
nn.init.zeros_(self.proj_out.weight) |
||||
nn.init.zeros_(self.proj_out.bias) |
||||
self.u_net.skip_stages = skip_stages |
||||
for i, block in enumerate(self.u_net.d_blocks): |
||||
block.set_downsample(i > skip_stages) |
||||
for i, block in enumerate(reversed(self.u_net.u_blocks)): |
||||
block.set_upsample(i > skip_stages) |
||||
return self |
||||
|
||||
def set_patch_size(self, patch_size): |
||||
self.patch_size = patch_size |
||||
self.proj_in = nn.Conv2d((self.c_in + self.unet_cond_dim) * self.patch_size ** 2, self.channels[max(0, self.u_net.skip_stages - 1)], 1) |
||||
self.proj_out = nn.Conv2d(self.channels[max(0, self.u_net.skip_stages - 1)], self.c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1) |
||||
nn.init.zeros_(self.proj_out.weight) |
||||
nn.init.zeros_(self.proj_out.bias) |
Loading…
Reference in new issue