1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
| # LeNet-5(经典卷积神经网络)复现 import time from tensorflow.keras.datasets import mnist # import numpy as np import matplotlib.pyplot as plt from tensorflow.keras.models import Sequential from tensorflow.keras.layers import Dense from tensorflow.keras.optimizers import SGD from tensorflow.keras.utils import to_categorical from tensorflow.keras.layers import Conv2D # 二维卷积 from tensorflow.keras.layers import AveragePooling2D # 二维池化 from tensorflow.keras.layers import Flatten # 展平后,接入全连接层
# Load the MNIST handwritten-digit dataset (28x28 grayscale images).
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()

# Add a channel axis and scale pixel values from [0, 255] to [0, 1].
# FIX: use -1 instead of the hard-coded sample counts (60000 / 10000) so
# the reshape still works if the dataset split sizes ever change.
X_train = X_train.reshape(-1, 28, 28, 1) / 255.0
X_test = X_test.reshape(-1, 28, 28, 1) / 255.0

# One-hot encode the integer labels into 10-way class vectors
# (required by the categorical cross-entropy loss used below).
Y_train = to_categorical(Y_train, 10)
Y_test = to_categorical(Y_test, 10)
# Record the wall-clock start time so training duration can be reported.
time_start = time.time()

# LeNet-5-style architecture.
# filters: number of convolution kernels; kernel_size: kernel dimensions;
# strides: step of the sliding window; padding: 'valid' (no zero-padding)
# or 'same' (pad so output size equals input size).
model = Sequential()
model.add(Conv2D(filters=6, kernel_size=(5, 5), strides=(1, 1),
                 input_shape=(28, 28, 1), padding='valid',
                 activation='relu'))
# Subsampling: 2x2 average pooling halves each spatial dimension.
model.add(AveragePooling2D(pool_size=(2, 2)))
# No input_shape needed here — Keras infers it from the previous layer.
model.add(Conv2D(filters=16, kernel_size=(5, 5), strides=(1, 1),
                 padding='valid', activation='relu'))
model.add(AveragePooling2D(pool_size=(2, 2)))
# Flatten the feature maps and feed them into fully connected layers.
model.add(Flatten())
model.add(Dense(units=120, activation='relu'))
model.add(Dense(units=84, activation='relu'))
# Softmax output — the usual choice for mutually exclusive multi-class
# classification.
model.add(Dense(units=10, activation='softmax'))

# Categorical cross-entropy pairs with the softmax output and works better
# than the mean-squared-error loss used previously.
# FIX: the SGD `lr` keyword was deprecated and later removed in Keras;
# `learning_rate` is the supported argument name.
model.compile(loss='categorical_crossentropy',
              optimizer=SGD(learning_rate=0.05),
              metrics=['accuracy'])
# epochs: full passes over the training set;
# batch_size: samples consumed per gradient update.
model.fit(X_train, Y_train, epochs=50, batch_size=1024)
time_end = time.time()

# Report wall-clock training time.
print('time cost', time_end - time_start, 's')
# NOTE(review): dumping every weight tensor floods stdout; kept for parity
# with the original script, but model.summary() is usually more useful.
print(model.get_weights())

# Evaluate generalization on the held-out test set.
loss, accuracy = model.evaluate(X_test, Y_test)
# FIX: the original concatenated the value directly onto the label
# ("loss0.034..."); add a separator so the output is readable.
print("loss: " + str(loss))
print("accuracy: " + str(accuracy))
# Final test accuracy is typically around 98%.
|