下載本文檔
版權(quán)說明:本文檔由用戶提供并上傳,收益歸屬內(nèi)容提供方,若內(nèi)容存在侵權(quán),請進(jìn)行舉報或認(rèn)領(lǐng)
文檔簡介
1、搭建網(wǎng)絡(luò)思路:? mnist數(shù)據(jù)集下載與處理?搭建神經(jīng)網(wǎng)絡(luò)?損失函數(shù)與優(yōu)化函數(shù)的選擇?定義可視化函數(shù)(不是本文的重點)?模型訓(xùn)練以及結(jié)果圖像繪制?模型測試結(jié)果可視化與結(jié)果分析代碼實現(xiàn)1 .導(dǎo)入需要用到的安裝包與模塊:這里我么直接使用 Pytorch自帶的mnist數(shù)據(jù)集,所以要用到torchvision安裝包。from torch.utils import datafrom torchvision import datasets, transformsfrom torch.nn import Sequential, Conv2d, ReLU, MaxPool2d, Linear, CrossE
2、ntropyLossfrom torch.optim import Adamfrom torch.autograd import Variablefrom matplotlib import cm, pyplot as pltfrom sklearn.manifold import TSNE #將最后一層的輸出降維處理,方便結(jié)果的 可視化import torch.nn as nnimport torchimport os2 .mnist數(shù)據(jù)集下載與處理:直接下載下來的mnist數(shù)據(jù)集并不符合輸入給網(wǎng)絡(luò)的格式,所以我們需要對 數(shù)據(jù)集做一些處理,以滿足模型訓(xùn)練的要求。由于小編電腦配置較低,測試 數(shù)
3、據(jù)量為1800張圖片。LR = 0.01BATCH_SIZE = 32DOWNLOAD_MINIST = FalseEPOCHS = 1HAS_SK = True#創(chuàng)建數(shù)據(jù)集目錄if not os.listdir(./model_datas/MnistDatasets/):DOWNLOAD_MINIST= True#訓(xùn)練數(shù)據(jù)集處理train_data = datasets.MNIST(root=./model_datas/MnistDatasets/,峻據(jù)集需要保存的路徑train=True, # true:訓(xùn)練數(shù)據(jù) False:測試數(shù)據(jù)transform=transforms.ToTenso
4、r(), #等 PIL.Image 或 numpy.ndarray數(shù)據(jù)轉(zhuǎn)換為 形狀為 torch.FloatTensor (C x H x W),同時歸一化download=DOWNLOAD_MINIST # 下載數(shù)據(jù)集,如果有,就直接加載,如果 如果沒有,就去下載)#將輸入圖片的shap漿化為(50, 1,28, 28)data_loader = data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True,)test_data = datasets.MNIST(root=./model_datas/Mnist
5、Datasets/”,train=False, )test_x = torch.unsqueeze(test_data.test_data:1800, dim=1).type(torch.FloatTensor) / 255. #將測試數(shù)據(jù)的輸入shape由原泵的(10000, 28, 28)轉(zhuǎn)化為( 10000, 1, 28, 28)并且將輸出數(shù)據(jù)歸一化test_y = test_data.test_labels:1800 #1 取測試數(shù)據(jù)的標(biāo)簽值3 .搭建神經(jīng)向絡(luò):將模型數(shù)據(jù)預(yù)處理完之后,接下來才是我們這篇文章的重點,搭建一個卷積 神經(jīng)網(wǎng)絡(luò)。輸入層數(shù)據(jù)格式:32, , 1,28 , 28第
6、一卷積層:卷積層:shape 32 , 1 , 28 , 28 -32, 16 , 28 ,28這里的卷積層只是將原來的通道數(shù)由1變?yōu)?6,圖片大小沒有變化激活層:這里我們選擇 Relu作為激活函數(shù)對卷積層輸出的數(shù)據(jù)進(jìn)行去線性。池化層:shape 32, 16, 28, 28-32, 16, 14, 14池化層不改變通道數(shù),只改變圖片的大小第二卷積層:卷積層:shape 32, 16,14, 14 -32,32,14,14這里的卷積層只是將上一層池化層輸出的通道數(shù)由16變?yōu)?2,圖片大小沒有變化激活層:選擇Relu作為激活函數(shù)對卷積層輸出的數(shù)據(jù)進(jìn)行去線性。池化層:shape 32, 32, 1
7、4,14-32, 32, 7, 7池化層不改變圖片通道數(shù),只改變圖片的大小全連接層:數(shù)據(jù)降維處理:32,32,7,7-32, 32*7*7首先將第二卷積層池化層的輸出降維處理為全連接層可以接收的數(shù)據(jù)格式全連接層:shape 32, 32*7*7 - 32, 10將降維處理完的數(shù)據(jù)格式作為全連接層的輸入,使用一個32*7*7, 10 的權(quán)重,將結(jié)果輸出;具體的網(wǎng)絡(luò)結(jié)果如下圖所示:class CNN(nn.Module):def _init_(self):super(CNN, self)._init_()self.conv_1 = Sequential 第一卷積層Conv2d( in_channe
8、ls=1, out_channels=16, kernel_size=5, stride=1, padding=2),ReLU(),MaxP0012d(kernel_size=2)# 最大池化層 ).self.conv_2 = Sequential(# 第二卷積層Conv2d( in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),ReLU(),MaxP0012d(kernel_size=2)# 最大池化層 ).self.out = Linear(32 * 7 * 7, 10) # 全連接層def forwar
9、d(self, x): conv_1 = self.conv_1(x) conv_2 = self.conv_2(conv_1) fcl_input = conv_2.view(conv_2.size(0), -1) fcl_output = self.out(fcl_input) return fcl_output, fcl_inputcnn = CNN()4 .損失函數(shù)與優(yōu)化函數(shù)搭建完神經(jīng)網(wǎng)絡(luò)之后,接下來就需要選擇目標(biāo)函數(shù)和優(yōu)化函數(shù)了。對于分類問題,一般使用交叉嫡(Cross-Entropy )作為損失進(jìn)行最小優(yōu) 而對于優(yōu)化函數(shù),這里選擇 Adam是在于經(jīng)過偏置校正后,每一次迭代學(xué)習(xí) 率都
10、有個確定范圍,使得參數(shù)比較平穩(wěn)。optimizer= Adam(cnn.parameters(), lr=LR)loss_func= CrossEntropyLoss()5 .定義可視化函數(shù):這里定義這個函數(shù)主要是為了將最終的分結(jié)果展示出來,方便理解。本文的 重點是神經(jīng)網(wǎng)絡(luò),對于結(jié)果展示代碼看不懂就直接忽視。def plot_with_labels(lowDWeights, labels):主要是將測試結(jié)果進(jìn)行可視化。:param lowDWeights:進(jìn)行降維處理后的測試集標(biāo)簽值:param labels:原測試集的標(biāo)簽值 plt.cla()X, Y = lowDWeights:, 0,
11、lowDWeights:, 1for x, y, s in zip(X, Y, labels):c = cm.rainbow(int(255 * s / 9)plt.text(x, y, s, backgroundcolor=c, fontsize=9)plt.xlim(X.min(), X.max()plt.ylim(Y.min(), Y.max()plt.title(Visualize last layer)plt.savefig(./model_datas/pictures/cnn.png)plt.pause(0.1)5 .模型訓(xùn)練以及結(jié)果圖像繪制:將之前處理好的mnist數(shù)據(jù)集輸入到搭
12、建好的神經(jīng)網(wǎng)絡(luò)里面進(jìn)行模型訓(xùn)練以及繪制測試結(jié)果的可視化。for epoch in range(EPOCHS):for step, (b_x, b_y) in enumerate(data_loader): batch_x, batch_y = Variable(b_x), Variable(b_y) pred_y = cnn(batch_x)0loss = loss_func(pred_y, batch_y)optimizer.zero_grad()loss.backward()optimizer.step()if step % 50 = 0:test_output, last_layer =
13、 cnn(Variable(test_x)test_y_pred = torch.max(test_output, 1)1.data.squeeze().numpy()accuracy = float(test_y_pred = test_y.numpy().astype(int).sum()/ float(test_y.size(0)print(epoch:, epoch, |step:, step, |train_loss:, loss.data.numpy(), ”|test_acuracy:%.2f % accuracy)if HAS_SK: #數(shù)據(jù)降維并且可視化tsne = TSNE
14、(perplexity=30, n_components=2, init=pca, n_iter=5000) plot_only = 500low_dim_embs = tsne.fit_transform(last_layer.data.numpy():plot_only, :) #將輸出數(shù)據(jù)陳線-labels = test_y.numpy():plot_onlyplot_with_labels(low_dim_embs, labels) #調(diào)用此函數(shù),將數(shù)據(jù)傳入并且展示出來plt.ioff()# print 10 predictions from test datatest_output,
15、 _ = cnn(test_x:10)pred_y = torch.max(test_output, 1)1.data.squeeze().numpy()print(pred_y, predict_number)print(test_y:10.numpy(), real number)6 .模型測試結(jié)果可視化與結(jié)果分析:為了更形象說明搭建的神經(jīng)網(wǎng)絡(luò)對mnist數(shù)據(jù)集分類結(jié)果準(zhǔn)確性,這里將每次訓(xùn)練與測試結(jié)果的對應(yīng)的準(zhǔn)確度和損失值的變化、分類結(jié)果的圖像兩 個方面來分析。tpULI i pI S L CJJ:T trdin_epoch:0step:1490Itrain_loss :0.1665232
16、 |test_acuracy:Q,97epoch:0Istep:145。Itrain_loss::L03144965 |test_acuracy:0.97epoch:e15tep:159G1 train_loss ::9109162366t est ac jracy ;0.97epoch:0Istep:1559Itrain,1 05S:9.03336678 1f est_actiracy :0.97epoch:0step:1000|train_loss:0t09505453test_acuracy;0.97epoch:0I step:1650I train_loss :0.QS195481 |test adj racyr 97epoch:0Istep:17091 trainloss:).0lb5379 i1 test_acuracy:,97圖一:損失與準(zhǔn)確度_Vhsualj2e ltft layerIHIIT圖二:測試結(jié)果的可視化圖 圖一我們可以看到,模型訓(xùn)練初期,訓(xùn)練數(shù)據(jù)的損失值和測試的準(zhǔn)確度都不能滿足我們的要求,
溫馨提示
- 1. 本站所有資源如無特殊說明,都需要本地電腦安裝OFFICE2007和PDF閱讀器。圖紙軟件為CAD,CAXA,PROE,UG,SolidWorks等.壓縮文件請下載最新的WinRAR軟件解壓。
- 2. 本站的文檔不包含任何第三方提供的附件圖紙等,如果需要附件,請聯(lián)系上傳者。文件的所有權(quán)益歸上傳用戶所有。
- 3. 本站RAR壓縮包中若帶圖紙,網(wǎng)頁內(nèi)容里面會有圖紙預(yù)覽,若沒有圖紙預(yù)覽就沒有圖紙。
- 4. 未經(jīng)權(quán)益所有人同意不得將文件中的內(nèi)容挪作商業(yè)或盈利用途。
- 5. 人人文庫網(wǎng)僅提供信息存儲空間,僅對用戶上傳內(nèi)容的表現(xiàn)方式做保護(hù)處理,對用戶上傳分享的文檔內(nèi)容本身不做任何修改或編輯,并不能對任何下載內(nèi)容負(fù)責(zé)。
- 6. 下載文件中如有侵權(quán)或不適當(dāng)內(nèi)容,請與我們聯(lián)系,我們立即糾正。
- 7. 本站不保證下載資源的準(zhǔn)確性、安全性和完整性, 同時也不承擔(dān)用戶因使用這些下載資源對自己和他人造成任何形式的傷害或損失。
最新文檔
- 2025道路運輸合同書范本
- 2025年東營貨運資格證題庫下載安裝
- 2025年廈門貨運從業(yè)資格模擬考試
- 2025年山東貨運從業(yè)資格證年考試題及答案
- 2025勞務(wù)派遣合同范本3
- 2025飛蟲綜合治理合同
- 2025購買公司股份合同
- 2025供電設(shè)備智能維護(hù)合同
- 水利實習(xí)報告范文
- 2025標(biāo)準(zhǔn)技術(shù)轉(zhuǎn)讓合同范文
- 2024年安徽省初中學(xué)業(yè)水平考試中考數(shù)學(xué)試卷(真題+答案)
- 2024年臨汾翼城縣就業(yè)困難高校畢業(yè)生公益性崗招考聘用70人重點基礎(chǔ)提升難、易點模擬試題(共500題)附帶答案詳解
- 護(hù)理中級職稱競聘
- 現(xiàn)代控制理論智慧樹知到期末考試答案章節(jié)答案2024年長安大學(xué)
- 國際公法學(xué)馬工程全套教學(xué)課件
- 數(shù)據(jù)安全重要數(shù)據(jù)風(fēng)險評估報告
- 汽車維修合伙利益分配協(xié)議書
- MOOC 普通地質(zhì)學(xué)-西南石油大學(xué) 中國大學(xué)慕課答案
- 醫(yī)療機(jī)構(gòu)感染預(yù)防與控制基本制度試題附有答案
- 生產(chǎn)部文員年終總結(jié)
- 半導(dǎo)體芯片知識講座
評論
0/150
提交評論