稀疏自編碼器訓練
提供使用 SAELens 訓練和分析稀疏自編碼器(SAE)的指南,用於將神經網絡激活分解為可解釋的特徵。適用於在語言模型中發現可解釋特徵、分析超位置(superposition)或研究單義性(monosemantic)表示時。
技能元數據
| 來源 | 可選 — 使用 hermes skills install official/mlops/saelens 安裝 |
| 路徑 | optional-skills/mlops/saelens |
| 版本 | 1.0.0 |
| 作者 | Orchestra Research |
| 許可證 | MIT |
| 依賴項 | sae-lens>=6.0.0, transformer-lens>=2.0.0, torch>=2.0.0 |
| 標籤 | Sparse Autoencoders, SAE, Mechanistic Interpretability, Feature Discovery, Superposition |
參考:完整 SKILL.md
信息
以下是 Hermes 在觸發此技能時加載的完整技能定義。這是技能激活時代理看到的指令。
SAELens:用於機械可解釋性的稀疏自編碼器
SAELens 是用於訓練和分析稀疏自編碼器(SAE)的主要庫——這是一種將多義性(polysemantic)神經網絡激活分解為稀疏、可解釋特徵的技術。基於 Anthropic 在單義性方面的突破性研究。
GitHub: jbloomAus/SAELens (1,100+ stars)
問題:多義性與超位置
神經網絡中的單個神經元是多義的——它們在多個語義不同的上下文中激活。這是因為模型使用超位置來表示比其神經元數量更多的特徵,使得可解釋性變得困難。
SAE 通過以下方式解決此問題:將密集激活分解為稀疏的單義特徵——通常對於任何給定輸入,只有少量特徵被激活,且每個特徵對應一個可解釋的概念。
何時使用 SAELens
在需要執行以下操作時使用 SAELens:
- 發現模型激活中的可解釋特徵
- 理解模型學到了哪些概念
- 研究超位置和特徵幾何
- 執行基於特徵的引導(steering)或消融
- 分析與安全相關的特徵(欺騙、偏見、有害內容)
在以下情況考慮替代方案:
- 你需要基本的激活分析 → 直接使用 TransformerLens
- 你想要因果乾預實驗 → 使用 pyvene 或 TransformerLens
- 你需要生產環境引導 → 考慮直接激活工程
安裝
pip install sae-lens
要求:Python 3.10+, transformer-lens>=2.0.0
核心概念
SAE 學習的內容
SAE 經過訓練,通過稀疏瓶頸重建模型激活:
Input Activation → Encoder → Sparse Features → Decoder → Reconstructed Activation
(d_model) ↓ (d_sae >> d_model) ↓ (d_model)
sparsity reconstruction
penalty loss
損失函數:MSE(original, reconstructed) + L1_coefficient × L1(features)
關鍵驗證(Anthropic 研究)
在《Towards Monosemanticity》中,人類評估者發現 70% 的 SAE 特徵具有真正的可解釋性。發現的特徵包括:
- DNA 序列、法律語言、HTTP 請求
- 希伯來語文本、營養說明、代碼語法
- 情感、命名實體、語法結構
工作流 1:加載和分析預訓練 SAE
分步指南
from transformer_lens import HookedTransformer
from sae_lens import SAE
# 1. Load model and pre-trained SAE
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, cfg_dict, sparsity = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
# 2. Get model activations
tokens = model.to_tokens("The capital of France is Paris")
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8] # [batch, pos, d_model]
# 3. Encode to SAE features
sae_features = sae.encode(activations) # [batch, pos, d_sae]
print(f"Active features: {(sae_features > 0).sum()}")
# 4. Find top features for each position
for pos in range(tokens.shape[1]):
top_features = sae_features[0, pos].topk(5)
token = model.to_str_tokens(tokens[0, pos:pos+1])[0]
print(f"Token '{token}': features {top_features.indices.tolist()}")
# 5. Reconstruct activations
reconstructed = sae.decode(sae_features)
reconstruction_error = (activations - reconstructed).norm()
可用的預訓練 SAE
| 發佈版本 | 模型 | 層 |
|---|---|---|
gpt2-small-res-jb | GPT-2 Small | 多個殘差流 |
gemma-2b-res | Gemma 2B | 殘差流 |
| HuggingFace 上的各種版本 | 搜索標籤 saelens | 各種 |
檢查清單
- 使用 TransformerLens 加載模型
- 為目標層加載匹配的 SAE
- 將激活編碼為稀疏特徵
- 識別每個 token 激活最高的特徵
- 驗證重建質量
工作流 2:訓練自定義 SAE
分步指南
from sae_lens import SAE, LanguageModelSAERunnerConfig, SAETrainingRunner
# 1. Configure training
cfg = LanguageModelSAERunnerConfig(
# Model
model_name="gpt2-small",
hook_name="blocks.8.hook_resid_pre",
hook_layer=8,
d_in=768, # Model dimension
# SAE architecture
architecture="standard", # or "gated", "topk"
d_sae=768 * 8, # Expansion factor of 8
activation_fn="relu",
# Training
lr=4e-4,
l1_coefficient=8e-5, # Sparsity penalty
l1_warm_up_steps=1000,
train_batch_size_tokens=4096,
training_tokens=100_000_000,
# Data
dataset_path="monology/pile-uncopyrighted",
context_size=128,
# Logging
log_to_wandb=True,
wandb_project="sae-training",
# Checkpointing
checkpoint_path="checkpoints",
n_checkpoints=5,
)
# 2. Train
trainer = SAETrainingRunner(cfg)
sae = trainer.run()
# 3. Evaluate
print(f"L0 (avg active features): {trainer.metrics['l0']}")
print(f"CE Loss Recovered: {trainer.metrics['ce_loss_score']}")
關鍵超參數
| 參數 | 典型值 | 效果 |
|---|---|---|
d_sae | 4-16× d_model | 更多特徵,更高容量 |
l1_coefficient | 5e-5 到 1e-4 | 越高 = 越稀疏,準確度越低 |
lr | 1e-4 到 1e-3 | 標準優化器學習率 |
l1_warm_up_steps | 500-2000 | 防止早期特徵死亡 |
評估指標
| 指標 | 目標 | 含義 |
|---|---|---|
| L0 | 50-200 | 每個 token 的平均激活特徵數 |
| CE Loss Score | 80-95% | 相對於原始值的交叉熵恢復率 |
| Dead Features | <5% | 從未激活的特徵 |
| Explained Variance | >90% | 重建質量 |
檢查清單
- 選擇目標層和鉤子點(hook point)
- 設置擴展因子(d_sae = 4-16× d_model)
- 調整 L1 係數以獲得所需的稀疏度
- 啟用 L1 預熱以防止特徵死亡
- 在訓練期間監控指標(W&B)
- 驗證 L0 和 CE 損失恢復
- 檢查死亡特徵比例
工作流 3:特徵分析和引導
分析單個特徵
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, _, _ = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
# Find what activates a specific feature
feature_idx = 1234
test_texts = [
"The scientist conducted an experiment",
"I love chocolate cake",
"The code compiles successfully",
"Paris is beautiful in spring",
]
for text in test_texts:
tokens = model.to_tokens(text)
_, cache = model.run_with_cache(tokens)
features = sae.encode(cache["resid_pre", 8])
activation = features[0, :, feature_idx].max().item()
print(f"{activation:.3f}: {text}")
特徵引導
def steer_with_feature(model, sae, prompt, feature_idx, strength=5.0):
"""Add SAE feature direction to residual stream."""
tokens = model.to_tokens(prompt)
# Get feature direction from decoder
feature_direction = sae.W_dec[feature_idx] # [d_model]
def steering_hook(activation, hook):
# Add scaled feature direction at all positions
activation += strength * feature_direction
return activation
# Generate with steering
output = model.generate(
tokens,
max_new_tokens=50,
fwd_hooks=[("blocks.8.hook_resid_pre", steering_hook)]
)
return model.to_string(output[0])
特徵歸因
# Which features most affect a specific output?
tokens = model.to_tokens("The capital of France is")
_, cache = model.run_with_cache(tokens)
# Get features at final position
features = sae.encode(cache["resid_pre", 8])[0, -1] # [d_sae]
# Get logit attribution per feature
# Feature contribution = feature_activation × decoder_weight × unembedding
W_dec = sae.W_dec # [d_sae, d_model]
W_U = model.W_U # [d_model, vocab]
# Contribution to "Paris" logit
paris_token = model.to_single_token(" Paris")
feature_contributions = features * (W_dec @ W_U[:, paris_token])
top_features = feature_contributions.topk(10)
print("Top features for 'Paris' prediction:")
for idx, val in zip(top_features.indices, top_features.values):
print(f" Feature {idx.item()}: {val.item():.3f}")
常見問題與解決方案
問題:高死特徵比例
# WRONG: No warm-up, features die early
cfg = LanguageModelSAERunnerConfig(
l1_coefficient=1e-4,
l1_warm_up_steps=0, # Bad!
)
# RIGHT: Warm-up L1 penalty
cfg = LanguageModelSAERunnerConfig(
l1_coefficient=8e-5,
l1_warm_up_steps=1000, # Gradually increase
use_ghost_grads=True, # Revive dead features
)
問題:重建效果差(交叉熵恢復率低)
# Reduce sparsity penalty
cfg = LanguageModelSAERunnerConfig(
l1_coefficient=5e-5, # Lower = better reconstruction
d_sae=768 * 16, # More capacity
)
問題:特徵不可解釋
# Increase sparsity (higher L1)
cfg = LanguageModelSAERunnerConfig(
l1_coefficient=1e-4, # Higher = sparser, more interpretable
)
# Or use TopK architecture
cfg = LanguageModelSAERunnerConfig(
architecture="topk",
activation_fn_kwargs={"k": 50}, # Exactly 50 active features
)
問題:訓練期間出現內存錯誤
cfg = LanguageModelSAERunnerConfig(
train_batch_size_tokens=2048, # Reduce batch size
store_batch_size_prompts=4, # Fewer prompts in buffer
n_batches_in_buffer=8, # Smaller activation buffer
)
與 Neuronpedia 集成
在 neuronpedia.org 瀏覽預訓練的 SAE 特徵:
# Features are indexed by SAE ID
# Example: gpt2-small layer 8 feature 1234
# → neuronpedia.org/gpt2-small/8-res-jb/1234
關鍵類參考
| 類 | 用途 |
|---|---|
SAE | 稀疏自編碼器模型 |
LanguageModelSAERunnerConfig | 訓練配置 |
SAETrainingRunner | 訓練循環管理器 |
ActivationsStore | 激活值收集與批處理 |
HookedSAETransformer | TransformerLens + SAE 集成 |
參考文檔
有關詳細的 API 文檔、教程和高級用法,請參閱 references/ 文件夾:
| 文件 | 內容 |
|---|---|
| references/README.md | 概述和快速入門指南 |
| references/api.md | SAE、TrainingSAE、配置的完整 API 參考 |
| references/tutorials.md | 訓練、分析、 steering 的分步教程 |
外部資源
教程
論文
- Towards Monosemanticity - Anthropic (2023)
- Scaling Monosemanticity - Anthropic (2024)
- Sparse Autoencoders Find Highly Interpretable Features - Cunningham 等人 (ICLR 2024)
官方文檔
- SAELens 文檔
- Neuronpedia - 特徵瀏覽器
SAE 架構
| 架構 | 描述 | 用例 |
|---|---|---|
| Standard | ReLU + L1 懲罰 | 通用目的 |
| Gated | 學習門控機制 | 更好的稀疏性控制 |
| TopK | 恰好 K 個活躍特徵 | 一致的稀疏性 |
# TopK SAE (exactly 50 features active)
cfg = LanguageModelSAERunnerConfig(
architecture="topk",
activation_fn="topk",
activation_fn_kwargs={"k": 50},
)