# ResLayer的中文意思是残差层,它可以解决深度网络的梯度消失问题
class ResLayer(nn.Module):
"Residual layer with `in_channels` inputs." # 有in_channels个输入的残差层
def __init__(self, in_channels: int): # 输入通道数
super().__init__() # 初始化父类
mid_channels = in_channels // 2 # 中间通道数
self.layer1 = BaseConv(
in_channels, mid_channels, ksize=1, stride=1, act="lrelu" # 1x1卷积,将通道数降低一半
)
self.layer2 = BaseConv(
mid_channels, in_channels, ksize=3, stride=1, act="lrelu" # 3x3卷积,将通道数恢复到原来的大小
)
def forward(self, x):
out = self.layer2(self.layer1(x)) # 1x1卷积,3x3卷积,通道数恢复到原来的大小
return x + out # shortcut y=y+x
# ResLayer的数据流图如下:
# x -> layer1(1x1) -> layer2(3x3) -> y
# |______________________________|
# 其中layer1是1x1卷积,layer2是3x3卷积
发表回复