0

0

Vision Transformer多标签分类:损失函数与评估策略深度解析

聖光之護

聖光之護

发布时间:2025-10-17 11:05:31

|

274人浏览过

|

来源于php中文网

原创

vision transformer多标签分类:损失函数与评估策略深度解析

本文旨在详细阐述如何将Vision Transformer(ViT)从单标签多分类任务转换为多标签分类任务,并重点介绍损失函数的选择与评估策略的调整。我们将探讨为何`CrossEntropyLoss`不适用于多标签场景,并深入讲解`BCEWithLogitsLoss`的使用方法,包括标签格式要求。此外,文章还将介绍多标签分类任务中常用的评估指标,如精确率、召回率、F1分数和mAP,并提供代码示例,确保读者能够顺利实现ViT在多标签环境下的训练与评估。

从单标签到多标签:核心概念转变

深度学习的图像分类任务中,单标签多分类(Single-label Multi-class Classification)是指每张图片只属于一个类别,模型需要从多个互斥的类别中预测出唯一正确的那个。而多标签分类(Multi-label Classification)则允许每张图片同时属于一个或多个类别,模型需要为每个类别独立地判断其是否存在于图片中。

这种任务性质的转变,要求我们对模型的输出层、损失函数以及评估策略进行相应的调整。对于Vision Transformer(ViT)而言,其特征提取部分通常保持不变,但最终的分类头和训练流程需要进行适配。

损失函数的选择与实现

在单标签多分类任务中,我们通常使用torch.nn.CrossEntropyLoss作为损失函数。它内部包含了Softmax激活函数和负对数似然损失,期望模型的输出是每个类别的Logits,并且这些Logits经过Softmax后会转化为概率分布,所有类别的概率和为1。

然而,在多标签分类任务中,由于图片可能同时属于多个类别,各个类别之间不再是互斥关系。因此,CrossEntropyLoss不再适用,因为它强制了类别之间的互斥性。

推荐的损失函数:torch.nn.BCEWithLogitsLoss

对于多标签分类任务,最常用且推荐的损失函数是torch.nn.BCEWithLogitsLoss。这个损失函数结合了Sigmoid激活函数和二元交叉熵损失(Binary Cross Entropy Loss)。

其主要优点包括:

  1. 独立处理每个类别: BCEWithLogitsLoss会对模型输出的每个Logit独立地计算二元交叉熵,这与多标签任务中各类别独立存在的特性相符。
  2. 数值稳定性: 它直接作用于模型的原始Logits输出,内部处理了Sigmoid激活,避免了先手动计算Sigmoid再计算交叉熵可能导致的数值溢出或下溢问题。

使用BCEWithLogitsLoss的注意事项:

  1. 模型输出: 模型的最终输出层应该是一个全连接层,输出维度等于类别的总数,且不应在其后接Softmax激活函数。例如,如果你的模型有7个类别,最终输出应为形状(batch_size, 7)的Logits张量。
  2. 标签格式: 标签(target)必须是与模型输出Logits形状相同的浮点型(torch.float)张量。它通常是一个“多热编码”(multi-hot encoding)向量,其中1表示该类别存在,0表示该类别不存在。例如,[0, 1, 1, 0, 0, 1, 0]表示第二个、第三个和第六个类别存在。

代码示例:替换损失函数

假设我们有一个ViT模型,其输出为pred(Logits),标签为labels(多热编码)。

import torch
import torch.nn as nn

# 假设模型输出的Logits,形状为 (batch_size, num_classes)
# 这里以 batch_size = 2, num_classes = 7 为例
logits = torch.randn(2, 7) # 模拟模型输出的原始Logits

# 假设对应的多标签,形状也为 (batch_size, num_classes)
# 注意:标签必须是浮点型 (torch.float)
labels = torch.tensor([
    [0, 1, 1, 0, 0, 1, 0], # 第一个样本的标签
    [1, 0, 1, 1, 0, 0, 0]  # 第二个样本的标签
]).float()

# 实例化 BCEWithLogitsLoss
loss_function = nn.BCEWithLogitsLoss()

# 计算损失
loss = loss_function(logits, labels)

print(f"Logits:\n{logits}")
print(f"Labels:\n{labels}")
print(f"Calculated Loss: {loss.item()}")

# 原始训练循环中的应用
# pred = model(images.to(device))
# loss = loss_function(pred, labels.to(device))
# loss.backward()
# optimizer.step()

多标签分类的评估策略

在单标签分类中,准确率(Accuracy)是最常用的评估指标。然而,在多标签分类中,仅仅计算准确率是不足够的,甚至可能产生误导。例如,如果一个模型总是预测所有类别都不存在,而实际只有少数类别存在,那么它的准确率可能很高(因为它正确预测了大量不存在的类别),但它对存在类别的识别能力却很差。

因此,我们需要采用更全面的指标来评估多标签分类模型的性能。

1. 从Logits到预测结果

智谱AI输入法
智谱AI输入法

智谱AI推出的AI语音输入法

下载

在计算评估指标之前,我们需要将模型的Logits输出转换为具体的类别预测。这通常通过对Logits应用Sigmoid函数,然后设定一个阈值(例如0.5)来完成。

# 假设 logits 是模型输出的Logits
# 例如:logits = torch.randn(batch_size, num_classes)

# 1. 应用Sigmoid函数将Logits转换为概率
probabilities = torch.sigmoid(logits)

# 2. 设定阈值,将概率转换为二元预测 (0或1)
threshold = 0.5
predictions = (probabilities > threshold).float()

print(f"Probabilities:\n{probabilities}")
print(f"Predictions (threshold={threshold}):\n{predictions}")

2. 常用评估指标

以下是多标签分类中常用的评估指标:

  • 精确率(Precision)、召回率(Recall)、F1分数(F1-score):

    • 精确率: 预测为正例的样本中,有多少是真正的正例。
    • 召回率: 实际为正例的样本中,有多少被模型预测为正例。
    • F1分数: 精确率和召回率的调和平均值,综合衡量模型的性能。
    • 这些指标可以针对每个类别独立计算(Per-class),也可以通过微平均(Micro-average)或宏平均(Macro-average)来汇总所有类别的结果。
      • Micro-average: 汇总所有类别的TP、FP、FN后再计算总体的Precision、Recall、F1。它更侧重于样本级别的性能,受样本数量较多的类别影响较大。
      • Macro-average: 先计算每个类别的Precision、Recall、F1,然后取这些值的平均。它给予每个类别相同的权重,不受类别样本数量不平衡的影响。
  • 平均精确率(Average Precision, AP)与平均精确率均值(mean Average Precision, mAP):

    • AP: 衡量单个类别在不同召回率下的精确率表现,通常通过计算PR曲线下面积获得。AP值越高,说明模型在该类别上的性能越好。
    • mAP: 对所有类别的AP值取平均,是衡量多标签分类模型整体性能的一个非常重要的指标,尤其在目标检测等领域广泛使用。
  • Jaccard Index (IoU) / Jaccard Similarity Score:

    • 衡量预测集合与真实标签集合的相似度,计算公式为交集大小除以并集大小。对于多标签分类,可以计算每个样本的预测标签集合与真实标签集合的Jaccard相似度,然后取平均。
  • Hamming Loss:

    • 衡量预测结果与真实标签不一致的标签比例。Hamming Loss越低越好。

3. 使用torchmetrics或scikit-learn进行评估

在PyTorch生态中,torchmetrics库提供了丰富的多标签评估指标。scikit-learn也是一个非常强大的工具,可以在CPU上方便地进行评估。

torchmetrics示例 (推荐用于PyTorch训练循环中):

import torch
from torchmetrics.classification import MultilabelF1Score, MultilabelAveragePrecision

# 假设真实标签和预测概率
# num_classes = 7
num_labels = 7
num_samples = 10
target_labels = torch.randint(0, 2, (num_samples, num_labels)).float() # 真实标签 (0或1)
predicted_probs = torch.rand(num_samples, num_labels) # 模型输出的概率 (经过Sigmoid)

# 或者直接使用Logits,让metrics内部处理Sigmoid
predicted_logits = torch.randn(num_samples, num_labels)


# 实例化F1分数,可以指定 average 方式 (e.g., 'micro', 'macro', 'weighted', 'none')
# MultilabelF1Score 期望输入是 (preds, target)
# preds: 概率 (float) 或 原始logits (float)
# target: 真实标签 (int 或 float, 0/1)
f1_score_micro = MultilabelF1Score(num_labels=num_labels, average='micro', validate_args=False)
f1_score_macro = MultilabelF1Score(num_labels=num_labels, average='macro', validate_args=False)

# 计算F1分数
# 注意:MultilabelF1Score 可以直接接收概率或logits,但通常建议给概率
f1_micro_val = f1_score_micro(predicted_probs, target_labels.long()) # target_labels需要是long类型对于F1Score
f1_macro_val = f1_score_macro(predicted_probs, target_labels.long())


print(f"Micro F1 Score: {f1_micro_val.item()}")
print(f"Macro F1 Score: {f1_macro_val.item()}")

# 实例化mAP
# MultilabelAveragePrecision 期望输入是 (preds, target)
# preds: 概率 (float)
# target: 真实标签 (int 或 float, 0/1)
map_metric = MultilabelAveragePrecision(num_labels=num_labels, validate_args=False)

# 计算mAP
map_val = map_metric(predicted_probs, target_labels.long()) # target_labels需要是long类型对于mAP

print(f"mAP: {map_val.item()}")

# 如果输入是logits,可以这样处理 (MultilabelF1Score 和 MultilabelAveragePrecision 默认不带sigmoid,需要手动处理或确保其内部处理了)
# 对于MultilabelF1Score和MultilabelAveragePrecision,当输入是概率时,通常需要手动将target转换为long
# 如果输入是logits,则需要确保metrics内部会执行sigmoid
# 更好的做法是,统一将模型输出转换为概率再传入metrics
probs_from_logits = torch.sigmoid(predicted_logits)
f1_micro_val_logits = f1_score_micro(probs_from_logits, target_labels.long())
map_val_logits = map_metric(probs_from_logits, target_labels.long())
print(f"Micro F1 Score (from logits): {f1_micro_val_logits.item()}")
print(f"mAP (from logits): {map_val_logits.item()}")

总结与注意事项

将ViT从单标签多分类转换为多标签分类,关键在于以下几点:

  1. 模型输出层: 确保模型的最终全连接层输出与类别数量相等的Logits,并且不带Softmax激活。
  2. 损失函数: 使用torch.nn.BCEWithLogitsLoss作为损失函数,它能独立处理每个类别的预测。
  3. 标签格式: 真实标签应为多热编码的浮点型张量,形状与模型输出的Logits相同。
  4. 评估指标: 采用适合多标签任务的评估指标,如Micro/Macro F1分数、mAP、Jaccard Index等,并结合torchmetrics或scikit-learn等库进行高效计算。
  5. 阈值选择: 在将概率转换为二元预测时,阈值的选择对最终的精确率和召回率有显著影响,可能需要通过验证集进行调优。
  6. 类别不平衡: 在多标签任务中,类别不平衡问题可能更复杂(例如,某些标签总是同时出现,某些标签非常稀有)。可以考虑使用加权BCE损失、Focal Loss或采样策略来缓解。

通过以上调整,您的Vision Transformer模型将能够有效地处理多标签图像分类任务。

相关专题

更多
css中float用法
css中float用法

css中float属性允许元素脱离文档流并沿其父元素边缘排列,用于创建并排列、对齐文本图像、浮动菜单边栏和重叠元素。想了解更多float的相关内容,可以阅读本专题下面的文章。

554

2024.04.28

C++中int、float和double的区别
C++中int、float和double的区别

本专题整合了c++中int和double的区别,阅读专题下面的文章了解更多详细内容。

95

2025.10.23

class在c语言中的意思
class在c语言中的意思

在C语言中,"class" 是一个关键字,用于定义一个类。想了解更多class的相关内容,可以阅读本专题下面的文章。

460

2024.01.03

python中class的含义
python中class的含义

本专题整合了python中class的相关内容,阅读专题下面的文章了解更多详细内容。

7

2025.12.06

golang map内存释放
golang map内存释放

本专题整合了golang map内存相关教程,阅读专题下面的文章了解更多相关内容。

73

2025.09.05

golang map相关教程
golang map相关教程

本专题整合了golang map相关教程,阅读专题下面的文章了解更多详细内容。

25

2025.11.16

golang map原理
golang map原理

本专题整合了golang map相关内容,阅读专题下面的文章了解更多详细内容。

37

2025.11.17

java判断map相关教程
java判断map相关教程

本专题整合了java判断map相关教程,阅读专题下面的文章了解更多详细内容。

32

2025.11.27

php源码安装教程大全
php源码安装教程大全

本专题整合了php源码安装教程,阅读专题下面的文章了解更多详细内容。

192

2025.12.31

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
Git 教程
Git 教程

共21课时 | 2.4万人学习

Git版本控制工具
Git版本控制工具

共8课时 | 1.5万人学习

Git中文开发手册
Git中文开发手册

共0课时 | 0人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号