|
|
|
@ -846,6 +846,7 @@ class SwinIR(nn.Module):
|
|
|
|
|
num_in_ch = in_chans |
|
|
|
|
num_out_ch = in_chans |
|
|
|
|
supports_fp16 = True |
|
|
|
|
self.start_unshuffle = 1 |
|
|
|
|
|
|
|
|
|
self.model_arch = "SwinIR" |
|
|
|
|
self.sub_type = "SR" |
|
|
|
@ -874,6 +875,11 @@ class SwinIR(nn.Module):
|
|
|
|
|
else 64 |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
if "conv_first.1.weight" in self.state: |
|
|
|
|
self.state["conv_first.weight"] = self.state.pop("conv_first.1.weight") |
|
|
|
|
self.state["conv_first.bias"] = self.state.pop("conv_first.1.bias") |
|
|
|
|
self.start_unshuffle = round(math.sqrt(self.state["conv_first.weight"].shape[1] // 3)) |
|
|
|
|
|
|
|
|
|
num_in_ch = self.state["conv_first.weight"].shape[1] |
|
|
|
|
in_chans = num_in_ch |
|
|
|
|
if "conv_last.weight" in state_keys: |
|
|
|
@ -968,7 +974,7 @@ class SwinIR(nn.Module):
|
|
|
|
|
self.depths = depths |
|
|
|
|
self.window_size = window_size |
|
|
|
|
self.mlp_ratio = mlp_ratio |
|
|
|
|
self.scale = upscale |
|
|
|
|
self.scale = upscale / self.start_unshuffle |
|
|
|
|
self.upsampler = upsampler |
|
|
|
|
self.img_size = img_size |
|
|
|
|
self.img_range = img_range |
|
|
|
@ -1101,6 +1107,9 @@ class SwinIR(nn.Module):
|
|
|
|
|
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) |
|
|
|
|
if self.upscale == 4: |
|
|
|
|
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) |
|
|
|
|
elif self.upscale == 8: |
|
|
|
|
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) |
|
|
|
|
self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) |
|
|
|
|
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) |
|
|
|
|
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) |
|
|
|
|
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) |
|
|
|
@ -1157,6 +1166,9 @@ class SwinIR(nn.Module):
|
|
|
|
|
self.mean = self.mean.type_as(x) |
|
|
|
|
x = (x - self.mean) * self.img_range |
|
|
|
|
|
|
|
|
|
if self.start_unshuffle > 1: |
|
|
|
|
x = torch.nn.functional.pixel_unshuffle(x, self.start_unshuffle) |
|
|
|
|
|
|
|
|
|
if self.upsampler == "pixelshuffle": |
|
|
|
|
# for classical SR |
|
|
|
|
x = self.conv_first(x) |
|
|
|
@ -1186,6 +1198,9 @@ class SwinIR(nn.Module):
|
|
|
|
|
) |
|
|
|
|
) |
|
|
|
|
) |
|
|
|
|
elif self.upscale == 8: |
|
|
|
|
x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) |
|
|
|
|
x = self.lrelu(self.conv_up3(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) |
|
|
|
|
x = self.conv_last(self.lrelu(self.conv_hr(x))) |
|
|
|
|
else: |
|
|
|
|
# for image denoising and JPEG compression artifact reduction |
|
|
|
|