• 正文
    • RNN(Recurrent Neural Network)
    • LSTM(Long Short-Term Memory)
    • GRU(Gated Recurrent Unit)
  • 相關(guān)推薦
申請入駐 產(chǎn)業(yè)圖譜

大話循環(huán)神經(jīng)網(wǎng)絡(luò)RNN、LSTM、GRU

2024/12/20
5480
加入交流群
掃碼加入
獲取工程師必備禮包
參與熱點資訊討論

CNN主要處理圖像信息,主要應(yīng)用于計算機視覺領(lǐng)域。

RNN(recurrent neural network)主要就是處理序列數(shù)據(jù)(自然語言處理、語音識別、視頻分類、文本情感分析、翻譯),核心就是它能保持過去的記憶。但RNN有著梯度消失問題,專家之后接著改進為LSTM和GRU結(jié)構(gòu)。下面將用通俗的語言分別詳細介紹。

在這里插入圖片描述

機器學(xué)習(xí)深度學(xué)習(xí)不太熟的童鞋可以先康康這幾篇哦:

《無廢話的機器學(xué)習(xí)筆記》《一文極速理解深度學(xué)習(xí)》、《一文總結(jié)經(jīng)典卷積神經(jīng)網(wǎng)絡(luò)CNN模型》

RNN(Recurrent Neural Network)

RNN中的處理單元,中間綠色就是過去處理的結(jié)果,左邊第一幅圖就是正常的DNN,不會保存過去的結(jié)果,右邊的圖都有一個特點,輸出的結(jié)果(藍色)不僅取決于當(dāng)前的輸入,還取決于過去的輸入!不同的單元能賦予RNN不同的能力,如 多對一就能對一串文本進行分類,輸出離散值,比如根據(jù)你的言語判斷你今天高不高興。

在這里插入圖片描述

RNN中保存著過去的信息,輸出取決于現(xiàn)在與過去。如果大伙學(xué)過數(shù)電,這就是狀態(tài)機!這玩意跟觸發(fā)器很像。

在這里插入圖片描述

有個很重要的點:

這個權(quán)重fw沿時間維度是一致的,權(quán)值共享。就像CNN中一個卷積核在卷積過程中參數(shù)一致。所以CNN是沿著空間維度權(quán)值共享;RNN是沿著時間維度權(quán)值共享。

在這里插入圖片描述

具體來說有三個權(quán)重,過去與現(xiàn)在各一個權(quán)重,加起來再來一個權(quán)重。 它們都沿著時間維度權(quán)值共享。不然每個時間都不一樣權(quán)重,參數(shù)量會很恐怖。

在這里插入圖片描述

整體的計算圖(多對多):

每次的輸出y可以與標(biāo)簽值構(gòu)建損失函數(shù),這樣就跟之前DNN訓(xùn)練模型思想一樣,訓(xùn)練3套權(quán)重使損失函數(shù)不斷下降至滿意。

在這里插入圖片描述

反向傳播要沿時間反向傳回去(backpropagation through time,BPTT)

Forward through entire sequence to compute loss, then backward through entire sequence to compute gradient.

在這里插入圖片描述

這樣會有問題,就是一下子把全部序列弄進來求梯度,運算量非常大。實際我們會將大序列分成等長的小序列,分別處理:

在這里插入圖片描述

不同隱含層中不同的值負(fù)責(zé)的是語料庫中不同的特征,所以隱含狀態(tài)的個數(shù)越多,模型就越能捕獲文本的底層特征。

下面來看一個例子:字符級語言模型(由上文預(yù)測下文):

我想輸入hell,然后模型預(yù)測我會輸出o;或者我輸入h,模型輸出e,我再輸入e,模型輸出l…

首先對h,e,l,o進行獨熱編碼,然后構(gòu)建模型進行訓(xùn)練。

在這里插入圖片描述

在這里插入圖片描述

輸入莎士比亞的劇本,讓模型自己生成劇本,訓(xùn)練過程:

在這里插入圖片描述

輸入latex文本,讓模型自己生成內(nèi)容,公式寫得有模有樣的,就不知道對不對:

在這里插入圖片描述

當(dāng)然輸入代碼,模型也會輸出代碼。所以現(xiàn)在火熱的Chatgpt的本質(zhì)就是RNN。

對于圖像描述,專家會先用CNN對圖像進行特征抽取(編碼器),然后將特征再輸入RNN進行圖像描述(解碼器。

在這里插入圖片描述

還可以結(jié)合注意力機制(Image captioning with attention):

在這里插入圖片描述

普通堆疊的RNN一旦隱含層變多變深,反向傳播時就很容易出現(xiàn)梯度消失/爆炸。

子豪兄總結(jié)得非常好,以最簡單的三層網(wǎng)絡(luò)來看,對于輸出的O3可以列出損失函數(shù)L3,對L3進行求偏導(dǎo),分別對輸出權(quán)重w0,輸入權(quán)重wx,過去權(quán)重ws進行求導(dǎo)。我們發(fā)現(xiàn)對w0求偏導(dǎo)會很輕松。

但是,由于鏈?zhǔn)椒▌t(chain rule),對輸入權(quán)重wx和過去權(quán)重ws求偏導(dǎo)就會很痛苦。在表達式里,對于越是前面層的鏈?zhǔn)角髮?dǎo),乘積項越多,所以很容易梯度消失/爆炸,梯度消失占大多數(shù)。

在這里插入圖片描述

LSTM(Long Short-Term Memory)

長短時記憶神經(jīng)網(wǎng)絡(luò)(LSTM) 應(yīng)運而生!

LSTM既有長期記憶也有短期記憶,包括遺忘門、輸入門、輸出門、長期記憶單元。右圖紅色函數(shù)是sigmoid,藍色函數(shù)是tanh。

在這里插入圖片描述

C是長期記憶,h是短期記憶。

所以當(dāng)前輸出ht是由短期記憶產(chǎn)生的。

在這里插入圖片描述

我們看到長期記憶那條線是貫通的,且只有乘加操作。

在這里插入圖片描述

LSTM算法詳解:

下面幾個圖完美解釋了:

在這里插入圖片描述
在這里插入圖片描述
在這里插入圖片描述
在這里插入圖片描述

所以總共有四個權(quán)重:Wf、Wi、Wc、Wo,當(dāng)然還有它們對應(yīng)的偏置項。

整體過程可以概括為:遺忘、更新、輸出。(更新包括先選擇保留信息,再更新最新記憶。)

原論文中的圖也非常形象:

在這里插入圖片描述

在這里插入圖片描述

現(xiàn)在反向傳播求偏導(dǎo)就舒服了

在這里插入圖片描述
在這里插入圖片描述

在這里插入圖片描述

GRU(Gated Recurrent Unit)

GRU也能很好解決梯度消失問題,結(jié)構(gòu)簡單一點,主要就是重置門更新門。

在這里插入圖片描述

在這里插入圖片描述

GRU與LSTM對比:

  1. 參數(shù)數(shù)量:GRU的參數(shù)數(shù)量相對LSTM來說更少,因為它將LSTM中的輸入門、遺忘門和輸出門合并為了一個門控單元,從而減少了模型參數(shù)的數(shù)量。
    LSTM中有三個門控單元:輸入門、遺忘門和輸出門。每個門控單元都有自己的權(quán)重矩陣和偏置向量。這些門控單元負(fù)責(zé)控制歷史信息的流入和流出。
    GRU中只有兩個門控單元:更新門和重置門。它們共享一個權(quán)重矩陣和一個偏置向量。更新門控制當(dāng)前輸入和上一時刻的輸出對當(dāng)前時刻的輸出的影響,而重置門則控制上一時刻的輸出對當(dāng)前時刻的影響。
  2. 計算速度:由于參數(shù)數(shù)量更少,GRU的計算速度相對LSTM更快。
  3. 長序列建模:在處理長序列數(shù)據(jù)時,LSTM更加優(yōu)秀。由于LSTM中引入了一個長期記憶單元(Cell State),使得它可以更好地處理長序列中的梯度消失和梯度爆炸問題。

GRU適用于:

處理簡單序列數(shù)據(jù),如語言模型和文本生成等任務(wù)。
處理序列數(shù)據(jù)時需要快速訓(xùn)練和推斷的任務(wù),如實時語音識別、語音合成等。
對計算資源有限的場景,如嵌入式設(shè)備、移動設(shè)備等。

LSTM適用于:

處理復(fù)雜序列數(shù)據(jù),如長文本分類、機器翻譯、語音識別等任務(wù)。
處理需要長時依賴關(guān)系的序列數(shù)據(jù),如長文本、長語音等。
對準(zhǔn)確度要求較高的場景,如股票預(yù)測、醫(yī)學(xué)診斷等。

公式總結(jié):

在這里插入圖片描述

相關(guān)推薦