You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
61 lines
1.8 KiB
61 lines
1.8 KiB
2 years ago
|
#!/usr/bin/env python3
|
||
|
# -*- coding:utf-8 -*-
|
||
|
#############################################################
|
||
|
# File: OSAG.py
|
||
|
# Created Date: Tuesday April 28th 2022
|
||
|
# Author: Chen Xuanhong
|
||
|
# Email: chenxuanhongzju@outlook.com
|
||
|
# Last Modified: Sunday, 23rd April 2023 3:08:49 pm
|
||
|
# Modified By: Chen Xuanhong
|
||
|
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||
|
#############################################################
|
||
|
|
||
|
|
||
|
import torch.nn as nn
|
||
|
|
||
|
from .esa import ESA
|
||
|
from .OSA import OSA_Block
|
||
|
|
||
|
|
||
|
class OSAG(nn.Module):
    """Residual group of stacked ``OSA_Block`` modules followed by ``ESA``.

    The group runs ``block_num`` ``OSA_Block`` modules in sequence, fuses
    their output with a 1x1 convolution, adds a global skip connection from
    the group input, and passes the sum through an ``ESA`` module.

    Submodule layout (kept stable so existing checkpoints still load):
        * ``residual_layer``: ``nn.Sequential`` whose indices ``0 ..
          block_num - 1`` are the ``OSA_Block`` modules and whose index
          ``block_num`` is the 1x1 ``nn.Conv2d``.
        * ``esa``: the ``ESA`` module applied after the skip connection.

    Args:
        channel_num: Number of feature channels. The 1x1 fusion conv maps
            ``channel_num -> channel_num``, so the input's channel dimension
            must equal this value.
        bias: Whether the 1x1 fusion conv has a bias term. Also passed
            positionally as the second argument of every ``OSA_Block``
            (its meaning there is defined by ``OSA_Block`` — not visible here).
        block_num: Number of ``OSA_Block`` modules to stack.
        ffn_bias: Forwarded unchanged to each ``OSA_Block`` as ``ffn_bias``.
        window_size: Forwarded unchanged to each ``OSA_Block`` as
            ``window_size``.
        pe: Forwarded unchanged to each ``OSA_Block`` as ``with_pe``.
    """

    def __init__(
        self,
        channel_num=64,
        bias=True,
        block_num=4,
        ffn_bias=False,
        window_size=0,
        pe=False,
    ):
        super().__init__()

        # Every block receives identical hyper-parameters. Blocks are built
        # in index order, so parameter initialization consumes the RNG in the
        # same sequence as before.
        layers = [
            OSA_Block(
                channel_num,
                bias,
                ffn_bias=ffn_bias,
                window_size=window_size,
                with_pe=pe,
            )
            for _ in range(block_num)
        ]
        # Pointwise (1x1, stride 1, no padding) conv that mixes channels of
        # the stacked blocks' output before the global skip connection.
        layers.append(nn.Conv2d(channel_num, channel_num, 1, 1, 0, bias=bias))
        self.residual_layer = nn.Sequential(*layers)

        # Channel count handed to ESA as its first argument: a quarter of
        # channel_num, floored at 16. NOTE(review): presumably ESA's internal
        # reduced width — confirm against the ESA definition.
        esa_channel = max(channel_num // 4, 16)
        self.esa = ESA(esa_channel, channel_num)

    def forward(self, x):
        """Apply the residual group, then ESA.

        Args:
            x: Input feature map. Its channel dimension must equal
                ``channel_num`` (required by the 1x1 conv). NOTE(review):
                presumably a 4-D ``(N, C, H, W)`` tensor, as ``nn.Conv2d``
                expects — confirm with ``OSA_Block``/``ESA`` requirements.

        Returns:
            The output of ``self.esa`` applied to ``residual_layer(x) + x``.
        """
        # Global skip: the stacked blocks learn a residual added onto x. This
        # requires residual_layer to preserve x's shape.
        out = self.residual_layer(x) + x
        return self.esa(out)
|