在本文中,您將學習如何使用 Pytorch FSDP 微調 LLAMA 2 70B 以及相關的最佳實踐。 在此過程中,我們將主要使用 Hugging Face Transformers、Accelerate 和 TRL 庫。 我們還將向您展示如何在 Slurm 中使用 Accelerate。
完全分片資料並行性 (FSDP) 是一種訓練正規化,其中優化器狀態、梯度和模型引數都跨裝置分片。 向前傳播時,每個 FSDP 單元執行 All Gather 以獲取完整權重,然後使用它們進行計算,並在計算後丟棄其他裝置的分片。 接下來是反向傳播,然後是損耗計算。 反向傳播時,每個 FSDP 單元執行 All Gather 操作以獲取完整權重,並執行計算以獲取本地批次的梯度。 這些梯度通過減少散點在裝置上進行平均和分片,以便每個裝置都可以更新其相應分片的引數。 有關 PyTorch FSDP 的詳細資訊,請參閱此部落格文章:使用 PyTorch 完全分片資料並行加速大型模型訓練。
FSDP 工作流。
節點數:2個,至少1個節點。
每個節點的 GPU 數:8
GPU型別:A100
GPU視訊記憶體:80GB
節點內互聯:nvlink
每個節點的記憶體:1TB
每個節點的 CPU 核心數:96
節點之間的互連:AWS 的 Elastic Fabric Adapter (EFA)。
在嘗試使用 FSDP 微調 LLAMA 2 70B 時,我們遇到了三個主要挑戰:
FSDP 在對模型進行分片之前載入整個預訓練模型。 這意味著節點內的每個程序(即 Rank)都會載入整個 LLAMA-70B 模型,因此需要 7048 GB 的 2TB CPU 記憶體,其中 4 是每個引數的位元組數,8 是每個節點的 GPU 數。 這會導致 CPU 記憶體不足,進而導致程序終止。 用full_state_dict
儲存完整的中間檢查點並將其解除安裝到秩為 0 的 CPU 記憶體需要花費大量時間,並且通常會導致 NCCL 超時錯誤,因為通訊庫需要無限期掛起才能完成儲存。 但是,完全關閉此選項並不是乙個好主意,因為在訓練結束時,我們需要儲存完整的模型狀態字典,而不是 fsdp 樣式分片的狀態字典。 我們需要提高速度並減少記憶體使用,以加快訓練速度並節省計算成本。 在下文中,我們將討論如何解決這些挑戰,並最終微調 70b 模型!
首先列出重現結果所需的所有資源:
庫:包含啟用 Flash Attention v2 的熱補丁。
FSDP 配置檔案:
slurm 啟動指令碼 -launch.slurm
型:meta-llama/llama-2-70b-chat-hf
資料集:Smangrul code-chat-assistant-v1(Lima 和 Guanaco 資料集的混合,已轉換為訓練所需的格式)。
首先按照以下步驟安裝 Flash Attention v2。 然後,安裝最新的 PyTorch Nightly (cuda 11)。8)。接下來,根據此檔案安裝其餘的相關軟體。 在本文中,我們將從 main 分支安裝