阅读( 202 )次

   评论(0)条

wudixx

Softmax分类器学习笔记(二)


最近晚上都有时间,每天晚上都坚持听课,还好老师讲的速度刚刚好,我能跟着一边听讲一边把代码敲进来实现。

有疑问的地方反复听了几遍,get到了老师的意思。不说了,直接上我敲的代码。


import sys
import argparse
import os
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data

os.environ['TF_CPP_MIN_LOG-LEVEL'] = '2'


def main(_):
    # 开始计算图
    with tf.Graph().as_default():
        # 输入占位符:
        with tf.name_scope('Input'):
            X = tf.placeholder(tf.float32, shape=[None, 784], name='X')
            Y_true = tf.placeholder(tf.float32, shape=[None, 10], name='Y_True')
        # Inference:前向预测
        with tf.name_scope('Inference'):
            # 参数
            W = tf.Variable(tf.zeros([784, 10]), name='Weight')
            b = tf.Variable(tf.zeros([10]), name='Bias')
            logits = tf.add(tf.matmul(X, W), b)
            # softmax把logits变成概率分布
            with tf.name_scope('Softmax'):
                Y_pred = tf.nn.softmax(logits=logits)

        # Loss:定义损失节点
        with tf.name_scope('Loss'):
            TrainLoss = tf.reduce_mean(
                -tf.reduce_sum(Y_true * tf.log(Y_pred), axis=1))
        # 定义训练节点
        with tf.name_scope('Train'):
            # Optimizer
            Optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
            # Train:
            TrainStep = Optimizer.minimize(TrainLoss)
        #定义评估节点:
        with tf.name_scope('Evaluate'):
            correct_prediction = tf.equal(tf.argmax(Y_pred,1), tf.argmax(Y_true,1))
            accuray = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        # 初始化:
        InitOp = tf.global_variables_initializer()
        # save graph
        writer = tf.summary.FileWriter(logdir='logs/minst_softmax', graph=tf.get_default_graph())
        writer.close()
        print('开始运行计算图')
        # 加载数据
        mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

        # 声明一个交互式会话
        sess = tf.InteractiveSession()
        # 初始化所有变量:W,b
        sess.run(InitOp)
        #按批次训练
        for step in range(10000):
            batch_xs, batch_ys = mnist.train.next_batch(100)
            _,train_loss = sess.run([TrainStep, TrainLoss],
                                    feed_dict={X: batch_xs, Y_true: batch_ys})
            print("train step: ", step, ", train_loss:", train_loss)
        accuray_score = sess.run(accuray,feed_dict={X:mnist.test.images,
                                                    Y_true:mnist.test.labels})
        print("模型正确率:",accuray_score)
#调用main()函数
if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--data_dir', type=str,
                         default='MNIST_data/',
                         help='数据存放路径')
     FLAGS, unparsed = parser.parse_known_args()
     tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)


 收藏 (0)  打赏  点赞 (0)

 ©2017 studyai.com 版权所有

关于我们