comfyanonymous
1 year ago
2 changed files with 85 additions and 1 deletions
@ -0,0 +1,83 @@
|
||||
#Taken from: https://github.com/tfernd/HyperTile/ |
||||
|
||||
import math |
||||
from einops import rearrange |
||||
import random |
||||
|
||||
def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter = 0) -> int: |
||||
min_value = min(min_value, value) |
||||
|
||||
# All big divisors of value (inclusive) |
||||
divisors = [i for i in range(min_value, value + 1) if value % i == 0] |
||||
|
||||
ns = [value // i for i in divisors[:max_options]] # has at least 1 element |
||||
|
||||
random.seed(counter) |
||||
idx = random.randint(0, len(ns) - 1) |
||||
|
||||
return ns[idx] |
||||
|
||||
class HyperTile: |
||||
@classmethod |
||||
def INPUT_TYPES(s): |
||||
return {"required": { "model": ("MODEL",), |
||||
"tile_size": ("INT", {"default": 256, "min": 1, "max": 2048}), |
||||
"swap_size": ("INT", {"default": 2, "min": 1, "max": 128}), |
||||
"max_depth": ("INT", {"default": 0, "min": 0, "max": 10}), |
||||
"scale_depth": ("BOOLEAN", {"default": False}), |
||||
}} |
||||
RETURN_TYPES = ("MODEL",) |
||||
FUNCTION = "patch" |
||||
|
||||
CATEGORY = "_for_testing" |
||||
|
||||
def patch(self, model, tile_size, swap_size, max_depth, scale_depth): |
||||
model_channels = model.model.model_config.unet_config["model_channels"] |
||||
|
||||
apply_to = set() |
||||
temp = model_channels |
||||
for x in range(max_depth + 1): |
||||
apply_to.add(temp) |
||||
temp *= 2 |
||||
|
||||
latent_tile_size = max(32, tile_size) // 8 |
||||
self.temp = None |
||||
self.counter = 1 |
||||
|
||||
def hypertile_in(q, k, v, extra_options): |
||||
if q.shape[-1] in apply_to: |
||||
shape = extra_options["original_shape"] |
||||
aspect_ratio = shape[-1] / shape[-2] |
||||
|
||||
hw = q.size(1) |
||||
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio)) |
||||
|
||||
factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1 |
||||
nh = random_divisor(h, latent_tile_size * factor, swap_size, self.counter) |
||||
self.counter += 1 |
||||
nw = random_divisor(w, latent_tile_size * factor, swap_size, self.counter) |
||||
self.counter += 1 |
||||
|
||||
if nh * nw > 1: |
||||
q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw) |
||||
self.temp = (nh, nw, h, w) |
||||
return q, k, v |
||||
|
||||
return q, k, v |
||||
def hypertile_out(out, extra_options): |
||||
if self.temp is not None: |
||||
nh, nw, h, w = self.temp |
||||
self.temp = None |
||||
out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw) |
||||
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw) |
||||
return out |
||||
|
||||
|
||||
m = model.clone() |
||||
m.set_model_attn1_patch(hypertile_in) |
||||
m.set_model_attn1_output_patch(hypertile_out) |
||||
return (m, ) |
||||
|
||||
NODE_CLASS_MAPPINGS = { |
||||
"HyperTile": HyperTile, |
||||
} |
Loading…
Reference in new issue