You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
455 lines
14 KiB
455 lines
14 KiB
# pylint: skip-file |
|
# ----------------------------------------------------------------------------------- |
|
# SCUNet: Practical Blind Denoising via Swin-Conv-UNet and Data Synthesis, https://arxiv.org/abs/2203.13278 |
|
# Zhang, Kai and Li, Yawei and Liang, Jingyun and Cao, Jiezhang and Zhang, Yulun and Tang, Hao and Timofte, Radu and Van Gool, Luc |
|
# ----------------------------------------------------------------------------------- |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from einops import rearrange |
|
from einops.layers.torch import Rearrange |
|
|
|
from .timm.drop import DropPath |
|
from .timm.weight_init import trunc_normal_ |
|
|
|
|
|
# Borrowed from https://github.com/cszn/SCUNet/blob/main/models/network_scunet.py |
|
class WMSA(nn.Module):
    """Window-based multi-head self-attention as used in Swin Transformer.

    The input feature map is split into non-overlapping
    ``window_size x window_size`` windows and attention is computed inside
    each window. For any ``type`` other than ``"W"`` the map is cyclically
    shifted by ``window_size // 2`` first, and a mask keeps positions that
    were wrapped around by the shift from attending to each other.

    Args:
        input_dim: channel count of the input tokens.
        output_dim: channel count produced by the output projection.
        head_dim: channels per attention head; the number of heads is
            ``input_dim // head_dim``.
        window_size: side length of the square attention window.
        type: ``"W"`` for plain windows, anything else for shifted windows.
    """

    def __init__(self, input_dim, output_dim, head_dim, window_size, type):
        super(WMSA, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.head_dim = head_dim
        self.scale = self.head_dim**-0.5
        self.n_heads = input_dim // head_dim
        self.window_size = window_size
        self.type = type

        # One projection emits q, k and v stacked along the channel axis.
        self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
        self.linear = nn.Linear(self.input_dim, self.output_dim)

        # Learnable relative position bias. It is initialised as a flat
        # (offsets, heads) table and then stored as
        # (heads, 2*ws-1, 2*ws-1) so it can be indexed by (row, col) offset.
        span = 2 * window_size - 1
        bias_table = torch.zeros(span * span, self.n_heads)
        trunc_normal_(bias_table, std=0.02)
        self.relative_position_params = nn.Parameter(
            bias_table.view(span, span, self.n_heads).permute(2, 0, 1)
        )

    def generate_mask(self, h, w, p, shift):
        """Build the boolean attention mask for shifted-window attention.

        Args:
            h: number of windows along the height.
            w: number of windows along the width.
            p: window side length.
            shift: cyclic shift that was applied to the feature map.

        Returns:
            Bool tensor of shape ``(1, 1, h*w, p*p, p*p)``; ``True`` marks
            query/key pairs that must not attend to each other. For type
            ``"W"`` the mask is all ``False``.
        """
        attn_mask = torch.zeros(
            h,
            w,
            p,
            p,
            p,
            p,
            dtype=torch.bool,
            device=self.relative_position_params.device,
        )

        if self.type != "W":
            s = p - shift
            # Only the last row / last column of windows contain pixels that
            # were wrapped around by the cyclic shift; inside those windows
            # the two sides of the seam are mutually masked.
            attn_mask[-1, :, :s, :, s:, :] = True
            attn_mask[-1, :, s:, :, :s, :] = True
            attn_mask[:, -1, :, :s, :, s:] = True
            attn_mask[:, -1, :, s:, :, :s] = True

        # Always return the documented 5-D layout (previously the "W" branch
        # returned the raw 6-D tensor, which cannot broadcast against `sim`).
        return rearrange(
            attn_mask, "w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)"
        )

    def forward(self, x):
        """Apply windowed multi-head self-attention.

        Args:
            x: tensor of shape ``[b, h, w, c]``; ``h`` and ``w`` must be
                multiples of ``window_size``.

        Returns:
            Tensor of shape ``[b, h, w, output_dim]``.
        """
        shifted = self.type != "W"
        half = self.window_size // 2

        if shifted:
            x = torch.roll(x, shifts=(-half, -half), dims=(1, 2))

        # Partition into windows.
        x = rearrange(
            x,
            "b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c",
            p1=self.window_size,
            p2=self.window_size,
        )
        h_windows = x.size(1)
        w_windows = x.size(2)

        # Flatten to (batch, windows, tokens-per-window, channels).
        x = rearrange(x, "b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c")

        qkv = self.embedding_layer(x)
        q, k, v = rearrange(
            qkv, "b nw np (threeh c) -> threeh b nw np c", c=self.head_dim
        ).chunk(3, dim=0)

        sim = torch.einsum("hbwpc,hbwqc->hbwpq", q, k) * self.scale
        # Add the learnable relative position bias (shared over batch/windows).
        sim = sim + rearrange(self.relative_embedding(), "h p q -> h 1 1 p q")

        if shifted:
            # Keep wrapped-around regions from attending to each other.
            attn_mask = self.generate_mask(
                h_windows, w_windows, self.window_size, shift=half
            )
            sim = sim.masked_fill_(attn_mask, float("-inf"))

        probs = nn.functional.softmax(sim, dim=-1)
        output = torch.einsum("hbwij,hbwjc->hbwic", probs, v)
        output = rearrange(output, "h b w p c -> b w p (h c)")
        output = self.linear(output)

        # Merge the windows back into a feature map.
        output = rearrange(
            output,
            "b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c",
            w1=h_windows,
            p1=self.window_size,
        )

        if shifted:
            # Undo the cyclic shift.
            output = torch.roll(output, shifts=(half, half), dims=(1, 2))

        return output

    def relative_embedding(self):
        """Gather the relative position bias for every token pair in a window.

        Returns:
            Tensor of shape ``(n_heads, ws*ws, ws*ws)`` where entry
            ``[h, i, j]`` is the bias of head ``h`` for query token ``i`` and
            key token ``j`` (tokens enumerated row-major inside the window).
        """
        ws = self.window_size
        axis = torch.arange(ws, device=self.relative_position_params.device)
        rows = axis.repeat_interleave(ws)  # row of each token: 0,0,..,1,1,..
        cols = axis.repeat(ws)  # column of each token: 0,1,..,0,1,..

        # Offsets lie in [-(ws-1), ws-1]; shift them to [0, 2*ws-2] so they
        # index the bias table directly.
        rel_rows = rows[:, None] - rows[None, :] + ws - 1
        rel_cols = cols[:, None] - cols[None, :] + ws - 1
        return self.relative_position_params[:, rel_rows, rel_cols]
|
|
|
|
|
class Block(nn.Module):
    """Swin Transformer block: windowed self-attention followed by an MLP.

    Both sub-layers are pre-normalised with LayerNorm and wrapped in a
    residual connection with optional stochastic depth (DropPath).

    Args:
        input_dim: channel count of the incoming tokens.
        output_dim: channel count produced by the MLP. The residual add in
            ``forward`` requires it to match ``input_dim``.
        head_dim: channels per attention head.
        window_size: side length of the attention window.
        drop_path: stochastic-depth rate; ``0`` disables it.
        type: ``"W"`` (regular windows) or ``"SW"`` (shifted windows).
        input_resolution: expected spatial size of the feature map. When
            given and not larger than ``window_size``, the block falls back
            to ``"W"`` because shifting a single window is pointless.
            ``None`` (the default) skips this check.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        head_dim,
        window_size,
        drop_path,
        type="W",
        input_resolution=None,
    ):
        super(Block, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        assert type in ["W", "SW"]
        self.type = type
        # Guard against the default `None`, which cannot be compared with an
        # int and used to raise TypeError here.
        if input_resolution is not None and input_resolution <= window_size:
            self.type = "W"

        self.ln1 = nn.LayerNorm(input_dim)
        self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.ln2 = nn.LayerNorm(input_dim)
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 4 * input_dim),
            nn.GELU(),
            nn.Linear(4 * input_dim, output_dim),
        )

    def forward(self, x):
        """Run attention and MLP sub-layers with residual connections.

        Args:
            x: tensor laid out as ``[b, h, w, c]`` (the layout ``WMSA``
                consumes).

        Returns:
            Tensor of the same shape as ``x``.
        """
        x = x + self.drop_path(self.msa(self.ln1(x)))
        x = x + self.drop_path(self.mlp(self.ln2(x)))
        return x
|
|
|
|
|
class ConvTransBlock(nn.Module):
    """Hybrid block running a conv branch and a Swin Transformer branch in parallel.

    The input (``conv_dim + trans_dim`` channels) goes through a 1x1 conv and
    is split channel-wise: the first ``conv_dim`` channels feed a residual
    3x3 conv stack, the remaining ``trans_dim`` channels feed a ``Block``.
    The two results are concatenated, fused by another 1x1 conv and added
    back to the input.

    Args:
        conv_dim: channels routed to the convolutional branch.
        trans_dim: channels routed to the transformer branch.
        head_dim: channels per attention head.
        window_size: side length of the attention window.
        drop_path: stochastic-depth rate passed to the transformer block.
        type: ``"W"`` or ``"SW"``.
        input_resolution: expected spatial size of the feature map. When
            given and not larger than ``window_size``, ``type`` is forced to
            ``"W"``. ``None`` (the default) skips this check.
    """

    def __init__(
        self,
        conv_dim,
        trans_dim,
        head_dim,
        window_size,
        drop_path,
        type="W",
        input_resolution=None,
    ):
        super(ConvTransBlock, self).__init__()
        self.conv_dim = conv_dim
        self.trans_dim = trans_dim
        self.head_dim = head_dim
        self.window_size = window_size
        self.drop_path = drop_path
        self.type = type
        self.input_resolution = input_resolution

        assert self.type in ["W", "SW"]
        # Guard against the default `None`, which cannot be compared with an
        # int and used to raise TypeError here.
        if (
            self.input_resolution is not None
            and self.input_resolution <= self.window_size
        ):
            self.type = "W"

        self.trans_block = Block(
            self.trans_dim,
            self.trans_dim,
            self.head_dim,
            self.window_size,
            self.drop_path,
            self.type,
            self.input_resolution,
        )

        total_dim = self.conv_dim + self.trans_dim
        # 1x1 convs that mix channels before the split and after the merge.
        self.conv1_1 = nn.Conv2d(total_dim, total_dim, 1, 1, 0, bias=True)
        self.conv1_2 = nn.Conv2d(total_dim, total_dim, 1, 1, 0, bias=True)

        self.conv_block = nn.Sequential(
            nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
        )

    def forward(self, x):
        """Process ``x`` through both branches and add the fused result to it.

        Args:
            x: tensor of shape ``[b, conv_dim + trans_dim, h, w]``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        conv_x, trans_x = torch.split(
            self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1
        )
        conv_x = self.conv_block(conv_x) + conv_x

        # The transformer block works channels-last; permute instead of
        # building a fresh einops `Rearrange` module on every call.
        trans_x = trans_x.permute(0, 2, 3, 1)  # b c h w -> b h w c
        trans_x = self.trans_block(trans_x)
        trans_x = trans_x.permute(0, 3, 1, 2)  # b h w c -> b c h w

        res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
        return x + res
|
|
|
|
|
class SCUNet(nn.Module):
    """Swin-Conv-UNet: a UNet whose stages are stacks of ``ConvTransBlock``.

    The network has a 3x3 conv head, three downsampling stages, a body, three
    upsampling stages with additive skip connections, and a 3x3 conv tail.
    Weights are loaded from ``state_dict`` (strictly) at construction time.

    Args:
        state_dict: weights to load; keys must match this module exactly.
        in_nc: number of input (and output) image channels.
        config: seven integers giving the number of blocks in
            down1, down2, down3, body, up3, up2 and up1 respectively.
        dim: base channel width after the head conv.
        drop_path_rate: maximum stochastic-depth rate; it is ramped linearly
            from 0 across all blocks.
        input_resolution: nominal input size, used by the blocks to decide
            whether shifted windows make sense at each scale.
    """

    def __init__(
        self,
        state_dict,
        in_nc=3,
        config=(4, 4, 4, 4, 4, 4, 4),
        dim=64,
        drop_path_rate=0.0,
        input_resolution=256,
    ):
        super(SCUNet, self).__init__()
        # Descriptive metadata; not read anywhere in this class (presumably
        # consumed by the hosting application - TODO confirm).
        self.model_arch = "SCUNet"
        self.sub_type = "SR"

        self.num_filters: int = 0

        self.state = state_dict
        # Stored as a fresh list so instances never share the default value.
        self.config = list(config)
        self.dim = dim
        self.head_dim = 32
        self.window_size = 8

        self.in_nc = in_nc
        self.out_nc = self.in_nc
        self.scale = 1
        self.supports_fp16 = True

        # Drop-path rate for every block, ramped linearly over the network.
        dpr = torch.linspace(0, drop_path_rate, sum(config)).tolist()

        # offsets[k] is the index into `dpr` of the first block of stage k.
        offsets = [0]
        for depth in config:
            offsets.append(offsets[-1] + depth)

        def rates(stage):
            return dpr[offsets[stage] : offsets[stage + 1]]

        res = input_resolution

        # Construction and registration order is kept identical to the
        # reference implementation so state-dict keys line up.
        self.m_head = nn.Sequential(nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False))

        self.m_down1 = nn.Sequential(
            *(
                self._make_stage(dim, config[0], rates(0), res)
                + [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]
            )
        )
        self.m_down2 = nn.Sequential(
            *(
                self._make_stage(2 * dim, config[1], rates(1), res // 2)
                + [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]
            )
        )
        self.m_down3 = nn.Sequential(
            *(
                self._make_stage(4 * dim, config[2], rates(2), res // 4)
                + [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]
            )
        )

        self.m_body = nn.Sequential(
            *self._make_stage(8 * dim, config[3], rates(3), res // 8)
        )

        self.m_up3 = nn.Sequential(
            *(
                [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False)]
                + self._make_stage(4 * dim, config[4], rates(4), res // 4)
            )
        )
        self.m_up2 = nn.Sequential(
            *(
                [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False)]
                + self._make_stage(2 * dim, config[5], rates(5), res // 2)
            )
        )
        self.m_up1 = nn.Sequential(
            *(
                [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False)]
                + self._make_stage(dim, config[6], rates(6), res)
            )
        )

        self.m_tail = nn.Sequential(nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False))

        # self.apply(self._init_weights)
        self.load_state_dict(state_dict, strict=True)

    def _make_stage(self, width, depth, drop_rates, resolution):
        """Build ``depth`` ConvTransBlocks for a stage with ``width`` channels.

        Channels are split evenly between the conv and transformer branches,
        and blocks alternate between regular ("W") and shifted ("SW") windows.
        """
        half = width // 2
        return [
            ConvTransBlock(
                half,
                half,
                self.head_dim,
                self.window_size,
                drop_rates[i],
                "W" if i % 2 == 0 else "SW",
                resolution,
            )
            for i in range(depth)
        ]

    def check_image_size(self, x):
        """Pad ``x`` on the bottom/right so height and width are multiples of 64.

        64 = 8 (three stride-2 downsamplings) * 8 (attention window size).

        Args:
            x: tensor of shape ``[b, c, h, w]``.

        Returns:
            The padded tensor; unchanged in size when already aligned.
        """
        _, _, h, w = x.size()
        mod_pad_h = (64 - h % 64) % 64
        mod_pad_w = (64 - w % 64) % 64
        # Reflect padding requires the pad to be smaller than the dimension
        # being padded; for tiny inputs fall back to edge replication rather
        # than raising.
        fits_reflect = mod_pad_h < h and mod_pad_w < w
        mode = "reflect" if fits_reflect else "replicate"
        return F.pad(x, (0, mod_pad_w, 0, mod_pad_h), mode)

    def forward(self, x0):
        """Run the network on an image batch.

        Args:
            x0: tensor of shape ``[b, in_nc, h, w]``.

        Returns:
            Tensor of shape ``[b, in_nc, h, w]`` (padding is cropped away).
        """
        h, w = x0.size()[-2:]
        x0 = self.check_image_size(x0)

        # Encoder.
        x1 = self.m_head(x0)
        x2 = self.m_down1(x1)
        x3 = self.m_down2(x2)
        x4 = self.m_down3(x3)

        # Bottleneck and decoder with additive skip connections.
        x = self.m_body(x4)
        x = self.m_up3(x + x4)
        x = self.m_up2(x + x3)
        x = self.m_up1(x + x2)
        x = self.m_tail(x + x1)

        # Remove the alignment padding.
        return x[:, :, :h, :w]

    def _init_weights(self, m):
        """Initialise Linear and LayerNorm layers (currently unused; see __init__)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
|
|
|