來自凹面非寺廟量子位元的豐富色彩 | qbitai
懸空的心終於死了:
被尊為變形金剛挑戰者的曼巴已被ICLR正式拒絕。
在“最初被拒絕”後,在學術界引起軒然,並變成了“待定”狀態)。
但是,這個“頂級流”的受歡迎程度會受到怎樣的影響呢?
不,對它的新流行解讀(作者:傑克·庫克,牛津網際網絡研究所研究員,曾在麻省理工學院、英偉達、Microsoft工作過),才剛剛誕生,至今仍被網友點讚和收藏。
有些人甚至稱其為:
迄今為止最好的年度最佳(解釋)。
我們也不能錯過它。
以下是原文的精髓:
背景:S4 架構。
MAMBA 的架構基於 S4,這是一種狀態空間模型 (SSM) 架構。
主要思想如下:
在較高層次上,s4 學習如何通過中間狀態 h(t) 將輸入 x(t) 對映到輸出 y(t)。
在這裡,由於 SSM 旨在很好地處理連續資料,例如音訊、感測器資料和影象,因此 x、y、t 是 x 的函式。
S4 通過三個連續引數矩陣 A、B 和 C 將它們互連,形式為以下兩個方程(MAMBA** 中的 1a 和 1b)。
由於在實踐中,我們通常處理離散資料(如文字),因此這要求我們通過使用特殊的第四個引數δ將連續引數 a、b 和 c 轉換為離散引數和 c 來離散化 SSM。
離散化後,我們可以通過這兩個方程(曼巴**中的 2a 和 2b)來表示 SSM:。
這些方程形成遞迴,類似於我們在 RNN 網路中看到的。 在每個步驟 t 中,我們將上乙個時間步長 ht 1 的隱藏狀態與當前輸入 xt 相結合,以建立乙個新的隱藏狀態 ht。
下圖顯示了當句子中的下乙個單詞(我們在“我的名字是傑克”之後有“和”)時它是如何工作的。
基於此,我們基本上可以使用 S4 作為迴圈神經網路 RNN 來一次生成乙個代幣。
然而,S4 真正酷的地方在於,你也可以將它用作卷積神經網路 CNN。
在上面的例子中,當我們擴充套件前面的離散方程來嘗試計算 h3 時會發生什麼?
為簡單起見,我們假設 x 1 = 0。
計算完 h3 後,我們可以將其代入 y3 的方程中,以**下乙個單詞:
現在,請注意,y3 實際上可以計算為點積,其中正確的向量是我們的輸入 x:
由於引數 和 c 是常數,我們可以預先計算左向量並將其儲存為卷積核。 這為我們提供了一種使用卷積計算 y 的簡單方法,如以下兩個方程所示(曼巴**中的 3a 和 3b):
重要提示:這些迴圈和卷積形式(作者稱之為“RNN 模式”和“CNN 模式”)在數學上是等價的。
因此,s4 可以根據您希望它執行的操作進行變形,而輸出沒有任何區別。
當然,CNN模式更適合訓練,RNN模式更適合推理。
第乙個主要思想:可選性。
在本節中,我們將討論 MAMBA 引入的第乙個主要思想:可選性。 讓我們回顧一下定義 S4 離散形式的兩個方程:
請注意,在 S4 中,我們的離散引數和 c 是常數。 但是,MAMBA使這些引數因輸入而異。 因此,我們最終得到如下結果:
MAMBA 作者(GU 和 DAO)認為,選擇性或輸入依賴性對於許多任務都很重要。
本文的科普作者認為,由於 S4 沒有選擇性,因此它被迫以完全相同的方式處理輸入的所有部分。
然而,當我們面對乙個句子時,其中一些詞不可避免地比其他詞更重要。
例如,“我想點乙個漢堡包。“這句話。
如果沒有選擇性,S4 將在每個單詞上花費相同的“精力”:
但是,如果它是乙個試圖對句子的意圖進行分類的模型,它可能希望更多地“關注”秩序、漢堡包,而不是想要。
如下圖所示,通過使模型引數成為輸入的函式,MAMBA可以“關注”輸入中對手頭任務更重要的部分。
然而,選擇性給我們帶來了乙個問題。 讓我們回想一下之前計算的卷積核。
在 S4 中,我們可以預先計算核心,儲存它,然後將其乘以輸入 x。
這很好,因為離散引數和 c 是常數。 但同樣,在 MAMBA 中,這些矩陣會根據輸入而變化! 因此,我們無法預先計算,也無法使用 CNN 模式來訓練我們的模型。 如果我們想要選擇性,我們必須在 RNN 模式下進行訓練。 這樣做的方法是刪除等式 3b 以獲得“戲劇性效果”。
但這給 MAMBA 的作者帶來了乙個問題:RNN 模式的訓練速度非常慢。
假設我們正在使用 1000 個令牌的序列來訓練我們的模型:
CNN 本質上是計算其核心和輸入向量之間的點積,並且這些計算可以併行執行。 相比之下,RNN 需要按順序更新其隱藏狀態 1000 次。
這導致MAMBA的作者提出了他們的第二個好主意。
第二個主要思想:無卷積的快速訓練。
MAMBA 可以在 RNN 模式下非常非常快速地訓練。
在某些時候,它們的遞迴與掃瞄演算法(也稱為字首總和、字首總和)非常相似。
要計算字首總和,我們需要獲取乙個輸入陣列 [x1,x2,...xn] 並返回乙個輸出陣列,其中每個元素是項及其前置元素的總和。
換句話說,輸出的第乙個元素將是 x1,第二個元素將是 [x1+[x2,依此類推。 舉個例子:
現在讓我們畫出在 RNN 模式下更新 MAMBA 隱藏狀態的過程。
等等......如果我們必須將字首 sum 形式化,我們可以將其寫成以下等式:
該方程形成遞迴:在每一步中,我們通過將先前儲存的值新增到當前輸入中來計算新值。 現在,我們再來看看更新後的曼巴隱藏狀態迴圈。
這兩個方程式真的非常相似,是的!
最酷的部分來了:雖然字首的計算在本質上似乎是連續的,但我們實際上有高效的並行演算法來完成這項任務!
在下圖中,我們可以看到並行字首和演算法的執行情況,其中每條垂直線代表陣列中的乙個專案。
花點時間複習一下這個演算法:
選擇任意垂直線,從頂部開始,然後向下移動,將每個新增項追溯到陣列的前幾項。 當您到達底部時,您應該會在行的左側看到所有專案的總和。
例如,在開頭將第乙個元素新增到第二個元素後,陣列的第三個元素在結尾處接收第二個元素的附加值。 因此,當並行掃瞄完成時,第三個元素包含第乙個元素。
第乙個、第二個和第三個元素的總和。
如果我們在沒有並行性的單個執行緒中執行演算法,則所需的時間將比僅按順序新增值所需的時間更長。 但 GPU 擁有大量處理器,可以進行高度平行計算。 因此,我們可以在大約 o(logn) 時間內計算出這個字首和/或掃瞄操作!
因此,MAMBA的作者意識到,如果他們想在RNN模式下進行有效的訓練,他們可能能夠使用並行掃瞄。
但是由於目前沒有 pytorch 的掃瞄實現,因此 mamba 的作者自己編寫了乙個 - 但是,結果並不好。
在上圖中,您可以看到他們基於 pytorch 的掃瞄實現(綠色)總是比 FlashAttention-2(藍色)慢,後者是可用的“精確注意力”的最快實現。
儘管當序列長度為 128000 個令牌時,掃瞄似乎在執行時趕上了,但它仍然耗盡了記憶體。
為了使MAMBA實用,它需要更快。 這讓 MAMBA 的作者看到了 DAO 之前在 FlashAttention 上的工作,從而解決了這個問題。
由於篇幅所限,我們在原文中省略了flashattention的原理介紹(評論:flashattention),感興趣的朋友可以檢視原文flashattention原文**,或者我們之前的原理介紹文章之一。
back to mamba
同樣,基於之前的比較圖表。
事實證明,如果在計算掃瞄時採用相同的記憶體感知平鋪方法,則可以大大加快速度。
通過此優化,MAMBA(紅色)現在在所有序列長度上都比 FlashAttention-2(藍色)快。
這些結果表明,MAMBA在速度方面是實用的,執行速度甚至比最快的Transformer還要快。 但這與語言建模有什麼關係嗎?
MAMBA的作者在涉及語言、基因組學和音訊的一系列序列建模任務中評估了MAMBA。
結果看起來很酷:MAMBA在對人類基因組計畫的DNA和Piano**資料集的音訊進行建模時,已經建立了最先進的效能。
然而,讓很多人興奮的是語言任務的結果。 很多關於曼巴的討論都集中在下圖上:
我們可以看到模型大小向右增加,語言建模效能進一步向下提高。
這意味著最好的模型應該在左邊:小(因此速度快)並且非常擅長建模語言。
由於 MAMBA 的作者是學者,負擔不起讓數千個 GPU 來訓練 GPT-4 大小的模型,因此該實驗是通過訓練一堆較小的模型(大約 125M 比 1)。3b 引數)。
如上圖所示,結果看起來非常有希望。 與其他類似規模的模型相比,MAMBA似乎是建模語言的最佳選擇。
為什麼被“拒絕了兩次”。
在寫作的最後,本文作者再次表達了對曼巴被拒絕的遺憾:
我真的認為 Mamba 以一種非常獨特和有趣的方式在語言建模方面進行了創新。 不幸的是,一些審稿人不同意。
從最近的拒絕來看,審稿人拒絕的原因之一與“兩個重要的基準評估”有關。
首先是缺乏LRA(Long Range Arena)評估,這是公認的長序列建模基準。
其次,僅僅將混淆評估作為主要評價指標是不夠的,因為低混淆不一定與發電效能呈正相關。
最後的總體思路是增加額外的實驗。
對於這個結果,也有網友再次評論:
這只能說明一篇文章是否被會議接受,與其對社群的價值貢獻無關。 因為前者很容易依靠極少數人的判斷。
事實上,當談到公認的好**會過去的事實時,MAMBA真的不是第乙個。
大約十年前,Word2vec 也被 ICLR “醜陋地拒絕”,但去年,它還贏得了 Neurips 時間測試。
你認為時間會“證明”曼巴的合理性嗎?
原文解讀:參考鏈結:[1]。