JAX backends and devices
A developer porting PyTorch code to JAX hit a GPU out-of-memory (OOM) error while loading a 19 GiB dataset, because JAX places new arrays on its default device, which is the GPU when one is available. They solved it by wrapping the loading code in the `jax.default_device(jax.devices("cpu")[0])` context manager, so the data was loaded into host RAM instead.
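A minimal sketch of that workaround, assuming a small NumPy array as a stand-in for the 19 GiB dataset (the names `host_data`, `data`, and `data_platform` are illustrative and not taken from the original code):

```python
import numpy as np
import jax
import jax.numpy as jnp

# The CPU backend is available even when a GPU is present.
cpu = jax.devices("cpu")[0]

# Hypothetical stand-in for the large dataset; the real one was about 19 GiB.
host_data = np.arange(12, dtype=np.float32).reshape(3, 4)

# Arrays created inside this block are placed on the CPU device,
# so they occupy host RAM rather than GPU memory.
with jax.default_device(cpu):
    data = jnp.asarray(host_data)

# Confirm where the array ended up.
data_platform = next(iter(data.devices())).platform
print(data_platform)
```

Outside the `with` block, newly created arrays go back to JAX's default device, so only the loading step is affected.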