Attention Gate in Attention UNet

对Attention UNet中使用的Attention Gate(AG)进行介绍和代码分析。

1 Attention UNet

原文链接

于2018年被收录于cvpr,在UNet的基础上,decoder的部分增加Attention Gate,以突出通过跳转连接的突出特征。文中介绍了2D和3D的实现方法。

Attention UNet整体结构如下:
Att-UNet

从原文中的实验结果来看,Attention UNet
都相较于U-Net在多个数据集(CT-150,TCIA Pancreas-CT,eg.)上有更优秀的分割性能,且从可视化结果来看,Attention UNet有更平滑且准确的边缘。

Segment Example

2 Attention Gate

下面来详细了解一下这个UNet中的Attention Gate。

作用:自适应地选择有用的特征,抑制无关的区域。

2.1 AG的结构与原理

AG

  • $x^l$:从encoder对应层skip来的特征图,通常是高分辨率的图像。
    • 可能包含冗余和无关信息
  • $g$:从上采样而来的低分辨率信号。
    • 解码器当前关注的区域,帮助引导关注重点

处理流程
假设:$x^l$(1,64,64,64),$g$(1,128,32,32)

通常g的空间尺寸更小,channel更大

  1. 将两个输入使用$1\times 1$卷积映射到统一维度
  2. 对齐空间尺寸(上采样gating信号):可以使用线性插值

权重大的特征在相加后更加明显

  1. 融合后是使用ReLU激活
  2. $1\times 1$卷积:通道降到1通道size:(1,1,32,32)
  3. 得到mask:sigmoid将值映射到0~1之间
  4. mask与$x^l$相乘,得到“重要区域”特征值更强,“不重要区域”被抑制的特征图。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
 x (1, 64, 64, 64)

[1×1 Conv] →────────────┐
│ │
(1, 32, 64, 64) │

Add + ReLU

[Upsample] ←── [1×1 Conv] ← g (1, 128, 32, 32)
(1, 32, 64, 64)


[1×1 Conv + Sigmoid]

Attention mask (1, 1, 64, 64)


Element-wise multiplication

Output: (1, 64, 64, 64)

2.2 Pytorch 实现AG(2D)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionGate2D(nn.Module):
def __init__(self, F_g, F_l, F_int):
"""
F_g: decoder 的通道数 (gating)
F_l: encoder 的通道数 (local feature x)
F_int: 中间通道数
"""
super(AttentionGate2D, self).__init__()

self.W_g = nn.Sequential(
nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
nn.BatchNorm2d(F_int)
)

self.W_x = nn.Sequential(
nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
nn.BatchNorm2d(F_int)
)

self.psi = nn.Sequential(
nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
nn.BatchNorm2d(1),
nn.Sigmoid()
)

self.relu = nn.ReLU(inplace=True)

def forward(self, x, g):
"""
x: 来自 encoder 的特征图 (B, F_l, H, W)
g: 来自 decoder 的 gating 信号 (B, F_g, H', W')
"""
g1 = self.W_g(g) # (B, F_int, H', W')
x1 = self.W_x(x) # (B, F_int, H, W)

# 上采样 gating 信号(如果大小不一致)
if g1.shape[-2:] != x1.shape[-2:]:
g1 = F.interpolate(g1, size=x1.shape[-2:], mode='bilinear', align_corners=True)

psi = self.relu(g1 + x1) # (B, F_int, H, W)
psi = self.psi(psi) # (B, 1, H, W)

return x * psi # (B, F_l, H, W) × (B, 1, H, W)

在上采样过程中:

1
2
3
4
5
x = self.up(x)                            # 上采样 decoder feature map
skip = self.attention_gate(skip, x) # 使用 AG 筛选 encoder feature
x = torch.cat([x, skip], dim=1) # 拼接经过 AG 的 skip connection
x = self.conv(x) # 双卷积融合
return x
作者

Zhou

发布于

2025-04-10

更新于

2025-04-11

许可协议

评论

+ + +