【霹雳吧啦】手把手带你入门语义分割の番外8:U-Net 源码讲解(PyTorch)—— 网络的搭建
时间:2024-04-27 10:25:32 来源:网络cs 作者:利杜鹃 栏目:防关联工具 阅读:
目录
前言
Preparation
一、U-Net 网络结构图
二、U-Net 网络源代码
1、DRIVE 数据集
2、U-Net 源代码的不同
(1)train.py
(2)train_and_val.py
(3)results.txt
3、predict.py 模型预测
4、unet.py 模型搭建
(1)DoubleConv 类
(2)Down 类
(3)Up 类
(4)OutConv 类
(5)UNet 类 ★
5、unet.py 的源代码
前言
文章性质:学习笔记 📖
视频教程:使用 Pytorch 搭建 U-Net 网络并基于 DRIVE 数据集训练(语义分割)- 1 网络的搭建
主要内容:根据 视频教程 中提供的 U-Net 源代码(PyTorch),对 DRIVE 文件夹结构和 predict.py、unet.py 文件进行具体讲解。
Preparation
源代码:https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/tree/master/pytorch_segmentation/unet
├── src: 搭建U-Net模型代码
├── train_utils: 训练、验证以及多GPU训练相关模块
├── my_dataset.py: 自定义dataset用于读取DRIVE数据集(视网膜血管分割)
├── train.py: 以单GPU为例进行训练
├── train_multi_GPU.py: 针对使用多GPU的用户使用
├── predict.py: 简易的预测脚本,使用训练好的权重进行预测测试
└── compute_mean_std.py: 统计数据集各通道的均值和标准差
一、U-Net 网络结构图
原论文提供的 U-Net 网络结构图如下所示:
原论文中提供的 U-Net 网络结构所使用的卷积层会改变特征层的高和宽,而现在比较主流的方式是 不去改变输入特征层的高和宽 ,将转置卷积替换成简单的双线性插值进行上采样,所以霹雳吧啦重绘的 U-Net 网络结构图也是按照 双线性插值 进行绘制的,如下图所示:
二、U-Net 网络源代码
1、DRIVE 数据集
将 DRIVE 数据集下载下来后放在 unet 项目目录下,因为在 train.py 文件中将读取数据集的目录默认设置为当前目录,如下图所示:
我们再来简单看看 DRIVE 文件夹的结构,它主要分为 test 测试(验证)集和 training 训练集。
在 training 训练集文件夹中:
1st_manual 提供了人工分割好的标签图片images 提供了用于分割的原图片mask 提供了二值图片,白色部分是要分割的感兴趣的区域,mask 提供了类似于蒙版的效果在 test 测试(验证)集文件夹中:
1st_manual 提供了人工分割好的标签图片,用于精标准2nd_manual 提供了人工分割好的标签图片,用于分割做验证images 提供了用于分割的原图片mask 提供了二值图片,白色部分是要分割的感兴趣的区域,mask 提供了类似于蒙版的效果【彩蛋】为了方便王子公主们下载 DRIVE 数据集,微臣将官网的下载地址贴在这里:Introduction - Grand Challenge
【补充】当然也可以前往【Preparation】中提供的 GitHub 地址,在霹雳吧啦提供的网盘链接中下载。
2、U-Net 源代码的不同
(1)train.py
这里的 train.py 训练脚本和之前讲过的 FCN 源代码中的训练脚本类似,不同之处在于 create_model 创建模型部分:只需简单调用 U-Net ,传入相应的参数后创建即可,不需要载入预训练权重。
(2)train_and_val.py
在模型的训练和验证过程中,我们在 criterion 函数中引入了 dice_loss ,在 evaluate 函数中增加了 dice 指标,这个后面会进行详细的讲解。
(3)results.txt
训练完成后会生成 results 文本文件,保存了每轮训练的训练损失 train_loss、学习率 lr、Dice 系数 dice coefficient、全局正确率 global correct、平均行正确率 average row correct、交并比 IoU 和平均交并比 mean IoU 值。
3、predict.py 模型预测
在 predict.py 文件的 main 函数中,设置类别的数量为 1 ,不包含背景类别。接着依次设置了训练好的模型权重文件路径 weights_path、输入图像路径 img_path、感兴趣区域的掩模路径 roi_mask_path,并对这些路径文件进行断言,检查是否存在,不存在则输出对应的错误信息。
运行 predict.py 文件完成网络预测后,将生成 test_result.png 图片:
4、unet.py 模型搭建
这个 unet.py 是 U-Net 网络搭建部分。
(1)DoubleConv 类
因为在 UNet 网络结构中,卷积层基本是成对出现的,因此构造了 DoubleConv 类,in_channels 是输入特征层的 channel ,out_channels 是通过 DoubleConv 后的输出特征层的 channel ,mid_channels 是通过第一个卷积层后的输出特征图的 channel 。
【说明】因为这里的卷积层采用的是比较主流的方式,即不去改变特征图的高和宽,因此 padding 设置为 1 。
(2)Down 类
Down 类继承自 nn.Sequential 父类,该模块由一个下采样和两个卷积构成(Encoder):
(3)Up 类
Up 类继承自 nn.Module 父类,该模块由一个上采样和两个卷积构成(Decoder):
【说明1】在 __init__ 初始化函数中,传入 bilinear 表示是否使用双线性插值替代转置卷积,Up 类默认会使用双线性插值,故令 bilinear=True ,这里的 in_channels 对应 concat 拼接之后的 channels 或者说对应 Up 模块中第一个卷积的输入 channels 。
【说明2】在 forward 函数中, 绿 框部分的处理是为了确保要 concat 拼接的 x1 和 x2 的宽高相同,因为当最初输入的特征图的宽高不是 16 的整数倍时,在下采样后需要进行取整,再进行上采样后可能会出现 尺寸对不上 的问题。
(4)OutConv 类
OutConv 类继承自 nn.Module 父类,对应 U-Net 网络结构的最后一个卷积层,其卷积核个数为包含背景的分类类别个数 num_classes 。
(5)UNet 类 ★
UNet 类继承自 nn.Module 父类,传入参数包括:
in_channels 是输入特征图的通道数,彩色图片为 3 ,灰度图片为 1num_classes 是包含背景的分类类别个数bilinear 表示是否使用双线性插值法base_c 是基础通道数 channel【说明】factor 是一个因子,用于控制上采样过程中特征图的通道数,这个因子的值取决于是否使用双线性插值的 bilinear 标志:
当 bilinear=True 时,factor = 2,说明在上采样过程中,特征图的通道数减半,当 base_c 为 64 时,上采样后特征图的通道数为 32当 bilinear=False 时,factor = 1,说明在上采样过程中,特征图的通道数不变,当 base_c 为 64 时,上采样后特征图的通道数为 64这个 factor 因子的引入主要是为了在上采用过程中控制特征图的大小和复杂度,以适应不同的任务需求和计算资源限制。
这是 U-Net 的 forward 前向传播函数,它接受一个 torch.Tensor 类型的输入 x ,并返回一个字典类型的输出。
5、unet.py 的源代码
from typing import Dictimport torchimport torch.nn as nnimport torch.nn.functional as Fclass DoubleConv(nn.Sequential): def __init__(self, in_channels, out_channels, mid_channels=None): if mid_channels is None: mid_channels = out_channels super(DoubleConv, self).__init__( nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(mid_channels), nn.ReLU(inplace=True), nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True) )class Down(nn.Sequential): def __init__(self, in_channels, out_channels): super(Down, self).__init__( nn.MaxPool2d(2, stride=2), DoubleConv(in_channels, out_channels) )class Up(nn.Module): def __init__(self, in_channels, out_channels, bilinear=True): super(Up, self).__init__() if bilinear: self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) else: self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) self.conv = DoubleConv(in_channels, out_channels) def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: x1 = self.up(x1) # [N, C, H, W] diff_y = x2.size()[2] - x1.size()[2] diff_x = x2.size()[3] - x1.size()[3] # padding_left, padding_right, padding_top, padding_bottom x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2]) x = torch.cat([x2, x1], dim=1) x = self.conv(x) return xclass OutConv(nn.Sequential): def __init__(self, in_channels, num_classes): super(OutConv, self).__init__( nn.Conv2d(in_channels, num_classes, kernel_size=1) )class UNet(nn.Module): def __init__(self, in_channels: int = 1, num_classes: int = 2, bilinear: bool = True, base_c: int = 64): super(UNet, self).__init__() self.in_channels = in_channels self.num_classes = num_classes self.bilinear = bilinear self.in_conv = DoubleConv(in_channels, base_c) self.down1 = Down(base_c, base_c * 2) self.down2 = Down(base_c * 2, base_c * 4) self.down3 = Down(base_c * 4, base_c * 8) factor = 2 if bilinear else 1 self.down4 = Down(base_c * 8, base_c * 16 // factor) self.up1 = Up(base_c * 16, base_c * 8 // factor, bilinear) self.up2 = Up(base_c * 8, base_c * 4 // factor, bilinear) self.up3 = Up(base_c * 4, base_c * 2 // factor, bilinear) self.up4 = Up(base_c * 2, base_c, bilinear) self.out_conv = OutConv(base_c, num_classes) def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: x1 = self.in_conv(x) x2 = self.down1(x1) x3 = self.down2(x2) x4 = self.down3(x3) x5 = self.down4(x4) x = self.up1(x5, x4) x = self.up2(x, x3) x = self.up3(x, x2) x = self.up4(x, x1) logits = self.out_conv(x) return {"out": logits}
本文链接:https://www.kjpai.cn/news/2024-04-27/162804.html,文章来源:网络cs,作者:利杜鹃,版权归作者所有,如需转载请注明来源和作者,否则将追究法律责任!
上一篇:操作系统安全:Linux安全审计,Linux日志详解
下一篇:返回列表