Flaxデバッグ:ハッシュでパラメータ変更を追跡する
JAX/Flax NNXの訓練ループで損失が減少しない問題をデバッグするため、パラメータ配列のハッシュ値を比較する手法を紹介。77Mものパラメータを逐一確認する代わりに、numpyのtobytes()とPythonのhash()を使って微小な変化も検出。原因は@jax.jitではなく@nnx.jitを使うべきところを間違えていたことで、FlaxのNNX APIにおける非関数型的なインプレース更新とJITコンパイルの互換性について学べる内容。
背景メモ
- JAXはGoogleが開発した数値計算ライブラリで、関数型プログラミングのスタイル(入力から出力を計算し、状態を書き換えない)を特徴とする。Flaxはその上に構築されたニューラルネットワークフレームワークで、NNXはそのサブAPI(従来のLinenとは異なる新しいAPI)。Flax NNXは「in-place更新」というPyTorch風の書き方を採用しているが、それがJAXの関数型原則と相性が悪く、ハマりどころになっている。
- 著者はGPT-2レベルのLLMをスクラッチでJAX/Flax NNXで実装中。本記事では、トークン埋め込み層と出力線形層だけを直結した簡易モデルで訓練ループの動作確認をしている(本来の「次トークン予測」ではなく、入力をそのまま再現させるタスク)。
- 訓練が全く進まずロスが10.82で膠着——これはGPT-2の語彙サイズ50,257におけるランダム推測と同等のperplexity(約5万)に相当する値。つまりモデルは何も学習していない。
- 原因は、著者が訓練ステップ関数に @jax.jit(通常のJAX用デコレータ)を使っていたこと。Flax NNXのin-place更新を正しく機能させるには @nnx.jit を使わなければならない。@nnx.jit はモデルパラメータの状態を自動的に伝搬(propagation)するが、通常の @jax.jit はそれを行わないため、勾配が計算されてもパラメータに反映されていなかった。
- デバッグ手法:77Mものパラメータを逐一確認する代わりに、パラメータ配列をバイト列に変換してPythonのhash()にかけ、ハッシュ値が変化するかどうかで更新の有無を検出した。ハッシュが変わらなければ「更新されていない」と即座に断定できる。