|
|
@ -37,24 +37,24 @@ class HyperTile: |
|
|
|
def patch(self, model, tile_size, swap_size, max_depth, scale_depth): |
|
|
|
def patch(self, model, tile_size, swap_size, max_depth, scale_depth): |
|
|
|
model_channels = model.model.model_config.unet_config["model_channels"] |
|
|
|
model_channels = model.model.model_config.unet_config["model_channels"] |
|
|
|
|
|
|
|
|
|
|
|
apply_to = set() |
|
|
|
|
|
|
|
temp = model_channels |
|
|
|
|
|
|
|
for x in range(max_depth + 1): |
|
|
|
|
|
|
|
apply_to.add(temp) |
|
|
|
|
|
|
|
temp *= 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
latent_tile_size = max(32, tile_size) // 8 |
|
|
|
latent_tile_size = max(32, tile_size) // 8 |
|
|
|
self.temp = None |
|
|
|
self.temp = None |
|
|
|
|
|
|
|
|
|
|
|
def hypertile_in(q, k, v, extra_options): |
|
|
|
def hypertile_in(q, k, v, extra_options): |
|
|
|
if q.shape[-1] in apply_to: |
|
|
|
model_chans = q.shape[-2] |
|
|
|
|
|
|
|
orig_shape = extra_options['original_shape'] |
|
|
|
|
|
|
|
apply_to = [] |
|
|
|
|
|
|
|
for i in range(max_depth + 1): |
|
|
|
|
|
|
|
apply_to.append((orig_shape[-2] / (2 ** i)) * (orig_shape[-1] / (2 ** i))) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if model_chans in apply_to: |
|
|
|
shape = extra_options["original_shape"] |
|
|
|
shape = extra_options["original_shape"] |
|
|
|
aspect_ratio = shape[-1] / shape[-2] |
|
|
|
aspect_ratio = shape[-1] / shape[-2] |
|
|
|
|
|
|
|
|
|
|
|
hw = q.size(1) |
|
|
|
hw = q.size(1) |
|
|
|
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio)) |
|
|
|
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio)) |
|
|
|
|
|
|
|
|
|
|
|
factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1 |
|
|
|
factor = (2 ** apply_to.index(model_chans)) if scale_depth else 1 |
|
|
|
nh = random_divisor(h, latent_tile_size * factor, swap_size) |
|
|
|
nh = random_divisor(h, latent_tile_size * factor, swap_size) |
|
|
|
nw = random_divisor(w, latent_tile_size * factor, swap_size) |
|
|
|
nw = random_divisor(w, latent_tile_size * factor, swap_size) |
|
|
|
|
|
|
|
|
|
|
|