Using Safetensors with Flax
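Safetensors' save_file accepts only a flat dictionary that maps string keys to arrays, so it cannot serialize the nested dictionaries Flax uses for model state directly. The author first flattens the model state with nnx.to_flat_state, then joins each key path into a single dot-separated string (for example, linear1.kernel) and passes the resulting dictionary to save_file. Loading reverses these steps: the dot-separated keys are split back into paths, and the nested state is rebuilt from them.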
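A minimal sketch of the key-conversion step in plain Python. It assumes the model state has already been turned into a nested dict whose leaves are arrays; the Flax-specific call is left out, and the helper names flatten_to_dotted and unflatten_from_dotted are hypothetical. Plain lists stand in for arrays so the example runs without JAX or safetensors installed.

```python
from typing import Any, Dict, Tuple

SEP = "."  # separator used to join path components into one string key


def flatten_to_dotted(tree: Dict[Any, Any], prefix: Tuple[str, ...] = ()) -> Dict[str, Any]:
    """Flatten a nested dict into {"a.b.c": leaf} form, the shape save_file expects."""
    flat: Dict[str, Any] = {}
    for key, value in tree.items():
        path = prefix + (str(key),)  # non-string keys (e.g. layer indices) become strings
        if isinstance(value, dict):
            flat.update(flatten_to_dotted(value, path))  # recurse into sub-dicts
        else:
            flat[SEP.join(path)] = value  # leaf: store under the joined key
    return flat


def unflatten_from_dotted(flat: Dict[str, Any]) -> Dict[str, Any]:
    """Rebuild the nested dict from dot-separated keys (the reverse step for loading)."""
    tree: Dict[str, Any] = {}
    for dotted_key, value in flat.items():
        *parents, leaf = dotted_key.split(SEP)
        node = tree
        for part in parents:
            node = node.setdefault(part, {})  # create intermediate dicts as needed
        node[leaf] = value
    return tree


# Hypothetical nested state; lists stand in for arrays.
params = {"linear1": {"kernel": [[1.0, 2.0]], "bias": [0.0]}, "linear2": {"kernel": [[3.0]]}}
flat = flatten_to_dotted(params)
print(sorted(flat))  # ['linear1.bias', 'linear1.kernel', 'linear2.kernel']
assert unflatten_from_dotted(flat) == params

# With real arrays (assumes safetensors is installed and the leaves are jax arrays):
#   from safetensors.flax import save_file, load_file
#   save_file(flatten_to_dotted(nested_arrays), "model.safetensors")
#   nested_arrays = unflatten_from_dotted(load_file("model.safetensors"))
```

Joining with "." assumes no key component itself contains a dot. Non-string components such as integer layer indices come back as strings after the round trip, so they may need converting before the state is restored into the model.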