From 8be46438be1c848e01e4085f54ae997e2e918771 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 7 Sep 2023 03:31:43 -0400 Subject: [PATCH] Support DiffBIR SwinIR models. --- .../chainner_models/architecture/SwinIR.py | 17 ++++++++++++++++- comfy_extras/nodes_upscale_model.py | 2 ++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/comfy_extras/chainner_models/architecture/SwinIR.py b/comfy_extras/chainner_models/architecture/SwinIR.py index 1abf450b..439dcbcb 100644 --- a/comfy_extras/chainner_models/architecture/SwinIR.py +++ b/comfy_extras/chainner_models/architecture/SwinIR.py @@ -846,6 +846,7 @@ class SwinIR(nn.Module): num_in_ch = in_chans num_out_ch = in_chans supports_fp16 = True + self.start_unshuffle = 1 self.model_arch = "SwinIR" self.sub_type = "SR" @@ -874,6 +875,11 @@ class SwinIR(nn.Module): else 64 ) + if "conv_first.1.weight" in self.state: + self.state["conv_first.weight"] = self.state.pop("conv_first.1.weight") + self.state["conv_first.bias"] = self.state.pop("conv_first.1.bias") + self.start_unshuffle = round(math.sqrt(self.state["conv_first.weight"].shape[1] // 3)) + num_in_ch = self.state["conv_first.weight"].shape[1] in_chans = num_in_ch if "conv_last.weight" in state_keys: @@ -968,7 +974,7 @@ class SwinIR(nn.Module): self.depths = depths self.window_size = window_size self.mlp_ratio = mlp_ratio - self.scale = upscale + self.scale = upscale / self.start_unshuffle self.upsampler = upsampler self.img_size = img_size self.img_range = img_range @@ -1101,6 +1107,9 @@ class SwinIR(nn.Module): self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) if self.upscale == 4: self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + elif self.upscale == 8: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) @@ -1157,6 +1166,9 @@ class SwinIR(nn.Module): self.mean = self.mean.type_as(x) x = (x - self.mean) * self.img_range + if self.start_unshuffle > 1: + x = torch.nn.functional.pixel_unshuffle(x, self.start_unshuffle) + if self.upsampler == "pixelshuffle": # for classical SR x = self.conv_first(x) @@ -1186,6 +1198,9 @@ class SwinIR(nn.Module): ) ) ) + elif self.upscale == 8: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.lrelu(self.conv_up3(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) x = self.conv_last(self.lrelu(self.conv_hr(x))) else: # for image denoising and JPEG compression artifact reduction diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index abd182e6..2b5e49a5 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -18,6 +18,8 @@ class UpscaleModelLoader: def load_model(self, model_name): model_path = folder_paths.get_full_path("upscale_models", model_name) sd = comfy.utils.load_torch_file(model_path, safe_load=True) + if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd: + sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""}) out = model_loading.load_state_dict(sd).eval() return (out, )