FlaxでSafetensorsを使用する
PyTorchのLLMコードをJAX/Flaxに移植する際、モデルのチェックポイント保存にSafetensorsを使いたい場合の注意点と解決方法。SafetensorsのFlax/JAX APIは単純なフラットな辞書構造を期待するが、Flaxの`nnx.State.to_pure_dict`で得られる辞書はネスト構造のためそのまま渡すとエラーになる。代わりに`nnx.to_flat_state`を使ってフラット化し、ドット区切りのキーに変換してからSafetensorsに渡すことで正しく保存・読み込みが可能になる。