神经网络学习笔记2:图像语义分割模型的选型与部署实践

在智能驾驶领域,常使用图像语义分割使控制器了解周围的环境。图像语义分割还能为视觉惯性组合导航生成其所需的语义化特征,具有广阔的应用前景。本文记录了一次完整的图像语义分割模型选型和部署的流程,以供参考。

在 Papers With Code 中挑选模型

Papers With Code 是一个社区驱动、公开透明的平台,专注于将最新的机器学习/人工智能研究论文与其开源代码实现、数据集与性能指标整合在一起。它提供论文摘要、GitHub 实现链接、基准榜单(如 SOTA leaderboard)及数据集和方法分类,帮助研究者和开发者快速了解、复现并比较技术细节与性能。

以这次需要的图像语义分割模型为例,在其 Browse State-of-the-Art 页面选择 Computer Vision 中的 Semantic Segmentation,可以发现很多 benchmarks 和在这些数据集中表现最好的模型。可以发现,其中较为流行的数据集为 Cityscapes,在 2024 年有 434 篇文章选择使用它作为数据集。它是一个专注于城市街景语义理解的大型图像数据集,包含来自德国等城市的高分辨率街景图像,主要用于语义分割、实例分割和场景理解等任务。数据集中提供了精细标注的像素级语义标签,广泛应用于自动驾驶与智能交通系统的研究中。

在 Cityscapes 数据集的详情页面,可以看到很多 benchmark,包括基于 Cityscapes 测试集和验证集的 Semantic Segmentation 测试、Panoptic Segmentation 测试、Real-Time Semantic Segmentation 测试等。测试中的每个条目都有论文链接和对应的开源 Github 链接,非常适合快速挑选并应用模型。

为了满足实时性需求,应该选择 Real-Time Semantic Segmentation 测试中表现较好的模型作为比较对象。注意到性能较好的模型为 PIDNet,它在 31.1 fps 的帧率下实现了 80.6% 的 mIoU,具备较强的识别能力,因此选择它进入下一阶段。

部署 PIDNet

在详情页面可以看到 PIDNet 的 (Github 仓库链接)[https://github.com/XuJiacong/PIDNet],在仓库的 README 中给出了不同尺寸的模型下载链接,以及详细的使用和复现流程。通过如下命令克隆仓库到本地:

1
git clone https://github.com/XuJiacong/PIDNet.git

接下来,为了将模型整合进现有的系统中,需要找到调用这个模型的方法。根据使用教程,在仓库的 tools 文件夹中找到了一些可以直接运行的文件。其中 train.py 是训练流程,此处不涉及。eval.py 用于验证模型在 Cityscapes 中的性能,包含了一些冗余代码。而 custom.py 用于对给出的单个图片进行语义分割,最为精简,可供参考。其代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# ------------------------------------------------------------------------------
# Written by Jiacong Xu (jiacong.xu@tamu.edu)
# ------------------------------------------------------------------------------

import glob
import argparse
import cv2
import os
import numpy as np
import _init_paths
import models
import torch
import torch.nn.functional as F
from PIL import Image

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

color_map = [(128, 64,128),
(244, 35,232),
( 70, 70, 70),
(102,102,156),
(190,153,153),
(153,153,153),
(250,170, 30),
(220,220, 0),
(107,142, 35),
(152,251,152),
( 70,130,180),
(220, 20, 60),
(255, 0, 0),
( 0, 0,142),
( 0, 0, 70),
( 0, 60,100),
( 0, 80,100),
( 0, 0,230),
(119, 11, 32)]

def parse_args():
parser = argparse.ArgumentParser(description='Custom Input')

parser.add_argument('--a', help='pidnet-s, pidnet-m or pidnet-l', default='pidnet-l', type=str)
parser.add_argument('--c', help='cityscapes pretrained or not', type=bool, default=True)
parser.add_argument('--p', help='dir for pretrained model', default='../pretrained_models/cityscapes/PIDNet_L_Cityscapes_test.pt', type=str)
parser.add_argument('--r', help='root or dir for input images', default='../samples/', type=str)
parser.add_argument('--t', help='the format of input images (.jpg, .png, ...)', default='.png', type=str)

args = parser.parse_args()

return args

def input_transform(image):
image = image.astype(np.float32)[:, :, ::-1]
image = image / 255.0
image -= mean
image /= std
return image

def load_pretrained(model, pretrained):
pretrained_dict = torch.load(pretrained, map_location='cpu')
if 'state_dict' in pretrained_dict:
pretrained_dict = pretrained_dict['state_dict']
model_dict = model.state_dict()
pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() if (k[6:] in model_dict and v.shape == model_dict[k[6:]].shape)}
msg = 'Loaded {} parameters!'.format(len(pretrained_dict))
print('Attention!!!')
print(msg)
print('Over!!!')
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict, strict = False)

return model

if __name__ == '__main__':
args = parse_args()
images_list = glob.glob(args.r+'*'+args.t)
sv_path = args.r+'outputs/'

model = models.pidnet.get_pred_model(args.a, 19 if args.c else 11)
model = load_pretrained(model, args.p).cuda()
model.eval()
with torch.no_grad():
for img_path in images_list:
img_name = img_path.split("\\")[-1]
img = cv2.imread(os.path.join(args.r, img_name),
cv2.IMREAD_COLOR)
sv_img = np.zeros_like(img).astype(np.uint8)
img = input_transform(img)
img = img.transpose((2, 0, 1)).copy()
img = torch.from_numpy(img).unsqueeze(0).cuda()
pred = model(img)
pred = F.interpolate(pred, size=img.size()[-2:],
mode='bilinear', align_corners=True)
pred = torch.argmax(pred, dim=1).squeeze(0).cpu().numpy()

for i, color in enumerate(color_map):
for j in range(3):
sv_img[:,:,j][pred==i] = color_map[i][j]
sv_img = Image.fromarray(sv_img)

if not os.path.exists(sv_path):
os.mkdir(sv_path)
sv_img.save(sv_path+img_name)

程序首先根据指定的模型类型(如 pidnet-l)初始化模型,并通过 load_pretrained 函数加载预训练模型参数,该函数会读取指定路径下的预训练权重文件,并将符合形状要求的权重加载进模型。随后,从指定目录读取所有图像路径,对每张图像进行预处理:先通过 input_transform 函数进行归一化和标准化处理(按照 ImageNet 的 meanstd),然后将图像维度从 HWC 转换为 CHW 并封装成 PyTorch Tensor,再将其送入模型进行前向推理。推理结果通过 F.interpolate 进行尺寸恢复,然后用 torch.argmax 获取像素级的分类结果。

由此,可以剔除多余的代码,封装一个类。这个类在创建时可以完成模型权重和各种常量的定义,并提供一个分割函数。将 OpenCV 的图像传入后,返回一个分割并着色后的结果。代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import models
import cv2

class PIDNetSegmenter:
def __init__(self,
model_type='pidnet-l',
pretrained_path='../pretrained_models/cityscapes/PIDNet_L_Cityscapes_test.pt',
use_cityscapes=True):
# 常量定义
self.mean = [0.485, 0.456, 0.406]
self.std = [0.229, 0.224, 0.225]
self.color_map = [
(128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156),
(190, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
(107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60),
(255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100),
(0, 80, 100), (0, 0, 230), (119, 11, 32)
]

# 模型初始化与加载
self.model = models.pidnet.get_pred_model(model_type, 19 if use_cityscapes else 11)
self._load_pretrained(pretrained_path)
self.model.eval()
self.model.cuda()

def _load_pretrained(self, pretrained):
pretrained_dict = torch.load(pretrained, map_location='cpu')
if 'state_dict' in pretrained_dict:
pretrained_dict = pretrained_dict['state_dict']
model_dict = self.model.state_dict()
pretrained_dict = {
k[6:]: v for k, v in pretrained_dict.items()
if (k[6:] in model_dict and v.shape == model_dict[k[6:]].shape)
}
model_dict.update(pretrained_dict)
self.model.load_state_dict(model_dict, strict=False)

def _transform(self, image):
image = image.astype(np.float32)[:, :, ::-1] / 255.0
image -= self.mean
image /= self.std
image = image.transpose(2, 0, 1)
return torch.from_numpy(image).unsqueeze(0).cuda()

def segment(self, image: np.ndarray) -> np.ndarray:
"""
输入:OpenCV 格式的 BGR 图像
输出:分割着色后的 RGB 图像(np.ndarray)
"""
with torch.no_grad():
img_tensor = self._transform(image)
pred = self.model(img_tensor)
pred = F.interpolate(pred, size=image.shape[:2], mode='bilinear', align_corners=True)
pred = torch.argmax(pred, dim=1).squeeze(0).cpu().numpy()

sv_img = np.zeros_like(image).astype(np.uint8)
for i, color in enumerate(self.color_map):
for j in range(3):
sv_img[:, :, j][pred == i] = color[j]
return sv_img

然后,可以将其整合进 ROS 节点,这个节点监听相机的图片 topic,在相机拍摄图片后进行分割并且输出到自己的 topic 中,供后续节点进行处理。代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#!/usr/bin/env python3
import rospy
from sensor_msgs.msg import Image
from cv_bridge import CvBridge
import numpy as np
import cv2

from pidnet_segmenter import PIDNetSegmenter

class SegmentationNode:
def __init__(self):
rospy.init_node('pidnet_segmentation_node', anonymous=True)

# 参数设定
self.input_topic = rospy.get_param('~input_topic', '/camera/image_raw')
self.output_topic = rospy.get_param('~output_topic', '/segmentation/image')

# OpenCV/ROS 图像桥
self.bridge = CvBridge()

# 初始化分割模型
self.segmenter = PIDNetSegmenter()

# ROS 发布者与订阅者
self.publisher = rospy.Publisher(self.output_topic, Image, queue_size=1)
self.subscriber = rospy.Subscriber(self.input_topic, Image, self.image_callback)

rospy.loginfo("PIDNet segmentation node started.")
rospy.spin()

def image_callback(self, msg):
try:
# 转换 ROS 图像为 OpenCV 图像 (BGR)
cv_image = self.bridge.imgmsg_to_cv2(msg, desired_encoding='bgr8')

# 执行分割
segmented_img = self.segmenter.segment(cv_image) # 输出为 RGB

# 转换回 ROS 图像消息
ros_image = self.bridge.cv2_to_imgmsg(segmented_img, encoding='rgb8')

# 发布结果
self.publisher.publish(ros_image)
except Exception as e:
rospy.logerr(f"Segmentation error: {e}")

if __name__ == '__main__':
try:
SegmentationNode()
except rospy.ROSInterruptException:
pass

至此,完成了图像语义分割模型的选型和部署。