mnist手写数字识别记录-使用keras分类

文章目录

    • 1、mnist介绍
      • 1、可视化
      • 2、各层介绍
    • 2、训练过程
      • 1、获取图像信息
      • 2.具体训练
    • 3、测试及反馈

本文是我b站上看的一个大佬讲的视频,讲的还是很通俗的,后面又自己查了点资料,加上了自己的一点理解而来的,仅代表本菜鸡理解,大佬勿喷!

b站原连接如下:https://www.bilibili.com/video/BV16g4y1z7Qu?spm_id_from=333.880.my_history.page.click

1、mnist介绍

1、可视化

首先可以看下这个网站,他是比较形象的表述了这个神经网络的这样的一个流程:https://www.cs.ryerson.ca/~aharley/vis/conv/
在这里插入图片描述
下面就可以写入数字然后查看这个各个层的显示效果了
在这里插入图片描述
点击这个层的每一个点,可以看到它由下面的层计算而来的关系
在这里插入图片描述

2、各层介绍

由上可见神经网络的训练就是各层之间形变化,最终收敛成一个分类的问题,这里以数字为例,就是最终判断是数字0-9的这样一个区间了

  • 首先还是卷积层

可以参考这个大佬写的文章图像卷积与滤波,在文中他很好的表示出了图像卷积带来的各种怪效应,另外还补充了很多新的知识点: https://blog.csdn.net/zouxy09/article/details/49080029

我在之前的opencv学习中也介绍过图像卷积相关的内容,简单的来说卷积就是把附近的一些像素点进行一些重新运算,得到新的像素矩阵,当然这里注意下面展示的是二维图,但是图像一般是多维的,可以理解就是我们前面讲的多个通道,比如RGB像素就是三个通道。
在这里插入图片描述

  • 池化(下采样层)

下采样层与卷积层比较类似,不过就是下采样层只取卷积核的最大或者平均值来计算,这就是他们的区别了,这个最大值和平均值称为最大值池化和平均值池化两个名称

一般我看网上的理解就是采用池化层可以提升神经网络的学习速度,防止过拟合等现象,所以如果能加的话还是加一加吧!

  • 全连接层

全连接层就是对前面找到的所有图像特征进行统一处理,因此这样的层一般最接近最终输出的一个层了,同时我们要知道就是当图像比较简单,但是特征比较单一,这个时候就直接使用全连接层就可以了,比如手写数字识别的这个实验,我们也可以就是看下面的图,全连接层包含了上一层的所有图像特征

在这里插入图片描述
常见的全连接层如下所示:
在这里插入图片描述

然后是损失函数什么的,因为神经网络训练的过程肯定是一个不断收敛的过程,所以就要用损失函数了,这里损失函数可以理解为就是你的老师,比如我们做错了事情,老师会告诉我们怎么去做,就是指明一个方向,损失函数也就是指明这个方向的作用了!

神经网络中,损失函数就是让我们自动去修改各层的权重,使得训练的准确度得到提高了,一般常见的好像是RELU,根据我查资料目前用的比较多的是这种的,有很多有点,缺点就在于这个函数比较容易造成一些0的输出,就是进行比较多的无效计算,就需要准备使用该函数作为损失函数的提前设置好一些参数!

2、训练过程

1、获取图像信息

这里我们用的是tensorflow+keras的方法进行的训练和测试,所以第一步还是安装环境,需要注意的就是这两个东西有个版本之间的对应,这里我找了一个对照表,我们要根据这个对照表来才行,不然装了也用不了,该表是我从一个大佬的博客上找到的,需要的对着这个来配置即可,这里的python版本不用注意,只要版本大于等于即可!
在这里插入图片描述
运行程序会有这个报错,这个报错是因为我没有配置对应的cuda还有cudnn的版本,没有配置就不能用GPU来训练了,这个不用在意,因为测试这个也就是几分钟的事情!
在这里插入图片描述
下面就可以开始查看训练集和测试集了
在这里插入图片描述
使用plt绘制出的图像效果如下所示:
在这里插入图片描述
我也把他打印出来的图像数据重新整理了下,结果如下,可以看看是不是有种似曾相识的感觉:
在这里插入图片描述
当然就是尾巴这里还带了个标签
在这里插入图片描述

2.具体训练

具体训练的流程我将一条条的进行解释:

  • 这里准备好配置我们需要训练的模型
    在这里插入图片描述

  • 这个首先是采用keras中的一种模型,下面这个Sequential是他里面的一种顺序模型,Sequential模型可以构建非常复杂的神经网络,包括全连接神经网络、卷积神经网络(CNN)、循环神经网络(RNN)、等等这里的Sequential可以理解为堆叠,通过堆叠许多层,构建出深度神经网络
    在这里插入图片描述

  • 下面就是我们对这个Sequential进行的堆叠了,units表示神经元数目,activation是激活函数,就是前面提到的损失函数
    在这里插入图片描述

  • 配置用于训练的模型,optimizer - 用来配置模型的优化器,loss - 用来配置模型的损失函数,metrics - 用来配置模型评价的方法,如accuracy、mse等。
    在这里插入图片描述

  • 这里开始正式训练网络,使用fit方法,可以显示损失loss还有精度acc,epochs是训练次数,表示训练20轮,batch_size是每次训练的数量大小,verbose是训练时输出的信息,2表示每个eoch输出一次信息。
    在这里插入图片描述

完整代码如下所示:

from keras.utils import to_categorical
from keras import models, layers, regularizers
from keras.optimizers import RMSprop
from keras.datasets import mnist
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np


(train_images, train_labels), (test_images, test_labels) = mnist.load_data()


train_images = train_images.reshape((60000, 28*28)).astype('float')
test_images = test_images.reshape((10000, 28*28)).astype('float')
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
network = models.Sequential()
network.add(layers.Dense(units=15, activation='relu', input_shape=(28*28, ),))
network.add(layers.Dense(units=10, activation='softmax'))

network.compile(optimizer=RMSprop(lr=0.001), loss='categorical_crossentropy', metrics=['accuracy'])

network.fit(train_images, train_labels, epochs=20, batch_size=128, verbose=2)

y_pre = network.predict(test_images[:5])
print(y_pre, test_labels[:5])
test_loss, test_accuracy = network.evaluate(test_images, test_labels)
print("test_loss:", test_loss, "    test_accuracy:", test_accuracy)

3、测试及反馈

最后我们也可以就是用测试集来测试模型的效果
在这里插入图片描述
根据上面提到的,会显示每轮训练时候的loss还有精度:
在这里插入图片描述
最终会打印这个模型的结果如下:
在这里插入图片描述
那么我们要优化这个模型,那不就是不断改变前面的那个堆叠的神经网络了!


http://www.niftyadmin.cn/n/980137.html

相关文章

堡垒机工作机制

网络***:之前介绍过、AAA采用C/S结构,AAA服务器上集中管理用户信息用户想访问网络资源,从而和网关建立连接,网关把用户的认证、授权、计费信息透传给radius服务器审计计费堡垒主机采用AAAA技术,为用户提供安全管理平台…

水比赛系列-HMI串口屏的使用

文章目录1、HMI串口屏介绍1、选型介绍2、开发工具3、新建工程2、HMI串口屏常用控件1、字库图片2、页面切换3、字符最大长度4、全局还是私有5、亮度调节和波特率6、变量7、定时器8、初始化事件3、串口屏数据交互1、串口发送数据2、模拟器仿真3、发送指令改变控件的值4、源码感觉…

1213: [视频]【计算几何】面积

1213: [视频]【计算几何】面积 时间限制: 1 Sec 内存限制: 128 MB 提交: 65 解决: 53 [提交][状态][讨论版] 题目描述 【题意】 在一个平面坐标系上随意画一条有n个点的封闭折线(按画线的顺序给出点的坐标),保证封闭折线的任意两条边都…

stm32-USB使用记录(一)

文章目录1、USB设备介绍2、虚拟串口进行数据收发1、在stm32F1上进行2、在stm32F4上进行3、大容量设备访问内部flash1、USB设备介绍 USB,即为通用串行总线,是一个外部总线标准,用于规范电脑与外部设备的连接和通讯。是应用在PC领域的接口技术…

1212: [视频]【计算几何】判断线段相交(跨立实验)

1212: [视频]【计算几何】判断线段相交(跨立实验) 时间限制: 1 Sec 内存限制: 128 MB 提交: 122 解决: 60 [提交][状态][讨论版] 题目描述 【题意】 有n条线段(编号为1~n),按1~n的顺序放在二维坐标系上&#xff…

zTree实现单独选中根节点中第一个节点

zTree实现单独选中根节点中第一个节点 1、实现源代码 <!DOCTYPE html> <html> <head><title>zTree实现基本树</title><meta http-equiv"content-type" content"text/html; charsetUTF-8"><link rel"styleshee…

stm32-USB使用记录(二)

文章目录1、使用外挂FLASH芯片模拟U盘2、使用sd卡模拟U盘前面的笔记中已经提到了就是可以通过STM32的USB外设来完成虚拟串口&#xff08;CDC&#xff09;还有大容量储存设备&#xff08;MSB&#xff09;的功能&#xff0c;但是对于单片机而言&#xff0c;内部的flsh总是不够的&…