給一般民眾 CME 章節

大型語言模型訓練

Stanford CME295 Lecture 4 - LLM Training

最後審閱 5 分鐘 16 張投影片 完整私人筆記
標籤13
Key Takeaways · 重點摘要
  • LLM 訓練沿用 transfer learning:先用大量資料預訓練,學到語言與程式碼的統計結構,再把預訓練權重微調到特定任務或助理行為。
  • 預訓練的核心目標是 next-token prediction;資料可包含 Common Crawl、Wikipedia、Reddit、GitHub、Stack Overflow 與多語言文字。GPT-3 約用 300 billion tokens;Llama 3 約用 15 trillion tokens。
  • FLOPs 是浮點運算量;FLOPS 是每秒浮點運算速度。LLM 預訓練常以約 $10^{25}$ FLOPs 量級描述,並大致受 token 數與參數數量共同影響。
  • Scaling laws 顯示,更多 compute、更多資料與更大模型通常改善 next-token prediction。Chinchilla law 的經驗關係指出,訓練 token 數約為參數數量 20 倍時,compute 使用較接近最適。
  • 預訓練成本高,常以數百萬至數千萬美元量級起跳,且會帶來時間與環境成本;base model 的知識只到訓練資料切斷日,形成 knowledge cutoff。
  • 訓練時 GPU 記憶體需存放權重、activations、gradients、optimizer states。資料平行化、ZeRO 與模型平行化用不同方式把資料、狀態或模型計算分散到多張 GPU。
  • FlashAttention 的重點不是近似 attention,而是用 tiling 把小區塊放到快速 SRAM 內完成計算,減少對 HBM 的讀寫;它也支援 backward pass 的 activation recomputation,通常同時節省記憶體與時間。
  • Quantization 降低數值精度以節省記憶體並提高硬體速度;mixed precision training 通常保留 FP32 權重與權重更新,用 FP16 執行 forward/backward。
  • SFT 以 input-output pair 監督模型,讓預訓練模型從「延續文字」轉成「依指令回應」。SFT 的 loss 通常從 output token 開始,不在固定 input 上計算。
  • Instruction tuning 是 SFT 的子類,資料包含故事、詩、清單、解釋、數學、證明、程式碼與 safety 行為;早期多由人工撰寫,近年也常由既有 LLM 生成後再經人工或 LLM 檢查。
  • SFT 資料量遠小於預訓練資料,但品質與 prompt distribution 很重要。GPT-3 instruction tuning 約 13K examples;Llama 3 約 10 million examples。
  • 模型評估不能只看單一分數。MMLU、GSM8K、程式碼與推理 benchmark 可量化能力,但若模型曾被訓練在 test task 上,跨模型比較會失真。
  • Chatbot Arena 以使用者兩兩偏好建立排名,但有 early comparison noise、可能被模型身分辨識干擾、使用者非專家、emoji 偏好差異與 safety refusal 被低估等問題。
  • Alignment 在本課定義為 pre-training 之後的 fine tuning 與 preference tuning;mid-training 則是新近出現、介於 pre-training 與 fine tuning 之間的階段。
  • LoRA 凍結預訓練權重 $W_0$,只訓練低秩矩陣 $B$ 與 $A$ 的乘積,用少量參數表達 task-specific update;rank $r$ 常很小,rank 4 是常見選擇。
  • QLoRA 將凍結的 base weights 量化為 NF4,並以 BF16 訓練 LoRA 權重;NF4 假設權重近似常態分布並用 quantile 切分,課中提到可帶來約 16 倍 VRAM 節省。

教學投影片 · Teaching Slides

16 張 · 每張一個重點
01

課程定位

  • 本講主題是 LLM training
  • 從 pretraining 講到 fine tuning
  • 接續前講的 MoE 與 decoding
  • 重點是訓練成本與效率
02

Transfer learning

  • 傳統模型為單一任務訓練
  • 語言任務共享文字理解能力
  • Pretrained model 可重複使用
  • Tuning 讓模型適應任務
03

Pretraining objective

  • 預訓練是最昂貴階段
  • 目標是 next-token prediction
  • 資料可涵蓋文字與程式碼
  • LLM 多為 decoder-only model
04

預訓練資料規模

  • Common Crawl 是常見來源
  • 可含 Wikipedia 與 Reddit
  • 程式碼來自 GitHub 等來源
  • GPT-3 約 300B tokens
  • Llama 3 約 15T tokens
05

FLOPs 與 FLOPS

  • FLOPs 是浮點運算總量
  • FLOPS 是每秒浮點運算
  • LLM 訓練約達 10^25 FLOPs
  • Compute 約受 tokens 與參數影響
06

Scaling laws

  • Compute 越多通常越好
  • 資料越多通常越好
  • 模型越大通常越好
  • 大模型較 sample efficient
  • Chinchilla 強調 tokens/params 平衡
07

預訓練限制

  • 成本可達數百萬美元以上
  • 也有時間與環境成本
  • Knowledge cutoff 限制新知
  • Knowledge editing 仍困難
  • 生成可能重現訓練內容
08

訓練記憶體壓力

  • Forward pass 需 activations
  • Backward pass 需 gradients
  • Adam 需 moment states
  • Context length 增加 attention 成本
  • 單張 H100 仍只有 80 GB
09

Data parallelism 與 ZeRO

  • DP 將 batch 分到多張 GPU
  • 每張 GPU 保留模型 copy
  • Gradients 需跨 GPU 平均
  • ZeRO 切分 states 與 parameters
  • 記憶體下降但溝通成本上升
10

Model parallelism

  • 模型平行化切分模型運算
  • Expert parallelism 切分 MoE experts
  • Tensor parallelism 切矩陣乘法
  • Pipeline parallelism 切 layers
  • 目標是降低單卡負擔
11

FlashAttention

  • HBM 大但慢,SRAM 小但快
  • Vanilla attention 反覆讀寫 HBM
  • Tiling 把 blocks 放進 SRAM
  • Block-wise softmax 仍是 exact
  • I/O 減少是核心收益
12

Recomputation

  • Backward pass 需 forward activations
  • FlashAttention 可快速重算
  • 可少存 activations
  • HBM read/write 接近 10 倍下降
  • 記憶體與 runtime 可同時改善
13

Quantization

  • 浮點數由 sign/exponent/mantissa 組成
  • FP16 記憶體約為 FP32 一半
  • 低 precision 可提高 GPU 速度
  • Mixed precision 保留 FP32 weights
  • Forward/backward 可用 FP16
14

SFT 與 instruction tuning

  • SFT 使用 input-output pairs
  • Input 不計入 loss
  • Loss 從 output tokens 開始
  • Instruction tuning 教模型回應指令
  • Safety 資料可訓練拒答
15

評估與 alignment

  • MMLU 評估多任務理解
  • GSM8K 評估數學推理
  • Training on test task 會干擾比較
  • Chatbot Arena 衡量使用者偏好
  • Alignment 含 fine tuning 與 preference tuning
16

LoRA 與 QLoRA

  • LoRA 凍結 base weights
  • 只訓練低秩矩陣 A 與 B
  • Rank 通常很小,rank 4 常見
  • QLoRA 將 W0 量化為 NF4
  • 課中提到約 16 倍 VRAM 節省