diff --git a/comfy/sd.py b/comfy/sd.py index d75bbd9a..5920ddde 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -2,6 +2,7 @@ import torch import contextlib import copy import inspect +import math from comfy import model_management from .ldm.util import instantiate_from_config @@ -1099,6 +1100,12 @@ class T2IAdapter(ControlBase): self.channels_in = channels_in self.control_input = None + def scale_image_to(self, width, height): + unshuffle_amount = self.t2i_model.unshuffle_amount + width = math.ceil(width / unshuffle_amount) * unshuffle_amount + height = math.ceil(height / unshuffle_amount) * unshuffle_amount + return width, height + def get_control(self, x_noisy, t, cond, batched_number): control_prev = None if self.previous_controlnet is not None: @@ -1116,7 +1123,8 @@ class T2IAdapter(ControlBase): del self.cond_hint self.control_input = None self.cond_hint = None - self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device) + width, height = self.scale_image_to(x_noisy.shape[3] * 8, x_noisy.shape[2] * 8) + self.cond_hint = utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float().to(self.device) if self.channels_in == 1 and self.cond_hint.shape[1] > 1: self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True) if x_noisy.shape[0] != self.cond_hint.shape[0]: diff --git a/comfy/t2i_adapter/adapter.py b/comfy/t2i_adapter/adapter.py index 000cf041..e9a606b1 100644 --- a/comfy/t2i_adapter/adapter.py +++ b/comfy/t2i_adapter/adapter.py @@ -103,17 +103,17 @@ class ResnetBlock(nn.Module): class Adapter(nn.Module): def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True, xl=True): super(Adapter, self).__init__() - unshuffle = 8 + self.unshuffle_amount = 8 resblock_no_downsample = [] resblock_downsample = [3, 2, 1] self.xl = xl if self.xl: - unshuffle = 16 + self.unshuffle_amount = 16 resblock_no_downsample = [1] resblock_downsample = [2] - self.input_channels = cin // (unshuffle * unshuffle) - self.unshuffle = nn.PixelUnshuffle(unshuffle) + self.input_channels = cin // (self.unshuffle_amount * self.unshuffle_amount) + self.unshuffle = nn.PixelUnshuffle(self.unshuffle_amount) self.channels = channels self.nums_rb = nums_rb self.body = [] @@ -264,9 +264,9 @@ class extractor(nn.Module): class Adapter_light(nn.Module): def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64): super(Adapter_light, self).__init__() - unshuffle = 8 - self.unshuffle = nn.PixelUnshuffle(unshuffle) - self.input_channels = cin // (unshuffle * unshuffle) + self.unshuffle_amount = 8 + self.unshuffle = nn.PixelUnshuffle(self.unshuffle_amount) + self.input_channels = cin // (self.unshuffle_amount * self.unshuffle_amount) self.channels = channels self.nums_rb = nums_rb self.body = []