Three answers to one question: when does the graph exist?
Every framework needs a graph of the computation, for two non-negotiable reasons: autodiff must walk the chain of operations backward, and the compiler needs the whole chain to fuse and schedule it (see JIT compilation). The frameworks differ less in kernels — everyone ultimately calls cuDNN, cuBLAS, or XLA — than in when that graph comes into existence, and what the user must promise to make it exist. That single choice cascades into debuggability, performance, and what the code feels like to write.
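To make "when does the graph exist" concrete on the JAX side, the sketch below asks JAX to trace a small function and print the graph it records. It assumes a recent JAX install. The names loss, w and x are illustrative. The exact printout (variable names, nesting) varies between JAX versions.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # A tiny chain of operations: matmul, nonlinearity, reduction.
    return jnp.sum(jnp.tanh(x @ w))

w = jnp.ones((3, 2))
x = jnp.ones((4, 3))

# make_jaxpr runs loss on abstract stand-ins for w and x and records every
# primitive operation it hits: the forward graph.
forward_graph = jax.make_jaxpr(loss)(w, x)
print(forward_graph)

# grad walks that recorded chain backward; tracing the gradient function
# shows the extra operations the backward pass adds.
backward_graph = jax.make_jaxpr(jax.grad(loss))(w, x)
print(backward_graph)
```

Both printouts are plain data structures. Autodiff and the XLA compiler consume them, and neither needs to re-run the Python.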
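TensorFlow 1 built the graph before anything ran: you declared it, then executed it in a session, so a print in model code shows a symbolic node rather than a value. Maximal optimisation, minimal ergonomics. Research users fled; TF2 capitulated to eager with tf.function retrofitted on top.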
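PyTorch builds the graph during execution (define-by-run): autograd records each operation as it runs, so control flow is plain Python, down to a data-dependent if. The price is that no complete graph exists ahead of time. That is the gap torch.compile now works to close from the eager side.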
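JAX builds the graph at trace time, whole and on demand, and in exchange demands purity. That means no in-place mutation (x.at[i].set(v) returns a new array), explicit RNG keys threaded by hand, and parameters passed as arguments rather than living in modules. It also means traced branches cannot depend on values (jax.lax.cond instead of if). Side effects do not fail loudly: a print inside jit fires at trace time, then stays silent on every later call that reuses the compiled version, until a new input shape or dtype forces a retrace.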
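A minimal sketch of three of those constraints, assuming a recent JAX. The function names are illustrative. jax.errors.ConcretizationTypeError is the error family JAX raises when traced code needs a concrete value; newer releases raise a more specific subclass for if on a tracer, which the except clause below also catches.

```python
import jax
import jax.numpy as jnp

# 1. No in-place mutation: .at[...].set(...) returns a new array.
x = jnp.zeros(4)
y = x.at[1].set(5.0)          # x is untouched; y carries the update

# 2. Branches on traced values: a Python `if` fails under jit...
@jax.jit
def abs_with_if(v):
    if v > 0:                 # needs a concrete bool; v is an abstract tracer
        return v
    return -v

try:
    abs_with_if(-3.0)
    if_failed = False
except jax.errors.ConcretizationTypeError:
    if_failed = True

# ...while lax.cond stages both branches into the graph and picks one at run time.
@jax.jit
def abs_with_cond(v):
    return jax.lax.cond(v > 0, lambda u: u, lambda u: -u, v)

# 3. Side effects run at trace time only.
trace_log = []

@jax.jit
def double(v):
    trace_log.append("traced")   # Python side effect: runs while tracing
    print("tracing, v is", v)    # prints a tracer, not a number
    return 2 * v

double(1.0)
double(2.0)   # same shape and dtype: cached executable, body not re-run
```

After both calls, trace_log holds a single entry: the second call never executed the Python body.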
Purity is what JAX charges; this is what it buys. Because a traced function is a closed mathematical object, transformations of it compose like functions do:
| Transform | Takes f and returns… | Replaces |
|---|---|---|
| grad(f) | The gradient function, itself traceable | Autograd tape machinery |
| jit(f) | f compiled whole by XLA | torch.compile, without graph breaks |
| vmap(f) | f vectorised over a new batch axis | Hand-written batching, loops |
| pmap / shard_map(f) | f running SPMD across devices, collectives inside | Much of the DDP/FSDP wrapper stack |
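Each row of the table can be exercised on a one-line function. This sketch covers grad, jit and vmap on f(x) = x**3, an illustrative choice. pmap and shard_map are left out because they need more than one device to show anything.

```python
import jax
import jax.numpy as jnp

def f(x):
    return x ** 3

df = jax.grad(f)            # a new function: 3x^2
d2f = jax.grad(df)          # grad of grad, because df is itself traceable: 6x
f_compiled = jax.jit(f)     # the same function, compiled whole by XLA
df_batched = jax.vmap(df)   # df mapped over a leading batch axis, no loop

print(df(3.0))                                 # 3 * 3**2 = 27
print(d2f(3.0))                                # 6 * 3 = 18
print(f_compiled(2.0))                         # 2**3 = 8
print(df_batched(jnp.array([1.0, 2.0, 3.0])))  # 3x^2 elementwise: 3, 12, 27
```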
The composition is the point. jit(vmap(grad(f))) computes per-example gradients, differentiated, batched over examples, and compiled, in one line. Put pmap in place of jit (pmap compiles too) and the same computation is also distributed across devices, with no framework machinery in sight. Things that are research projects in other ecosystems (per-sample gradients for differential privacy, meta-gradients through training steps, ensembles via vmap over parameter stacks) are compositions in JAX. Sharding follows the same philosophy: annotate how arrays are laid out across the mesh, and the XLA partitioner derives the communication, rather than the user orchestrating it imperatively.
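A minimal sketch of the per-example gradient composition, assuming a squared-error linear model as a stand-in for f. The names loss, per_example_grads, mean_loss and clip_norm are illustrative. The clipping step is a simplified DP-SGD-style clip with no noise added, included only to show why per-example gradients, rather than an average, are needed. The last lines stack three parameter vectors to show the ensemble case with the same vmap.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error of a linear model on ONE example: x has shape (d,), y is scalar.
    return (jnp.dot(x, w) - y) ** 2

# grad differentiates w.r.t. w; vmap maps over examples (w shared via None,
# x and y batched on axis 0); jit compiles the whole thing.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 1.0])

g = per_example_grads(w, xs, ys)   # shape (3, 2): one gradient row per example

# Sanity check: averaging the rows recovers the ordinary batch gradient.
def mean_loss(w):
    return jnp.mean(jax.vmap(loss, in_axes=(None, 0, 0))(w, xs, ys))

batch_grad = jax.grad(mean_loss)(w)

# Simplified DP-SGD-style clipping (no noise): rescale each row to norm <= clip_norm.
clip_norm = 1.0
norms = jnp.linalg.norm(g, axis=1)
g_clipped = g / jnp.maximum(1.0, norms / clip_norm)[:, None]

# Ensembles: stack parameters on a leading axis and vmap over that axis instead.
w_stack = jnp.stack([w, 2 * w, -w])                     # three ensemble members
ensemble_preds = jax.vmap(lambda wi: xs @ wi)(w_stack)  # shape (3 members, 3 examples)
```

The only thing that changes between the per-example and ensemble cases is which argument vmap maps over.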
| | PyTorch | JAX | TensorFlow |
|---|---|---|---|
| Graph exists | During execution (autograd trail); ahead-of-time via compile | At trace time, on demand, whole | TF1: before. TF2: eager + tf.function retrace |
| State & params | Mutable modules, in-place ops | Pure functions; state threaded explicitly (pytrees) | Mutable tf.Variable in Keras objects |
| Debugging | Native Python, best in class | Good eager; inside jit, trace-time surprises | Historically the complaint that built PyTorch |
| Compiler story | torch.compile, partial graphs, guards | XLA-native, whole program, designed-in | XLA available, grafted on |
| Hardware | GPUs first; TPU support secondary | TPUs first-class; GPUs well supported | TPUs supported; the original TPU framework |
| Randomness | Global stateful RNG | Explicit keys (reproducible, verbose) | Global, with seeds |
| Ecosystem center | Research papers, HF models | Scaling shops, RL, scientific computing | Legacy production, mobile (TFLite) |
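The Randomness row is the one that shows up most in everyday JAX code. A minimal sketch, using the long-standing jax.random.PRNGKey constructor (newer releases also offer jax.random.key). The variable names are illustrative.

```python
import jax

key = jax.random.PRNGKey(0)

# split derives independent child keys; nothing global is advanced or mutated.
key, init_key, dropout_key = jax.random.split(key, 3)

a = jax.random.normal(init_key, (3,))
b = jax.random.normal(init_key, (3,))      # same key: identical numbers
c = jax.random.normal(dropout_key, (3,))   # different key: different numbers
```

This is the verbosity the table refers to: every function that needs randomness takes a key argument, and the caller must split before reusing one. The payoff is that a given key always reproduces the same numbers, inside or outside jit.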
PyTorch is the default. The large majority of new research code, the Hugging Face ecosystem, and most open-weight model releases are PyTorch-first. Define-by-run won the argument: researcher iteration speed beat compiler convenience, and the compiler gap has since been narrowed from the eager side rather than the reverse.
JAX is the choice of TPU and large-scale shops. Google DeepMind, Anthropic, and a cluster of scaling-focused labs run on it, because whole-program XLA plus mesh sharding is a genuinely better substrate at thousand-chip scale. Functional purity also pays increasing dividends as systems grow: reproducibility, checkpointable state, no hidden mutation. Outside those shops, its research mindshare remains a minority.
TensorFlow is legacy production. Plenty still runs — established serving stacks, TFLite on mobile, older recommender systems — but new projects rarely start there, and Google's own center of gravity moved to JAX. Choosing it for new work in 2026 needs a specific reason (usually an existing deployment pipeline).