Flax debugging: making a hash of things
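The author debugged a Flax NNX training loop whose loss was stuck at 10.82. That value is what a model produces when it guesses near-uniformly over the vocabulary: the cross-entropy of a uniform guess is the natural log of the vocabulary size, and e^10.82 is roughly 50,000. By hashing the model parameters and comparing the hashes across steps, the author found that the parameters were not changing at all. The root cause was decorating the training step with `@jax.jit` instead of `@nnx.jit`; NNX needs the latter to propagate in-place parameter updates out of the compiled function.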
Background
- **JAX** is a high-performance numerical computing library from Google, designed for machine learning; it's functional (pure functions, no side effects) and compiles code with `jax.jit` for speed. **Flax NNX** is a newer, more PyTorch-like API built on top of JAX that lets you mutate model state "in place" rather than returning new copies.
- The author is building a tiny LLM from scratch for learning purposes. The training loop ran without errors, but the loss stayed constant from step to step, which meant the model was not learning.
- The bug: the author used the standard `@jax.jit` decorator on their training step, but Flax NNX's in-place parameter updates require its own `@nnx.jit` decorator, which automatically handles propagating mutated state back out of the compiled function. Using plain `@jax.jit` silently discarded the parameter updates.
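- The debugging trick: with 77 million parameters, printing values is useless. Instead, the author hashed the raw byte representations of the parameter arrays using Python's `hash()`. Even a tiny change to any value almost certainly produces a completely different hash, so comparing hashes before and after a step makes it trivial to check whether the parameters are actually being updated. Note that `hash()` on bytes is salted per interpreter process (unless `PYTHONHASHSEED` is set), so the hashes are only comparable within a single run.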