GCNet: Global Context Network(ICCV 2019)原理与代码解析
时间:2024-04-04 14:35:35 来源:网络cs 作者:胡椒 栏目:跨境风云 阅读:
paper:GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond
official implementaion:https://github.com/xvjiarui/GCNet
Third party implementation:https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/bricks/context_block.py
存在的问题
通过捕获long-range dependency提取全局信息,对各种视觉任务都是很有帮助的。Non-local Network(介绍见https://blog.csdn.net/ooooocj/article/details/124573078)通过自注意力机制来解决这个问题。对于每个查询位置(query position),non-local network首先计算该位置和所有位置之间一个两两成对的关系,得到一个attention map。然后对attention map所有位置的权重加权求和得到汇总特征,每一个查询位置都得到一个汇总特征,将汇总特征与原始特征相加得到最终输出。
对于某个query position,non-local network计算的另一个位置与该位置的关系即一个权重值表示这个位置对query位置的重要程度。本文可视化attention map发现,不同的query位置其对应的attention map几乎一样,如下图所示
data:image/s3,"s3://crabby-images/842f1/842f1dd1ee682d83e5aca207f4d4215b44527abc" alt=""
non-local block可以表示为下式
data:image/s3,"s3://crabby-images/d75a1/d75a159409ba95145eb401ead9d9cbbb1d6826e7" alt=""
其中 \(i\) 是query position的索引,\(j\) 遍历所有位置,\(f(\mathbf{x}_{i},\mathbf{x}_{j})\) 表示位置 \(i\) 和 \(j\) 之间的关系,\(\mathcal{C}(\mathbf{x})\) 是归一化因子,\(W_{z}\) 和 \(W_{v}\) 是线性变换矩阵例如 \(1\times1\) 卷积。non-local block有多种不同的实例化方法,例如Gaussian、Embedded Gaussian、Dot product、Concat,下图(a)是Embedded Gaussian的结构。
data:image/s3,"s3://crabby-images/76b1b/76b1b03b3d59fce7723086ab7dc929ff84306b5d" alt=""
由于需要计算每个query位置的attention map,因此non-local block的时间和空间复杂度都是所有位置的平方关系。
下图是作者从COCO数据集中随机挑选的6张图片,并可视化出3个不同的query position即图中的红点与对应的query-specific attention map,可以看出对于不同的查询位置,它们的attention map几乎是相同的。
data:image/s3,"s3://crabby-images/96421/964219d08b022596d69347beb2b9db588b1d01c8" alt=""
为了进一步验证这一观察结果,作者又分析了不同的查询位置与全局上下文之间的距离。结果如下表所示。其中计算了三种向量之间的余弦距离cosine distance,分别为non-local block的输入、输出、以及查询的注意力图,对应表中的input、output、att。
data:image/s3,"s3://crabby-images/7c922/7c92275058e786dd7da2f17c9aa502021a3fdd83" alt=""
从表中可以看出,input列的余弦距离比较大表明non-local的输入特征可以在不同的位置进行区分,但output列的余弦距离非常小,表明non-local block建模的全局上下文特征对于不同的query position几乎是相同的。attention map上的距离非常小,也验证了可视化的观察结果。
尽管non-local block是打算针对每个位置计算全局上下文的,但训练后的全局上下文实际上是独立于查询位置的。因为没有必要为每个查询位置单独计算query-specific全局上下文。
本文的创新点
本文通过观察发现non-local block针对每个query position计算的attention map最终结果是独立于查询位置的,那么就没有必要针对每个查询位置计算了,因此提出计算一个通用的attention map并应用于输入feature map上的所有位置,大大减少了计算量的同时又没有导致性能的降低。此外,结合SE block,设计了一个新的Global Context (GC) block,既轻量又可以有效地建模全局上下文。GC Block结合了Non-local block和SE block的优点,基于GC Block设计的GCNet在多个任务上均超过了NLNet和SENet。
方法介绍
作者舍去了式(1)中的 \(W_{z}\) 即图(3)(a)中的query分支,得到下式
data:image/s3,"s3://crabby-images/80b41/80b4118fbca4c06c39511e08603e3bedb0b74e3d" alt=""
这里采用的是最常用的Embedded Gaussiian的实例化方式。简化后的non-local block如图(3)(b)所示。
为了进一步减少计算量,作者应用分配率将 \(W_{v}\) 移到attention pool的外面,如下
data:image/s3,"s3://crabby-images/49595/495956babb9065d86dda6a65fb55a4b7dac79d24" alt=""
这里简化后的non-local block如图4(b)所示。1x1卷积 \(W_{v}\) 的FLOPs从 \(\mathcal{O}(HWC^{2})\) 减小到 \(\mathcal{O}(C^{2})\)。
data:image/s3,"s3://crabby-images/f7000/f7000f3a3bbecf23a3f991af2e0a825c3e9e896a" alt=""
到目前为止简化的NL block中参数量最大的部分在transform module即图4(b)中的Transform部分,这里是一个1x1卷积但参数量为 \(C\cdot C\),当把Nl block应用到较深的层例如resnet中的 \(res_{5}\) 时,CxC=2028x2048,占据了整个block大部分的计算量。为了进一步减少计算量,作者借鉴了SE block的思想如图4(c),将图4(b)中的transform module换成了图4(d)中的bottleneck transform module,其中 \(r\) 是reduction ratio,这样参数量就从C·C变成了2·C·C/r,默认情况下r=16,因此参数量就减少为1/8。
实验结果
下表baseline是backbone为ResNet-50的Mask R-CNN在COCO数据集上的目标检测和实例分割的结果。将1个non-local block(NL)、1个simplified non-local block(SNL)、1个global context block(GC)插入到c4的最后一个residual block前,可以看出GC block获得的相似的性能但参数量更小。将GC block添加到所有的residual block中在参数量相似的情况下得到了更高的性能。
data:image/s3,"s3://crabby-images/fea5c/fea5cd509275a2582479006f2eadddb3ad4bb0d2" alt=""
代码解析
这里的代码是mmcv中的实现。
# Copyright (c) OpenMMLab. All rights reserved.from typing import Unionimport torchfrom torch import nnfrom ..utils import constant_init, kaiming_initfrom .registry import PLUGIN_LAYERSdef last_zero_init(m: Union[nn.Module, nn.Sequential]) -> None: if isinstance(m, nn.Sequential): constant_init(m[-1], val=0) else: constant_init(m, val=0)@PLUGIN_LAYERS.register_module()class ContextBlock(nn.Module): """ContextBlock module in GCNet. See 'GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond' (https://arxiv.org/abs/1904.11492) for details. Args: in_channels (int): Channels of the input feature map. ratio (float): Ratio of channels of transform bottleneck pooling_type (str): Pooling method for context modeling. Options are 'att' and 'avg', stand for attention pooling and average pooling respectively. Default: 'att'. fusion_types (Sequence[str]): Fusion method for feature fusion, Options are 'channels_add', 'channel_mul', stand for channelwise addition and multiplication respectively. Default: ('channel_add',) """ _abbr_ = 'context_block' def __init__(self, in_channels: int, ratio: float, pooling_type: str = 'att', fusion_types: tuple = ('channel_add', )): super().__init__() assert pooling_type in ['avg', 'att'] assert isinstance(fusion_types, (list, tuple)) valid_fusion_types = ['channel_add', 'channel_mul'] assert all([f in valid_fusion_types for f in fusion_types]) assert len(fusion_types) > 0, 'at least one fusion should be used' self.in_channels = in_channels self.ratio = ratio self.planes = int(in_channels * ratio) self.pooling_type = pooling_type self.fusion_types = fusion_types if pooling_type == 'att': self.conv_mask = nn.Conv2d(in_channels, 1, kernel_size=1) self.softmax = nn.Softmax(dim=2) else: self.avg_pool = nn.AdaptiveAvgPool2d(1) if 'channel_add' in fusion_types: self.channel_add_conv = nn.Sequential( nn.Conv2d(self.in_channels, self.planes, kernel_size=1), nn.LayerNorm([self.planes, 1, 1]), nn.ReLU(inplace=True), # yapf: disable nn.Conv2d(self.planes, self.in_channels, kernel_size=1)) else: self.channel_add_conv = None if 'channel_mul' in fusion_types: self.channel_mul_conv = nn.Sequential( nn.Conv2d(self.in_channels, self.planes, kernel_size=1), nn.LayerNorm([self.planes, 1, 1]), nn.ReLU(inplace=True), # yapf: disable nn.Conv2d(self.planes, self.in_channels, kernel_size=1)) else: self.channel_mul_conv = None self.reset_parameters() def reset_parameters(self): if self.pooling_type == 'att': kaiming_init(self.conv_mask, mode='fan_in') self.conv_mask.inited = True if self.channel_add_conv is not None: last_zero_init(self.channel_add_conv) if self.channel_mul_conv is not None: last_zero_init(self.channel_mul_conv) def spatial_pool(self, x: torch.Tensor) -> torch.Tensor: batch, channel, height, width = x.size() if self.pooling_type == 'att': input_x = x # [N, C, H * W] input_x = input_x.view(batch, channel, height * width) # [N, 1, C, H * W] input_x = input_x.unsqueeze(1) # [N, 1, H, W] context_mask = self.conv_mask(x) # [N, 1, H * W] context_mask = context_mask.view(batch, 1, height * width) # [N, 1, H * W] context_mask = self.softmax(context_mask) # [N, 1, H * W, 1] context_mask = context_mask.unsqueeze(-1) # [N, 1, C, 1] context = torch.matmul(input_x, context_mask) # [N, C, 1, 1] context = context.view(batch, channel, 1, 1) else: # [N, C, 1, 1] context = self.avg_pool(x) return context def forward(self, x: torch.Tensor) -> torch.Tensor: # [N, C, 1, 1] context = self.spatial_pool(x) out = x if self.channel_mul_conv is not None: # [N, C, 1, 1] channel_mul_term = torch.sigmoid(self.channel_mul_conv(context)) out = out * channel_mul_term if self.channel_add_conv is not None: # [N, C, 1, 1] channel_add_term = self.channel_add_conv(context) out = out + channel_add_term return out
阅读本书更多章节>>>>
本文链接:https://www.kjpai.cn/fengyun/2024-04-04/153954.html,文章来源:网络cs,作者:胡椒,版权归作者所有,如需转载请注明来源和作者,否则将追究法律责任!
下一篇:返回列表