聯(lián)邦學習(Federated Learning)近幾年很火,本文擬用最通俗語言解釋聯(lián)邦學習的來龍去脈。
分布式學習是什么
其實聯(lián)邦學習沒有那么新穎,任何技術(shù)都是有一條主鏈的,并不是從石頭里蹦出來。聯(lián)邦學習的前身就是分布式學習。(并不是說聯(lián)邦學習就取代分布式學習,各有適合的場景)
那先簡單介紹一下分布式學習,就是你有一個模型要訓練,而你手頭上有多個CPU或GPU,這時你就想在多個計算節(jié)點上同時訓練這個模型,來加快訓練,這就是分布式學習。一般這些節(jié)點有一個老大哥,叫server;一群小弟,叫worker。那么很直觀的有兩招,同步(多個worker訓練完模型的一部分,同時與server通信。),異步(每個worker訓練完模型的一部分后就立刻可以與server通信,從而進入下一輪計算。)
左圖為同步,右圖為異步
同步(MapReduce)
異步(Parameter Server)
現(xiàn)在一般都是異步,因為同步的代價很大,而且會有木桶效應(yīng),算力強的節(jié)點會被算力弱的節(jié)點拖住。但是異步也有它的挑戰(zhàn),比如如果一個算力強的節(jié)點已經(jīng)算了7、8輪,一個算力弱的節(jié)點才計算完1輪,這時這個比較遲更新的梯度基于的是老參數(shù),對模型有可能是有害的,即需要設(shè)備算力比較均勻。
去中心化
還有一個方案就是去中心化,簡單來說就是沒有老大哥server了,又或者個個都是server。
題外話:并行計算和分布式學習其實就一個東西,嚴格點定義就是節(jié)點與server之間通信時是有線且在一個不大的區(qū)域內(nèi),就是并行計算;如果是無線通信,一般我們就叫它分布式計算。
聯(lián)邦學習是什么
聯(lián)邦學習是為了解決一些特殊場景下分布式學習,可以理解為分布式學習的一種,帶約束的分布式學習。
比如現(xiàn)在想基于用戶的手機上產(chǎn)生的數(shù)據(jù)來訓練模型,那么多臺手機,我們自然想到分布式學習,但是,現(xiàn)在手機上的數(shù)據(jù)因為隱私問題我們不希望傳到server,手機(worker)對數(shù)據(jù)有絕對的控制權(quán);同樣的場景也有像銀行、保險公司、醫(yī)院等有數(shù)據(jù)隱私的地方,它們想?yún)⑴c分布式學習,但是還想對數(shù)據(jù)有絕對控制權(quán)。
所以聯(lián)邦學習就是允許多個參與者協(xié)同訓練共享模型,同時保持各自數(shù)據(jù)的隱私和安全。
在聯(lián)邦學習中,數(shù)據(jù)不需要集中存儲或處理,而是保留在本地。模型的更新(如梯度或模型參數(shù))在本地計算,然后被發(fā)送到中央服務(wù)器進行聚合,最終形成更新后的全局模型。
分類
根據(jù)數(shù)據(jù)分布:
- 水平聯(lián)邦學習(Horizontal FL):不同參與者的數(shù)據(jù)在特征空間上相似,但在樣本空間上不同。
- 垂直聯(lián)邦學習(Vertical FL):不同參與者的數(shù)據(jù)在樣本空間上相似,但在特征空間上不同。
根據(jù)模型更新的方式分類:
- 同步更新:所有參與者同時進行本地更新,然后進行聚合。
- 異步更新:參與者在完成本地更新后即可發(fā)送更新,無需等待其他參與者。
聯(lián)邦學習與分布式學習區(qū)別(Challenge)
聯(lián)邦學習有它自己的特點:
- Worker對數(shù)據(jù)有絕對控制權(quán),不上傳到server。(分布式學習server有時會分發(fā)數(shù)據(jù)和要求worker上傳數(shù)據(jù))
- Worker不穩(wěn)定,且異構(gòu)性強。 比如手機有時會關(guān)機;手機和ipad不一樣,且即使都是手機,型號也不一樣。(分布式學習的worker默認是24小時工作,且設(shè)備差異不大)
- Worker與Server通信代價很大,大于計算的代價,比如手機有時離server很遠,通信帶寬小,延遲高等。
- 數(shù)據(jù)不平衡不均勻,不是獨立同分布的。(數(shù)據(jù)孤島)
個人理解,分布式學習:專職;聯(lián)邦學習:兼職(不穩(wěn)定,異構(gòu)性強)
前沿方向
- 主流的一個研究問題就是怎么減少節(jié)點的通信次數(shù)(因為通信代價大嘛,所以想communication-efficient)。經(jīng)典的方法為 Federated averaging。對于worker,它收到下發(fā)的權(quán)重w后,用本地data多次地更新w:算梯度g,然后梯度下降更新w,再算梯度g,再更新w。。多次計算后,將最終的w發(fā)往server。而server收到多個worker上發(fā)的w,對它們求個平均,搞定,再下發(fā)給worker。。整個過程很簡單,與分布式學習的區(qū)別就是邊緣節(jié)點做了大量的本地計算,以犧牲計算量為代價換取更少通信次數(shù)。(一定通信次數(shù)后,F(xiàn)ederated averaging的方法能使模型收斂更快。)(許多研究已經(jīng)表明聯(lián)邦學習的數(shù)據(jù)不需要獨立同分布)
- 隱私(Privacy):梯度只是對數(shù)據(jù)做了個變換,worker上傳的梯度多少帶有數(shù)據(jù)的一些特征,所以可以拿梯度反推數(shù)據(jù)(性別、種族、年齡、疾病。。) 直觀的解決方法:為上傳的梯度加噪聲(noise),但是效果不好,因為加了noise會影響模型的訓練,noise太小也沒有什么意義,所以這方向我感覺還沒有太有效的方法。
- 提高魯棒性:如果worker中有不穩(wěn)定的/異常的節(jié)點,它可能發(fā)送錯誤或有害的信息到server,影響模型訓練。比較直觀的方法,讓server對收到的梯度或權(quán)重進行驗證(但是因為worker的異構(gòu)性,梯度或權(quán)重總是有些區(qū)別,所以這個區(qū)別正常worker和異常worker驗證方法比較難設(shè)計)?;蚴亲宻erver用別的方法更新w,比如不用求平均,用求中位數(shù),等等。
- 與其他學習范式的結(jié)合:如聯(lián)邦學習與強化學習、遷移學習的結(jié)合,以解決更廣泛的問題。
- 去中心化聯(lián)邦學習:研究去中心化的方法,其中不依賴于中央服務(wù)器,而是通過點對點的方式進行模型的更新和聚合。