博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Tensorflow神经网络框架(第三课 3-2MNIST数据集分类简单版本,手写数字识别)
阅读量:4299 次
发布时间:2019-05-27

本文共 4289 字,大约阅读时间需要 14 分钟。

3-2MNIST数据集分类简单版本,手写数字识别Last Checkpoint: 31 分钟前(unsaved changes)
Logout
In [1]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
In [2]:
mnist = input_data.read_data_sets("MNIST_data",one_hot=True)
#每个批次的大小
batch_size = 100
#计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size
#定义两个placeholder
x = tf.placeholder(tf.float32,[None,784])#样本
y = tf.placeholder(tf.float32,[None,10])#标签
#创建一个简单的神经网络,后面优化可以用多个隐藏层
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,W)+b)
#定义二次代价函数
loss = tf.reduce_mean(tf.square(y-prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
#初始化变量
init = tf.global_variables_initializer()
#结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
#求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
with tf.Session() as sess:
sess.run(init)
for epoch in range(21):
for batch in range(n_batch):
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step, feed_dict = {
x:batch_xs,y:batch_ys})
 
acc = sess.run(accuracy,feed_dict={
x:mnist.test.images,y:mnist.test.labels})
print("Iter" + str(epoch)+ ",Testing Accuracy " + str(acc))
WARNING:tensorflow:From 
:1: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.Instructions for updating:Please use alternatives such as official/mnist/dataset.py from tensorflow/models.WARNING:tensorflow:From D:\Program Files\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.Instructions for updating:Please write your own downloading logic.WARNING:tensorflow:From D:\Program Files\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.Instructions for updating:Please use tf.data to implement this functionality.Extracting MNIST_data\train-images-idx3-ubyte.gzWARNING:tensorflow:From D:\Program Files\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.Instructions for updating:Please use tf.data to implement this functionality.Extracting MNIST_data\train-labels-idx1-ubyte.gzWARNING:tensorflow:From D:\Program Files\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.Instructions for updating:Please use tf.one_hot on tensors.Extracting MNIST_data\t10k-images-idx3-ubyte.gzExtracting MNIST_data\t10k-labels-idx1-ubyte.gzWARNING:tensorflow:From D:\Program Files\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.Instructions for updating:Please use alternatives such as official/mnist/dataset.py from tensorflow/models.Iter0,Testing Accuracy 0.8331Iter1,Testing Accuracy 0.8699Iter2,Testing Accuracy 0.8813Iter3,Testing Accuracy 0.8877Iter4,Testing Accuracy 0.8935Iter5,Testing Accuracy 0.8972Iter6,Testing Accuracy 0.9004Iter7,Testing Accuracy 0.9015Iter8,Testing Accuracy 0.9038Iter9,Testing Accuracy 0.9053Iter10,Testing Accuracy 0.906Iter11,Testing Accuracy 0.9074Iter12,Testing Accuracy 0.9088Iter13,Testing Accuracy 0.909Iter14,Testing Accuracy 0.9099Iter15,Testing Accuracy 0.9112Iter16,Testing Accuracy 0.9117Iter17,Testing Accuracy 0.9123Iter18,Testing Accuracy 0.9128Iter19,Testing Accuracy 0.9134Iter20,Testing Accuracy 0.9135
In [ ]:
#可以看到最后的预测正确率为:91.356,其实这个程序有许多地方可以进行优化:1增加隐含层,2迭代多次

转载地址:http://zosws.baihongyu.com/

你可能感兴趣的文章
<a4j:keeyAlive>的英文介绍
查看>>
关于list对象的转化问题
查看>>
VOPO对象介绍
查看>>
suse创建的虚拟机,修改ip地址
查看>>
linux的挂载的问题,重启后就挂载就没有了
查看>>
docker原始镜像启动容器并创建Apache服务器实现反向代理
查看>>
docker容器秒死的解决办法
查看>>
管理网&业务网的一些笔记
查看>>
openstack报错解决一
查看>>
openstack报错解决二
查看>>
linux source命令
查看>>
openstack报错解决三
查看>>
乙未年年终总结
查看>>
子网掩码
查看>>
第一天上班没精神
查看>>
启动eclipse报错:Failed to load the JNI shared library
查看>>
eclipse安装插件的两种方式在线和离线
查看>>
linux下源的相关笔记(suse)
查看>>
linux系统分区文件系统划分札记
查看>>
Linux(SUSE 12)安装Tomcat
查看>>