0

0

如何在 PyTorch 中正确实现 K 折交叉验证

心靈之曲

心靈之曲

发布时间:2026-01-02 12:32:33

|

213人浏览过

|

来源于php中文网

原创

如何在 PyTorch 中正确实现 K 折交叉验证

k 折交叉验证要求每折使用不同的训练/验证数据划分,因此 dataloader 必须在每折内动态构建——不能复用外部定义的固定 dataloader;否则将失去交叉验证的意义。本文详解如何重构数据加载逻辑以支持 k 折验证。

在 PyTorch 中实现 K 折交叉验证(K-Fold Cross Validation)时,核心原则是:每一折(fold)必须对应一组独立、互斥的数据划分。这意味着 train_dataset 和 val_dataset 需在每次 fold 迭代中重新生成,进而构建对应的 DataLoader。你当前代码中将 get_train_utils() 和 get_val_utils() 定义在 fold 外部,本质上创建的是全局固定划分的 dataloader,这与 K 折验证的目标相悖——它无法评估模型在不同子集上的泛化能力。

✅ 正确做法:将数据划分与 dataloader 构建移入 fold 循环

你需要使用 torch.utils.data.Subset 或 sklearn.model_selection.KFold 配合原始完整数据集(如 torch.utils.data.Dataset 子类实例),在每折中生成新的子集,并据此构建 dataloader。以下是关键重构步骤:

CodeSquire
CodeSquire

AI代码编写助手,把你的想法变成代码

下载

1. 准备完整数据集(不划分)

# 在 main_worker 或主流程开头一次性加载完整数据集
full_dataset = YourCustomDataset(
    root_dir=opt.data_root,
    transform=...  # 基础预处理(不包含 fold 特定增强)
)

2. 使用 KFold 划分索引(推荐 sklearn)

from sklearn.model_selection import KFold

kf = KFold(n_splits=opt.n_folds, shuffle=True, random_state=42)
fold_results = []

for fold, (train_idx, val_idx) in enumerate(kf.split(full_dataset), 1):
    print(f"\n=== Starting Fold {fold}/{opt.n_folds} ===")

    # 创建 fold-specific 子集
    train_subset = torch.utils.data.Subset(full_dataset, train_idx)
    val_subset   = torch.utils.data.Subset(full_dataset, val_idx)

    # ✅ 每折独立构建 dataloader(含 fold-specific augmentations)
    train_loader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.n_threads,
        pin_memory=True,
        worker_init_fn=worker_init_fn
    )

    val_loader = torch.utils.data.DataLoader(
        val_subset,
        batch_size=opt.batch_size // opt.n_val_samples,
        shuffle=False,
        num_workers=opt.n_threads,
        pin_memory=True,
        worker_init_fn=worker_init_fn
    )

    # ✅ 每折独立初始化模型、优化器、调度器(避免参数污染)
    model = build_model(opt)  # 重置模型权重
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=opt.multistep_milestones, gamma=0.1
    )

    # ✅ 执行该 fold 的完整训练+验证循环
    best_val_acc = 0.0
    for epoch in range(1, opt.n_epochs + 1):
        train_epoch(epoch, train_loader, model, criterion, optimizer, 
                   opt.device, train_logger, tb_writer)

        val_acc = val_epoch(epoch, val_loader, model, criterion, 
                           opt.device, val_logger, tb_writer)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # 可选:保存本 fold 最佳模型
            torch.save(model.state_dict(), f"{opt.result_path}/best_fold_{fold}.pth")

    fold_results.append(best_val_acc)
    print(f"Fold {fold} best validation accuracy: {best_val_acc:.4f}")

⚠️ 关键注意事项:

  • 不要复用外部 dataloader:get_train_utils() 和 get_val_utils() 应被重构为接受 dataset 和 indices 参数的工厂函数,而非全局调用。
  • 模型需重置:每折必须初始化新模型(或严格 reset 权重),否则前一折的参数会污染后续 fold。
  • 日志与检查点隔离:为避免混淆,建议为每折创建独立日志目录(如 result_path/fold_1/),或使用 fold 标签区分 TensorBoard 曲线。
  • 分布式训练适配:若启用 DistributedSampler,需确保 train_sampler 基于当前 train_subset 构建,并在每个 epoch 调用 set_epoch()。
  • 数据增强一致性:训练增强可保留,但验证增强应保持确定性(如禁用随机裁剪)。

✅ 总结

K 折交叉验证不是“在固定数据上跑多次训练”,而是在 K 组不同数据划分上评估模型稳定性。因此,dataset → subset → dataloader 的链条必须在每折内完成。强行复用外部 dataloader 不仅技术上不可行(索引错位、采样冲突),更会彻底破坏交叉验证的统计意义。重构后,你将获得 K 个独立验证指标,最终取均值与标准差,这才是可信的模型性能评估。

相关专题

更多
什么是分布式
什么是分布式

分布式是一种计算和数据处理的方式,将计算任务或数据分散到多个计算机或节点中进行处理。本专题为大家提供分布式相关的文章、下载、课程内容,供大家免费下载体验。

319

2023.08.11

分布式和微服务的区别
分布式和微服务的区别

分布式和微服务的区别在定义和概念、设计思想、粒度和复杂性、服务边界和自治性、技术栈和部署方式等。本专题为大家提供分布式和微服务相关的文章、下载、课程内容,供大家免费下载体验。

229

2023.10.07

什么是分布式
什么是分布式

分布式是一种计算和数据处理的方式,将计算任务或数据分散到多个计算机或节点中进行处理。本专题为大家提供分布式相关的文章、下载、课程内容,供大家免费下载体验。

319

2023.08.11

分布式和微服务的区别
分布式和微服务的区别

分布式和微服务的区别在定义和概念、设计思想、粒度和复杂性、服务边界和自治性、技术栈和部署方式等。本专题为大家提供分布式和微服务相关的文章、下载、课程内容,供大家免费下载体验。

229

2023.10.07

pytorch是干嘛的
pytorch是干嘛的

pytorch是一个基于python的深度学习框架,提供以下主要功能:动态图计算,提供灵活性。强大的张量操作,实现高效处理。自动微分,简化梯度计算。预构建的神经网络模块,简化模型构建。各种优化器,用于性能优化。想了解更多pytorch的相关内容,可以阅读本专题下面的文章。

428

2024.05.29

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

9

2025.12.22

php源码安装教程大全
php源码安装教程大全

本专题整合了php源码安装教程,阅读专题下面的文章了解更多详细内容。

65

2025.12.31

php网站源码教程大全
php网站源码教程大全

本专题整合了php网站源码相关教程,阅读专题下面的文章了解更多详细内容。

43

2025.12.31

视频文件格式
视频文件格式

本专题整合了视频文件格式相关内容,阅读专题下面的文章了解更多详细内容。

35

2025.12.31

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
Java 教程
Java 教程

共578课时 | 40.7万人学习

国外Web开发全栈课程全集
国外Web开发全栈课程全集

共12课时 | 0.9万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号