0

0

FRN——小样本学习SOTA模型

P粉084495128

P粉084495128

发布时间:2025-07-22 14:03:56

|

540人浏览过

|

来源于php中文网

原创

本文介绍CVPR2021论文提出的小样本学习模型FRN,其将分类问题归为特征重构问题,以闭合解形式从支持样本回归查询样本特征,性能与效率更优。文中展示了基于PaddlePaddle复现的FRN在mini-ImageNet上的精度,还介绍了数据集、环境依赖、快速开始步骤、代码结构及模型信息等内容。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

frn——小样本学习sota模型 - php中文网

FRN——小样本学习SOTA模型

一、论文概述

论文Few-Shot Classification with Feature Map Reconstruction Networks是顶会CVPR2021上发表的一种小样本学习经典方法。该方法在小样本学习的benchmark上依然具有最佳的性能指标,是该领域的重要方法。

FRN将小样本分类问题归结为潜在空间中的特征重构问题。作者认为,通过支持样本重构查询样本特征的能力,决定了查询样本的所属类别。作者在小样本学习中引入了一种新的机制,以闭合解的形式从支持样本特征直接向查询样本特征做回归,无需引入新的模块或者大规模的训练参数。上述方法得到的模型(FRN),相比先前的其他方法,无论在计算效率上还是性能表现上都更有优势。FRN在四个细粒度数据集上展现出实质性提升。在通用的粗粒度数据集mini-ImageNet和tiered-ImageNet上,也达到了SOTA指标。

下图展示了FRN的基本工作流程。

FRN——小样本学习SOTA模型 - php中文网        

二、复现精度

基于paddlepaddle深度学习框架,对文献算法进行复现后,本项目在mini-ImageNet上达到的测试精度,如下表所示。

task 本项目精度 参考文献精度
5-Way 1-Shot
66.45+-0.19
5-Way 5-Shot
82.83+-0.13

模型训练包括了两个过程,首先是模型预训练,按照典型分类网络的训练过程,将整个训练集送入backbone进行训练;然后是微调过程,按照episode training的训练范式,配置为20-Way 5-Shot方式进行微调训练。这两个训练过程的训练超参数设置如下:

(1)预训练过程

超参数名 设置值
lr 0.1
gamma 0.1
epoch 350
milestones 200 300
batch_size 512

(2)微调训练过程

超参数名 设置值
lr 1e-3
gamma 0.1
epoch 150
train_n_episode 1000
milestones 70 120
train_n_way 20
n_shot 5

三、数据集

miniImageNet数据集节选自ImageNet数据集。 DeepMind团队首次将miniImageNet数据集用于小样本学习研究,从此miniImageNet成为了元学习和小样本领域的基准数据集。 关于该数据集的介绍可以参考https://blog.csdn.net/wangkaidehao/article/details/105531837

miniImageNet是由Oriol Vinyals等在Matching Networks 中首次提出的,该文献是小样本分类任务的开山制作,也是本次复现论文关于该数据集的参考文献。在Matching Networks中, 作者提出对ImageNet中的类别和样本进行抽取(参见其Appendix B),形成了一个数据子集,将其命名为miniImageNet。 划分方法,作者仅给出了一个文本文件进行说明。 Vinyals在文中指明了miniImageNet图片尺寸为84x84。因此,后续小样本领域的研究者,均是基于原始图像,在代码中进行预处理, 将图像缩放到84x84的规格。

至于如何缩放到84x84,本领域研究者各有各的方法,通常与研究者的个人理解相关,但一般对实验结果影响不大。本次文献论文原文,未能给出 miniImageNet的具体实现方法,本项目即参考领域内较为通用的预处理方法进行处理。

  • 数据集大小:
    • miniImageNet包含100类共60000张彩色图片,其中每类有600个样本。 mini-imagenet一共有2.86GB
  • 数据格式:
|- miniImagenet|  |- images/|  |  |- n0153282900000005.jpg 
|  |  |- n0153282900000006.jpg|  |  |- …|  |- train.csv|  |- test.csv|  |- val.csv
       

数据集链接:miniImagenet

四、环境依赖

  • 硬件:

    • x86 cpu
    • NVIDIA GPU
  • 框架:

    • PaddlePaddle = 2.4
  • 其他依赖项:

    玻璃钢企业网站源码1.5
    玻璃钢企业网站源码1.5

    本程序源码为asp与acc编写,并没有花哨的界面与繁琐的功能,维护简单方便,只要你有一些点点asp的基础,二次开发易如反掌。 1.功能包括产品,新闻,留言簿,招聘,下载,...是大部分中小型的企业建站的首选。本程序是免费开源,只为大家学习之用。如果用于商业,版权问题概不负责。1.采用asp+access更加适合中小企业的网站模式。 2.网站页面div+css兼容目前所有主流浏览器,ie6+,Ch

    下载
    • numpy==1.19.3
    • tqdm==4.59.0
    • Pillow==8.3.1

五、快速开始

1、解压数据集和源代码:

!unzip -n -d ./data/ ./data/data105646/mini-imagenet-sxc.zip

In [ ]
%cd /home/aistudio/
!unzip -n -d ./data/ ./data/data105646/mini-imagenet-sxc.zip
   
In [ ]
%cd /home/aistudio/work/
!unzip -o frn.zip
   
In [ ]
# 生成json文件!cp write_miniImagenet_filelist.py /home/aistudio/data/mini-imagenet-sxc/
%cd /home/aistudio/data/mini-imagenet-sxc/
!python write_miniImagenet_filelist.py
   

2、执行以下命令启动预训练:

python pretrain.py --dataset mini_imagenet --data_path /home/aistudio/data/mini-imagenet-sxc --method stl_frn --lr 1e-1 --gamma 1e-1 --epoch 350 --milestones 200 300 --batch_size 512 --val_n_episode 600 --image_size 84 --model ResNet12 --n_shot 1 --n_query 15 --gpu
       

模型开始训练,运行完毕后,训练log和模型参数保存在./checkpoints/mini_imagenet/ResNet12_stl_frn_pretrain/目录下,分别是:

best_model.pdparams  # 最优模型参数文件output.log  # 训练LOG信息
       

训练完成后,可将上述文件手动保存到其他目录下,避免被后续训练操作覆盖。

In [ ]
%cd /home/aistudio/work
!python pretrain.py --dataset mini_imagenet --data_path /home/aistudio/data/mini-imagenet-sxc --method stl_frn --lr 1e-1 --gamma 1e-1 --epoch 350 --milestones 200 300 --batch_size 512 --val_n_episode 600 --image_size 84 --model ResNet12 --n_shot 1 --n_query 15   --gpu
   

3、执行以下命令启动微调训练:

python meta_train.py --dataset mini_imagenet --data_path /home/aistudio/data/mini-imagenet-sxc --method meta_frn --lr 1e-3 --gamma 1e-1 --epoch 150 --train_n_episode 1000 --val_n_episode 600 --milestones 70 120 --image_size 84 --model ResNet12 --train_n_way 20 --val_n_way 5 --n_shot 5 --n_query 15 --gpu --pretrain_path ./checkpoints/mini_imagenet/ResNet12_stl_frn_pretrain/best_model.pdparams
       

模型开始训练,运行完毕后,训练log和模型参数保存在./checkpoints/mini_imagenet/ResNet12_meta_frn_20way_5shot_metatrain/目录下,分别是:

best_model.pdparams  # 最优模型参数文件output.log  # 训练LOG信息
       

训练完成后,可将上述文件手动保存到其他目录下,避免被后续训练操作覆盖。

In [ ]
%cd /home/aistudio/work
!python meta_train.py --dataset mini_imagenet --data_path /home/aistudio/data/mini-imagenet-sxc --method meta_frn --lr 1e-3 --gamma 1e-1 --epoch 150 --train_n_episode 1000 --val_n_episode 600 --milestones 70 120 --image_size 84 --model ResNet12 --train_n_way 20 --val_n_way 5 --n_shot 5 --n_query 15 --gpu --pretrain_path ./checkpoints/mini_imagenet/ResNet12_stl_frn_pretrain/best_model.pdparams
   

4、执行以下命令进行评估

python test.py --dataset mini_imagenet --data_path /home/aistudio/data/mini-imagenet-sxc --model ResNet12 --method meta_frn --image_size 84 --gpu --n_shot 1 --model_path ./checkpoints/mini_imagenet/ResNet12_meta_frn_20way_5shot_metatrain/best_model.pdparams --test_task_nums 1 --test_n_episode 600
       

用于评估模型在小样本任务下的精度。

In [ ]
# 5-Way 1-Shot评估%cd /home/aistudio/work
!python test.py --dataset mini_imagenet --data_path /home/aistudio/data/mini-imagenet-sxc --model ResNet12 --method meta_frn --image_size 84 --gpu --n_shot 1 --model_path ./checkpoints/mini_imagenet/ResNet12_meta_frn_20way_5shot_metatrain/best_model.pdparams --test_task_nums 1 --test_n_episode 600
   
In [ ]
# 5-Way 5-Shot评估%cd /home/aistudio/work
!python test.py --dataset mini_imagenet --data_path /home/aistudio/data/mini-imagenet-sxc --model ResNet12 --method meta_frn --image_size 84 --gpu --n_shot 5 --model_path ./checkpoints/mini_imagenet/ResNet12_meta_frn_20way_5shot_metatrain/best_model.pdparams --test_task_nums 1 --test_n_episode 600
   

六、代码结构与详细说明

6.1 代码结构

├── data                               # 数据处理相关│   ├── datamgr.py                       # data manager模块│   ├── dataset.py                       # data set模块├── methods                             # 模型相关│   ├── FRN.py                          # FRN核心算法├── network                             # backbone│   ├── conv.py                         # Conv-4和Conv-6代码实现│   ├── resnet.py                        # ResNet-12代码实现├── scripts                             # 运行工程脚本│   ├── mini_imagenet                     
│   │   ├── run_frn                     
│   │   │   ├── run_frn_metatrain.sh         # 运行微调训练│   │   │   ├── run_frn_pretrain.sh          # 运行预训练│   │   │   ├── run_frn_test.sh            # 运行测试├── meta_train.py                         # 微调训练代码├── pretrain.py                          # 预训练代码├── test.py                             # 测试代码├── utils.py                            # 公共调用函数├── wirite_miniImagenet_filelist.py             # 生成mini-ImageNet数据json文件
   

6.2 参数说明

可以在 pretrain.py 中设置训练与评估相关参数,具体如下:

参数 默认值 说明
----batch_size 128 batch size
--lr 0.05 初始学习率
--wd 5e-4 weight decay超参
--gamma 0.1 lr_scheduler衰减系数
--milestones 80, 120 达到相应epoch后,lr_scheduler开始衰减
--epoch 150 遍历数据集的迭代轮数
--gpu True 是否使用GPU进行训练
--dataset mini_imagenet 指定训练数据集
--data_path '' 指定数据集的路径
--model ResNet-12 指定采用的backbone
--val meta 指定验证方式
--train_n_way 20 小样本训练类别数
--val_n_episode 600 验证时测试多少个episode
--val_n_way 5 小样本验证类别数
--n_shot 1 给定支持样本的个数
--n_query 15 指定查询样本的个数
--num_classes 64 指定base set类别总数
--save_freq 50 指定每隔多少个epoch保存一次模型参数
--seed 0 指定随机数种子
--resume '' 指定恢复训练时加载的中间参数文件路径

6.3 训练流程

可参考快速开始章节中的描述

训练输出

执行训练开始后,将得到类似如下的输出。每一轮epoch训练将会打印当前training loss、training acc、val loss、val acc以及训练kl散度。

Epoch 0 | Batch 0/150 | Loss 4.158544
best model! save...
val loss is 0.00, val acc is 37.46
model best acc is 37.46, best acc epoch is 0
This epoch use 7.61 minutes
train loss is 3.72, train acc is 10.84
Epoch 1 | Batch 0/150 | Loss 3.052964
val loss is 0.00, val acc is 37.46
model best acc is 37.46, best acc epoch is 0
This epoch use 3.73 minutes
train loss is 2.96, train acc is 25.28
Epoch 2 | Batch 0/150 | Loss 2.588413
val loss is 0.00, val acc is 37.46
model best acc is 37.46, best acc epoch is 0
This epoch use 3.71 minutes
train loss is 2.59, train acc is 33.27
...
   

6.4 测试流程

可参考快速开始章节中的描述

此时的输出为:


   

八、模型信息

训练完成后,模型和相关LOG保存在./results/5w1s和./results/5w5s目录下。

训练和测试日志保存在results目录下。

信息 说明
发布者 hrdwsong
时间 2023.03
框架版本 Paddle 2.4
应用场景 小样本学习
支持硬件 GPU、CPU
Aistudio地址 https://aistudio.baidu.com/aistudio/projectdetail/5723600?contributionType=1&sUid=527829&shared=1&ts=1678943299939

相关专题

更多
golang map内存释放
golang map内存释放

本专题整合了golang map内存相关教程,阅读专题下面的文章了解更多相关内容。

73

2025.09.05

golang map相关教程
golang map相关教程

本专题整合了golang map相关教程,阅读专题下面的文章了解更多详细内容。

23

2025.11.16

golang map原理
golang map原理

本专题整合了golang map相关内容,阅读专题下面的文章了解更多详细内容。

36

2025.11.17

java判断map相关教程
java判断map相关教程

本专题整合了java判断map相关教程,阅读专题下面的文章了解更多详细内容。

31

2025.11.27

页面置换算法
页面置换算法

页面置换算法是操作系统中用来决定在内存中哪些页面应该被换出以便为新的页面提供空间的算法。本专题为大家提供页面置换算法的相关文章,大家可以免费体验。

378

2023.08.14

页面置换算法
页面置换算法

页面置换算法是操作系统中用来决定在内存中哪些页面应该被换出以便为新的页面提供空间的算法。本专题为大家提供页面置换算法的相关文章,大家可以免费体验。

378

2023.08.14

http与https有哪些区别
http与https有哪些区别

http与https的区别:1、协议安全性;2、连接方式;3、证书管理;4、连接状态;5、端口号;6、资源消耗;7、兼容性。本专题为大家提供相关的文章、下载、课程内容,供大家免费下载体验。

1509

2024.08.16

苹果官网入口直接访问
苹果官网入口直接访问

苹果官网直接访问入口是https://www.apple.com/cn/,该页面具备0.8秒首屏渲染、HTTP/3与Brotli加速、WebP+AVIF双格式图片、免登录浏览全参数等特性。本专题为大家提供相关的文章、下载、课程内容,供大家免费下载体验。

115

2025.12.24

拼豆图纸在线生成器
拼豆图纸在线生成器

拼豆图纸生成器有PixelBeads在线版、BeadGen和“豆图快转”;推荐通过pixelbeads.online或搜索“beadgen free online”直达官网,避开需注册的诱导页面。本专题为大家提供相关的文章、下载、课程内容,供大家免费下载体验。

84

2025.12.24

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
最新Python教程 从入门到精通
最新Python教程 从入门到精通

共4课时 | 0.6万人学习

Django 教程
Django 教程

共28课时 | 2.4万人学习

SciPy 教程
SciPy 教程

共10课时 | 0.9万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号