[ge_arrow] Update to JAX and compare runtime #717
📖 Netlify Preview Ready! Preview URL: https://pr-717--sunny-cactus-210e3e.netlify.app (b58cafe)
📚 Changed Lecture Pages: aiyagari, cake_eating_numerical, career, coleman_policy_iter, egm_policy_iter, ge_arrow, harrison_kreps, ifp, ifp_advanced, jv, lake_model, lqcontrol, mccall_correlated, mccall_fitted_vfi, mccall_model, mccall_model_with_separation, mccall_q, odu, optgrowth_fast, two_auctions, wald_friedman_2
- Replace `@partial(jax.jit)` with `jax.jit` on the main function `compute_rc_model`.
- Write a function to compute example 3 and add `jax.jit` decorator.
A Quantitative Evaluation System for JAX Rewrites of QuantEcon Lectures
This document defines a reusable, quantitative system for deciding whether rewriting a QuantEcon lecture's code (e.g., converting NumPy → JAX) actually improves the lecture. It was designed against the first such change,
The guiding principle: these are teaching lectures first and programs second. A rewrite that makes the code faster or more "modern" but harder for a learner to read, or that silently changes the numbers, is not automatically an improvement. The system therefore weights pedagogy heavily and never treats "uses JAX" as a goal in itself: JAX must earn its place on each lecture.
1. The seven dimensions
Weights sum to 1.0. Readability (0.25) outranks efficiency (0.15) on purpose: the audience is learners, and most lecture models are tiny. Adjust the weights per lecture family if needed (e.g. a "performance" lecture could raise dimension 3), but record any change. Each dimension is scored 1–5 against the anchors below, then combined:
Interpreting the total
2. Scoring anchors + worked high/low examples
For each dimension, we give (a) the metric(s) that quantify it, (b) the 1–5 anchors, and (c) a HIGH-scoring and LOW-scoring example so reviewers agree on what "good" looks like. Dimensions 1, 2, 3, 6 carry numeric score thresholds (a measured number maps directly to 1–5); dimensions 4, 5, 7 are structural and scored against criteria + cited evidence.
The numeric thresholds were calibrated against two real, measured end points: a HIGH case (the aiyagari Bellman pattern, 25× faster as-used) and a LOW case (the full
Every example below is real code already in
Dimension 1: Correctness & numerical fidelity · weight 0.20
Metrics (from
Anchors (numeric, keyed to
Dimension 2: Readability & pedagogical clarity · weight 0.25
Metrics (from
Anchors (numeric, keyed to Δ prerequisite-concepts vs the original and to docstring coverage; both from
(Use the worse of the two columns; the "&" column is the tie-breaker.)
Dimension 3: Computational efficiency (as actually used) · weight 0.15
Crucial rule: measure efficiency in the regime the lecture runs, not a hypothetical large-scale one. For JAX that means including trace+compile time whenever the lecture hits a new shape or
Metrics (from
The metric that decides the score is the as-used speedup measured over the lecture's actual sequence of solver calls, at its actual problem sizes, in a fresh interpreter (so JAX's compiles count).
Anchors (numeric)
Dimension 4: Logic & design · weight 0.15
Metrics:
Anchors
Dimension 5: Coding style & idiom · weight 0.10
Metrics: PEP 8 / project-style conformance, and, for JAX, whether the code uses idiomatic JAX (vectorisation,
Anchors
Dimension 6: API ergonomics & reusability · weight 0.10
Metrics:
Anchors (numeric, keyed to
Dimension 7: Maintainability & robustness · weight 0.05
Metrics: testability (pure vs stateful), debuggability (can you step through it?), and "footguns" left for future editors.
Anchors
3. Limitations
Evaluation Report —
| Dimension | Wt | Score | Weighted |
|---|---|---|---|
| Correctness & numerical fidelity | 0.20 | 3 | 0.60 |
| Readability & pedagogical clarity | 0.25 | 2 | 0.50 |
| Computational efficiency (as used) | 0.15 | 2 | 0.30 |
| Logic & design | 0.15 | 4 | 0.60 |
| Coding style & idiom | 0.10 | 3 | 0.30 |
| API ergonomics & reusability | 0.10 | 5 | 0.50 |
| Maintainability & robustness | 0.05 | 3 | 0.15 |
| Total | 1.00 | | 2.95 |
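As a quick check of the arithmetic, the total can be recomputed from the table. This is a minimal sketch that assumes the combination rule is a plain weighted sum, which is what the Weighted column implies; the dictionary keys and the `weighted_total` helper are illustrative names, not part of the evaluation scripts.

```python
# Weights and scores copied from the report table above.
weights = {
    "correctness": 0.20,
    "readability": 0.25,
    "efficiency": 0.15,
    "logic_design": 0.15,
    "style_idiom": 0.10,
    "api_ergonomics": 0.10,
    "maintainability": 0.05,
}
scores = {
    "correctness": 3,
    "readability": 2,
    "efficiency": 2,
    "logic_design": 4,
    "style_idiom": 3,
    "api_ergonomics": 5,
    "maintainability": 3,
}


def weighted_total(weights, scores):
    """Weighted sum of 1-5 scores (assumed combination rule)."""
    # Guard against a silent change in the weights: they must sum to 1.0.
    if abs(sum(weights.values()) - 1.0) > 1e-9:
        raise ValueError("weights must sum to 1.0")
    return sum(weights[name] * scores[name] for name in weights)


total = weighted_total(weights, scores)
print(round(total, 2))  # 2.95, matching the Total row
```

Because readability carries the largest weight, a one-point change there moves the total by 0.25, more than any other single dimension.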
What changed
| | Original (main) | Rewrite (update_ge_arrow) |
|---|---|---|
| Library | NumPy | JAX (`jnp`, `lax`, `jit`) |
| Container | mutable class with methods | immutable `NamedTuple` of results |
| Entry point | build object + 3 ordered method calls | one `@jit` function `compute_rc_model` |
| Loops | Python `for` (×6) | `jax.lax.fori_loop` / `lax.cond` (0 Python loops) |
| Infinite-horizon flag | `T=None` | `T=0` |
| Notable | typo `value_functionss`; uses global `P`, `n`, `K` | fixes both |
Evidence by dimension
1 · Correctness & numerical fidelity → 3/5
check_equivalence.py over all 11 example/initial-state combinations:
- Under float64: every object matches, max|Δ| = 1.4e-14 → the rewrite's logic is identical. ✅
- As the lecture actually runs (float32 default, no `jax_enable_x64`): `ex2` deviates by 1.7e-4; several others ~1e-4. The published tables move in the 4th–5th significant figure. ❌ Unflagged precision loss.

→ Correct economics, silently reduced precision. Score capped at 3.
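The effect of the float32 default can be illustrated without JAX. The sketch below is not `check_equivalence.py`; it uses NumPy and a made-up 2×2 economy (the values of `beta`, `gamma`, `y`, `P` are hypothetical), and it assumes, purely for illustration, that the object of interest is the resolvent `(I - Q)^{-1}` of the pricing kernel. It runs the same computation in both precisions and reports the largest gap.

```python
import numpy as np


def resolvent(beta, gamma, y, P, dtype):
    """Compute (I - Q)^{-1} for Q = beta * (y_j / y_i)**(-gamma) * P_ij in the given dtype."""
    y = np.asarray(y, dtype=dtype)
    P = np.asarray(P, dtype=dtype)
    # Broadcasting builds the full matrix of endowment ratios y_j / y_i.
    Q = beta * (y[None, :] / y[:, None]) ** (-gamma) * P
    return np.linalg.inv(np.eye(len(y), dtype=dtype) - Q)


# Hypothetical parameters, chosen only so that I - Q is invertible.
beta, gamma = 0.96, 2.0
y = [1.0, 1.5]
P = [[0.5, 0.5], [0.5, 0.5]]

V64 = resolvent(beta, gamma, y, P, np.float64)
V32 = resolvent(beta, gamma, y, P, np.float32)

# Largest absolute gap between the two precisions.
max_dev = np.max(np.abs(V64 - V32.astype(np.float64)))
print(f"max |float64 - float32| = {max_dev:.2e}")
```

The size of the gap depends on how close `I - Q` is to singular, so a deviation that is harmless in one example can reach the published digits in another.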
2 · Readability & pedagogical clarity → 2/5
static_metrics.py:
| metric | old | new |
|---|---|---|
| prerequisite concepts | 7 | 13 |
| docstring coverage | 0.90 | 0.55 |
| code lines (model def) | 119 | 161 |
| sub-definitions | 10 | 22 |
| Python loops a reader parses | 6 | 0 (replaced by fori_loop closures) |
The pricing kernel, mathematically just `Q = β*(y[None,:]/y[:,None])**(-γ)*P`, is written as nested `fori_loop`s with `.at[j].set(...)` carries. For a lecture whose economies are 2×2, this is pure cognitive overhead. It is the biggest single driver of the negative verdict (and readability is the heaviest-weighted dimension).
3 · Computational efficiency (as used) → 2/5
This was the stated motivation, so it matters that it is not achieved here.
Headline metric — replaying the entire lecture solver sequence once in a fresh process (as_used_total.py):
| NumPy total | JAX total | as-used speedup |
|---|---|---|
| 0.035 s | 1.56 s | 0.022× — i.e. ~45× slower |
Per-regime detail explaining why:
| Regime (n=2 unless noted) | NumPy | JAX | Result |
|---|---|---|---|
| First solve (cold, incl. compile) | 6.2 ms | 286 ms | 46× slower |
| Recompile per new s0_idx/T | — | 133 ms | each distinct call recompiles |
| Warm repeat | 0.032 ms | 0.022 ms | 1.4× faster (never used) |
| λ-sweep (100 pts), as run once | 1.8 ms | 300 ms cold | 170× slower |
| λ-sweep warm | — | 0.37 ms | 4.8× faster (never realized) |
Scaling crossover (benchmark.py): NumPy and JAX-warm are even near n≈10; JAX wins 2–6× for n = 25…200. The lecture never exceeds n=3. For calibration, the same machinery on the large, repeatedly-solved aiyagari pattern (bellman_bench.py) is 25× faster — a score-5 case. ge_arrow's 0.022× maps to score 2 (< 0.8×, but correct and fixable).
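One way to obtain an as-used number like the headline metric is to time the solver sequence inside a newly started interpreter, so that nothing is cached from an earlier run. The following is a rough sketch using only the standard library; `time_in_fresh_process` and `as_used_speedup` are hypothetical helpers, not the contents of `as_used_total.py`. In this sketch the setup code (imports, parameters) is excluded from the timed region, which is an assumption about what should count.

```python
import subprocess
import sys


def time_in_fresh_process(setup: str, body: str) -> float:
    """Run `body` once in a new Python process and return its wall time in seconds."""
    script = "\n".join([
        "import time",
        setup,                                # imports and parameters, not timed
        "_t0 = time.perf_counter()",
        body,                                 # the calls being measured
        "print(time.perf_counter() - _t0)",
    ])
    result = subprocess.run(
        [sys.executable, "-c", script],
        check=True, capture_output=True, text=True,
    )
    # The elapsed time is the last line printed by the child process.
    return float(result.stdout.strip().splitlines()[-1])


def as_used_speedup(baseline_seconds: float, rewrite_seconds: float) -> float:
    """Speedup of the rewrite relative to the baseline (values below 1 mean slower)."""
    return baseline_seconds / rewrite_seconds


# Totals reported in the table above: NumPy 0.035 s, JAX 1.56 s.
speedup = as_used_speedup(0.035, 1.56)
print(round(speedup, 3))   # 0.022
print(round(1 / speedup))  # 45, i.e. about 45x slower
```

For a JAX solver, `body` would hold the lecture's full sequence of calls, so every trace and compile triggered by a new shape or static argument is paid inside the timed region.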
4 · Logic & design → 4/5
Genuine improvements, all verified in the diff:
- removes order-dependent stateful methods (old required `wealth_distribution` → `continuation_wealths` → `value_functionss`);
- removes reliance on module-level `P, n, K` (a latent bug in the original);
- fixes the `value_functionss` typo;
- de-duplicates (`R` no longer recomputes `sum(Q)`); returns one result object.
Minus one point: the pricing kernel is ported as an O(n²) scalar loop instead of a vectorised outer product.
5 · Coding style & idiom → 3/5
NamedTuple + pure function is clean. But two anti-idiomatic JAX choices: the nested-fori_loop pricing kernel (vectorisation was trivial) and jax.lax.cond(T==0, …) which traces both branches although T is already a static argument — a plain if would compile only the needed branch.
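The point about a static `T` can be sketched with a toy function. This is not the lecture's `compute_rc_model`; `geometric_sum` is a hypothetical stand-in that follows the same convention (`T=0` means infinite horizon) and branches with a plain Python `if`, which is evaluated while tracing because `T` is declared static.

```python
from functools import partial

import jax


@partial(jax.jit, static_argnames=("T",))
def geometric_sum(beta, T=0):
    """Sum of beta**t for t = 0..T-1, or the infinite sum when T == 0."""
    # T is a static argument, so this is an ordinary Python branch resolved
    # at trace time: only the branch taken is traced and compiled.
    if T == 0:
        return 1.0 / (1.0 - beta)
    return (1.0 - beta**T) / (1.0 - beta)


infinite = geometric_sum(0.5, T=0)  # infinite-horizon branch
finite = geometric_sum(0.5, T=2)    # finite-horizon branch, T = 2
print(float(infinite), float(finite))
```

Each distinct value of `T` still triggers its own compilation, since static arguments are part of the compilation cache key; the `if` only avoids tracing the branch that is not needed.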
6 · API ergonomics & reusability → 5/5
statements_for_one_result: 4 → 1. compute_rc_model(s, P, ys, s0_idx=1, T=10) returns an immutable bundle; fully jit/vmap-composable. Clear win.
7 · Maintainability & robustness → 3/5
Purity aids unit testing, but jit + static_argnames + 3-deep closures hinder step-debugging, and the float32 default is a silent trap for future reuse.
Recommendation
The conversion is not yet a net improvement for this particular lecture. Two paths:
A. Keep NumPy for ge_arrow. The models are 2×2/3×3; NumPy is faster as-used, more readable, and matches the published numbers. Reserve JAX for lectures with large, repeated, fixed-shape computation.
B. If JAX is kept, fix these before re-scoring (each maps to a dimension):
- Vectorise the pricing kernel → `Q = β*(y[None,:]/y[:,None])**(-γ)*P` (D2 readability, D3 efficiency, D5 idiom).
- Enable float64: `jax.config.update("jax_enable_x64", True)` so published numbers are preserved (D1, D7).
- Reduce recompiles: avoid making `s0_idx`/`T` static, or vectorise over `s0_idx`, so the lecture doesn't pay a fresh compile per call (D3).
- Restore docstrings on the nested helpers; replace `lax.cond` on a static `T` with a Python `if` (D2, D5).
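The first fix can be checked in a few lines. The sketch below uses NumPy so it runs without JAX (the same broadcasting expression works with `jax.numpy`). The loop version is an assumption about what the scalar port computes, inferred from the vectorised formula above; the function names and parameter values are hypothetical.

```python
import numpy as np


def pricing_kernel_loop(beta, gamma, y, P):
    """Element-by-element O(n^2) construction of Q (assumed form of the scalar port)."""
    n = len(y)
    Q = np.empty((n, n))
    for i in range(n):
        for j in range(n):
            Q[i, j] = beta * (y[j] / y[i]) ** (-gamma) * P[i, j]
    return Q


def pricing_kernel_vectorised(beta, gamma, y, P):
    """Same kernel in one broadcasting expression."""
    # y[None, :] varies along columns (y_j); y[:, None] varies along rows (y_i).
    return beta * (y[None, :] / y[:, None]) ** (-gamma) * P


# Hypothetical 3-state economy.
beta, gamma = 0.98, 2.0
y = np.array([1.0, 2.0, 1.5])
P = np.array([[0.8, 0.1, 0.1],
              [0.2, 0.6, 0.2],
              [0.1, 0.3, 0.6]])

Q_loop = pricing_kernel_loop(beta, gamma, y, P)
Q_vec = pricing_kernel_vectorised(beta, gamma, y, P)
print(np.allclose(Q_loop, Q_vec))  # True
```

The vectorised form reads almost like the formula in the theory section, which is the readability argument for this fix.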
Re-running run_all.py after these fixes would likely lift readability to ~3, efficiency to ~3, and the total above the 3.0 "merge after fixes" line.

Updated `ge_arrow.md` to JAX and adjusted the styling to be consistent with the operation manual.

Key changes:
- `RecurCompetitive` class as a `NamedTuple`.
- `compute_rc_model` function. Inside this function, arguments of sub-functions can be written in the same way as the definitions in the theory part.
- `jitted` the main computation function, and used `jax.lax.fori_loop` to conduct loops.

Update: Runtime Comparison Between `JAX (GPU)`, `JAX (CPU)`, and `NumPy`

Methodology: nearly the same as in #654
- `JAX` version uses the code in this PR, while the `NumPy` version uses the code in `main`.
- `JAX (GPU)` is measured using Google Colab `T4 GPU` runtime.
- `qe.timeit` over 1,000 iterations.

Results:
`JAX (CPU)` > `NumPy` > `JAX (GPU)`.