0

0

Pytorch入门之mnist分类实例

不言

不言

发布时间:2018-04-14 16:00:57

|

4839人浏览过

|

来源于php中文网

原创

这篇文章主要为大家详细介绍了pytorch入门之mnist分类实例,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

本文实例为大家分享了Pytorch入门之mnist分类的具体代码,供大家参考,具体内容如下

#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'denny'
__time__ = '2017-9-9 9:03'

import torch
import torchvision
from torch.autograd import Variable
import torch.utils.data.dataloader as Data

train_data = torchvision.datasets.MNIST(
 './mnist', train=True, transform=torchvision.transforms.ToTensor(), download=True
)
test_data = torchvision.datasets.MNIST(
 './mnist', train=False, transform=torchvision.transforms.ToTensor()
)
print("train_data:", train_data.train_data.size())
print("train_labels:", train_data.train_labels.size())
print("test_data:", test_data.test_data.size())

train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = Data.DataLoader(dataset=test_data, batch_size=64)


class Net(torch.nn.Module):
 def __init__(self):
 super(Net, self).__init__()
 self.conv1 = torch.nn.Sequential(
  torch.nn.Conv2d(1, 32, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2))
 self.conv2 = torch.nn.Sequential(
  torch.nn.Conv2d(32, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.conv3 = torch.nn.Sequential(
  torch.nn.Conv2d(64, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.dense = torch.nn.Sequential(
  torch.nn.Linear(64 * 3 * 3, 128),
  torch.nn.ReLU(),
  torch.nn.Linear(128, 10)
 )

 def forward(self, x):
 conv1_out = self.conv1(x)
 conv2_out = self.conv2(conv1_out)
 conv3_out = self.conv3(conv2_out)
 res = conv3_out.view(conv3_out.size(0), -1)
 out = self.dense(res)
 return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
 print('epoch {}'.format(epoch + 1))
 # training-----------------------------
 train_loss = 0.
 train_acc = 0.
 for batch_x, batch_y in train_loader:
 batch_x, batch_y = Variable(batch_x), Variable(batch_y)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 train_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 train_correct = (pred == batch_y).sum()
 train_acc += train_correct.data[0]
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
 train_data)), train_acc / (len(train_data))))

 # evaluation--------------------------------
 model.eval()
 eval_loss = 0.
 eval_acc = 0.
 for batch_x, batch_y in test_loader:
 batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 eval_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 num_correct = (pred == batch_y).sum()
 eval_acc += num_correct.data[0]
 print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
 test_data)), eval_acc / (len(test_data))))

相关推荐:

python如何读取二进制mnist实例详解

知了追踪
知了追踪

AI智能信息助手,智能追踪你的兴趣资讯

下载

一篇不错的Python入门教程_python

相关专题

更多
php源码安装教程大全
php源码安装教程大全

本专题整合了php源码安装教程,阅读专题下面的文章了解更多详细内容。

3

2025.12.31

php网站源码教程大全
php网站源码教程大全

本专题整合了php网站源码相关教程,阅读专题下面的文章了解更多详细内容。

1

2025.12.31

视频文件格式
视频文件格式

本专题整合了视频文件格式相关内容,阅读专题下面的文章了解更多详细内容。

4

2025.12.31

不受国内限制的浏览器大全
不受国内限制的浏览器大全

想找真正自由、无限制的上网体验?本合集精选2025年最开放、隐私强、访问无阻的浏览器App,涵盖Tor、Brave、Via、X浏览器、Mullvad等高自由度工具。支持自定义搜索引擎、广告拦截、隐身模式及全球网站无障碍访问,部分更具备防追踪、去谷歌化、双内核切换等高级功能。无论日常浏览、隐私保护还是突破地域限制,总有一款适合你!

6

2025.12.31

出现404解决方法大全
出现404解决方法大全

本专题整合了404错误解决方法大全,阅读专题下面的文章了解更多详细内容。

30

2025.12.31

html5怎么播放视频
html5怎么播放视频

想让网页流畅播放视频?本合集详解HTML5视频播放核心方法!涵盖<video>标签基础用法、多格式兼容(MP4/WebM/OGV)、自定义播放控件、响应式适配及常见浏览器兼容问题解决方案。无需插件,纯前端实现高清视频嵌入,助你快速打造现代化网页视频体验。

3

2025.12.31

关闭win10系统自动更新教程大全
关闭win10系统自动更新教程大全

本专题整合了关闭win10系统自动更新教程大全,阅读专题下面的文章了解更多详细内容。

2

2025.12.31

阻止电脑自动安装软件教程
阻止电脑自动安装软件教程

本专题整合了阻止电脑自动安装软件教程,阅读专题下面的文章了解更多详细教程。

3

2025.12.31

html5怎么使用
html5怎么使用

想快速上手HTML5开发?本合集为你整理最实用的HTML5使用指南!涵盖HTML5基础语法、主流框架(如Bootstrap、Vue、React)集成方法,以及无需安装、直接在线编辑运行的平台推荐(如CodePen、JSFiddle)。无论你是新手还是进阶开发者,都能轻松掌握HTML5网页制作、响应式布局与交互功能开发,零配置开启高效前端编程之旅!

2

2025.12.31

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
Node.js 教程
Node.js 教程

共57课时 | 7.7万人学习

CSS3 教程
CSS3 教程

共18课时 | 4.1万人学习

Rust 教程
Rust 教程

共28课时 | 4万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号