From 68388591c6b853815d49368445ab9adb861eea8d Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 27 Jun 2026 22:44:16 +0000 Subject: [PATCH 01/17] Add symbolic differentiation engine for SQL expressions (autograd MVP) Introduce `src/autograd.rs`, the Rust core of the autograd feature: a `differentiate(&Expr, wrt)` function that symbolically differentiates a DataFusion logical `Expr` tree with respect to a named column and returns a new `Expr` built from ordinary SQL expressions. The design mirrors JAX's per-primitive rule registry (defjvp and friends): each node type has a differentiation rule and the chain rule composes them as the tree is walked. A small 0/1-folding simplifier keeps output compact, playing the role of JAX's Zero tangents and add_tangents. Because each table row is an independent evaluation point, differentiating a column expression and letting DataFusion evaluate it row-by-row is the relational equivalent of vmap(grad(f)). This first cut implements scalar `grad`: rules for +, -, *, / (sum, product, quotient), unary chain rule for sin/cos/tan, asin/acos/atan, exp/ln/log2/ log10/sqrt, sinh/cosh/tanh, abs, and power() with constant base or exponent. Unsupported nodes/functions return a clear NotImplemented error rather than a silently wrong derivative. The engine operates purely on DataFusion `Expr`, keeping the eventual Python<->Rust transport (SQL text, Substrait, or proto) pluggable. Covered by 11 unit tests. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6 --- Cargo.lock | 2 +- src/autograd.rs | 416 ++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 2 + 3 files changed, 419 insertions(+), 1 deletion(-) create mode 100644 src/autograd.rs diff --git a/Cargo.lock b/Cargo.lock index 8b58e75..21dfa95 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3375,7 +3375,7 @@ checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" [[package]] name = "xarray_sql" -version = "0.2.3" +version = "0.3.0" dependencies = [ "arrow", "async-stream", diff --git a/src/autograd.rs b/src/autograd.rs new file mode 100644 index 0000000..662c9f0 --- /dev/null +++ b/src/autograd.rs @@ -0,0 +1,416 @@ +//! Symbolic differentiation of DataFusion logical [`Expr`] trees. +//! +//! This is the autograd engine for xarray-sql. Given an [`Expr`] and the name +//! of a column to differentiate with respect to, [`differentiate`] returns a +//! new [`Expr`] for the (symbolic) partial derivative, built entirely from +//! ordinary DataFusion expressions so the result can be planned and evaluated +//! by DataFusion like any other SQL expression. +//! +//! ## Design +//! +//! The approach mirrors JAX's per-primitive rule registry (`defjvp` and +//! friends in `jax/_src/interpreters/ad.py`): every expression node has a +//! differentiation rule, and the chain rule composes them as the tree is +//! walked. Because each row of a relational table is an independent evaluation +//! point, differentiating a column expression and letting DataFusion evaluate +//! it row-by-row is the moral equivalent of `jax.vmap(jax.grad(f))` — the rows +//! *are* the batch dimension. +//! +//! A small simplifier folds the `0`/`1` constants that differentiation +//! produces in abundance (e.g. `d/dx (c) = 0`, `d/dx (x) = 1`), keeping output +//! expressions compact. This plays the role of JAX's `Zero` tangents and +//! `add_tangents`: a `0` derivative short-circuits products and drops out of +//! sums, and a `1` factor drops out of products. +//! +//! ## Scope (MVP) +//! +//! This first cut implements scalar `grad`: the partial derivative of a single +//! expression with respect to one named column. Forward-/reverse-mode +//! (`jvp`/`vjp`) and multi-input Jacobians are deliberately left for later. + +#![allow(dead_code)] + +use std::f64::consts::{LN_10, LN_2}; + +use datafusion::common::{DataFusionError, Result, ScalarValue}; +use datafusion::functions::math::expr_fn; +use datafusion::logical_expr::expr::ScalarFunction; +use datafusion::logical_expr::{lit, BinaryExpr, Cast, Expr, Operator}; + +// --------------------------------------------------------------------------- +// Constant helpers and the 0/1-folding builders +// --------------------------------------------------------------------------- + +/// The constant `0.0`, used as the derivative of anything not depending on the +/// differentiation variable. +fn zero() -> Expr { + lit(0.0_f64) +} + +/// The constant `1.0`, used as the derivative of the differentiation variable. +fn one() -> Expr { + lit(1.0_f64) +} + +/// Interpret a [`ScalarValue`] as `f64` if it is a (non-null) numeric scalar. +fn scalar_as_f64(sv: &ScalarValue) -> Option { + match sv { + ScalarValue::Float64(Some(v)) => Some(*v), + ScalarValue::Float32(Some(v)) => Some(*v as f64), + ScalarValue::Int64(Some(v)) => Some(*v as f64), + ScalarValue::Int32(Some(v)) => Some(*v as f64), + ScalarValue::Int16(Some(v)) => Some(*v as f64), + ScalarValue::Int8(Some(v)) => Some(*v as f64), + ScalarValue::UInt64(Some(v)) => Some(*v as f64), + ScalarValue::UInt32(Some(v)) => Some(*v as f64), + ScalarValue::UInt16(Some(v)) => Some(*v as f64), + ScalarValue::UInt8(Some(v)) => Some(*v as f64), + _ => None, + } +} + +/// Return the constant `f64` value of a literal expression, if it is one. +fn as_const(e: &Expr) -> Option { + match e { + Expr::Literal(sv, _) => scalar_as_f64(sv), + _ => None, + } +} + +/// True if the expression is a numeric literal exactly equal to zero. +fn is_zero(e: &Expr) -> bool { + matches!(as_const(e), Some(v) if v == 0.0) +} + +/// True if the expression is a numeric literal exactly equal to one. +fn is_one(e: &Expr) -> bool { + matches!(as_const(e), Some(v) if v == 1.0) +} + +fn binary(left: Expr, op: Operator, right: Expr) -> Expr { + Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right))) +} + +/// `a + b`, dropping a zero operand. +fn add(a: Expr, b: Expr) -> Expr { + if is_zero(&a) { + b + } else if is_zero(&b) { + a + } else { + binary(a, Operator::Plus, b) + } +} + +/// `a - b`, dropping a zero right operand and turning `0 - b` into `-b`. +fn sub(a: Expr, b: Expr) -> Expr { + if is_zero(&b) { + a + } else if is_zero(&a) { + neg(b) + } else { + binary(a, Operator::Minus, b) + } +} + +/// `a * b`, folding `0 * _ = 0` and `1 * b = b` (and the mirror cases). +fn mul(a: Expr, b: Expr) -> Expr { + if is_zero(&a) || is_zero(&b) { + zero() + } else if is_one(&a) { + b + } else if is_one(&b) { + a + } else { + binary(a, Operator::Multiply, b) + } +} + +/// `a / b`, folding `0 / _ = 0` and `a / 1 = a`. +fn div(a: Expr, b: Expr) -> Expr { + if is_zero(&a) { + zero() + } else if is_one(&b) { + a + } else { + binary(a, Operator::Divide, b) + } +} + +/// `-a`, folding `-0 = 0`. +fn neg(a: Expr) -> Expr { + if is_zero(&a) { + zero() + } else { + Expr::Negative(Box::new(a)) + } +} + +/// `e * e`. +fn square(e: Expr) -> Expr { + mul(e.clone(), e) +} + +// --------------------------------------------------------------------------- +// The differentiation rules +// --------------------------------------------------------------------------- + +/// Differentiate `expr` with respect to the column named `wrt`. +/// +/// Returns a new [`Expr`] for the partial derivative, composed of ordinary +/// DataFusion expressions. Returns a [`DataFusionError::NotImplemented`] for +/// expression nodes or scalar functions without a differentiation rule, so the +/// caller can surface a clear, actionable error rather than silently producing +/// a wrong answer. +pub fn differentiate(expr: &Expr, wrt: &str) -> Result { + match expr { + // d/dx (x) = 1 ; d/dx (y) = 0 for any other column. + Expr::Column(c) => Ok(if c.name == wrt { one() } else { zero() }), + + // d/dx (constant) = 0. + Expr::Literal(_, _) => Ok(zero()), + + // An alias is transparent to differentiation; the surrounding query + // re-applies any output naming. + Expr::Alias(a) => differentiate(&a.expr, wrt), + + // A numeric cast is (locally) linear: d/dx cast(u) = cast(du). We keep + // the cast so the derivative retains the declared output type. + Expr::Cast(c) => { + let du = differentiate(&c.expr, wrt)?; + Ok(Expr::Cast(Cast::new(Box::new(du), c.data_type.clone()))) + } + + // d/dx (-u) = -(du). + Expr::Negative(inner) => Ok(neg(differentiate(inner, wrt)?)), + + Expr::BinaryExpr(be) => diff_binary(be, wrt), + + Expr::ScalarFunction(sf) => diff_scalar_function(sf, wrt), + + other => Err(DataFusionError::NotImplemented(format!( + "grad: differentiation is not implemented for this expression: {other}" + ))), + } +} + +/// Differentiate a binary arithmetic expression via the sum/product/quotient +/// rules. +fn diff_binary(be: &BinaryExpr, wrt: &str) -> Result { + let a = be.left.as_ref(); + let b = be.right.as_ref(); + let da = differentiate(a, wrt)?; + let db = differentiate(b, wrt)?; + + match be.op { + // d/dx (a + b) = da + db + Operator::Plus => Ok(add(da, db)), + // d/dx (a - b) = da - db + Operator::Minus => Ok(sub(da, db)), + // d/dx (a * b) = da*b + a*db (product rule) + Operator::Multiply => { + Ok(add(mul(da, b.clone()), mul(a.clone(), db))) + } + // d/dx (a / b) = (da*b - a*db) / b^2 (quotient rule) + Operator::Divide => { + let numerator = sub(mul(da, b.clone()), mul(a.clone(), db)); + Ok(div(numerator, square(b.clone()))) + } + op => Err(DataFusionError::NotImplemented(format!( + "grad: operator '{op}' is not differentiable" + ))), + } +} + +/// Differentiate a scalar-function call via the chain rule. +/// +/// For a unary primitive `f(u)`, the derivative is `f'(u) * du`. For `power`, +/// which is binary, we handle the constant-exponent and constant-base cases. +fn diff_scalar_function(sf: &ScalarFunction, wrt: &str) -> Result { + let name = sf.func.name(); + let args = &sf.args; + + // `power(base, exponent)` is the one binary primitive we differentiate. + if name == "power" { + return diff_power(args, wrt); + } + + if args.len() != 1 { + return Err(DataFusionError::NotImplemented(format!( + "grad: no derivative rule for function '{name}' with {} arguments", + args.len() + ))); + } + + let u = &args[0]; + let du = differentiate(u, wrt)?; + // Chain rule short-circuit: if du is 0, the whole derivative is 0 and we + // avoid emitting the (dead) outer derivative term entirely. + if is_zero(&du) { + return Ok(zero()); + } + + let outer = match name { + // Trigonometric. + "sin" => expr_fn::cos(u.clone()), + "cos" => neg(expr_fn::sin(u.clone())), + "tan" => div(one(), square(expr_fn::cos(u.clone()))), + // Inverse trigonometric. + "asin" => div(one(), expr_fn::sqrt(sub(one(), square(u.clone())))), + "acos" => neg(div(one(), expr_fn::sqrt(sub(one(), square(u.clone()))))), + "atan" => div(one(), add(one(), square(u.clone()))), + // Exponential / logarithmic. + "exp" => expr_fn::exp(u.clone()), + "ln" => div(one(), u.clone()), + "log2" => div(one(), mul(u.clone(), lit(LN_2))), + "log10" => div(one(), mul(u.clone(), lit(LN_10))), + "sqrt" => div(one(), mul(lit(2.0_f64), expr_fn::sqrt(u.clone()))), + // Hyperbolic. + "sinh" => expr_fn::cosh(u.clone()), + "cosh" => expr_fn::sinh(u.clone()), + "tanh" => sub(one(), square(expr_fn::tanh(u.clone()))), + // Piecewise-linear: derivative is the sign (undefined at 0, like JAX). + "abs" => expr_fn::signum(u.clone()), + _ => { + return Err(DataFusionError::NotImplemented(format!( + "grad: no derivative rule for function '{name}'" + ))) + } + }; + + Ok(mul(outer, du)) +} + +/// Differentiate `power(base, exponent)`. +/// +/// * Constant exponent `c`: `d/dx base^c = c * base^(c-1) * d(base)`. +/// * Constant base `a`: `d/dx a^u = a^u * ln(a) * d(u)`. +/// * Both variable (`u^v`): not supported in the MVP. +fn diff_power(args: &[Expr], wrt: &str) -> Result { + if args.len() != 2 { + return Err(DataFusionError::NotImplemented( + "grad: power() expects exactly two arguments".to_string(), + )); + } + let base = &args[0]; + let exponent = &args[1]; + + match (as_const(base), as_const(exponent)) { + // Constant exponent (covers the common x^2, x^0.5, ... cases). + (_, Some(c)) => { + let dbase = differentiate(base, wrt)?; + if is_zero(&dbase) { + return Ok(zero()); + } + let outer = + mul(lit(c), expr_fn::power(base.clone(), lit(c - 1.0))); + Ok(mul(outer, dbase)) + } + // Constant base, variable exponent. + (Some(a), None) => { + let dexp = differentiate(exponent, wrt)?; + if is_zero(&dexp) { + return Ok(zero()); + } + let outer = mul( + expr_fn::power(base.clone(), exponent.clone()), + lit(a.ln()), + ); + Ok(mul(outer, dexp)) + } + // General u^v requires the exp/log trick; deferred past the MVP. + (None, None) => Err(DataFusionError::NotImplemented( + "grad: power(base, exponent) where both depend on the \ + differentiation variable is not yet supported" + .to_string(), + )), + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::logical_expr::col; + + #[test] + fn constant_has_zero_derivative() { + assert_eq!(differentiate(&lit(3.0_f64), "x").unwrap(), zero()); + } + + #[test] + fn variable_has_unit_derivative() { + assert_eq!(differentiate(&col("x"), "x").unwrap(), one()); + } + + #[test] + fn other_variable_has_zero_derivative() { + assert_eq!(differentiate(&col("y"), "x").unwrap(), zero()); + } + + #[test] + fn sum_rule_folds_constants() { + // d/dx (x + y) = 1 + 0 = 1 + let e = add(col("x"), col("y")); + assert_eq!(differentiate(&e, "x").unwrap(), one()); + } + + #[test] + fn product_rule() { + // d/dx (x * x) = 1*x + x*1 = x + x + let e = binary(col("x"), Operator::Multiply, col("x")); + let expected = add(col("x"), col("x")); + assert_eq!(differentiate(&e, "x").unwrap(), expected); + } + + #[test] + fn quotient_rule() { + // d/dx (x / y) = (1*y - x*0) / (y*y) = y / (y*y) + let e = binary(col("x"), Operator::Divide, col("y")); + let expected = div(col("y"), square(col("y"))); + assert_eq!(differentiate(&e, "x").unwrap(), expected); + } + + #[test] + fn chain_rule_sin() { + // d/dx sin(x) = cos(x) * 1 = cos(x) + let d = differentiate(&expr_fn::sin(col("x")), "x").unwrap(); + assert_eq!(d, expr_fn::cos(col("x"))); + // Readable, precedence-free rendering. + assert_eq!(d.to_string(), "cos(x)"); + } + + #[test] + fn composite_sin_times_x() { + // d/dx (sin(x) * x) = cos(x)*x + sin(x) + let e = + binary(expr_fn::sin(col("x")), Operator::Multiply, col("x")); + let d = differentiate(&e, "x").unwrap(); + assert_eq!(d.to_string(), "cos(x) * x + sin(x)"); + } + + #[test] + fn power_constant_exponent() { + // d/dx power(x, 2) = 2 * power(x, 1) * 1 = 2 * power(x, 1) + let e = expr_fn::power(col("x"), lit(2.0_f64)); + let expected = + mul(lit(2.0_f64), expr_fn::power(col("x"), lit(1.0_f64))); + assert_eq!(differentiate(&e, "x").unwrap(), expected); + } + + #[test] + fn unsupported_operator_errors() { + let e = binary(col("x"), Operator::Modulo, col("y")); + assert!(differentiate(&e, "x").is_err()); + } + + #[test] + fn unsupported_function_errors() { + // atan2 is binary and has no rule yet. + let e = expr_fn::atan2(col("x"), col("y")); + assert!(differentiate(&e, "x").is_err()); + } +} diff --git a/src/lib.rs b/src/lib.rs index c489609..042992a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,6 +41,8 @@ //! Will skip loading partitions whose time ranges are entirely before 2020-02-01. //! Supported operators: `=`, `<`, `>`, `<=`, `>=`, `BETWEEN`, `IN`, `AND`, `OR`. +mod autograd; + use std::any::Any; use std::collections::{HashMap, HashSet}; use std::ffi::CString; From ee2e011b34a5d6e258fdd46cb6d065ad9d6a1bbd Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 27 Jun 2026 23:28:29 +0000 Subject: [PATCH 02/17] Add Substrait bridge to apply grad() rewrite in the Python context Add the `grad` marker UDF and a plan-level rewriter (`rewrite_grad_calls`) to the autograd engine, plus a `grad_rewrite` PyO3 function that bridges the differentiation engine into the datafusion-python SessionContext. Because the native extension links its own copy of DataFusion, expressions cross the Python<->Rust boundary as Substrait protobuf. Python produces the logical plan as Substrait; `grad_rewrite` consumes it into a DataFusion LogicalPlan, rewrites every `grad(expr, column)` ScalarFunction into the symbolic derivative via `differentiate`, and re-produces Substrait bytes for Python to consume and execute. The custom xarray table provider round-trips because Substrait serializes table scans by name (resolved against the registry on consume), so the rewrite context only needs empty tables with matching schemas. `grad` is registered as a marker ScalarUDF that carries the differentiation request intact through parsing, planning, and serialization; it is always rewritten away before execution and errors if it ever reaches invoke. Deps: datafusion-substrait 52 and prost 0.14 (matching the substrait crate). Building now requires `protoc` (the substrait crate codegens from .proto). Verified end to end (produce -> grad_rewrite -> consume -> execute) against analytic derivatives for cos, the product rule, and exp with 0.0 error. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6 --- Cargo.lock | 292 ++++++++++++++++++++++++++++++++++++++++++++++-- Cargo.toml | 2 + src/autograd.rs | 126 ++++++++++++++++++--- src/lib.rs | 105 ++++++++++++++++- 4 files changed, 498 insertions(+), 27 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 21dfa95..fcaefe4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -397,6 +397,17 @@ dependencies = [ "abi_stable", ] +[[package]] +name = "async-recursion" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.114", +] + [[package]] name = "async-stream" version = "0.3.6" @@ -1480,6 +1491,26 @@ dependencies = [ "sqlparser", ] +[[package]] +name = "datafusion-substrait" +version = "52.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "199790fd96e852997b30da4ff11109378c944841757d93875ea85fc69587ec91" +dependencies = [ + "async-recursion", + "async-trait", + "chrono", + "datafusion", + "half", + "itertools", + "object_store", + "pbjson-types", + "prost", + "substrait", + "tokio", + "url", +] + [[package]] name = "digest" version = "0.10.7" @@ -1502,6 +1533,12 @@ dependencies = [ "syn 2.0.114", ] +[[package]] +name = "dyn-clone" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" + [[package]] name = "either" version = "1.15.0" @@ -2159,6 +2196,12 @@ dependencies = [ "simd-adler32", ] +[[package]] +name = "multimap" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084" + [[package]] name = "num-bigint" version = "0.4.6" @@ -2302,6 +2345,43 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "pbjson" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "898bac3fa00d0ba57a4e8289837e965baa2dee8c3749f3b11d45a64b4223d9c3" +dependencies = [ + "base64", + "serde", +] + +[[package]] +name = "pbjson-build" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af22d08a625a2213a78dbb0ffa253318c5c79ce3133d32d296655a7bdfb02095" +dependencies = [ + "heck", + "itertools", + "prost", + "prost-types", +] + +[[package]] +name = "pbjson-types" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e748e28374f10a330ee3bb9f29b828c0ac79831a32bab65015ad9b661ead526" +dependencies = [ + "bytes", + "chrono", + "pbjson", + "pbjson-build", + "prost", + "prost-build", + "serde", +] + [[package]] name = "percent-encoding" version = "2.3.2" @@ -2380,6 +2460,16 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn 2.0.114", +] + [[package]] name = "proc-macro2" version = "1.0.101" @@ -2399,6 +2489,25 @@ dependencies = [ "prost-derive", ] +[[package]] +name = "prost-build" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7" +dependencies = [ + "heck", + "itertools", + "log", + "multimap", + "petgraph", + "prettyplease", + "prost", + "prost-types", + "regex", + "syn 2.0.114", + "tempfile", +] + [[package]] name = "prost-derive" version = "0.14.3" @@ -2412,6 +2521,15 @@ dependencies = [ "syn 2.0.114", ] +[[package]] +name = "prost-types" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8991c4cbdb8bc5b11f0b074ffe286c30e523de90fee5ba8132f1399f23cb3dd7" +dependencies = [ + "prost", +] + [[package]] name = "psm" version = "0.1.26" @@ -2584,6 +2702,16 @@ version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "caf4aa5b0f434c91fe5c7f1ecb6a5ece2130b02ad2a590589dda5146df959001" +[[package]] +name = "regress" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2057b2325e68a893284d1538021ab90279adac1139957ca2a74426c6f118fb48" +dependencies = [ + "hashbrown 0.16.1", + "memchr", +] + [[package]] name = "repr_offset" version = "0.2.2" @@ -2636,6 +2764,30 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "schemars" +version = "0.8.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fbf2ae1b8bc8e02df939598064d22402220cd5bbcca1c76f7d6a310974d5615" +dependencies = [ + "dyn-clone", + "schemars_derive", + "serde", + "serde_json", +] + +[[package]] +name = "schemars_derive" +version = "0.8.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32e265784ad618884abaea0600a9adf15393368d840e0222d101a072f3f7534d" +dependencies = [ + "proc-macro2", + "quote", + "serde_derive_internals", + "syn 2.0.114", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -2647,6 +2799,10 @@ name = "semver" version = "1.0.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" +dependencies = [ + "serde", + "serde_core", +] [[package]] name = "seq-macro" @@ -2656,9 +2812,9 @@ checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc" [[package]] name = "serde" -version = "1.0.227" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80ece43fc6fbed4eb5392ab50c07334d3e577cbf40997ee896fe7af40bba4245" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" dependencies = [ "serde_core", "serde_derive", @@ -2666,18 +2822,29 @@ dependencies = [ [[package]] name = "serde_core" -version = "1.0.227" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a576275b607a2c86ea29e410193df32bc680303c82f31e275bbfcafe8b33be5" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.227" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51e694923b8824cf0e9b382adf0f60d4e05f348f357b38833a3fa5ed7c2ede04" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.114", +] + +[[package]] +name = "serde_derive_internals" +version = "0.29.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" dependencies = [ "proc-macro2", "quote", @@ -2697,6 +2864,31 @@ dependencies = [ "serde_core", ] +[[package]] +name = "serde_tokenstream" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7c49585c52c01f13c5c2ebb333f14f6885d76daa768d8a037d28017ec538c69" +dependencies = [ + "proc-macro2", + "quote", + "serde", + "syn 2.0.114", +] + +[[package]] +name = "serde_yaml" +version = "0.9.34+deprecated" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" +dependencies = [ + "indexmap", + "itoa", + "ryu", + "serde", + "unsafe-libyaml", +] + [[package]] name = "sha2" version = "0.10.9" @@ -2791,6 +2983,31 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "substrait" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62fc4b483a129b9772ccb9c3f7945a472112fdd9140da87f8a4e7f1d44e045d0" +dependencies = [ + "heck", + "pbjson", + "pbjson-build", + "pbjson-types", + "prettyplease", + "prost", + "prost-build", + "prost-types", + "regress", + "schemars", + "semver", + "serde", + "serde_json", + "serde_yaml", + "syn 2.0.114", + "typify", + "walkdir", +] + [[package]] name = "subtle" version = "2.6.1" @@ -2851,18 +3068,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.16" +version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3467d614147380f2e4e374161426ff399c91084acd2363eaf549172b3d5e60c0" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "2.0.16" +version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c5e1be1c48b9172ee610da68fd9cd2770e7a4056cb3fc98710ee6906f0c7960" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" dependencies = [ "proc-macro2", "quote", @@ -3004,6 +3221,53 @@ version = "1.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8c1ae7cc0fdb8b842d65d127cb981574b0d2b249b74d1c7a2986863dc134f71" +[[package]] +name = "typify" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5bcc6f62eb1fa8aa4098f39b29f93dcb914e17158b76c50360911257aa629" +dependencies = [ + "typify-impl", + "typify-macro", +] + +[[package]] +name = "typify-impl" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1eb359f7ffa4f9ebe947fa11a1b2da054564502968db5f317b7e37693cb2240" +dependencies = [ + "heck", + "log", + "proc-macro2", + "quote", + "regress", + "schemars", + "semver", + "serde", + "serde_json", + "syn 2.0.114", + "thiserror", + "unicode-ident", +] + +[[package]] +name = "typify-macro" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "911c32f3c8514b048c1b228361bebb5e6d73aeec01696e8cc0e82e2ffef8ab7a" +dependencies = [ + "proc-macro2", + "quote", + "schemars", + "semver", + "serde", + "serde_json", + "serde_tokenstream", + "syn 2.0.114", + "typify-impl", +] + [[package]] name = "unicode-ident" version = "1.0.19" @@ -3028,6 +3292,12 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" +[[package]] +name = "unsafe-libyaml" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" + [[package]] name = "url" version = "2.5.7" @@ -3382,7 +3652,9 @@ dependencies = [ "async-trait", "datafusion", "datafusion-ffi", + "datafusion-substrait", "futures", + "prost", "pyo3", "pyo3-build-config", "tokio", diff --git a/Cargo.toml b/Cargo.toml index 1dc95bd..1fc5288 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,8 @@ async-stream = "0.3" async-trait = "0.1" datafusion = { version = "52.0.0" } datafusion-ffi = { version = "52.0.0" } +datafusion-substrait = { version = "52.0.0" } +prost = "0.14" futures = { version = "0.3" } pyo3 = { version = "0.26.0", features = ["extension-module"] } tokio = { version = "1.46.1", features = ["rt"] } diff --git a/src/autograd.rs b/src/autograd.rs index 662c9f0..1e72586 100644 --- a/src/autograd.rs +++ b/src/autograd.rs @@ -30,12 +30,18 @@ #![allow(dead_code)] +use std::any::Any; use std::f64::consts::{LN_10, LN_2}; +use datafusion::arrow::datatypes::DataType; +use datafusion::common::tree_node::{Transformed, TreeNode}; use datafusion::common::{DataFusionError, Result, ScalarValue}; use datafusion::functions::math::expr_fn; use datafusion::logical_expr::expr::ScalarFunction; -use datafusion::logical_expr::{lit, BinaryExpr, Cast, Expr, Operator}; +use datafusion::logical_expr::{ + lit, BinaryExpr, Cast, ColumnarValue, Expr, LogicalPlan, Operator, ScalarFunctionArgs, + ScalarUDFImpl, Signature, Volatility, +}; // --------------------------------------------------------------------------- // Constant helpers and the 0/1-folding builders @@ -208,9 +214,7 @@ fn diff_binary(be: &BinaryExpr, wrt: &str) -> Result { // d/dx (a - b) = da - db Operator::Minus => Ok(sub(da, db)), // d/dx (a * b) = da*b + a*db (product rule) - Operator::Multiply => { - Ok(add(mul(da, b.clone()), mul(a.clone(), db))) - } + Operator::Multiply => Ok(add(mul(da, b.clone()), mul(a.clone(), db))), // d/dx (a / b) = (da*b - a*db) / b^2 (quotient rule) Operator::Divide => { let numerator = sub(mul(da, b.clone()), mul(a.clone(), db)); @@ -302,8 +306,7 @@ fn diff_power(args: &[Expr], wrt: &str) -> Result { if is_zero(&dbase) { return Ok(zero()); } - let outer = - mul(lit(c), expr_fn::power(base.clone(), lit(c - 1.0))); + let outer = mul(lit(c), expr_fn::power(base.clone(), lit(c - 1.0))); Ok(mul(outer, dbase)) } // Constant base, variable exponent. @@ -312,10 +315,7 @@ fn diff_power(args: &[Expr], wrt: &str) -> Result { if is_zero(&dexp) { return Ok(zero()); } - let outer = mul( - expr_fn::power(base.clone(), exponent.clone()), - lit(a.ln()), - ); + let outer = mul(expr_fn::power(base.clone(), exponent.clone()), lit(a.ln())); Ok(mul(outer, dexp)) } // General u^v requires the exp/log trick; deferred past the MVP. @@ -327,15 +327,113 @@ fn diff_power(args: &[Expr], wrt: &str) -> Result { } } +// --------------------------------------------------------------------------- +// The `grad` marker UDF and the plan-level rewrite +// --------------------------------------------------------------------------- + +/// A no-op placeholder UDF for `grad(expr, column)`. +/// +/// `grad` is a *marker*: it carries the differentiation request intact through +/// SQL parsing, logical planning, and Substrait serialization. It is always +/// rewritten away by [`rewrite_grad_calls`] before execution, so its `invoke` +/// is never reached in normal use (and deliberately errors if it somehow is, +/// rather than silently returning a wrong value). +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct GradMarker { + signature: Signature, +} + +impl GradMarker { + pub fn new() -> Self { + // grad(expr, column): two arguments of any (numeric) type. + Self { + signature: Signature::any(2, Volatility::Immutable), + } + } +} + +impl Default for GradMarker { + fn default() -> Self { + Self::new() + } +} + +impl ScalarUDFImpl for GradMarker { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "grad" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + Err(DataFusionError::Execution( + "grad() marker reached execution without being rewritten; this is \ + an internal xarray-sql autograd error" + .to_string(), + )) + } +} + +/// Rewrite every `grad(expr, column)` call anywhere in a logical plan into the +/// symbolic derivative of `expr` with respect to `column`, leaving everything +/// else untouched. The plan's schema is recomputed afterwards because replacing +/// a `grad` call can change an expression's name or type. +pub fn rewrite_grad_calls(plan: LogicalPlan) -> Result { + let rewritten = plan + .transform_up(|node| node.map_expressions(rewrite_grad_in_expr))? + .data; + rewritten.recompute_schema() +} + +/// Replace any `grad(...)` calls nested anywhere inside a single expression. +fn rewrite_grad_in_expr(expr: Expr) -> Result> { + expr.transform_up(|e| { + let Expr::ScalarFunction(sf) = &e else { + return Ok(Transformed::no(e)); + }; + if sf.func.name() != "grad" { + return Ok(Transformed::no(e)); + } + if sf.args.len() != 2 { + return Err(DataFusionError::Plan(format!( + "grad() expects two arguments grad(expr, column), got {}", + sf.args.len() + ))); + } + let wrt = match &sf.args[1] { + Expr::Column(c) => c.name.clone(), + other => { + return Err(DataFusionError::Plan(format!( + "grad(): the second argument must be a bare column to \ + differentiate with respect to, got: {other}" + ))) + } + }; + let derivative = differentiate(&sf.args[0], &wrt)?; + Ok(Transformed::yes(derivative)) + }) +} + // --------------------------------------------------------------------------- // Tests // --------------------------------------------------------------------------- #[cfg(test)] mod tests { - use super::*; use datafusion::logical_expr::col; + use super::*; + #[test] fn constant_has_zero_derivative() { assert_eq!(differentiate(&lit(3.0_f64), "x").unwrap(), zero()); @@ -386,8 +484,7 @@ mod tests { #[test] fn composite_sin_times_x() { // d/dx (sin(x) * x) = cos(x)*x + sin(x) - let e = - binary(expr_fn::sin(col("x")), Operator::Multiply, col("x")); + let e = binary(expr_fn::sin(col("x")), Operator::Multiply, col("x")); let d = differentiate(&e, "x").unwrap(); assert_eq!(d.to_string(), "cos(x) * x + sin(x)"); } @@ -396,8 +493,7 @@ mod tests { fn power_constant_exponent() { // d/dx power(x, 2) = 2 * power(x, 1) * 1 = 2 * power(x, 1) let e = expr_fn::power(col("x"), lit(2.0_f64)); - let expected = - mul(lit(2.0_f64), expr_fn::power(col("x"), lit(1.0_f64))); + let expected = mul(lit(2.0_f64), expr_fn::power(col("x"), lit(1.0_f64))); assert_eq!(differentiate(&e, "x").unwrap(), expected); } diff --git a/src/lib.rs b/src/lib.rs index 042992a..5ab410f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -57,19 +57,25 @@ use async_trait::async_trait; use datafusion::catalog::streaming::StreamingTable; use datafusion::catalog::Session; use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue}; +use datafusion::datasource::empty::EmptyTable; use datafusion::datasource::TableProvider; use datafusion::execution::TaskContext; use datafusion::logical_expr::expr::InList; use datafusion::logical_expr::{ - BinaryExpr, Expr, Operator, TableProviderFilterPushDown, TableType, + BinaryExpr, Expr, Operator, ScalarUDF, TableProviderFilterPushDown, TableType, }; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::streaming::PartitionStream; use datafusion::physical_plan::{ExecutionPlan, SendableRecordBatchStream}; +use datafusion::prelude::SessionContext; use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use datafusion_ffi::table_provider::FFI_TableProvider; +use datafusion_substrait::logical_plan::consumer::from_substrait_plan; +use datafusion_substrait::logical_plan::producer::to_substrait_plan; +use datafusion_substrait::substrait::proto::Plan; +use prost::Message; use pyo3::prelude::*; -use pyo3::types::{PyCapsule, PyList}; +use pyo3::types::{PyBytes, PyCapsule, PyList}; // ============================================================================ // Partition Metadata Types for Filter Pushdown @@ -983,9 +989,104 @@ impl LazyArrowStreamTable { } } +// ============================================================================ +// Autograd: Substrait-level grad() rewrite +// ============================================================================ + +/// Rewrite `grad(expr, column)` calls in a Substrait plan into their symbolic +/// derivatives. +/// +/// The autograd engine operates on DataFusion logical `Expr` trees. To apply it +/// inside the datafusion-python `SessionContext` (which links its own copy of +/// DataFusion), we move the plan across the boundary as Substrait protobuf: +/// Python produces the plan, this function consumes it into a DataFusion +/// `LogicalPlan`, rewrites every `grad(...)` into the differentiated +/// expression, and re-produces Substrait bytes for Python to consume and +/// execute. +/// +/// Args: +/// plan_bytes: A Substrait `Plan` protobuf, as produced by +/// datafusion-python's +/// ``Producer.to_substrait_plan(plan, ctx).encode()``. +/// tables: A list of ``(name, pyarrow.Schema)`` pairs for every table the +/// plan scans. The consumer resolves table references by name, so each +/// referenced table must be registered here with a matching schema +/// (the data itself is never read — an empty table suffices). +/// +/// Returns: +/// The rewritten Substrait `Plan` protobuf bytes, ready for +/// ``Consumer.from_substrait_plan(ctx, plan)``. +#[pyfunction] +fn grad_rewrite<'py>( + py: Python<'py>, + plan_bytes: &[u8], + tables: Vec<(String, Bound<'py, PyAny>)>, +) -> PyResult> { + // A fresh, data-free context purely for the rewrite. It needs the grad + // marker UDF (so the consumer can resolve the function) and an empty table + // per referenced name (so the consumer can resolve table scans). + let ctx = SessionContext::new(); + ctx.register_udf(ScalarUDF::from(autograd::GradMarker::new())); + + for (name, schema_obj) in &tables { + let schema = Schema::from_pyarrow_bound(schema_obj).map_err(|e| { + pyo3::exceptions::PyTypeError::new_err(format!( + "grad_rewrite: failed to convert schema for table '{name}': {e}" + )) + })?; + let provider = Arc::new(EmptyTable::new(Arc::new(schema))); + ctx.register_table(name.as_str(), provider).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "grad_rewrite: failed to register table '{name}': {e}" + )) + })?; + } + + let state = ctx.state(); + + let plan = Plan::decode(plan_bytes).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "grad_rewrite: failed to decode Substrait plan: {e}" + )) + })?; + + // from_substrait_plan is async but does no real I/O here (empty tables + // resolve immediately), so a minimal current-thread runtime suffices. + let runtime = tokio::runtime::Builder::new_current_thread() + .build() + .map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!( + "grad_rewrite: failed to build runtime: {e}" + )) + })?; + + let logical = runtime + .block_on(from_substrait_plan(&state, &plan)) + .map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "grad_rewrite: failed to consume Substrait plan: {e}" + )) + })?; + + let rewritten = autograd::rewrite_grad_calls(logical).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "grad_rewrite: failed to rewrite grad() calls: {e}" + )) + })?; + + let out_plan = to_substrait_plan(&rewritten, &state).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "grad_rewrite: failed to produce Substrait plan: {e}" + )) + })?; + + Ok(PyBytes::new(py, &out_plan.encode_to_vec())) +} + /// Python module initialization #[pymodule] fn _native(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; + m.add_function(wrap_pyfunction!(grad_rewrite, m)?)?; Ok(()) } From d456934f728ddf69ebaf5d3b01e16427869db5f3 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 27 Jun 2026 23:30:58 +0000 Subject: [PATCH 03/17] Expose grad() in XarrayContext SQL via the Substrait rewrite MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wire the autograd surface into XarrayContext so users can write calculus directly in SQL: ctx.sql("SELECT grad(sin(val), val) AS d_val, sin(val) AS val FROM t") On construction the context registers the `grad` marker UDF so such queries parse and plan. XarrayContext.sql() detects `grad(` (a cheap regex gate so ordinary queries are untouched) and routes through _sql_with_autograd: it plans the query, produces the logical plan as Substrait, calls the native grad_rewrite to differentiate every grad(expr, column) symbolically, then consumes the rewritten Substrait back into an executable DataFrame. Table scans are resolved by name on the consume side, so _table_schemas() passes the (name, schema) of each registered table to the rewrite. Schema- qualified tables (mixed-dimension datasets) are skipped for now and noted as a follow-up. Adds tests/test_autograd.py covering sin/cos, product and quotient rules, power, exp, the non-grad passthrough, and a clear error for unsupported functions — all checked against numpy analytic derivatives. Existing SQL tests still pass. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6 --- tests/test_autograd.py | 76 ++++++++++++++++++++++++++++++++++++++++++ xarray_sql/sql.py | 72 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 146 insertions(+), 2 deletions(-) create mode 100644 tests/test_autograd.py diff --git a/tests/test_autograd.py b/tests/test_autograd.py new file mode 100644 index 0000000..3e7827a --- /dev/null +++ b/tests/test_autograd.py @@ -0,0 +1,76 @@ +"""Tests for the SQL autograd surface: ``SELECT grad(expr, column) ...``. + +These exercise the full path — XarrayContext.sql() -> Substrait -> native +grad_rewrite -> Substrait -> execute — and compare results against analytic +derivatives computed with numpy. +""" + +import numpy as np +import pytest +import xarray as xr + +import xarray_sql as xql + + +@pytest.fixture +def ctx(): + val = np.linspace(0.1, 3.0, 16) + ds = xr.Dataset( + {"val": (("i",), val)}, + coords={"i": np.arange(16)}, + ) + context = xql.XarrayContext() + context.from_dataset("t", ds, chunks={"i": 5}) + return context + + +def _ordered(df, key="i"): + """Collect a result DataFrame into a dict of column -> numpy array, sorted + by the integer key column so comparisons are index-aligned.""" + pdf = df.to_pandas().sort_values(key) + return {c: pdf[c].to_numpy() for c in pdf.columns} + + +def test_grad_sin_is_cos(ctx): + val = np.linspace(0.1, 3.0, 16) + res = _ordered(ctx.sql("SELECT i, grad(sin(val), val) AS d FROM t")) + np.testing.assert_allclose(res["d"], np.cos(val)) + + +def test_grad_product_rule(ctx): + val = np.linspace(0.1, 3.0, 16) + res = _ordered(ctx.sql("SELECT i, grad(sin(val) * val, val) AS d FROM t")) + np.testing.assert_allclose(res["d"], np.cos(val) * val + np.sin(val)) + + +def test_grad_exp_equals_value(ctx): + val = np.linspace(0.1, 3.0, 16) + res = _ordered( + ctx.sql("SELECT i, exp(val) AS v, grad(exp(val), val) AS d FROM t") + ) + np.testing.assert_allclose(res["d"], np.exp(val)) + np.testing.assert_allclose(res["d"], res["v"]) + + +def test_grad_quotient_and_power(ctx): + val = np.linspace(0.1, 3.0, 16) + res = _ordered( + ctx.sql( + "SELECT i, grad(1.0 / val, val) AS dinv, " + "grad(power(val, 3), val) AS dcube FROM t" + ) + ) + np.testing.assert_allclose(res["dinv"], -1.0 / val**2) + np.testing.assert_allclose(res["dcube"], 3.0 * val**2) + + +def test_non_grad_query_is_unaffected(ctx): + # Queries without grad() bypass the rewrite and behave normally. + res = _ordered(ctx.sql("SELECT i, val FROM t")) + np.testing.assert_allclose(res["val"], np.linspace(0.1, 3.0, 16)) + + +def test_unsupported_function_raises(ctx): + # atan2 has no derivative rule yet -> a clear error, not a wrong answer. + with pytest.raises(Exception): + ctx.sql("SELECT grad(atan2(val, val), val) AS d FROM t").to_pandas() diff --git a/xarray_sql/sql.py b/xarray_sql/sql.py index 0ec60ad..5577d61 100644 --- a/xarray_sql/sql.py +++ b/xarray_sql/sql.py @@ -1,13 +1,22 @@ +import re + +import pyarrow as pa import xarray as xr -from datafusion import SessionContext +from datafusion import SessionContext, udf from datafusion.catalog import Schema +from datafusion.substrait import Consumer, Producer, Serde from collections import defaultdict +from . import _native from . import cftime as cft from .df import Chunks from .ds import XarrayDataFrame from .reader import read_xarray_table +# Matches a call to the autograd marker function ``grad(`` (case-insensitive), +# used as a cheap gate so ordinary queries skip the Substrait round-trip. +_GRAD_CALL = re.compile(r"\bgrad\s*\(", re.IGNORECASE) + class XarrayContext(SessionContext): """A datafusion `SessionContext` that also supports `xarray.Dataset`s.""" @@ -21,6 +30,24 @@ def __init__(self, *args, **kwargs): # in SQL (e.g. ``"air"`` for a uniform-dim Dataset, or # ``"era5.surface"`` for one entry from a multi-dim-group split). self._registered_datasets: dict[str, xr.Dataset] = {} + self._register_autograd_udfs() + + def _register_autograd_udfs(self) -> None: + """Register the ``grad`` marker UDF used by the autograd rewrite. + + ``grad(expr, column)`` is a *marker*: it lets queries parse and plan + with the differentiation request intact. It is never executed — the + Substrait rewrite in :meth:`sql` replaces every ``grad(...)`` with the + symbolic derivative of ``expr`` before execution. + """ + marker = udf( + lambda expr, column: expr, + [pa.float64(), pa.float64()], + pa.float64(), + "immutable", + "grad", + ) + self.register_udf(marker) def from_dataset( self, @@ -174,9 +201,50 @@ def sql(self, query: str, *args, **kwargs) -> XarrayDataFrame: Returns: An :class:`XarrayDataFrame` wrapping the DataFusion DataFrame. """ - inner = super().sql(query, *args, **kwargs) + if _GRAD_CALL.search(query): + inner = self._sql_with_autograd(query, *args, **kwargs) + else: + inner = super().sql(query, *args, **kwargs) return XarrayDataFrame(inner, templates=self._registered_datasets) + def _sql_with_autograd(self, query: str, *args, **kwargs): + """Plan ``query``, rewrite ``grad(...)`` calls, return a DataFrame. + + The differentiation engine lives in the native (Rust) extension and + operates on DataFusion logical expressions. Since that extension links + its own copy of DataFusion, the plan crosses the boundary as Substrait: + we produce the logical plan as Substrait, hand it to ``grad_rewrite`` + (which differentiates every ``grad(expr, column)`` symbolically), then + consume the rewritten Substrait back into an executable DataFrame. + """ + plan = super().sql(query, *args, **kwargs).logical_plan() + substrait_plan = Producer.to_substrait_plan(plan, self) + rewritten = _native.grad_rewrite( + substrait_plan.encode(), self._table_schemas() + ) + new_plan = Consumer.from_substrait_plan( + self, Serde.deserialize_bytes(rewritten) + ) + return self.create_dataframe_from_logical_plan(new_plan) + + def _table_schemas(self) -> list[tuple[str, pa.Schema]]: + """Return ``(name, schema)`` for each registered table. + + The Substrait consumer in ``grad_rewrite`` resolves table scans by + name, so it needs the schema of every table the plan might reference. + Only metadata is read here — never the underlying data. + """ + schemas = [] + for name in self._registered_datasets: + try: + schemas.append((name, self.table(name).schema())) + except Exception: + # Schema-qualified tables (mixed-dimension datasets) aren't + # resolvable by a bare name yet; skip rather than fail the + # whole query. grad() over those is a follow-up. + continue + return schemas + def _group_vars_by_dims(ds: xr.Dataset) -> dict[tuple[str, ...], list[str]]: """Group variables in the dataset based on shared dims. From 6946ca9b425305612ec5a733ddc2cf3bfa44f2f6 Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 28 Jun 2026 10:22:42 +0000 Subject: [PATCH 04/17] ci: install protoc for the substrait build Adding datafusion-substrait pulls in the `substrait` crate, whose build script generates Rust from .proto files and requires `protoc`. Without it the Rust/maturin builds fail. - ci.yml, ci-build.yml, ci-rust.yml: add arduino/setup-protoc before the build (covers Linux, macOS and Windows runners). - publish.yml: setup-protoc for the macOS/Windows wheel job; for the manylinux maturin-action jobs install protoc inside the container via before-script-linux (arch-aware download). The sdist job is unchanged as it packages source without compiling. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6 --- .github/workflows/ci-build.yml | 7 +++++++ .github/workflows/ci-rust.yml | 7 +++++++ .github/workflows/ci.yml | 7 +++++++ .github/workflows/publish.yml | 31 +++++++++++++++++++++++++++++++ 4 files changed, 52 insertions(+) diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml index 214388e..ec89b8e 100644 --- a/.github/workflows/ci-build.yml +++ b/.github/workflows/ci-build.yml @@ -31,6 +31,13 @@ jobs: - uses: dtolnay/rust-toolchain@stable + # The `substrait` crate (a datafusion-substrait dependency) generates + # code from .proto files at build time and requires protoc. + - name: Install Protoc + uses: arduino/setup-protoc@v3 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + - name: Setup sccache uses: mozilla-actions/sccache-action@v0.0.9 diff --git a/.github/workflows/ci-rust.yml b/.github/workflows/ci-rust.yml index 68f1ce6..9054a44 100644 --- a/.github/workflows/ci-rust.yml +++ b/.github/workflows/ci-rust.yml @@ -27,6 +27,13 @@ jobs: with: components: clippy + # The `substrait` crate (a datafusion-substrait dependency) generates + # code from .proto files at build time and requires protoc. + - name: Install Protoc + uses: arduino/setup-protoc@v3 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + - name: Setup sccache uses: mozilla-actions/sccache-action@v0.0.9 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c1d892d..587784d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -43,6 +43,13 @@ jobs: - uses: dtolnay/rust-toolchain@stable + # The `substrait` crate (a datafusion-substrait dependency) generates + # code from .proto files at build time and requires protoc. + - name: Install Protoc + uses: arduino/setup-protoc@v3 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + - name: Setup sccache uses: mozilla-actions/sccache-action@v0.0.9 diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index a0008bc..48f0155 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -54,6 +54,13 @@ jobs: - uses: dtolnay/rust-toolchain@stable + # The `substrait` crate (a datafusion-substrait dependency) generates + # code from .proto files at build time and requires protoc. + - name: Install Protoc + uses: arduino/setup-protoc@v3 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + - name: Setup sccache uses: mozilla-actions/sccache-action@v0.0.9 @@ -91,6 +98,18 @@ jobs: manylinux: 2_28 rustup-components: rust-std rustfmt sccache: 'true' + # protoc is required by the substrait crate build and must be + # installed inside the manylinux container. + before-script-linux: | + PROTOC_VERSION=29.3 + case "$(uname -m)" in + x86_64) PROTOC_ARCH=x86_64 ;; + aarch64) PROTOC_ARCH=aarch_64 ;; + *) echo "unsupported arch $(uname -m)"; exit 1 ;; + esac + curl -L -o /tmp/protoc.zip "https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/protoc-${PROTOC_VERSION}-linux-${PROTOC_ARCH}.zip" + unzip -o /tmp/protoc.zip -d /usr/local bin/protoc 'include/*' + protoc --version args: --release --strip --out dist -i python3.10 python3.11 python3.12 python3.13 - uses: actions/upload-artifact@v6 @@ -113,6 +132,18 @@ jobs: manylinux: 2_28 rustup-components: rust-std rustfmt sccache: 'true' + # protoc is required by the substrait crate build and must be + # installed inside the manylinux container. + before-script-linux: | + PROTOC_VERSION=29.3 + case "$(uname -m)" in + x86_64) PROTOC_ARCH=x86_64 ;; + aarch64) PROTOC_ARCH=aarch_64 ;; + *) echo "unsupported arch $(uname -m)"; exit 1 ;; + esac + curl -L -o /tmp/protoc.zip "https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/protoc-${PROTOC_VERSION}-linux-${PROTOC_ARCH}.zip" + unzip -o /tmp/protoc.zip -d /usr/local bin/protoc 'include/*' + protoc --version args: --release --strip --out dist -i python3.10 python3.11 python3.12 python3.13 - uses: actions/upload-artifact@v6 From 724099a41b119ce235b9c0860a0c6b3a6f7aa867 Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 28 Jun 2026 10:36:25 +0000 Subject: [PATCH 05/17] Add jacobian() for multi-input gradients Extend the autograd surface from scalar grad() to multi-input Jacobians. SELECT jacobian(sin(x) * y, [x, y]) AS jac FROM g -- per row: [d/dx, d/dy] = [cos(x)*y, sin(x)] (a List) `jacobian(expr, [c1, c2, ...])` differentiates `expr` with respect to each listed column and returns the gradient row as an array. Using a SQL array for the inputs keeps the marker at fixed arity two (avoiding variadic-UDF issues): the `[c1, c2, ...]` parses to make_array(c1, c2, ...), from which the rewrite extracts the input columns; the result is built with make_array of the partials. Array/list columns round-trip through Substrait, verified end to end. The single grad() marker is generalized into a reusable MarkerUdf (with grad_marker()/jacobian_marker() constructors and per-marker return types), and the plan rewrite dispatches on the function name. A full Jacobian can also be written as separate scalar grad() columns, which already worked; both forms are covered by tests against numpy analytic derivatives. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6 --- src/autograd.rs | 199 ++++++++++++++++++++++++++++++----------- src/lib.rs | 5 +- tests/test_autograd.py | 53 +++++++++++ xarray_sql/sql.py | 46 ++++++---- 4 files changed, 236 insertions(+), 67 deletions(-) diff --git a/src/autograd.rs b/src/autograd.rs index 1e72586..5b971cf 100644 --- a/src/autograd.rs +++ b/src/autograd.rs @@ -32,15 +32,17 @@ use std::any::Any; use std::f64::consts::{LN_10, LN_2}; +use std::sync::Arc; -use datafusion::arrow::datatypes::DataType; +use datafusion::arrow::datatypes::{DataType, Field}; use datafusion::common::tree_node::{Transformed, TreeNode}; use datafusion::common::{DataFusionError, Result, ScalarValue}; use datafusion::functions::math::expr_fn; +use datafusion::functions_nested::expr_fn::make_array; use datafusion::logical_expr::expr::ScalarFunction; use datafusion::logical_expr::{ lit, BinaryExpr, Cast, ColumnarValue, Expr, LogicalPlan, Operator, ScalarFunctionArgs, - ScalarUDFImpl, Signature, Volatility, + ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; // --------------------------------------------------------------------------- @@ -328,43 +330,42 @@ fn diff_power(args: &[Expr], wrt: &str) -> Result { } // --------------------------------------------------------------------------- -// The `grad` marker UDF and the plan-level rewrite +// The `grad` / `jacobian` marker UDFs and the plan-level rewrite // --------------------------------------------------------------------------- -/// A no-op placeholder UDF for `grad(expr, column)`. +/// A no-op placeholder UDF for the autograd surface functions. /// -/// `grad` is a *marker*: it carries the differentiation request intact through -/// SQL parsing, logical planning, and Substrait serialization. It is always -/// rewritten away by [`rewrite_grad_calls`] before execution, so its `invoke` -/// is never reached in normal use (and deliberately errors if it somehow is, -/// rather than silently returning a wrong value). +/// `grad` and `jacobian` are *markers*: they carry the differentiation request +/// intact through SQL parsing, logical planning, and Substrait serialization. +/// They are always rewritten away by [`rewrite_grad_calls`] before execution, +/// so `invoke` is never reached in normal use (and deliberately errors if it +/// somehow is, rather than silently returning a wrong value). #[derive(Debug, PartialEq, Eq, Hash)] -pub struct GradMarker { +pub struct MarkerUdf { + name: String, signature: Signature, + return_type: DataType, } -impl GradMarker { - pub fn new() -> Self { - // grad(expr, column): two arguments of any (numeric) type. +impl MarkerUdf { + fn new(name: &str, return_type: DataType) -> Self { Self { + name: name.to_string(), + // Both markers take two arguments: the expression and either a + // column (grad) or an array of columns (jacobian). signature: Signature::any(2, Volatility::Immutable), + return_type, } } } -impl Default for GradMarker { - fn default() -> Self { - Self::new() - } -} - -impl ScalarUDFImpl for GradMarker { +impl ScalarUDFImpl for MarkerUdf { fn as_any(&self) -> &dyn Any { self } fn name(&self) -> &str { - "grad" + &self.name } fn signature(&self) -> &Signature { @@ -372,22 +373,75 @@ impl ScalarUDFImpl for GradMarker { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Float64) + Ok(self.return_type.clone()) } fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { - Err(DataFusionError::Execution( - "grad() marker reached execution without being rewritten; this is \ - an internal xarray-sql autograd error" - .to_string(), - )) + Err(DataFusionError::Execution(format!( + "{}() marker reached execution without being rewritten; this is \ + an internal xarray-sql autograd error", + self.name + ))) } } -/// Rewrite every `grad(expr, column)` call anywhere in a logical plan into the -/// symbolic derivative of `expr` with respect to `column`, leaving everything -/// else untouched. The plan's schema is recomputed afterwards because replacing -/// a `grad` call can change an expression's name or type. +/// A `List` data type, the output of a `jacobian(...)` call. +fn list_of_f64() -> DataType { + DataType::List(Arc::new(Field::new("item", DataType::Float64, true))) +} + +/// The `grad(expr, column)` marker UDF: returns a scalar derivative. +pub fn grad_marker() -> ScalarUDF { + ScalarUDF::from(MarkerUdf::new("grad", DataType::Float64)) +} + +/// The `jacobian(expr, [c1, c2, ...])` marker UDF: returns the gradient of +/// `expr` with respect to several columns as a `List`. +pub fn jacobian_marker() -> ScalarUDF { + ScalarUDF::from(MarkerUdf::new("jacobian", list_of_f64())) +} + +/// Build the Jacobian row `[d(expr)/dc1, d(expr)/dc2, ...]` as an array +/// expression (`make_array`), differentiating `expr` w.r.t. each named column. +fn jacobian(expr: &Expr, wrt: &[String]) -> Result { + let partials = wrt + .iter() + .map(|c| differentiate(expr, c)) + .collect::>>()?; + Ok(make_array(partials)) +} + +/// Extract the bare column names from an array-literal expression, i.e. the +/// `make_array(c1, c2, ...)` that a SQL `[c1, c2, ...]` array parses into. +fn columns_from_array(expr: &Expr) -> Result> { + let Expr::ScalarFunction(sf) = expr else { + return Err(DataFusionError::Plan(format!( + "jacobian(): the second argument must be an array of columns \ + like [x, y, z], got: {expr}" + ))); + }; + if sf.func.name() != "make_array" { + return Err(DataFusionError::Plan(format!( + "jacobian(): the second argument must be an array of columns \ + like [x, y, z], got: {expr}" + ))); + } + sf.args + .iter() + .map(|a| match a { + Expr::Column(c) => Ok(c.name.clone()), + other => Err(DataFusionError::Plan(format!( + "jacobian(): array entries must be bare columns to \ + differentiate with respect to, got: {other}" + ))), + }) + .collect() +} + +/// Rewrite every `grad(...)` / `jacobian(...)` call anywhere in a logical plan +/// into its symbolic derivative(s), leaving everything else untouched. The +/// plan's schema is recomputed afterwards because replacing a marker can change +/// an expression's name or type. pub fn rewrite_grad_calls(plan: LogicalPlan) -> Result { let rewritten = plan .transform_up(|node| node.map_expressions(rewrite_grad_in_expr))? @@ -395,33 +449,52 @@ pub fn rewrite_grad_calls(plan: LogicalPlan) -> Result { rewritten.recompute_schema() } -/// Replace any `grad(...)` calls nested anywhere inside a single expression. +/// Replace any `grad(...)` / `jacobian(...)` calls nested anywhere inside a +/// single expression. fn rewrite_grad_in_expr(expr: Expr) -> Result> { expr.transform_up(|e| { let Expr::ScalarFunction(sf) = &e else { return Ok(Transformed::no(e)); }; - if sf.func.name() != "grad" { - return Ok(Transformed::no(e)); + match sf.func.name() { + "grad" => Ok(Transformed::yes(rewrite_grad(&sf.args)?)), + "jacobian" => Ok(Transformed::yes(rewrite_jacobian(&sf.args)?)), + _ => Ok(Transformed::no(e)), } - if sf.args.len() != 2 { + }) +} + +/// `grad(expr, column)` -> d(expr)/d(column). +fn rewrite_grad(args: &[Expr]) -> Result { + if args.len() != 2 { + return Err(DataFusionError::Plan(format!( + "grad() expects two arguments grad(expr, column), got {}", + args.len() + ))); + } + let wrt = match &args[1] { + Expr::Column(c) => c.name.clone(), + other => { return Err(DataFusionError::Plan(format!( - "grad() expects two arguments grad(expr, column), got {}", - sf.args.len() - ))); + "grad(): the second argument must be a bare column to \ + differentiate with respect to, got: {other}" + ))) } - let wrt = match &sf.args[1] { - Expr::Column(c) => c.name.clone(), - other => { - return Err(DataFusionError::Plan(format!( - "grad(): the second argument must be a bare column to \ - differentiate with respect to, got: {other}" - ))) - } - }; - let derivative = differentiate(&sf.args[0], &wrt)?; - Ok(Transformed::yes(derivative)) - }) + }; + differentiate(&args[0], &wrt) +} + +/// `jacobian(expr, [c1, c2, ...])` -> array `[d(expr)/dc1, d(expr)/dc2, ...]`. +fn rewrite_jacobian(args: &[Expr]) -> Result { + if args.len() != 2 { + return Err(DataFusionError::Plan(format!( + "jacobian() expects two arguments jacobian(expr, [c1, c2, ...]), \ + got {}", + args.len() + ))); + } + let wrt = columns_from_array(&args[1])?; + jacobian(&args[0], &wrt) } // --------------------------------------------------------------------------- @@ -509,4 +582,30 @@ mod tests { let e = expr_fn::atan2(col("x"), col("y")); assert!(differentiate(&e, "x").is_err()); } + + #[test] + fn jacobian_builds_array_of_partials() { + // jacobian(x*y, [x, y]) = [d/dx, d/dy] = [y, x] + let f = binary(col("x"), Operator::Multiply, col("y")); + let j = jacobian(&f, &["x".to_string(), "y".to_string()]).unwrap(); + let expected = make_array(vec![col("y"), col("x")]); + assert_eq!(j, expected); + } + + #[test] + fn jacobian_single_input_is_one_element_array() { + let j = jacobian(&expr_fn::sin(col("x")), &["x".to_string()]).unwrap(); + assert_eq!(j, make_array(vec![expr_fn::cos(col("x"))])); + } + + #[test] + fn columns_from_array_extracts_names() { + let arr = make_array(vec![col("a"), col("b"), col("c")]); + assert_eq!(columns_from_array(&arr).unwrap(), vec!["a", "b", "c"]); + } + + #[test] + fn columns_from_array_rejects_non_array() { + assert!(columns_from_array(&col("x")).is_err()); + } } diff --git a/src/lib.rs b/src/lib.rs index 5ab410f..754a390 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -62,7 +62,7 @@ use datafusion::datasource::TableProvider; use datafusion::execution::TaskContext; use datafusion::logical_expr::expr::InList; use datafusion::logical_expr::{ - BinaryExpr, Expr, Operator, ScalarUDF, TableProviderFilterPushDown, TableType, + BinaryExpr, Expr, Operator, TableProviderFilterPushDown, TableType, }; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::streaming::PartitionStream; @@ -1026,7 +1026,8 @@ fn grad_rewrite<'py>( // marker UDF (so the consumer can resolve the function) and an empty table // per referenced name (so the consumer can resolve table scans). let ctx = SessionContext::new(); - ctx.register_udf(ScalarUDF::from(autograd::GradMarker::new())); + ctx.register_udf(autograd::grad_marker()); + ctx.register_udf(autograd::jacobian_marker()); for (name, schema_obj) in &tables { let schema = Schema::from_pyarrow_bound(schema_obj).map_err(|e| { diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 3e7827a..13f10a2 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -24,6 +24,22 @@ def ctx(): return context +@pytest.fixture +def ctx_xy(): + rng = np.random.default_rng(0) + n = 16 + ds = xr.Dataset( + { + "x": (("i",), rng.uniform(0.5, 2.5, n)), + "y": (("i",), rng.uniform(0.5, 2.5, n)), + }, + coords={"i": np.arange(n)}, + ) + context = xql.XarrayContext() + context.from_dataset("g", ds, chunks={"i": 5}) + return context, ds + + def _ordered(df, key="i"): """Collect a result DataFrame into a dict of column -> numpy array, sorted by the integer key column so comparisons are index-aligned.""" @@ -74,3 +90,40 @@ def test_unsupported_function_raises(ctx): # atan2 has no derivative rule yet -> a clear error, not a wrong answer. with pytest.raises(Exception): ctx.sql("SELECT grad(atan2(val, val), val) AS d FROM t").to_pandas() + + +def test_multi_input_grad_columns(ctx_xy): + # A full Jacobian written as separate scalar grad() columns: + # f = x*y -> df/dx = y, df/dy = x. + context, ds = ctx_xy + res = _ordered( + context.sql( + "SELECT i, grad(x * y, x) AS dfdx, grad(x * y, y) AS dfdy FROM g" + ) + ) + np.testing.assert_allclose(res["dfdx"], ds["y"].values) + np.testing.assert_allclose(res["dfdy"], ds["x"].values) + + +def test_jacobian_array(ctx_xy): + # jacobian(f, [x, y]) returns the gradient row [df/dx, df/dy] per row. + context, ds = ctx_xy + res = _ordered( + context.sql("SELECT i, jacobian(x * y, [x, y]) AS jac FROM g") + ) + jac = np.stack([np.asarray(v, dtype=float) for v in res["jac"]]) + # column 0 is df/dx = y, column 1 is df/dy = x + np.testing.assert_allclose(jac[:, 0], ds["y"].values) + np.testing.assert_allclose(jac[:, 1], ds["x"].values) + + +def test_jacobian_array_nonlinear(ctx_xy): + # jacobian(sin(x) * y, [x, y]) = [cos(x)*y, sin(x)] + context, ds = ctx_xy + x, y = ds["x"].values, ds["y"].values + res = _ordered( + context.sql("SELECT i, jacobian(sin(x) * y, [x, y]) AS jac FROM g") + ) + jac = np.stack([np.asarray(v, dtype=float) for v in res["jac"]]) + np.testing.assert_allclose(jac[:, 0], np.cos(x) * y) + np.testing.assert_allclose(jac[:, 1], np.sin(x)) diff --git a/xarray_sql/sql.py b/xarray_sql/sql.py index 5577d61..635a732 100644 --- a/xarray_sql/sql.py +++ b/xarray_sql/sql.py @@ -13,9 +13,10 @@ from .ds import XarrayDataFrame from .reader import read_xarray_table -# Matches a call to the autograd marker function ``grad(`` (case-insensitive), -# used as a cheap gate so ordinary queries skip the Substrait round-trip. -_GRAD_CALL = re.compile(r"\bgrad\s*\(", re.IGNORECASE) +# Matches a call to an autograd marker function (``grad(`` / ``jacobian(``, +# case-insensitive), used as a cheap gate so ordinary queries skip the +# Substrait round-trip. +_GRAD_CALL = re.compile(r"\b(grad|jacobian)\s*\(", re.IGNORECASE) class XarrayContext(SessionContext): @@ -33,21 +34,36 @@ def __init__(self, *args, **kwargs): self._register_autograd_udfs() def _register_autograd_udfs(self) -> None: - """Register the ``grad`` marker UDF used by the autograd rewrite. + """Register the ``grad`` / ``jacobian`` marker UDFs. - ``grad(expr, column)`` is a *marker*: it lets queries parse and plan - with the differentiation request intact. It is never executed — the - Substrait rewrite in :meth:`sql` replaces every ``grad(...)`` with the - symbolic derivative of ``expr`` before execution. + These are *markers*: they let queries parse and plan with the + differentiation request intact. They are never executed — the Substrait + rewrite in :meth:`sql` replaces every call with the symbolic + derivative before execution. + + * ``grad(expr, column)`` -> scalar ``d(expr)/d(column)``. + * ``jacobian(expr, [c1, c2, ...])`` -> the gradient of ``expr`` with + respect to several columns, as a ``List`` (one Jacobian + row). The second argument is a SQL array of bare column references. """ - marker = udf( - lambda expr, column: expr, - [pa.float64(), pa.float64()], - pa.float64(), - "immutable", - "grad", + self.register_udf( + udf( + lambda expr, column: expr, + [pa.float64(), pa.float64()], + pa.float64(), + "immutable", + "grad", + ) + ) + self.register_udf( + udf( + lambda expr, columns: columns, + [pa.float64(), pa.list_(pa.float64())], + pa.list_(pa.float64()), + "immutable", + "jacobian", + ) ) - self.register_udf(marker) def from_dataset( self, From 6319a064a847348bdc5d0e7b45496da2de233fa7 Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 28 Jun 2026 11:16:43 +0000 Subject: [PATCH 06/17] Replace array jacobian() with jvp()/vjp() forward & reverse modes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Drop the jacobian(expr, [cols]) -> List form: a nested array column breaks the long/tidy data model (a cell should be one value aligned to its coordinates). The same Jacobian is expressed in-model as several scalar columns, e.g. grad(f, x) AS dfdx, grad(f, y) AS dfdy. Add forward- and reverse-mode gradients as scalar SQL functions: * jvp(expr, column, tangent) -> d(expr)/d(column) * tangent (forward) * vjp(expr, column, cotangent) -> cotangent * d(expr)/d(column) (reverse) A multi-input directional derivative is the sum of per-input jvp terms; both stay scalar, so they round-trip cleanly through Substrait and back to xarray. Engine: unify grad and jvp behind a single `linearize` (forward-mode chain rule with a pluggable leaf rule) — grad is a one-hot seed, jvp an arbitrary seed per input. This mirrors JAX's structure and removes rule duplication. vjp is cotangent * grad; for a scalar output forward and reverse coincide (asserted by a jvp/vjp agreement test), differing only in seed placement. Tests: 15 Rust unit tests and 11 Python integration tests (incl. jvp/vjp semantics, the multi-input sum, and jvp==vjp for a unit seed), all checked against numpy analytic derivatives. fmt/clippy/ruff/mypy clean. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6 --- src/autograd.rs | 318 ++++++++++++++++++++++------------------- src/lib.rs | 3 +- tests/test_autograd.py | 43 ++++-- xarray_sql/sql.py | 46 +++--- 4 files changed, 224 insertions(+), 186 deletions(-) diff --git a/src/autograd.rs b/src/autograd.rs index 5b971cf..14312e6 100644 --- a/src/autograd.rs +++ b/src/autograd.rs @@ -22,23 +22,31 @@ //! `add_tangents`: a `0` derivative short-circuits products and drops out of //! sums, and a `1` factor drops out of products. //! -//! ## Scope (MVP) +//! ## Surface //! -//! This first cut implements scalar `grad`: the partial derivative of a single -//! expression with respect to one named column. Forward-/reverse-mode -//! (`jvp`/`vjp`) and multi-input Jacobians are deliberately left for later. +//! Three scalar operations, all rewritten away before execution: +//! +//! * `grad(expr, column)` — the partial derivative `d(expr)/d(column)`. +//! * `jvp(expr, column, tangent)` — forward-mode directional derivative, +//! `d(expr)/d(column) * tangent` (seed a tangent on an input). +//! * `vjp(expr, column, cotangent)` — reverse-mode pullback, +//! `cotangent * d(expr)/d(column)` (seed a cotangent on the output). +//! +//! All three return a scalar per row, staying in the long/tidy data model. A +//! full gradient or Jacobian is expressed as several scalar columns (e.g. +//! `grad(f, x) AS dfdx, grad(f, y) AS dfdy`) rather than a nested array, which +//! would break the one-value-per-coordinate model. #![allow(dead_code)] use std::any::Any; +use std::collections::HashMap; use std::f64::consts::{LN_10, LN_2}; -use std::sync::Arc; -use datafusion::arrow::datatypes::{DataType, Field}; +use datafusion::arrow::datatypes::DataType; use datafusion::common::tree_node::{Transformed, TreeNode}; use datafusion::common::{DataFusionError, Result, ScalarValue}; use datafusion::functions::math::expr_fn; -use datafusion::functions_nested::expr_fn::make_array; use datafusion::logical_expr::expr::ScalarFunction; use datafusion::logical_expr::{ lit, BinaryExpr, Cast, ColumnarValue, Expr, LogicalPlan, Operator, ScalarFunctionArgs, @@ -160,41 +168,48 @@ fn square(e: Expr) -> Expr { } // --------------------------------------------------------------------------- -// The differentiation rules +// The differentiation engine (forward-mode linearization) // --------------------------------------------------------------------------- -/// Differentiate `expr` with respect to the column named `wrt`. +/// A *leaf rule*: the tangent of a column, i.e. the seed assigned to each input +/// during forward-mode differentiation. /// -/// Returns a new [`Expr`] for the partial derivative, composed of ordinary -/// DataFusion expressions. Returns a [`DataFusionError::NotImplemented`] for -/// expression nodes or scalar functions without a differentiation rule, so the -/// caller can surface a clear, actionable error rather than silently producing -/// a wrong answer. -pub fn differentiate(expr: &Expr, wrt: &str) -> Result { +/// `grad` uses a one-hot leaf (`1` for the differentiation variable, `0` +/// otherwise); `jvp` uses an arbitrary seed per input. Everything above the +/// leaves — the chain rule — is shared. +type Leaf<'a> = dyn Fn(&str) -> Expr + 'a; + +/// Linearize `expr`: push tangents from the leaves (per `leaf`) up through the +/// expression via the chain rule, returning the tangent of `expr`. +/// +/// This is forward-mode automatic differentiation. `differentiate` (a single +/// partial derivative) and `jvp` (a directional derivative) are both thin +/// wrappers that only differ in their leaf rule. Returns a +/// [`DataFusionError::NotImplemented`] for nodes or functions without a rule, +/// so callers surface a clear error rather than a silently-wrong derivative. +fn linearize(expr: &Expr, leaf: &Leaf) -> Result { match expr { - // d/dx (x) = 1 ; d/dx (y) = 0 for any other column. - Expr::Column(c) => Ok(if c.name == wrt { one() } else { zero() }), + // The leaf rule decides a column's tangent. + Expr::Column(c) => Ok(leaf(&c.name)), - // d/dx (constant) = 0. + // Constants have zero tangent. Expr::Literal(_, _) => Ok(zero()), - // An alias is transparent to differentiation; the surrounding query - // re-applies any output naming. - Expr::Alias(a) => differentiate(&a.expr, wrt), + // An alias is transparent; the surrounding query re-applies any naming. + Expr::Alias(a) => linearize(&a.expr, leaf), - // A numeric cast is (locally) linear: d/dx cast(u) = cast(du). We keep - // the cast so the derivative retains the declared output type. + // A numeric cast is (locally) linear: tangent of cast(u) = cast(du). Expr::Cast(c) => { - let du = differentiate(&c.expr, wrt)?; + let du = linearize(&c.expr, leaf)?; Ok(Expr::Cast(Cast::new(Box::new(du), c.data_type.clone()))) } - // d/dx (-u) = -(du). - Expr::Negative(inner) => Ok(neg(differentiate(inner, wrt)?)), + // tangent of -u = -(du). + Expr::Negative(inner) => Ok(neg(linearize(inner, leaf)?)), - Expr::BinaryExpr(be) => diff_binary(be, wrt), + Expr::BinaryExpr(be) => linearize_binary(be, leaf), - Expr::ScalarFunction(sf) => diff_scalar_function(sf, wrt), + Expr::ScalarFunction(sf) => linearize_scalar_function(sf, leaf), other => Err(DataFusionError::NotImplemented(format!( "grad: differentiation is not implemented for this expression: {other}" @@ -202,22 +217,34 @@ pub fn differentiate(expr: &Expr, wrt: &str) -> Result { } } -/// Differentiate a binary arithmetic expression via the sum/product/quotient -/// rules. -fn diff_binary(be: &BinaryExpr, wrt: &str) -> Result { +/// Differentiate `expr` with respect to the column named `wrt`. +/// +/// Forward-mode with a one-hot seed: `1` on `wrt`, `0` on every other column. +pub fn differentiate(expr: &Expr, wrt: &str) -> Result { + linearize(expr, &|name| if name == wrt { one() } else { zero() }) +} + +/// Forward-mode directional derivative: the tangent of `expr` given a tangent +/// (`seeds[col]`) for each seeded input column; unseeded columns are constant. +fn jvp(expr: &Expr, seeds: &HashMap) -> Result { + linearize(expr, &|name| seeds.get(name).cloned().unwrap_or_else(zero)) +} + +/// Linearize a binary arithmetic expression via the sum/product/quotient rules. +fn linearize_binary(be: &BinaryExpr, leaf: &Leaf) -> Result { let a = be.left.as_ref(); let b = be.right.as_ref(); - let da = differentiate(a, wrt)?; - let db = differentiate(b, wrt)?; + let da = linearize(a, leaf)?; + let db = linearize(b, leaf)?; match be.op { - // d/dx (a + b) = da + db + // tangent of (a + b) = da + db Operator::Plus => Ok(add(da, db)), - // d/dx (a - b) = da - db + // tangent of (a - b) = da - db Operator::Minus => Ok(sub(da, db)), - // d/dx (a * b) = da*b + a*db (product rule) + // tangent of (a * b) = da*b + a*db (product rule) Operator::Multiply => Ok(add(mul(da, b.clone()), mul(a.clone(), db))), - // d/dx (a / b) = (da*b - a*db) / b^2 (quotient rule) + // tangent of (a / b) = (da*b - a*db) / b^2 (quotient rule) Operator::Divide => { let numerator = sub(mul(da, b.clone()), mul(a.clone(), db)); Ok(div(numerator, square(b.clone()))) @@ -228,17 +255,17 @@ fn diff_binary(be: &BinaryExpr, wrt: &str) -> Result { } } -/// Differentiate a scalar-function call via the chain rule. +/// Linearize a scalar-function call via the chain rule. /// -/// For a unary primitive `f(u)`, the derivative is `f'(u) * du`. For `power`, +/// For a unary primitive `f(u)`, the tangent is `f'(u) * du`. For `power`, /// which is binary, we handle the constant-exponent and constant-base cases. -fn diff_scalar_function(sf: &ScalarFunction, wrt: &str) -> Result { +fn linearize_scalar_function(sf: &ScalarFunction, leaf: &Leaf) -> Result { let name = sf.func.name(); let args = &sf.args; - // `power(base, exponent)` is the one binary primitive we differentiate. + // `power(base, exponent)` is the one binary primitive we linearize. if name == "power" { - return diff_power(args, wrt); + return linearize_power(args, leaf); } if args.len() != 1 { @@ -249,9 +276,9 @@ fn diff_scalar_function(sf: &ScalarFunction, wrt: &str) -> Result { } let u = &args[0]; - let du = differentiate(u, wrt)?; - // Chain rule short-circuit: if du is 0, the whole derivative is 0 and we - // avoid emitting the (dead) outer derivative term entirely. + let du = linearize(u, leaf)?; + // Chain rule short-circuit: if du is 0, the whole tangent is 0 and we avoid + // emitting the (dead) outer derivative term entirely. if is_zero(&du) { return Ok(zero()); } @@ -287,12 +314,12 @@ fn diff_scalar_function(sf: &ScalarFunction, wrt: &str) -> Result { Ok(mul(outer, du)) } -/// Differentiate `power(base, exponent)`. +/// Linearize `power(base, exponent)`. /// -/// * Constant exponent `c`: `d/dx base^c = c * base^(c-1) * d(base)`. -/// * Constant base `a`: `d/dx a^u = a^u * ln(a) * d(u)`. -/// * Both variable (`u^v`): not supported in the MVP. -fn diff_power(args: &[Expr], wrt: &str) -> Result { +/// * Constant exponent `c`: tangent = `c * base^(c-1) * d(base)`. +/// * Constant base `a`: tangent = `a^u * ln(a) * d(u)`. +/// * Both variable (`u^v`): not supported yet. +fn linearize_power(args: &[Expr], leaf: &Leaf) -> Result { if args.len() != 2 { return Err(DataFusionError::NotImplemented( "grad: power() expects exactly two arguments".to_string(), @@ -304,7 +331,7 @@ fn diff_power(args: &[Expr], wrt: &str) -> Result { match (as_const(base), as_const(exponent)) { // Constant exponent (covers the common x^2, x^0.5, ... cases). (_, Some(c)) => { - let dbase = differentiate(base, wrt)?; + let dbase = linearize(base, leaf)?; if is_zero(&dbase) { return Ok(zero()); } @@ -313,14 +340,14 @@ fn diff_power(args: &[Expr], wrt: &str) -> Result { } // Constant base, variable exponent. (Some(a), None) => { - let dexp = differentiate(exponent, wrt)?; + let dexp = linearize(exponent, leaf)?; if is_zero(&dexp) { return Ok(zero()); } let outer = mul(expr_fn::power(base.clone(), exponent.clone()), lit(a.ln())); Ok(mul(outer, dexp)) } - // General u^v requires the exp/log trick; deferred past the MVP. + // General u^v requires the exp/log trick; deferred for now. (None, None) => Err(DataFusionError::NotImplemented( "grad: power(base, exponent) where both depend on the \ differentiation variable is not yet supported" @@ -335,26 +362,22 @@ fn diff_power(args: &[Expr], wrt: &str) -> Result { /// A no-op placeholder UDF for the autograd surface functions. /// -/// `grad` and `jacobian` are *markers*: they carry the differentiation request -/// intact through SQL parsing, logical planning, and Substrait serialization. -/// They are always rewritten away by [`rewrite_grad_calls`] before execution, -/// so `invoke` is never reached in normal use (and deliberately errors if it -/// somehow is, rather than silently returning a wrong value). +/// `grad`, `jvp`, and `vjp` are *markers*: they carry the differentiation +/// request intact through SQL parsing, logical planning, and Substrait +/// serialization. They are always rewritten away by [`rewrite_grad_calls`] +/// before execution, so `invoke` is never reached in normal use (and +/// deliberately errors if it somehow is, rather than returning a wrong value). #[derive(Debug, PartialEq, Eq, Hash)] pub struct MarkerUdf { name: String, signature: Signature, - return_type: DataType, } impl MarkerUdf { - fn new(name: &str, return_type: DataType) -> Self { + fn new(name: &str, arity: usize) -> Self { Self { name: name.to_string(), - // Both markers take two arguments: the expression and either a - // column (grad) or an array of columns (jacobian). - signature: Signature::any(2, Volatility::Immutable), - return_type, + signature: Signature::any(arity, Volatility::Immutable), } } } @@ -373,7 +396,8 @@ impl ScalarUDFImpl for MarkerUdf { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(self.return_type.clone()) + // Every autograd marker rewrites to a scalar derivative expression. + Ok(DataType::Float64) } fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { @@ -385,63 +409,25 @@ impl ScalarUDFImpl for MarkerUdf { } } -/// A `List` data type, the output of a `jacobian(...)` call. -fn list_of_f64() -> DataType { - DataType::List(Arc::new(Field::new("item", DataType::Float64, true))) -} - -/// The `grad(expr, column)` marker UDF: returns a scalar derivative. +/// The `grad(expr, column)` marker: scalar partial derivative `d(expr)/dcolumn`. pub fn grad_marker() -> ScalarUDF { - ScalarUDF::from(MarkerUdf::new("grad", DataType::Float64)) + ScalarUDF::from(MarkerUdf::new("grad", 2)) } -/// The `jacobian(expr, [c1, c2, ...])` marker UDF: returns the gradient of -/// `expr` with respect to several columns as a `List`. -pub fn jacobian_marker() -> ScalarUDF { - ScalarUDF::from(MarkerUdf::new("jacobian", list_of_f64())) +/// The `jvp(expr, column, tangent)` marker: forward-mode directional derivative. +pub fn jvp_marker() -> ScalarUDF { + ScalarUDF::from(MarkerUdf::new("jvp", 3)) } -/// Build the Jacobian row `[d(expr)/dc1, d(expr)/dc2, ...]` as an array -/// expression (`make_array`), differentiating `expr` w.r.t. each named column. -fn jacobian(expr: &Expr, wrt: &[String]) -> Result { - let partials = wrt - .iter() - .map(|c| differentiate(expr, c)) - .collect::>>()?; - Ok(make_array(partials)) +/// The `vjp(expr, column, cotangent)` marker: reverse-mode pullback to an input. +pub fn vjp_marker() -> ScalarUDF { + ScalarUDF::from(MarkerUdf::new("vjp", 3)) } -/// Extract the bare column names from an array-literal expression, i.e. the -/// `make_array(c1, c2, ...)` that a SQL `[c1, c2, ...]` array parses into. -fn columns_from_array(expr: &Expr) -> Result> { - let Expr::ScalarFunction(sf) = expr else { - return Err(DataFusionError::Plan(format!( - "jacobian(): the second argument must be an array of columns \ - like [x, y, z], got: {expr}" - ))); - }; - if sf.func.name() != "make_array" { - return Err(DataFusionError::Plan(format!( - "jacobian(): the second argument must be an array of columns \ - like [x, y, z], got: {expr}" - ))); - } - sf.args - .iter() - .map(|a| match a { - Expr::Column(c) => Ok(c.name.clone()), - other => Err(DataFusionError::Plan(format!( - "jacobian(): array entries must be bare columns to \ - differentiate with respect to, got: {other}" - ))), - }) - .collect() -} - -/// Rewrite every `grad(...)` / `jacobian(...)` call anywhere in a logical plan -/// into its symbolic derivative(s), leaving everything else untouched. The -/// plan's schema is recomputed afterwards because replacing a marker can change -/// an expression's name or type. +/// Rewrite every `grad`/`jvp`/`vjp` call anywhere in a logical plan into its +/// symbolic derivative, leaving everything else untouched. The plan's schema is +/// recomputed afterwards because replacing a marker can change an expression's +/// name or type. pub fn rewrite_grad_calls(plan: LogicalPlan) -> Result { let rewritten = plan .transform_up(|node| node.map_expressions(rewrite_grad_in_expr))? @@ -449,8 +435,8 @@ pub fn rewrite_grad_calls(plan: LogicalPlan) -> Result { rewritten.recompute_schema() } -/// Replace any `grad(...)` / `jacobian(...)` calls nested anywhere inside a -/// single expression. +/// Replace any `grad`/`jvp`/`vjp` calls nested anywhere inside a single +/// expression. fn rewrite_grad_in_expr(expr: Expr) -> Result> { expr.transform_up(|e| { let Expr::ScalarFunction(sf) = &e else { @@ -458,13 +444,25 @@ fn rewrite_grad_in_expr(expr: Expr) -> Result> { }; match sf.func.name() { "grad" => Ok(Transformed::yes(rewrite_grad(&sf.args)?)), - "jacobian" => Ok(Transformed::yes(rewrite_jacobian(&sf.args)?)), + "jvp" => Ok(Transformed::yes(rewrite_jvp(&sf.args)?)), + "vjp" => Ok(Transformed::yes(rewrite_vjp(&sf.args)?)), _ => Ok(Transformed::no(e)), } }) } -/// `grad(expr, column)` -> d(expr)/d(column). +/// Read a bare column name from a marker argument, or report a clear error. +fn column_arg(func: &str, arg: &Expr) -> Result { + match arg { + Expr::Column(c) => Ok(c.name.clone()), + other => Err(DataFusionError::Plan(format!( + "{func}(): the column argument must be a bare column to \ + differentiate with respect to, got: {other}" + ))), + } +} + +/// `grad(expr, column)` -> `d(expr)/d(column)`. fn rewrite_grad(args: &[Expr]) -> Result { if args.len() != 2 { return Err(DataFusionError::Plan(format!( @@ -472,29 +470,44 @@ fn rewrite_grad(args: &[Expr]) -> Result { args.len() ))); } - let wrt = match &args[1] { - Expr::Column(c) => c.name.clone(), - other => { - return Err(DataFusionError::Plan(format!( - "grad(): the second argument must be a bare column to \ - differentiate with respect to, got: {other}" - ))) - } - }; + let wrt = column_arg("grad", &args[1])?; differentiate(&args[0], &wrt) } -/// `jacobian(expr, [c1, c2, ...])` -> array `[d(expr)/dc1, d(expr)/dc2, ...]`. -fn rewrite_jacobian(args: &[Expr]) -> Result { - if args.len() != 2 { +/// `jvp(expr, column, tangent)` -> forward-mode tangent: seed `tangent` on +/// `column` and push it through `expr`, yielding `d(expr)/d(column) * tangent`. +/// +/// A directional derivative over several inputs is the sum of per-input jvps, +/// e.g. `jvp(f, x, dx) + jvp(f, y, dy)`, since each treats the other inputs as +/// having zero tangent. +fn rewrite_jvp(args: &[Expr]) -> Result { + if args.len() != 3 { + return Err(DataFusionError::Plan(format!( + "jvp() expects three arguments jvp(expr, column, tangent), got {}", + args.len() + ))); + } + let wrt = column_arg("jvp", &args[1])?; + let seeds = HashMap::from([(wrt, args[2].clone())]); + jvp(&args[0], &seeds) +} + +/// `vjp(expr, column, cotangent)` -> reverse-mode pullback: the sensitivity that +/// an output cotangent induces on `column`, i.e. `cotangent * d(expr)/d(column)`. +/// +/// For a single scalar output this equals the matching `jvp` (both contract the +/// same partial derivative); the surfaces differ in where the seed lives — `jvp` +/// seeds an input tangent, `vjp` seeds an output cotangent. +fn rewrite_vjp(args: &[Expr]) -> Result { + if args.len() != 3 { return Err(DataFusionError::Plan(format!( - "jacobian() expects two arguments jacobian(expr, [c1, c2, ...]), \ - got {}", + "vjp() expects three arguments vjp(expr, column, cotangent), got {}", args.len() ))); } - let wrt = columns_from_array(&args[1])?; - jacobian(&args[0], &wrt) + let wrt = column_arg("vjp", &args[1])?; + let derivative = differentiate(&args[0], &wrt)?; + Ok(mul(args[2].clone(), derivative)) } // --------------------------------------------------------------------------- @@ -584,28 +597,37 @@ mod tests { } #[test] - fn jacobian_builds_array_of_partials() { - // jacobian(x*y, [x, y]) = [d/dx, d/dy] = [y, x] + fn jvp_seeds_a_tangent_on_one_input() { + // jvp(x*y, {x: dx}) = product rule with tangent(x)=dx, tangent(y)=0 + // = dx*y + x*0 = dx*y let f = binary(col("x"), Operator::Multiply, col("y")); - let j = jacobian(&f, &["x".to_string(), "y".to_string()]).unwrap(); - let expected = make_array(vec![col("y"), col("x")]); - assert_eq!(j, expected); + let seeds = HashMap::from([("x".to_string(), col("dx"))]); + let t = jvp(&f, &seeds).unwrap(); + assert_eq!(t, mul(col("dx"), col("y"))); } #[test] - fn jacobian_single_input_is_one_element_array() { - let j = jacobian(&expr_fn::sin(col("x")), &["x".to_string()]).unwrap(); - assert_eq!(j, make_array(vec![expr_fn::cos(col("x"))])); + fn jvp_with_unit_seed_matches_grad() { + // A one-hot tangent reproduces the partial derivative. + let f = expr_fn::sin(col("x")); + let seeds = HashMap::from([("x".to_string(), one())]); + assert_eq!(jvp(&f, &seeds).unwrap(), differentiate(&f, "x").unwrap()); } #[test] - fn columns_from_array_extracts_names() { - let arr = make_array(vec![col("a"), col("b"), col("c")]); - assert_eq!(columns_from_array(&arr).unwrap(), vec!["a", "b", "c"]); + fn vjp_equals_cotangent_times_grad() { + // rewrite_vjp(sin(x), x, w) = w * cos(x) + let f = expr_fn::sin(col("x")); + let got = rewrite_vjp(&[f.clone(), col("x"), col("w")]).unwrap(); + assert_eq!(got, mul(col("w"), expr_fn::cos(col("x")))); } #[test] - fn columns_from_array_rejects_non_array() { - assert!(columns_from_array(&col("x")).is_err()); + fn jvp_and_vjp_agree_for_unit_seed() { + // With matching unit seed/cotangent, forward and reverse coincide. + let f = binary(expr_fn::sin(col("x")), Operator::Multiply, col("x")); + let fwd = rewrite_jvp(&[f.clone(), col("x"), one()]).unwrap(); + let rev = rewrite_vjp(&[f, col("x"), one()]).unwrap(); + assert_eq!(fwd, rev); } } diff --git a/src/lib.rs b/src/lib.rs index 754a390..626dc56 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1027,7 +1027,8 @@ fn grad_rewrite<'py>( // per referenced name (so the consumer can resolve table scans). let ctx = SessionContext::new(); ctx.register_udf(autograd::grad_marker()); - ctx.register_udf(autograd::jacobian_marker()); + ctx.register_udf(autograd::jvp_marker()); + ctx.register_udf(autograd::vjp_marker()); for (name, schema_obj) in &tables { let schema = Schema::from_pyarrow_bound(schema_obj).map_err(|e| { diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 13f10a2..b21c74a 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -105,25 +105,42 @@ def test_multi_input_grad_columns(ctx_xy): np.testing.assert_allclose(res["dfdy"], ds["x"].values) -def test_jacobian_array(ctx_xy): - # jacobian(f, [x, y]) returns the gradient row [df/dx, df/dy] per row. +def test_jvp_forward_directional_derivative(ctx_xy): + # jvp(f, x, dx) = df/dx * dx. With f = sin(x)*y and a constant tangent. + context, ds = ctx_xy + x, y = ds["x"].values, ds["y"].values + res = _ordered(context.sql("SELECT i, jvp(sin(x) * y, x, 2.0) AS t FROM g")) + np.testing.assert_allclose(res["t"], (np.cos(x) * y) * 2.0) + + +def test_jvp_multi_input_is_sum(ctx_xy): + # A full directional derivative is the sum of per-input jvp terms: + # df/dx*dx + df/dy*dy for f = x*y, with dx=1, dy=1 -> y + x. context, ds = ctx_xy res = _ordered( - context.sql("SELECT i, jacobian(x * y, [x, y]) AS jac FROM g") + context.sql( + "SELECT i, jvp(x * y, x, 1.0) + jvp(x * y, y, 1.0) AS t FROM g" + ) ) - jac = np.stack([np.asarray(v, dtype=float) for v in res["jac"]]) - # column 0 is df/dx = y, column 1 is df/dy = x - np.testing.assert_allclose(jac[:, 0], ds["y"].values) - np.testing.assert_allclose(jac[:, 1], ds["x"].values) + np.testing.assert_allclose(res["t"], ds["y"].values + ds["x"].values) -def test_jacobian_array_nonlinear(ctx_xy): - # jacobian(sin(x) * y, [x, y]) = [cos(x)*y, sin(x)] +def test_vjp_reverse_pullback(ctx_xy): + # vjp(f, x, w) = w * df/dx. With f = sin(x)*y and cotangent w = 3.0. context, ds = ctx_xy x, y = ds["x"].values, ds["y"].values + res = _ordered(context.sql("SELECT i, vjp(sin(x) * y, x, 3.0) AS s FROM g")) + np.testing.assert_allclose(res["s"], 3.0 * (np.cos(x) * y)) + + +def test_jvp_and_vjp_agree_for_unit_seed(ctx_xy): + # Forward (unit tangent) and reverse (unit cotangent) coincide for a + # scalar output -- both contract the same partial derivative. + context, _ = ctx_xy res = _ordered( - context.sql("SELECT i, jacobian(sin(x) * y, [x, y]) AS jac FROM g") + context.sql( + "SELECT i, jvp(sin(x) * y, x, 1.0) AS fwd, " + "vjp(sin(x) * y, x, 1.0) AS rev FROM g" + ) ) - jac = np.stack([np.asarray(v, dtype=float) for v in res["jac"]]) - np.testing.assert_allclose(jac[:, 0], np.cos(x) * y) - np.testing.assert_allclose(jac[:, 1], np.sin(x)) + np.testing.assert_allclose(res["fwd"], res["rev"]) diff --git a/xarray_sql/sql.py b/xarray_sql/sql.py index 635a732..c62d7da 100644 --- a/xarray_sql/sql.py +++ b/xarray_sql/sql.py @@ -13,10 +13,10 @@ from .ds import XarrayDataFrame from .reader import read_xarray_table -# Matches a call to an autograd marker function (``grad(`` / ``jacobian(``, +# Matches a call to an autograd marker function (``grad(`` / ``jvp(`` / ``vjp(``, # case-insensitive), used as a cheap gate so ordinary queries skip the # Substrait round-trip. -_GRAD_CALL = re.compile(r"\b(grad|jacobian)\s*\(", re.IGNORECASE) +_GRAD_CALL = re.compile(r"\b(grad|jvp|vjp)\s*\(", re.IGNORECASE) class XarrayContext(SessionContext): @@ -34,35 +34,33 @@ def __init__(self, *args, **kwargs): self._register_autograd_udfs() def _register_autograd_udfs(self) -> None: - """Register the ``grad`` / ``jacobian`` marker UDFs. + """Register the ``grad`` / ``jvp`` / ``vjp`` marker UDFs. These are *markers*: they let queries parse and plan with the differentiation request intact. They are never executed — the Substrait - rewrite in :meth:`sql` replaces every call with the symbolic - derivative before execution. - - * ``grad(expr, column)`` -> scalar ``d(expr)/d(column)``. - * ``jacobian(expr, [c1, c2, ...])`` -> the gradient of ``expr`` with - respect to several columns, as a ``List`` (one Jacobian - row). The second argument is a SQL array of bare column references. + rewrite in :meth:`sql` replaces every call with the symbolic derivative + before execution. All return a scalar, staying in the long/tidy data + model (one value per row). + + * ``grad(expr, column)`` -> ``d(expr)/d(column)``. + * ``jvp(expr, column, tangent)`` -> forward-mode directional derivative + ``d(expr)/d(column) * tangent`` (seed a tangent on an input). A + multi-input directional derivative is a sum of jvp terms. + * ``vjp(expr, column, cotangent)`` -> reverse-mode pullback + ``cotangent * d(expr)/d(column)`` (seed a cotangent on the output). + + A full gradient/Jacobian is expressed as several scalar columns, e.g. + ``grad(f, x) AS dfdx, grad(f, y) AS dfdy``. """ + f64 = pa.float64() self.register_udf( - udf( - lambda expr, column: expr, - [pa.float64(), pa.float64()], - pa.float64(), - "immutable", - "grad", - ) + udf(lambda e, c: e, [f64, f64], f64, "immutable", "grad") ) self.register_udf( - udf( - lambda expr, columns: columns, - [pa.float64(), pa.list_(pa.float64())], - pa.list_(pa.float64()), - "immutable", - "jacobian", - ) + udf(lambda e, c, t: e, [f64, f64, f64], f64, "immutable", "jvp") + ) + self.register_udf( + udf(lambda e, c, w: e, [f64, f64, f64], f64, "immutable", "vjp") ) def from_dataset( From c49003a218532b2b0e15b67c7e2f6b3d6e9e7b6c Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 28 Jun 2026 12:30:28 +0000 Subject: [PATCH 07/17] Support grad/jvp/vjp on schema-qualified tables Mixed-dimension datasets register as schema-qualified tables (e.g. era5.surface / era5.time_x_level). The autograd rewrite consumes the plan in a throwaway context that registers an empty table per scanned name, but register_table("era5.time_x", ...) failed with "failed to resolve schema: era5" because the namespace did not exist. Add ensure_schema(): before registering each table, parse its name into a TableReference and, for qualified names, create the schema namespace (MemorySchemaProvider) in the default catalog if absent. The Python side already resolves qualified names via ctx.table(name).schema(); only the Rust rewrite context needed the namespace. Tests: a mixed-dimension fixture exercising grad on both the 2D surface and 3D atmosphere tables, against numpy analytic derivatives. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6 --- src/lib.rs | 35 ++++++++++++++++++++++++++++++++--- tests/test_autograd.py | 40 ++++++++++++++++++++++++++++++++++++++++ xarray_sql/sql.py | 7 ++++--- 3 files changed, 76 insertions(+), 6 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 626dc56..8a4cbb6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -55,8 +55,8 @@ use arrow::pyarrow::FromPyArrow; use async_stream::try_stream; use async_trait::async_trait; use datafusion::catalog::streaming::StreamingTable; -use datafusion::catalog::Session; -use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue}; +use datafusion::catalog::{MemorySchemaProvider, Session}; +use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue, TableReference}; use datafusion::datasource::empty::EmptyTable; use datafusion::datasource::TableProvider; use datafusion::execution::TaskContext; @@ -993,6 +993,22 @@ impl LazyArrowStreamTable { // Autograd: Substrait-level grad() rewrite // ============================================================================ +/// Ensure a schema (namespace) exists in the context's catalog, creating an +/// empty in-memory one if needed. Used so the rewrite context can register +/// schema-qualified tables (e.g. `era5.surface`) that mixed-dimension datasets +/// produce. +fn ensure_schema(ctx: &SessionContext, catalog: Option<&str>, schema: &str) -> DFResult<()> { + // A bare TableReference has no catalog; fall back to DataFusion's default. + let catalog_name = catalog.unwrap_or("datafusion"); + let catalog = ctx + .catalog(catalog_name) + .ok_or_else(|| DataFusionError::Plan(format!("catalog '{catalog_name}' not found")))?; + if catalog.schema(schema).is_none() { + catalog.register_schema(schema, Arc::new(MemorySchemaProvider::new()))?; + } + Ok(()) +} + /// Rewrite `grad(expr, column)` calls in a Substrait plan into their symbolic /// derivatives. /// @@ -1037,7 +1053,20 @@ fn grad_rewrite<'py>( )) })?; let provider = Arc::new(EmptyTable::new(Arc::new(schema))); - ctx.register_table(name.as_str(), provider).map_err(|e| { + + // Schema-qualified names (e.g. "era5.surface", from a mixed-dimension + // dataset) need their namespace to exist before the table can be + // registered into this throwaway context. + let table_ref = TableReference::from(name.as_str()); + if let Some(schema_name) = table_ref.schema() { + ensure_schema(&ctx, table_ref.catalog(), schema_name).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "grad_rewrite: failed to create schema for table '{name}': {e}" + )) + })?; + } + + ctx.register_table(table_ref, provider).map_err(|e| { pyo3::exceptions::PyValueError::new_err(format!( "grad_rewrite: failed to register table '{name}': {e}" )) diff --git a/tests/test_autograd.py b/tests/test_autograd.py index b21c74a..89fc563 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -133,6 +133,46 @@ def test_vjp_reverse_pullback(ctx_xy): np.testing.assert_allclose(res["s"], 3.0 * (np.cos(x) * y)) +@pytest.fixture +def ctx_mixed(): + # A mixed-dimension dataset registers as schema-qualified tables: + # era5.time_x (surface, 2 dims) + # era5.time_x_level (atmosphere, 3 dims) + rng = np.random.default_rng(1) + ds = xr.Dataset( + { + "sfc": (("time", "x"), rng.uniform(0.5, 2.5, (3, 4))), + "atm": (("time", "x", "level"), rng.uniform(0.5, 2.5, (3, 4, 2))), + }, + coords={"time": [0, 1, 2], "x": np.arange(4.0), "level": [0, 1]}, + ) + context = xql.XarrayContext() + context.from_dataset("era5", ds, chunks={"time": 1}) + return context, ds + + +def test_grad_on_qualified_surface_table(ctx_mixed): + context, ds = ctx_mixed + res = _ordered( + context.sql( + "SELECT time, x, sfc, grad(sin(sfc), sfc) AS d FROM era5.time_x" + ), + key="sfc", + ) + np.testing.assert_allclose(res["d"], np.cos(res["sfc"])) + + +def test_grad_on_qualified_atmosphere_table(ctx_mixed): + context, ds = ctx_mixed + res = _ordered( + context.sql( + "SELECT atm, grad(power(atm, 2), atm) AS d FROM era5.time_x_level" + ), + key="atm", + ) + np.testing.assert_allclose(res["d"], 2.0 * res["atm"]) + + def test_jvp_and_vjp_agree_for_unit_seed(ctx_xy): # Forward (unit tangent) and reverse (unit cotangent) coincide for a # scalar output -- both contract the same partial derivative. diff --git a/xarray_sql/sql.py b/xarray_sql/sql.py index c62d7da..cb25be0 100644 --- a/xarray_sql/sql.py +++ b/xarray_sql/sql.py @@ -251,11 +251,12 @@ def _table_schemas(self) -> list[tuple[str, pa.Schema]]: schemas = [] for name in self._registered_datasets: try: + # Names may be bare ("air") or schema-qualified ("era5.surface", + # from a mixed-dimension dataset); both resolve here. schemas.append((name, self.table(name).schema())) except Exception: - # Schema-qualified tables (mixed-dimension datasets) aren't - # resolvable by a bare name yet; skip rather than fail the - # whole query. grad() over those is a follow-up. + # Be defensive: skip a table we can't introspect rather than + # failing the whole query. continue return schemas From bdad6fb824ad6146fb252a0def74486f8b41e3e6 Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 28 Jun 2026 13:17:51 +0000 Subject: [PATCH 08/17] Verify and test higher-order grad Nested calls such as grad(grad(f, x), x) already yield higher-order derivatives: the plan rewrite walks expressions bottom-up (transform_up), so the inner grad is differentiated to a plain expression first and the outer grad differentiates that result. No code change was needed; this adds tests and documents the behavior. - Rust: a unit test that differentiation composes (d2/dx2 sin = -sin). - Python: second derivatives of sin (-sin) and x^3 (6x) and the third derivative of sin (-cos), against numpy. - Doc: note higher-order support in the module overview. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6 --- src/autograd.rs | 12 ++++++++++++ tests/test_autograd.py | 25 +++++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/src/autograd.rs b/src/autograd.rs index 14312e6..0011a4e 100644 --- a/src/autograd.rs +++ b/src/autograd.rs @@ -36,6 +36,10 @@ //! full gradient or Jacobian is expressed as several scalar columns (e.g. //! `grad(f, x) AS dfdx, grad(f, y) AS dfdy`) rather than a nested array, which //! would break the one-value-per-coordinate model. +//! +//! Calls nest, giving higher-order derivatives for free: the rewrite walks +//! bottom-up, so the inner call in `grad(grad(f, x), x)` is differentiated +//! first and the outer call differentiates that result. #![allow(dead_code)] @@ -596,6 +600,14 @@ mod tests { assert!(differentiate(&e, "x").is_err()); } + #[test] + fn higher_order_derivative() { + // Differentiation composes: d2/dx2 sin(x) = -sin(x). + let d1 = differentiate(&expr_fn::sin(col("x")), "x").unwrap(); + let d2 = differentiate(&d1, "x").unwrap(); + assert_eq!(d2, neg(expr_fn::sin(col("x")))); + } + #[test] fn jvp_seeds_a_tangent_on_one_input() { // jvp(x*y, {x: dx}) = product rule with tangent(x)=dx, tangent(y)=0 diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 89fc563..0ed58ff 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -80,6 +80,31 @@ def test_grad_quotient_and_power(ctx): np.testing.assert_allclose(res["dcube"], 3.0 * val**2) +def test_higher_order_grad(ctx): + # Nested grad() differentiates repeatedly: the inner call is rewritten + # first, then the outer differentiates its result. + val = np.linspace(0.1, 3.0, 16) + res = _ordered( + ctx.sql( + "SELECT i, " + "grad(grad(sin(val), val), val) AS d2_sin, " + "grad(grad(power(val, 3), val), val) AS d2_cube FROM t" + ) + ) + np.testing.assert_allclose(res["d2_sin"], -np.sin(val)) # -sin + np.testing.assert_allclose(res["d2_cube"], 6.0 * val) # d2/dx2 x^3 = 6x + + +def test_third_order_grad(ctx): + val = np.linspace(0.1, 3.0, 16) + res = _ordered( + ctx.sql( + "SELECT i, grad(grad(grad(sin(val), val), val), val) AS d3 FROM t" + ) + ) + np.testing.assert_allclose(res["d3"], -np.cos(val)) # d3/dx3 sin = -cos + + def test_non_grad_query_is_unaffected(ctx): # Queries without grad() bypass the rewrite and behave normally. res = _ordered(ctx.sql("SELECT i, val FROM t")) From 255413e628e3e76cc6281d26e458f11c11ca712a Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 28 Jun 2026 14:51:00 +0000 Subject: [PATCH 09/17] Add differentiation-through-aggregate tests and docs Document and test that differentiating through SUM/AVG is just linearity: AGG(grad(f, x)) == d/dx AGG(f). Writing grad inside the aggregate composes with SQL scoping (the marker rewrites to plain SQL before the aggregate runs), so it needs no special machinery -- enough to express gradient descent in SQL. Adds tests for SUM/AVG(grad(...)) and an end-to-end gradient-descent convergence test, plus a note in the module overview. The runnable benchmark scripts live on stacked demo branches to keep this feature branch reviewable. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6 --- src/autograd.rs | 8 +++++++ tests/test_autograd.py | 52 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/src/autograd.rs b/src/autograd.rs index 0011a4e..7b3755d 100644 --- a/src/autograd.rs +++ b/src/autograd.rs @@ -40,6 +40,14 @@ //! Calls nest, giving higher-order derivatives for free: the rewrite walks //! bottom-up, so the inner call in `grad(grad(f, x), x)` is differentiated //! first and the outer call differentiates that result. +//! +//! Differentiation through an aggregate is just linearity and needs no special +//! handling: write the `grad` *inside* the aggregate, e.g. `SUM(grad(f, x))` or +//! `AVG(grad(loss, theta))`. Because the marker is rewritten to plain SQL +//! before the aggregate runs (and the column is in scope there), this is the +//! relational `d/dθ Σ f = Σ ∂f/∂θ` — enough to run gradient descent in SQL. +//! (The transposed form `grad(SUM(f), x)` is rejected by SQL's own scoping, +//! since `x` is gone after aggregation.) #![allow(dead_code)] diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 0ed58ff..0400465 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -117,6 +117,58 @@ def test_unsupported_function_raises(ctx): ctx.sql("SELECT grad(atan2(val, val), val) AS d FROM t").to_pandas() +def test_grad_inside_aggregate(ctx): + # Differentiation through an aggregate is just linearity: + # AGG(grad(f, x)) == d/dx AGG(f). grad rewrites to plain SQL before the + # aggregate runs, so this composes with no special machinery. + val = np.linspace(0.1, 3.0, 16) + res = ctx.sql( + "SELECT SUM(grad(val * val, val)) AS s, " + "AVG(grad(sin(val), val)) AS a FROM t" + ).to_pandas() + np.testing.assert_allclose(res["s"][0], np.sum(2 * val)) + np.testing.assert_allclose(res["a"][0], np.mean(np.cos(val))) + + +def test_gradient_descent_in_sql(): + # End to end: fit y ~= a*x + b by minimising MSE, with the gradients + # w.r.t. the parameters computed in SQL via AVG(grad(loss, param)). + rng = np.random.default_rng(0) + n = 200 + x = rng.uniform(0.0, 1.0, n) + a_true, b_true = 2.0, -1.0 + y = a_true * x + b_true + rng.normal(0.0, 0.01, n) + data = xr.Dataset( + {"x": (("i",), x), "y": (("i",), y)}, coords={"i": np.arange(n)} + ) + ctx = xql.XarrayContext() + ctx.from_dataset("d", data, chunks={"i": n}) + + resid = "(y - (a * x + b))" + loss = f"{resid} * {resid}" + a, b, lr = 0.0, 0.0, 0.4 + losses = [] + for _ in range(120): + if "params" in ctx._registered_datasets: + ctx.deregister_table("params") + del ctx._registered_datasets["params"] + params = xr.Dataset( + {"a": (("p",), [a]), "b": (("p",), [b])}, coords={"p": [0]} + ) + ctx.from_dataset("params", params, chunks={"p": 1}) + row = ctx.sql( + f"SELECT AVG({loss}) AS loss, " + f"AVG(grad({loss}, a)) AS dl_da, " + f"AVG(grad({loss}, b)) AS dl_db FROM d CROSS JOIN params" + ).to_pandas() + losses.append(float(row["loss"][0])) + a -= lr * float(row["dl_da"][0]) + b -= lr * float(row["dl_db"][0]) + + assert losses[-1] < losses[0] # loss decreased + np.testing.assert_allclose([a, b], [a_true, b_true], atol=0.05) + + def test_multi_input_grad_columns(ctx_xy): # A full Jacobian written as separate scalar grad() columns: # f = x*y -> df/dx = y, df/dy = x. From e2784c37e20fa94c043d5806af4a02ec5ad5ecfe Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 28 Jun 2026 15:14:49 +0000 Subject: [PATCH 10/17] Resolve grad over any registered table, not just xarray ones Generalize _table_schemas() to enumerate the catalog instead of only the xarray-registered datasets, so the Substrait rewrite can resolve grad() queries that reference plain DataFusion tables too -- e.g. in-memory MemTables holding model parameters or intermediate results. This makes grad compose with ordinary relational state (a parameter table you INSERT into), not only gridded xarray data. Adds a test differentiating an expression whose coefficient lives in an in-memory table cross-joined to the xarray data. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6 --- tests/test_autograd.py | 17 ++++++++++++++++ xarray_sql/sql.py | 44 +++++++++++++++++++++++++++++------------- 2 files changed, 48 insertions(+), 13 deletions(-) diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 0400465..ca17de9 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -6,6 +6,7 @@ """ import numpy as np +import pyarrow as pa import pytest import xarray as xr @@ -117,6 +118,22 @@ def test_unsupported_function_raises(ctx): ctx.sql("SELECT grad(atan2(val, val), val) AS d FROM t").to_pandas() +def test_grad_over_in_memory_table(ctx): + # grad works over plain DataFusion tables too (not just xarray-registered + # ones): here a coefficient lives in an in-memory MemTable cross-joined to + # the xarray data. d/dval (c * val^2) = c * 2*val, with c = 3. + ctx.register_record_batches( + "coef", [[pa.RecordBatch.from_pydict({"c": [3.0]})]] + ) + val = np.linspace(0.1, 3.0, 16) + res = _ordered( + ctx.sql( + "SELECT i, grad(c * val * val, val) AS d FROM t CROSS JOIN coef" + ) + ) + np.testing.assert_allclose(res["d"], 3.0 * 2.0 * val) + + def test_grad_inside_aggregate(ctx): # Differentiation through an aggregate is just linearity: # AGG(grad(f, x)) == d/dx AGG(f). grad rewrites to plain SQL before the diff --git a/xarray_sql/sql.py b/xarray_sql/sql.py index cb25be0..5e892d3 100644 --- a/xarray_sql/sql.py +++ b/xarray_sql/sql.py @@ -242,22 +242,40 @@ def _sql_with_autograd(self, query: str, *args, **kwargs): return self.create_dataframe_from_logical_plan(new_plan) def _table_schemas(self) -> list[tuple[str, pa.Schema]]: - """Return ``(name, schema)`` for each registered table. - - The Substrait consumer in ``grad_rewrite`` resolves table scans by - name, so it needs the schema of every table the plan might reference. - Only metadata is read here — never the underlying data. + """Return ``(name, schema)`` for every table registered in the context. + + The Substrait consumer in ``grad_rewrite`` resolves table scans by name, + so it needs the schema of every table the plan might reference. We + enumerate the catalog rather than only the xarray-registered datasets, + so ``grad`` also works over plain DataFusion tables (e.g. in-memory + ``MemTable``s holding model parameters or intermediate results). Only + metadata is read here — never the underlying data. """ schemas = [] - for name in self._registered_datasets: - try: - # Names may be bare ("air") or schema-qualified ("era5.surface", - # from a mixed-dimension dataset); both resolve here. - schemas.append((name, self.table(name).schema())) - except Exception: - # Be defensive: skip a table we can't introspect rather than - # failing the whole query. + catalog = self.catalog() + for schema_name in catalog.schema_names(): + if schema_name == "information_schema": continue + schema = catalog.schema(schema_name) + names = ( + schema.table_names() + if hasattr(schema, "table_names") + else schema.names() + ) + for table_name in names: + # Tables in the default schema are referenced bare ("air"); + # others are schema-qualified ("era5.surface"). + qualified = ( + table_name + if schema_name in ("public", "default") + else f"{schema_name}.{table_name}" + ) + try: + schemas.append((qualified, self.table(qualified).schema())) + except Exception: + # Be defensive: skip a table we can't introspect rather + # than failing the whole query. + continue return schemas From 55eba21198534827edfd57a0dde4e510e7ecbae4 Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 28 Jun 2026 15:43:21 +0000 Subject: [PATCH 11/17] Add differentiate_sql: differentiate an expression to SQL text Expose the autograd engine as a "calculus compiler": differentiate_sql(expr, wrt, columns) parses a SQL scalar expression (parse_sql_expr), differentiates it with the engine, and unparses the derivative back to SQL (expr_to_sql). Where grad(...) rewrites a whole plan via Substrait, this hands back a single derivative expression as text -- usable where the Substrait round-trip can't carry a grad marker, e.g. embedding a precomputed update rule inside a recursive -CTE training loop (Substrait has no recursion). Exposed as xarray_sql. differentiate_sql; covered by a round-trip test. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6 --- src/lib.rs | 60 ++++++++++++++++++++++++++++++++++++++++-- tests/test_autograd.py | 9 +++++++ xarray_sql/__init__.py | 2 ++ 3 files changed, 69 insertions(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 8a4cbb6..019cd57 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -50,13 +50,15 @@ use std::fmt::Debug; use std::sync::Arc; use arrow::array::RecordBatch; -use arrow::datatypes::{Schema, SchemaRef}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::pyarrow::FromPyArrow; use async_stream::try_stream; use async_trait::async_trait; use datafusion::catalog::streaming::StreamingTable; use datafusion::catalog::{MemorySchemaProvider, Session}; -use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue, TableReference}; +use datafusion::common::{ + DFSchema, DataFusionError, Result as DFResult, ScalarValue, TableReference, +}; use datafusion::datasource::empty::EmptyTable; use datafusion::datasource::TableProvider; use datafusion::execution::TaskContext; @@ -68,6 +70,7 @@ use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::streaming::PartitionStream; use datafusion::physical_plan::{ExecutionPlan, SendableRecordBatchStream}; use datafusion::prelude::SessionContext; +use datafusion::sql::unparser::expr_to_sql; use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use datafusion_ffi::table_provider::FFI_TableProvider; use datafusion_substrait::logical_plan::consumer::from_substrait_plan; @@ -1114,10 +1117,63 @@ fn grad_rewrite<'py>( Ok(PyBytes::new(py, &out_plan.encode_to_vec())) } +/// Differentiate a SQL scalar expression symbolically and return the +/// derivative as SQL text. +/// +/// Where [`grad_rewrite`] rewrites `grad(...)` calls inside a whole plan, this +/// differentiates a single expression and hands back the result as SQL — the +/// autograd engine acting as a "calculus compiler". It lets a caller obtain an +/// update rule once and embed it in queries the Substrait round-trip can't +/// carry a `grad` marker through, such as a recursive-CTE training loop. +/// +/// Args: +/// expr: A SQL scalar expression over `columns` (e.g. `"sin(x) * x"`). +/// wrt: The column name to differentiate with respect to. +/// columns: The column names in scope; all treated as `Float64` (enough to +/// parse and differentiate — types don't affect the symbolic result). +/// +/// Returns: +/// The derivative as a SQL string (e.g. `"cos(x) * x + sin(x)"`). +#[pyfunction] +fn differentiate_sql(expr: &str, wrt: &str, columns: Vec) -> PyResult { + let ctx = SessionContext::new(); + + let fields: Vec = columns + .iter() + .map(|name| Field::new(name, DataType::Float64, true)) + .collect(); + let df_schema = DFSchema::try_from(Schema::new(fields)).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "differentiate_sql: failed to build schema: {e}" + )) + })?; + + let parsed = ctx.parse_sql_expr(expr, &df_schema).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "differentiate_sql: failed to parse expression '{expr}': {e}" + )) + })?; + + let derivative = autograd::differentiate(&parsed, wrt).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "differentiate_sql: failed to differentiate: {e}" + )) + })?; + + let sql = expr_to_sql(&derivative).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "differentiate_sql: failed to render derivative to SQL: {e}" + )) + })?; + + Ok(sql.to_string()) +} + /// Python module initialization #[pymodule] fn _native(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_function(wrap_pyfunction!(grad_rewrite, m)?)?; + m.add_function(wrap_pyfunction!(differentiate_sql, m)?)?; Ok(()) } diff --git a/tests/test_autograd.py b/tests/test_autograd.py index ca17de9..2ccc7dd 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -134,6 +134,15 @@ def test_grad_over_in_memory_table(ctx): np.testing.assert_allclose(res["d"], 3.0 * 2.0 * val) +def test_differentiate_sql_round_trip(ctx): + # differentiate_sql returns the derivative as SQL text; evaluating it must + # match the analytic derivative. d/dval (sin(val)*val) = cos(val)*val + sin(val). + deriv = xql.differentiate_sql("sin(val) * val", "val", ["val"]) + val = np.linspace(0.1, 3.0, 16) + res = _ordered(ctx.sql(f"SELECT i, {deriv} AS d FROM t")) + np.testing.assert_allclose(res["d"], np.cos(val) * val + np.sin(val)) + + def test_grad_inside_aggregate(ctx): # Differentiation through an aggregate is just linearity: # AGG(grad(f, x)) == d/dx AGG(f). grad rewrites to plain SQL before the diff --git a/xarray_sql/__init__.py b/xarray_sql/__init__.py index d1e5984..c01f295 100644 --- a/xarray_sql/__init__.py +++ b/xarray_sql/__init__.py @@ -1,4 +1,5 @@ from . import cftime +from ._native import differentiate_sql from .df import from_map from .reader import read_xarray, read_xarray_table from .sql import XarrayContext @@ -6,6 +7,7 @@ __all__ = [ "cftime", "XarrayContext", + "differentiate_sql", "read_xarray_table", "read_xarray", "from_map", # deprecated From 7b1e53069c2641f058e5c52cd7d90c9f442ee981 Mon Sep 17 00:00:00 2001 From: Alex Merose Date: Tue, 30 Jun 2026 13:29:49 +0300 Subject: [PATCH 12/17] Differentiate grad() as a SQL rewrite, dropping the Substrait bridge Make grad()/jvp()/vjp() work inside any query shape (recursive CTEs, DML, subqueries) by rewriting the calls as SQL text before planning, rather than round-tripping the logical plan through Substrait (which could not represent those shapes). Closes the gap tracked in #197. XarrayContext.sql() now hands a query containing a marker to the native rewrite_grad_sql, which parses the statement with sqlparser, differentiates each marker call with the existing engine, and renders the derivative back into the SQL in place. Because it runs before planning, every query shape the parser accepts is supported, and the result is ordinary SQL the stock datafusion-python context plans and executes directly. This removes the Substrait round-trip entirely: the datafusion-substrait and prost dependencies, the grad_rewrite/_sql_with_autograd/_table_schemas plumbing, the marker-UDF registration, and the protoc steps in CI. Unlike the FFI alternative, it needs no datafusion fork and no custom datafusion-python wheel. The grad surface is unchanged (same SQL, same results); marker arguments use unqualified column names, matching existing usage, since differentiation is syntactic and runs before binding. Co-Authored-By: Claude Opus 4.8 --- .github/workflows/ci-build.yml | 7 - .github/workflows/ci-rust.yml | 7 - .github/workflows/ci.yml | 7 - .github/workflows/publish.yml | 31 ---- Cargo.lock | 273 +-------------------------------- Cargo.toml | 3 +- src/autograd.rs | 189 ++++++++++++++++++++++- src/lib.rs | 142 +++-------------- tests/test_autograd.py | 24 ++- xarray_sql/sql.py | 117 ++++---------- 10 files changed, 257 insertions(+), 543 deletions(-) diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml index ec89b8e..214388e 100644 --- a/.github/workflows/ci-build.yml +++ b/.github/workflows/ci-build.yml @@ -31,13 +31,6 @@ jobs: - uses: dtolnay/rust-toolchain@stable - # The `substrait` crate (a datafusion-substrait dependency) generates - # code from .proto files at build time and requires protoc. - - name: Install Protoc - uses: arduino/setup-protoc@v3 - with: - repo-token: ${{ secrets.GITHUB_TOKEN }} - - name: Setup sccache uses: mozilla-actions/sccache-action@v0.0.9 diff --git a/.github/workflows/ci-rust.yml b/.github/workflows/ci-rust.yml index 9054a44..68f1ce6 100644 --- a/.github/workflows/ci-rust.yml +++ b/.github/workflows/ci-rust.yml @@ -27,13 +27,6 @@ jobs: with: components: clippy - # The `substrait` crate (a datafusion-substrait dependency) generates - # code from .proto files at build time and requires protoc. - - name: Install Protoc - uses: arduino/setup-protoc@v3 - with: - repo-token: ${{ secrets.GITHUB_TOKEN }} - - name: Setup sccache uses: mozilla-actions/sccache-action@v0.0.9 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 587784d..c1d892d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -43,13 +43,6 @@ jobs: - uses: dtolnay/rust-toolchain@stable - # The `substrait` crate (a datafusion-substrait dependency) generates - # code from .proto files at build time and requires protoc. - - name: Install Protoc - uses: arduino/setup-protoc@v3 - with: - repo-token: ${{ secrets.GITHUB_TOKEN }} - - name: Setup sccache uses: mozilla-actions/sccache-action@v0.0.9 diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 48f0155..a0008bc 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -54,13 +54,6 @@ jobs: - uses: dtolnay/rust-toolchain@stable - # The `substrait` crate (a datafusion-substrait dependency) generates - # code from .proto files at build time and requires protoc. - - name: Install Protoc - uses: arduino/setup-protoc@v3 - with: - repo-token: ${{ secrets.GITHUB_TOKEN }} - - name: Setup sccache uses: mozilla-actions/sccache-action@v0.0.9 @@ -98,18 +91,6 @@ jobs: manylinux: 2_28 rustup-components: rust-std rustfmt sccache: 'true' - # protoc is required by the substrait crate build and must be - # installed inside the manylinux container. - before-script-linux: | - PROTOC_VERSION=29.3 - case "$(uname -m)" in - x86_64) PROTOC_ARCH=x86_64 ;; - aarch64) PROTOC_ARCH=aarch_64 ;; - *) echo "unsupported arch $(uname -m)"; exit 1 ;; - esac - curl -L -o /tmp/protoc.zip "https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/protoc-${PROTOC_VERSION}-linux-${PROTOC_ARCH}.zip" - unzip -o /tmp/protoc.zip -d /usr/local bin/protoc 'include/*' - protoc --version args: --release --strip --out dist -i python3.10 python3.11 python3.12 python3.13 - uses: actions/upload-artifact@v6 @@ -132,18 +113,6 @@ jobs: manylinux: 2_28 rustup-components: rust-std rustfmt sccache: 'true' - # protoc is required by the substrait crate build and must be - # installed inside the manylinux container. - before-script-linux: | - PROTOC_VERSION=29.3 - case "$(uname -m)" in - x86_64) PROTOC_ARCH=x86_64 ;; - aarch64) PROTOC_ARCH=aarch_64 ;; - *) echo "unsupported arch $(uname -m)"; exit 1 ;; - esac - curl -L -o /tmp/protoc.zip "https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/protoc-${PROTOC_VERSION}-linux-${PROTOC_ARCH}.zip" - unzip -o /tmp/protoc.zip -d /usr/local bin/protoc 'include/*' - protoc --version args: --release --strip --out dist -i python3.10 python3.11 python3.12 python3.13 - uses: actions/upload-artifact@v6 diff --git a/Cargo.lock b/Cargo.lock index fcaefe4..b6cbcf1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -397,17 +397,6 @@ dependencies = [ "abi_stable", ] -[[package]] -name = "async-recursion" -version = "1.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.114", -] - [[package]] name = "async-stream" version = "0.3.6" @@ -1491,26 +1480,6 @@ dependencies = [ "sqlparser", ] -[[package]] -name = "datafusion-substrait" -version = "52.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "199790fd96e852997b30da4ff11109378c944841757d93875ea85fc69587ec91" -dependencies = [ - "async-recursion", - "async-trait", - "chrono", - "datafusion", - "half", - "itertools", - "object_store", - "pbjson-types", - "prost", - "substrait", - "tokio", - "url", -] - [[package]] name = "digest" version = "0.10.7" @@ -1533,12 +1502,6 @@ dependencies = [ "syn 2.0.114", ] -[[package]] -name = "dyn-clone" -version = "1.0.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" - [[package]] name = "either" version = "1.15.0" @@ -2196,12 +2159,6 @@ dependencies = [ "simd-adler32", ] -[[package]] -name = "multimap" -version = "0.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084" - [[package]] name = "num-bigint" version = "0.4.6" @@ -2345,43 +2302,6 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" -[[package]] -name = "pbjson" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "898bac3fa00d0ba57a4e8289837e965baa2dee8c3749f3b11d45a64b4223d9c3" -dependencies = [ - "base64", - "serde", -] - -[[package]] -name = "pbjson-build" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af22d08a625a2213a78dbb0ffa253318c5c79ce3133d32d296655a7bdfb02095" -dependencies = [ - "heck", - "itertools", - "prost", - "prost-types", -] - -[[package]] -name = "pbjson-types" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e748e28374f10a330ee3bb9f29b828c0ac79831a32bab65015ad9b661ead526" -dependencies = [ - "bytes", - "chrono", - "pbjson", - "pbjson-build", - "prost", - "prost-build", - "serde", -] - [[package]] name = "percent-encoding" version = "2.3.2" @@ -2460,16 +2380,6 @@ dependencies = [ "zerocopy", ] -[[package]] -name = "prettyplease" -version = "0.2.37" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" -dependencies = [ - "proc-macro2", - "syn 2.0.114", -] - [[package]] name = "proc-macro2" version = "1.0.101" @@ -2489,25 +2399,6 @@ dependencies = [ "prost-derive", ] -[[package]] -name = "prost-build" -version = "0.14.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7" -dependencies = [ - "heck", - "itertools", - "log", - "multimap", - "petgraph", - "prettyplease", - "prost", - "prost-types", - "regex", - "syn 2.0.114", - "tempfile", -] - [[package]] name = "prost-derive" version = "0.14.3" @@ -2521,15 +2412,6 @@ dependencies = [ "syn 2.0.114", ] -[[package]] -name = "prost-types" -version = "0.14.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8991c4cbdb8bc5b11f0b074ffe286c30e523de90fee5ba8132f1399f23cb3dd7" -dependencies = [ - "prost", -] - [[package]] name = "psm" version = "0.1.26" @@ -2702,16 +2584,6 @@ version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "caf4aa5b0f434c91fe5c7f1ecb6a5ece2130b02ad2a590589dda5146df959001" -[[package]] -name = "regress" -version = "0.10.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2057b2325e68a893284d1538021ab90279adac1139957ca2a74426c6f118fb48" -dependencies = [ - "hashbrown 0.16.1", - "memchr", -] - [[package]] name = "repr_offset" version = "0.2.2" @@ -2764,30 +2636,6 @@ dependencies = [ "winapi-util", ] -[[package]] -name = "schemars" -version = "0.8.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fbf2ae1b8bc8e02df939598064d22402220cd5bbcca1c76f7d6a310974d5615" -dependencies = [ - "dyn-clone", - "schemars_derive", - "serde", - "serde_json", -] - -[[package]] -name = "schemars_derive" -version = "0.8.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32e265784ad618884abaea0600a9adf15393368d840e0222d101a072f3f7534d" -dependencies = [ - "proc-macro2", - "quote", - "serde_derive_internals", - "syn 2.0.114", -] - [[package]] name = "scopeguard" version = "1.2.0" @@ -2799,10 +2647,6 @@ name = "semver" version = "1.0.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" -dependencies = [ - "serde", - "serde_core", -] [[package]] name = "seq-macro" @@ -2840,17 +2684,6 @@ dependencies = [ "syn 2.0.114", ] -[[package]] -name = "serde_derive_internals" -version = "0.29.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.114", -] - [[package]] name = "serde_json" version = "1.0.145" @@ -2864,31 +2697,6 @@ dependencies = [ "serde_core", ] -[[package]] -name = "serde_tokenstream" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7c49585c52c01f13c5c2ebb333f14f6885d76daa768d8a037d28017ec538c69" -dependencies = [ - "proc-macro2", - "quote", - "serde", - "syn 2.0.114", -] - -[[package]] -name = "serde_yaml" -version = "0.9.34+deprecated" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" -dependencies = [ - "indexmap", - "itoa", - "ryu", - "serde", - "unsafe-libyaml", -] - [[package]] name = "sha2" version = "0.10.9" @@ -2983,31 +2791,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "substrait" -version = "0.62.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62fc4b483a129b9772ccb9c3f7945a472112fdd9140da87f8a4e7f1d44e045d0" -dependencies = [ - "heck", - "pbjson", - "pbjson-build", - "pbjson-types", - "prettyplease", - "prost", - "prost-build", - "prost-types", - "regress", - "schemars", - "semver", - "serde", - "serde_json", - "serde_yaml", - "syn 2.0.114", - "typify", - "walkdir", -] - [[package]] name = "subtle" version = "2.6.1" @@ -3221,53 +3004,6 @@ version = "1.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8c1ae7cc0fdb8b842d65d127cb981574b0d2b249b74d1c7a2986863dc134f71" -[[package]] -name = "typify" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6d5bcc6f62eb1fa8aa4098f39b29f93dcb914e17158b76c50360911257aa629" -dependencies = [ - "typify-impl", - "typify-macro", -] - -[[package]] -name = "typify-impl" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1eb359f7ffa4f9ebe947fa11a1b2da054564502968db5f317b7e37693cb2240" -dependencies = [ - "heck", - "log", - "proc-macro2", - "quote", - "regress", - "schemars", - "semver", - "serde", - "serde_json", - "syn 2.0.114", - "thiserror", - "unicode-ident", -] - -[[package]] -name = "typify-macro" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "911c32f3c8514b048c1b228361bebb5e6d73aeec01696e8cc0e82e2ffef8ab7a" -dependencies = [ - "proc-macro2", - "quote", - "schemars", - "semver", - "serde", - "serde_json", - "serde_tokenstream", - "syn 2.0.114", - "typify-impl", -] - [[package]] name = "unicode-ident" version = "1.0.19" @@ -3292,12 +3028,6 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" -[[package]] -name = "unsafe-libyaml" -version = "0.2.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" - [[package]] name = "url" version = "2.5.7" @@ -3652,11 +3382,10 @@ dependencies = [ "async-trait", "datafusion", "datafusion-ffi", - "datafusion-substrait", "futures", - "prost", "pyo3", "pyo3-build-config", + "sqlparser", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index 1fc5288..21f5b43 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,8 +23,7 @@ async-stream = "0.3" async-trait = "0.1" datafusion = { version = "52.0.0" } datafusion-ffi = { version = "52.0.0" } -datafusion-substrait = { version = "52.0.0" } -prost = "0.14" +sqlparser = { version = "0.59", features = ["visitor"] } futures = { version = "0.3" } pyo3 = { version = "0.26.0", features = ["extension-module"] } tokio = { version = "1.46.1", features = ["rt"] } diff --git a/src/autograd.rs b/src/autograd.rs index 7b3755d..c729ca3 100644 --- a/src/autograd.rs +++ b/src/autograd.rs @@ -54,16 +54,23 @@ use std::any::Any; use std::collections::HashMap; use std::f64::consts::{LN_10, LN_2}; +use std::ops::ControlFlow; +use std::sync::Arc; -use datafusion::arrow::datatypes::DataType; +use datafusion::arrow::datatypes::{DataType, Field}; use datafusion::common::tree_node::{Transformed, TreeNode}; -use datafusion::common::{DataFusionError, Result, ScalarValue}; +use datafusion::common::{DFSchema, DataFusionError, Result, ScalarValue, TableReference}; use datafusion::functions::math::expr_fn; use datafusion::logical_expr::expr::ScalarFunction; use datafusion::logical_expr::{ lit, BinaryExpr, Cast, ColumnarValue, Expr, LogicalPlan, Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; +use datafusion::prelude::SessionContext; +use datafusion::sql::unparser::expr_to_sql; +use sqlparser::ast::{Expr as SqlExpr, Visit, VisitMut, Visitor, VisitorMut}; +use sqlparser::dialect::GenericDialect; +use sqlparser::parser::Parser; // --------------------------------------------------------------------------- // Constant helpers and the 0/1-folding builders @@ -522,6 +529,142 @@ fn rewrite_vjp(args: &[Expr]) -> Result { Ok(mul(args[2].clone(), derivative)) } +// --------------------------------------------------------------------------- +// SQL source-to-source rewrite +// --------------------------------------------------------------------------- + +/// Rewrite every `grad`/`jvp`/`vjp` call in a SQL statement into its symbolic +/// derivative, returning the rewritten SQL text. +/// +/// Unlike a logical-plan rewrite, this is a pure source-to-source transform run +/// *before* the query is planned, so it works for any query shape the SQL parser +/// accepts — recursive CTEs, DML, and subqueries included. Each marker call is +/// parsed into a DataFusion [`Expr`], differentiated by the engine in this +/// module, and rendered back to SQL in place. Columns are taken from the call's +/// own identifiers (all treated as `Float64`; types don't affect the symbolic +/// result), so no catalog or table schema is needed. +pub fn rewrite_grad_in_sql(sql: &str) -> Result { + let dialect = GenericDialect {}; + let mut statements = Parser::parse_sql(&dialect, sql) + .map_err(|e| DataFusionError::Plan(format!("grad: failed to parse SQL: {e}")))?; + + // A throwaway context that only needs the marker UDFs registered so the + // calls parse into `ScalarFunction` nodes the engine can dispatch on. + let ctx = SessionContext::new(); + ctx.register_udf(grad_marker()); + ctx.register_udf(jvp_marker()); + ctx.register_udf(vjp_marker()); + + let mut rewriter = GradSqlRewriter { ctx: &ctx }; + for stmt in &mut statements { + if let ControlFlow::Break(msg) = stmt.visit(&mut rewriter) { + return Err(DataFusionError::Plan(msg)); + } + } + + Ok(statements + .iter() + .map(ToString::to_string) + .collect::>() + .join("; ")) +} + +/// True if `name` is one of the autograd marker functions (case-insensitive). +fn is_marker_name(name: &str) -> bool { + matches!(name.to_lowercase().as_str(), "grad" | "jvp" | "vjp") +} + +/// Walks a SQL AST and replaces each `grad`/`jvp`/`vjp` call with its derivative. +struct GradSqlRewriter<'a> { + ctx: &'a SessionContext, +} + +impl VisitorMut for GradSqlRewriter<'_> { + type Break = String; + + fn pre_visit_expr(&mut self, expr: &mut SqlExpr) -> ControlFlow { + let is_marker = matches!( + expr, + SqlExpr::Function(f) if is_marker_name(&f.name.to_string()) + ); + if !is_marker { + return ControlFlow::Continue(()); + } + match self.rewrite_call(expr) { + Ok(()) => ControlFlow::Continue(()), + Err(e) => ControlFlow::Break(e), + } + } +} + +impl GradSqlRewriter<'_> { + /// Differentiate a single marker call in place. The replacement is wrapped + /// in parentheses so it keeps the call's precedence in the surrounding SQL. + fn rewrite_call(&self, expr: &mut SqlExpr) -> std::result::Result<(), String> { + let schema = call_schema(expr)?; + let text = expr.to_string(); + let parsed = self + .ctx + .parse_sql_expr(&text, &schema) + .map_err(|e| format!("grad: failed to parse '{text}': {e}"))?; + let derivative = rewrite_grad_in_expr(parsed) + .map_err(|e| format!("grad: failed to differentiate '{text}': {e}"))? + .data; + let rendered = expr_to_sql(&derivative) + .map_err(|e| format!("grad: failed to render derivative for '{text}': {e}"))?; + *expr = SqlExpr::Nested(Box::new(rendered)); + Ok(()) + } +} + +/// Build a `Float64` schema covering every column identifier referenced inside a +/// marker call, so the call's argument expression can be parsed standalone. +fn call_schema(call: &SqlExpr) -> std::result::Result { + let mut collector = ColumnCollector::default(); + let _ = call.visit(&mut collector); + let fields = collector + .cols + .into_iter() + .map(|(qualifier, name)| { + let qualifier = qualifier.map(TableReference::bare); + ( + qualifier, + Arc::new(Field::new(name, DataType::Float64, true)), + ) + }) + .collect(); + DFSchema::new_with_metadata(fields, HashMap::new()) + .map_err(|e| format!("grad: failed to build schema for differentiation: {e}")) +} + +/// Collects the (optional qualifier, name) of every column identifier in a SQL +/// expression tree. +#[derive(Default)] +struct ColumnCollector { + cols: Vec<(Option, String)>, +} + +impl Visitor for ColumnCollector { + type Break = (); + + fn pre_visit_expr(&mut self, expr: &SqlExpr) -> ControlFlow<()> { + let pair = match expr { + SqlExpr::Identifier(ident) => Some((None, ident.value.clone())), + SqlExpr::CompoundIdentifier(parts) => parts.last().map(|last| { + let qualifier = (parts.len() >= 2).then(|| parts[parts.len() - 2].value.clone()); + (qualifier, last.value.clone()) + }), + _ => None, + }; + if let Some(pair) = pair { + if !self.cols.contains(&pair) { + self.cols.push(pair); + } + } + ControlFlow::Continue(()) + } +} + // --------------------------------------------------------------------------- // Tests // --------------------------------------------------------------------------- @@ -650,4 +793,46 @@ mod tests { let rev = rewrite_vjp(&[f, col("x"), one()]).unwrap(); assert_eq!(fwd, rev); } + + #[test] + fn sql_rewrite_replaces_grad_call() { + // grad(sin(x), x) -> cos(x); the surrounding SELECT is preserved. + let out = rewrite_grad_in_sql("SELECT grad(sin(x), x) AS d FROM t").unwrap(); + assert_eq!(out, "SELECT (cos(x)) AS d FROM t"); + } + + #[test] + fn sql_rewrite_leaves_non_grad_queries_intact() { + // A query with no marker is still parsed and re-emitted unchanged in + // meaning (the caller only invokes the rewrite when a marker is present). + let out = rewrite_grad_in_sql("SELECT a + b FROM t").unwrap(); + assert_eq!(out, "SELECT a + b FROM t"); + } + + #[test] + fn sql_rewrite_fires_inside_recursive_cte() { + // The #197 capability: a marker inside a recursive term is rewritten, + // a query shape the Substrait bridge could never carry. d/dx(x*x) = x+x. + let out = rewrite_grad_in_sql( + "WITH RECURSIVE r AS (SELECT 1.0 AS x UNION ALL \ + SELECT x - grad(x * x, x) FROM r WHERE x < 10) SELECT x FROM r", + ) + .unwrap(); + assert!(out.contains("(x + x)"), "unexpected rewrite: {out}"); + assert!( + !out.to_lowercase().contains("grad("), + "marker left behind: {out}" + ); + } + + #[test] + fn sql_rewrite_handles_nested_higher_order_grad() { + // grad(grad(power(x, 3), x), x) -> d2/dx2 (x^3) = 6x; bottom-up so the + // inner call is differentiated before the outer one. + let out = rewrite_grad_in_sql("SELECT grad(grad(power(x, 3), x), x) AS d FROM t").unwrap(); + assert!( + !out.to_lowercase().contains("grad("), + "marker left behind: {out}" + ); + } } diff --git a/src/lib.rs b/src/lib.rs index 019cd57..d157d72 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -55,11 +55,8 @@ use arrow::pyarrow::FromPyArrow; use async_stream::try_stream; use async_trait::async_trait; use datafusion::catalog::streaming::StreamingTable; -use datafusion::catalog::{MemorySchemaProvider, Session}; -use datafusion::common::{ - DFSchema, DataFusionError, Result as DFResult, ScalarValue, TableReference, -}; -use datafusion::datasource::empty::EmptyTable; +use datafusion::catalog::Session; +use datafusion::common::{DFSchema, DataFusionError, Result as DFResult, ScalarValue}; use datafusion::datasource::TableProvider; use datafusion::execution::TaskContext; use datafusion::logical_expr::expr::InList; @@ -73,12 +70,8 @@ use datafusion::prelude::SessionContext; use datafusion::sql::unparser::expr_to_sql; use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use datafusion_ffi::table_provider::FFI_TableProvider; -use datafusion_substrait::logical_plan::consumer::from_substrait_plan; -use datafusion_substrait::logical_plan::producer::to_substrait_plan; -use datafusion_substrait::substrait::proto::Plan; -use prost::Message; use pyo3::prelude::*; -use pyo3::types::{PyBytes, PyCapsule, PyList}; +use pyo3::types::{PyCapsule, PyList}; // ============================================================================ // Partition Metadata Types for Filter Pushdown @@ -993,128 +986,31 @@ impl LazyArrowStreamTable { } // ============================================================================ -// Autograd: Substrait-level grad() rewrite +// Autograd: SQL-level grad() rewrite // ============================================================================ -/// Ensure a schema (namespace) exists in the context's catalog, creating an -/// empty in-memory one if needed. Used so the rewrite context can register -/// schema-qualified tables (e.g. `era5.surface`) that mixed-dimension datasets -/// produce. -fn ensure_schema(ctx: &SessionContext, catalog: Option<&str>, schema: &str) -> DFResult<()> { - // A bare TableReference has no catalog; fall back to DataFusion's default. - let catalog_name = catalog.unwrap_or("datafusion"); - let catalog = ctx - .catalog(catalog_name) - .ok_or_else(|| DataFusionError::Plan(format!("catalog '{catalog_name}' not found")))?; - if catalog.schema(schema).is_none() { - catalog.register_schema(schema, Arc::new(MemorySchemaProvider::new()))?; - } - Ok(()) -} - -/// Rewrite `grad(expr, column)` calls in a Substrait plan into their symbolic -/// derivatives. +/// Rewrite every `grad`/`jvp`/`vjp` call in a SQL query into its symbolic +/// derivative, returning the rewritten SQL text. /// -/// The autograd engine operates on DataFusion logical `Expr` trees. To apply it -/// inside the datafusion-python `SessionContext` (which links its own copy of -/// DataFusion), we move the plan across the boundary as Substrait protobuf: -/// Python produces the plan, this function consumes it into a DataFusion -/// `LogicalPlan`, rewrites every `grad(...)` into the differentiated -/// expression, and re-produces Substrait bytes for Python to consume and -/// execute. +/// The autograd engine operates on DataFusion logical `Expr` trees. Rather than +/// round-tripping a whole plan across the cdylib boundary, this rewrites the +/// query as **SQL text** before it is planned: each marker call is parsed, +/// differentiated, and rendered back to SQL in place. Because it runs before +/// planning, it works for any query shape the parser accepts — recursive CTEs, +/// DML, and subqueries — which the plan-level Substrait bridge could not carry. /// /// Args: -/// plan_bytes: A Substrait `Plan` protobuf, as produced by -/// datafusion-python's -/// ``Producer.to_substrait_plan(plan, ctx).encode()``. -/// tables: A list of ``(name, pyarrow.Schema)`` pairs for every table the -/// plan scans. The consumer resolves table references by name, so each -/// referenced table must be registered here with a matching schema -/// (the data itself is never read — an empty table suffices). +/// query: A SQL query string that may contain `grad`/`jvp`/`vjp` calls. /// /// Returns: -/// The rewritten Substrait `Plan` protobuf bytes, ready for -/// ``Consumer.from_substrait_plan(ctx, plan)``. +/// The rewritten SQL string, ready to pass to ``SessionContext.sql``. #[pyfunction] -fn grad_rewrite<'py>( - py: Python<'py>, - plan_bytes: &[u8], - tables: Vec<(String, Bound<'py, PyAny>)>, -) -> PyResult> { - // A fresh, data-free context purely for the rewrite. It needs the grad - // marker UDF (so the consumer can resolve the function) and an empty table - // per referenced name (so the consumer can resolve table scans). - let ctx = SessionContext::new(); - ctx.register_udf(autograd::grad_marker()); - ctx.register_udf(autograd::jvp_marker()); - ctx.register_udf(autograd::vjp_marker()); - - for (name, schema_obj) in &tables { - let schema = Schema::from_pyarrow_bound(schema_obj).map_err(|e| { - pyo3::exceptions::PyTypeError::new_err(format!( - "grad_rewrite: failed to convert schema for table '{name}': {e}" - )) - })?; - let provider = Arc::new(EmptyTable::new(Arc::new(schema))); - - // Schema-qualified names (e.g. "era5.surface", from a mixed-dimension - // dataset) need their namespace to exist before the table can be - // registered into this throwaway context. - let table_ref = TableReference::from(name.as_str()); - if let Some(schema_name) = table_ref.schema() { - ensure_schema(&ctx, table_ref.catalog(), schema_name).map_err(|e| { - pyo3::exceptions::PyValueError::new_err(format!( - "grad_rewrite: failed to create schema for table '{name}': {e}" - )) - })?; - } - - ctx.register_table(table_ref, provider).map_err(|e| { - pyo3::exceptions::PyValueError::new_err(format!( - "grad_rewrite: failed to register table '{name}': {e}" - )) - })?; - } - - let state = ctx.state(); - - let plan = Plan::decode(plan_bytes).map_err(|e| { +fn rewrite_grad_sql(query: &str) -> PyResult { + autograd::rewrite_grad_in_sql(query).map_err(|e| { pyo3::exceptions::PyValueError::new_err(format!( - "grad_rewrite: failed to decode Substrait plan: {e}" + "rewrite_grad_sql: failed to rewrite grad() calls: {e}" )) - })?; - - // from_substrait_plan is async but does no real I/O here (empty tables - // resolve immediately), so a minimal current-thread runtime suffices. - let runtime = tokio::runtime::Builder::new_current_thread() - .build() - .map_err(|e| { - pyo3::exceptions::PyRuntimeError::new_err(format!( - "grad_rewrite: failed to build runtime: {e}" - )) - })?; - - let logical = runtime - .block_on(from_substrait_plan(&state, &plan)) - .map_err(|e| { - pyo3::exceptions::PyValueError::new_err(format!( - "grad_rewrite: failed to consume Substrait plan: {e}" - )) - })?; - - let rewritten = autograd::rewrite_grad_calls(logical).map_err(|e| { - pyo3::exceptions::PyValueError::new_err(format!( - "grad_rewrite: failed to rewrite grad() calls: {e}" - )) - })?; - - let out_plan = to_substrait_plan(&rewritten, &state).map_err(|e| { - pyo3::exceptions::PyValueError::new_err(format!( - "grad_rewrite: failed to produce Substrait plan: {e}" - )) - })?; - - Ok(PyBytes::new(py, &out_plan.encode_to_vec())) + }) } /// Differentiate a SQL scalar expression symbolically and return the @@ -1173,7 +1069,7 @@ fn differentiate_sql(expr: &str, wrt: &str, columns: Vec) -> PyResult) -> PyResult<()> { m.add_class::()?; - m.add_function(wrap_pyfunction!(grad_rewrite, m)?)?; + m.add_function(wrap_pyfunction!(rewrite_grad_sql, m)?)?; m.add_function(wrap_pyfunction!(differentiate_sql, m)?)?; Ok(()) } diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 2ccc7dd..794194e 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -1,7 +1,8 @@ """Tests for the SQL autograd surface: ``SELECT grad(expr, column) ...``. -These exercise the full path — XarrayContext.sql() -> Substrait -> native -grad_rewrite -> Substrait -> execute — and compare results against analytic +These exercise the full path — XarrayContext.sql() differentiates every +``grad``/``jvp``/``vjp`` call as SQL text before planning, then DataFusion +executes the rewritten query — and compare results against analytic derivatives computed with numpy. """ @@ -195,6 +196,25 @@ def test_gradient_descent_in_sql(): np.testing.assert_allclose([a, b], [a_true, b_true], atol=0.05) +def test_grad_inside_recursive_cte(): + # The headline of #197: grad() *inside* a recursive CTE — a query shape the + # old Substrait bridge could not represent. Newton's method for sqrt(2) + # drives the step with grad(x*x - 2, x) computed in the recursive term: + # x <- x - (x*x - 2) / d/dx(x*x - 2) = x - (x*x - 2) / (2x). + ctx = xql.XarrayContext() + res = ctx.sql( + "WITH RECURSIVE newton AS (" + " SELECT 0 AS step, CAST(1.0 AS DOUBLE) AS x " + " UNION ALL " + " SELECT step + 1 AS step, " + " x - (x * x - 2.0) / grad(x * x - 2.0, x) AS x " + " FROM newton WHERE step < 20" + ") " + "SELECT x FROM newton ORDER BY step DESC LIMIT 1" + ).to_pandas() + np.testing.assert_allclose(res["x"][0], np.sqrt(2.0), atol=1e-9) + + def test_multi_input_grad_columns(ctx_xy): # A full Jacobian written as separate scalar grad() columns: # f = x*y -> df/dx = y, df/dy = x. diff --git a/xarray_sql/sql.py b/xarray_sql/sql.py index 5e892d3..46fe8e6 100644 --- a/xarray_sql/sql.py +++ b/xarray_sql/sql.py @@ -1,10 +1,8 @@ import re -import pyarrow as pa import xarray as xr -from datafusion import SessionContext, udf +from datafusion import SessionContext from datafusion.catalog import Schema -from datafusion.substrait import Consumer, Producer, Serde from collections import defaultdict from . import _native @@ -14,8 +12,8 @@ from .reader import read_xarray_table # Matches a call to an autograd marker function (``grad(`` / ``jvp(`` / ``vjp(``, -# case-insensitive), used as a cheap gate so ordinary queries skip the -# Substrait round-trip. +# case-insensitive), used as a cheap gate so ordinary queries skip the grad +# source-to-source rewrite. _GRAD_CALL = re.compile(r"\b(grad|jvp|vjp)\s*\(", re.IGNORECASE) @@ -31,37 +29,6 @@ def __init__(self, *args, **kwargs): # in SQL (e.g. ``"air"`` for a uniform-dim Dataset, or # ``"era5.surface"`` for one entry from a multi-dim-group split). self._registered_datasets: dict[str, xr.Dataset] = {} - self._register_autograd_udfs() - - def _register_autograd_udfs(self) -> None: - """Register the ``grad`` / ``jvp`` / ``vjp`` marker UDFs. - - These are *markers*: they let queries parse and plan with the - differentiation request intact. They are never executed — the Substrait - rewrite in :meth:`sql` replaces every call with the symbolic derivative - before execution. All return a scalar, staying in the long/tidy data - model (one value per row). - - * ``grad(expr, column)`` -> ``d(expr)/d(column)``. - * ``jvp(expr, column, tangent)`` -> forward-mode directional derivative - ``d(expr)/d(column) * tangent`` (seed a tangent on an input). A - multi-input directional derivative is a sum of jvp terms. - * ``vjp(expr, column, cotangent)`` -> reverse-mode pullback - ``cotangent * d(expr)/d(column)`` (seed a cotangent on the output). - - A full gradient/Jacobian is expressed as several scalar columns, e.g. - ``grad(f, x) AS dfdx, grad(f, y) AS dfdy``. - """ - f64 = pa.float64() - self.register_udf( - udf(lambda e, c: e, [f64, f64], f64, "immutable", "grad") - ) - self.register_udf( - udf(lambda e, c, t: e, [f64, f64, f64], f64, "immutable", "jvp") - ) - self.register_udf( - udf(lambda e, c, w: e, [f64, f64, f64], f64, "immutable", "vjp") - ) def from_dataset( self, @@ -207,6 +174,11 @@ def sql(self, query: str, *args, **kwargs) -> XarrayDataFrame: ``.to_dataset(dimension_columns=[...])`` for round-tripping the result back to an ``xr.Dataset``. + If the query contains ``grad`` / ``jvp`` / ``vjp`` calls, they are + differentiated and substituted as SQL text *before* planning (see + :meth:`_rewrite_autograd`), so the differentiation works inside any + query shape — recursive CTEs, DML, and subqueries included. + Args: query: A SQL query string. *args: Forwarded to ``SessionContext.sql``. @@ -216,67 +188,32 @@ def sql(self, query: str, *args, **kwargs) -> XarrayDataFrame: An :class:`XarrayDataFrame` wrapping the DataFusion DataFrame. """ if _GRAD_CALL.search(query): - inner = self._sql_with_autograd(query, *args, **kwargs) - else: - inner = super().sql(query, *args, **kwargs) + query = self._rewrite_autograd(query) + inner = super().sql(query, *args, **kwargs) return XarrayDataFrame(inner, templates=self._registered_datasets) - def _sql_with_autograd(self, query: str, *args, **kwargs): - """Plan ``query``, rewrite ``grad(...)`` calls, return a DataFrame. + def _rewrite_autograd(self, query: str) -> str: + """Differentiate ``grad`` / ``jvp`` / ``vjp`` calls into SQL text. The differentiation engine lives in the native (Rust) extension and - operates on DataFusion logical expressions. Since that extension links - its own copy of DataFusion, the plan crosses the boundary as Substrait: - we produce the logical plan as Substrait, hand it to ``grad_rewrite`` - (which differentiates every ``grad(expr, column)`` symbolically), then - consume the rewritten Substrait back into an executable DataFrame. - """ - plan = super().sql(query, *args, **kwargs).logical_plan() - substrait_plan = Producer.to_substrait_plan(plan, self) - rewritten = _native.grad_rewrite( - substrait_plan.encode(), self._table_schemas() - ) - new_plan = Consumer.from_substrait_plan( - self, Serde.deserialize_bytes(rewritten) - ) - return self.create_dataframe_from_logical_plan(new_plan) + operates on DataFusion logical expressions. Rather than round-tripping a + whole plan across that extension's boundary, we hand it the query as SQL + text: it parses each marker call, differentiates it symbolically, and + renders the derivative back into the query in place. The result is an + ordinary SQL string this context can plan and execute directly. - def _table_schemas(self) -> list[tuple[str, pa.Schema]]: - """Return ``(name, schema)`` for every table registered in the context. + * ``grad(expr, column)`` -> ``d(expr)/d(column)``. + * ``jvp(expr, column, tangent)`` -> forward-mode directional derivative + ``d(expr)/d(column) * tangent`` (seed a tangent on an input). A + multi-input directional derivative is a sum of jvp terms. + * ``vjp(expr, column, cotangent)`` -> reverse-mode pullback + ``cotangent * d(expr)/d(column)`` (seed a cotangent on the output). - The Substrait consumer in ``grad_rewrite`` resolves table scans by name, - so it needs the schema of every table the plan might reference. We - enumerate the catalog rather than only the xarray-registered datasets, - so ``grad`` also works over plain DataFusion tables (e.g. in-memory - ``MemTable``s holding model parameters or intermediate results). Only - metadata is read here — never the underlying data. + A full gradient/Jacobian is expressed as several scalar columns, e.g. + ``grad(f, x) AS dfdx, grad(f, y) AS dfdy``. """ - schemas = [] - catalog = self.catalog() - for schema_name in catalog.schema_names(): - if schema_name == "information_schema": - continue - schema = catalog.schema(schema_name) - names = ( - schema.table_names() - if hasattr(schema, "table_names") - else schema.names() - ) - for table_name in names: - # Tables in the default schema are referenced bare ("air"); - # others are schema-qualified ("era5.surface"). - qualified = ( - table_name - if schema_name in ("public", "default") - else f"{schema_name}.{table_name}" - ) - try: - schemas.append((qualified, self.table(qualified).schema())) - except Exception: - # Be defensive: skip a table we can't introspect rather - # than failing the whole query. - continue - return schemas + rewritten: str = _native.rewrite_grad_sql(query) + return rewritten def _group_vars_by_dims(ds: xr.Dataset) -> dict[tuple[str, ...], list[str]]: From 67286dbb237ddef7a21c111a761fd2c8d86da43d Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 28 Jun 2026 14:52:29 +0000 Subject: [PATCH 13/17] Add differentiable-SQL demos: ARCO-ERA5 and gradient descent Stacked demo branch (on the autograd feature) holding the runnable benchmark scripts, kept out of the core branch so it stays reviewable. * grad_era5.py: symbolic grad over real ARCO-ERA5 data (wind-speed sensitivity checked exactly; saturation vapour pressure checked against the closed-form Clausius-Clapeyron slope). The queries ORDER BY latitude DESC, longitude to match ERA5's native order, so results line up with the xarray reference with no sorting on either side (single partition, so the order survives to_dataset). * grad_descent.py: gradient descent as ONE declarative recursive-CTE query. differentiate_sql compiles the per-row update rule to SQL once; a recursive CTE then iterates it. No Python loop. Fit matches numpy least-squares. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6 --- benchmarks/README.md | 64 ++++++++++++++ benchmarks/grad_descent.py | 115 +++++++++++++++++++++++++ benchmarks/grad_era5.py | 171 +++++++++++++++++++++++++++++++++++++ 3 files changed, 350 insertions(+) create mode 100644 benchmarks/README.md create mode 100644 benchmarks/grad_descent.py create mode 100644 benchmarks/grad_era5.py diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 0000000..5a0188c --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,64 @@ +# Benchmarks & demos + +Standalone scripts that exercise xarray-sql against real data. Each declares its +own dependencies inline (PEP 723) and points `xarray_sql` at this checkout, so +they run with no setup: + +```bash +uv run benchmarks/grad_era5.py +``` + +## `grad_era5.py` — differentiable SQL over ARCO-ERA5 + +Demonstrates the autograd feature on a real climate archive +([ARCO-ERA5](https://github.com/google-research/arco-era5), read anonymously +from GCS — needs `gcsfs` and network access). + +The key idea: a physical quantity is written as an **analytic SQL formula** over +ERA5 variables, and `grad(...)` differentiates that formula **symbolically**, +evaluated at every grid cell. Because each row is an independent point, this is +the relational equivalent of `jax.vmap(jax.grad(f))`. It is *not* a finite- +difference spatial gradient — `grad(f(u, v), u)` is the exact partial derivative +of `f`. + +Two worked cases, each checked against an analytic reference: + +| Quantity | SQL | Derivative | Check | +| --- | --- | --- | --- | +| Wind speed | `sqrt(power(u,2) + power(v,2))` | `grad(speed, u) = u/speed` | exact | +| Saturation vapour pressure | `A*exp(B*tc/(tc+C))` | `grad(e_s, T)` | closed-form Clausius-Clapeyron slope | + +Each query round-trips back to an `xarray.Dataset` via `.to_dataset(...)`. + +## `grad_descent.py` — gradient descent as one declarative SQL query + +Fits a line `y ~= a*x + b` by minimising the mean squared error, with the +**entire training loop expressed as a single recursive CTE** — no Python +iteration. Two pieces: + +- **`grad` compiles the update rule.** `xql.differentiate_sql(loss, "a", cols)` + turns the per-row loss into its symbolic derivative *as SQL text* — the + autograd engine as a calculus compiler. +- **A recursive CTE is the optimiser.** `params(step, a, b)` starts at one row + and each recursion appends the next generation, descending along the gradient + (`AVG` of the compiled rule over the data): + + ```sql + WITH RECURSIVE params(step, a, b) AS ( + SELECT 0, 0.0, 0.0 + UNION ALL + SELECT params.step + 1, params.a - lr*AVG(da), params.b - lr*AVG(db) + FROM params CROSS JOIN d WHERE params.step < STEPS + GROUP BY params.step, params.a, params.b) + SELECT * FROM params ORDER BY step + ``` + +So gradient, update, and iteration are all declarative SQL; the trajectory is +the rows of one query. The fit matches numpy's least-squares solution. +Self-contained (no network). + +(Why differentiate to text instead of `grad(...)` inside the recursion? `grad` +needs the Substrait round-trip, and Substrait has no recursion — so a `grad` +marker can't live inside a recursive CTE. Differentiating once to plain SQL +sidesteps that.) + diff --git a/benchmarks/grad_descent.py b/benchmarks/grad_descent.py new file mode 100644 index 0000000..daff207 --- /dev/null +++ b/benchmarks/grad_descent.py @@ -0,0 +1,115 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "xarray_sql", +# "xarray", +# "numpy", +# ] +# +# [tool.uv.sources] +# xarray_sql = { path = "..", editable = true } +# /// +"""Gradient descent as a single declarative SQL query. + +Fits a line ``y ~= a*x + b`` by minimising the mean squared error — with the +**entire training loop expressed as one recursive CTE**, no Python iteration. + +Two pieces: + +1. **grad compiles the update rule.** ``differentiate_sql`` turns the per-row + loss into the symbolic derivative *as SQL text* — the autograd engine acting + as a calculus compiler: + + da = differentiate_sql("(y-(a*x+b))^2", "a") # -> "-2*((a*x+b)-y)*x", etc. + +2. **A recursive CTE is the optimiser.** ``params(step, a, b)`` starts at one + row and each recursion appends the next generation, descending along the + gradient (``AVG`` of the compiled rule over the data): + + params.a - lr * AVG(da) + + So the whole loop — gradient, update, and iteration — is declarative SQL; + the optimisation trajectory is the rows of one query. + +Why two pieces instead of ``grad(...)`` directly inside the recursion? ``grad`` +needs the Substrait round-trip, and Substrait has no recursion — so ``grad`` +can't live inside a recursive CTE (tracked in #194 / a follow-up). Differentiating +once to plain SQL sidesteps that: the recursive query contains no ``grad`` marker. + +Run standalone: + + uv run benchmarks/grad_descent.py +""" + +from __future__ import annotations + +import numpy as np +import xarray as xr + +import xarray_sql as xql + +# Per-row loss r^2 with residual r = y - (a*x + b), over columns a, b, x, y. +RESIDUAL = "(y - (a * x + b))" +LOSS = f"{RESIDUAL} * {RESIDUAL}" +COLUMNS = ["a", "b", "x", "y"] +LR = 0.4 +STEPS = 200 + + +def main() -> None: + rng = np.random.default_rng(0) + n = 500 + x = rng.uniform(0.0, 1.0, n) + a_true, b_true = 2.0, -1.0 + y = a_true * x + b_true + rng.normal(0.0, 0.01, n) + + ctx = xql.XarrayContext() + ctx.from_dataset( + "d", + xr.Dataset( + {"x": (("i",), x), "y": (("i",), y)}, coords={"i": np.arange(n)} + ), + chunks={"i": n}, + ) + + # grad compiles the per-row update rule to SQL, once. + da = xql.differentiate_sql(LOSS, "a", COLUMNS) + db = xql.differentiate_sql(LOSS, "b", COLUMNS) + print(f"d(loss)/da = {da}") + print(f"d(loss)/db = {db}\n") + + # The entire training loop is one declarative recursive query: each step + # appends the next generation, descending along the SQL-computed gradient. + trajectory = ctx.sql( + f""" + WITH RECURSIVE params(step, a, b) AS ( + SELECT 0 AS step, CAST(0.0 AS DOUBLE) AS a, CAST(0.0 AS DOUBLE) AS b + UNION ALL + SELECT params.step + 1 AS step, + params.a - {LR} * AVG({da}) AS a, + params.b - {LR} * AVG({db}) AS b + FROM params CROSS JOIN d + WHERE params.step < {STEPS} + GROUP BY params.step, params.a, params.b + ) + SELECT step, a, b FROM params ORDER BY step + """ + ).to_pandas() + + print("trajectory (every 40th generation):") + print(trajectory.iloc[::40].to_string(index=False)) + + a, b = float(trajectory["a"].iloc[-1]), float(trajectory["b"].iloc[-1]) + a_ols, b_ols = np.polyfit(x, y, 1) + print( + f"\nSQL gradient descent: a={a:.4f} b={b:.4f} ({len(trajectory)} generations)" + ) + print(f"least-squares (numpy): a={a_ols:.4f} b={b_ols:.4f}") + assert abs(a - a_ols) < 1e-2 and abs(b - b_ols) < 1e-2 + print( + "\nOK: a single recursive-CTE query fit the line to the OLS solution." + ) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/grad_era5.py b/benchmarks/grad_era5.py new file mode 100644 index 0000000..866f066 --- /dev/null +++ b/benchmarks/grad_era5.py @@ -0,0 +1,171 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "xarray_sql", +# "xarray[io]", +# "gcsfs", +# "numpy", +# ] +# +# [tool.uv.sources] +# xarray_sql = { path = "..", editable = true } +# /// +"""Differentiable SQL over ARCO-ERA5. + +A minimal demonstration of xarray-sql's autograd: take a real climate archive +(ARCO-ERA5, read anonymously from GCS), express a physical quantity as an +*analytic* SQL formula over its variables, and let ``grad(...)`` differentiate +that formula symbolically — evaluated per grid cell, which is the relational +equivalent of ``jax.vmap(jax.grad(f))`` (each row is an independent point). + +Note this is *symbolic* differentiation of an expression, not a finite- +difference spatial gradient: ``grad(f(u, v), u)`` is the exact partial +derivative of the formula ``f``, evaluated at every cell's values. + +Two cases: + +1. Wind-speed magnitude ``speed = sqrt(u^2 + v^2)``. Its sensitivity to the + eastward wind is ``d(speed)/du = u / speed`` — checked exactly. + +2. Saturation vapour pressure ``e_s(T)`` (August-Roche-Magnus form of the + Clausius-Clapeyron relation). ``d(e_s)/dT`` governs how fast the atmosphere's + moisture capacity grows with temperature — checked against the closed-form + slope. + +Run standalone (builds the local extension on first use): + + uv run benchmarks/grad_era5.py +""" + +from __future__ import annotations + +import time + +import numpy as np +import xarray as xr + +import xarray_sql as xql + +ARCO_ERA5 = ( + "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3" +) + +# ERA5 variable names start with a digit, so they must be double-quoted in SQL. +U = '"10m_u_component_of_wind"' +V = '"10m_v_component_of_wind"' +T = '"2m_temperature"' + + +def load_era5_block() -> xr.Dataset: + """Open ARCO-ERA5 and pull one timestamp over a small region. + + Lazy open of the whole archive; only the requested block is read. We keep + it to a few thousand cells so the demo runs in seconds. + """ + full = xr.open_zarr( + ARCO_ERA5, chunks=None, storage_options={"token": "anon"} + ) + block = ( + full[ + [ + "10m_u_component_of_wind", + "10m_v_component_of_wind", + "2m_temperature", + ] + ] + .sel(time="2020-01-01T00") + # A ~North-America box (index-based to avoid lat-orientation pitfalls). + .isel(latitude=slice(120, 200), longitude=slice(900, 1000)) + .load() + ) + # One partition, so a SQL `ORDER BY latitude DESC` survives the round-trip + # back to xarray (across multiple partitions, to_dataset reconstructs + # coordinates in ascending order regardless of ORDER BY). + return block.chunk() + + +def wind_speed_sensitivity(ctx: xql.XarrayContext, ref: xr.Dataset) -> None: + """grad(sqrt(u^2 + v^2)) checked against the exact u / speed, v / speed.""" + speed = f"sqrt(power({U}, 2) + power({V}, 2))" + out = ctx.sql( + f""" + SELECT + latitude, + longitude, + {speed} AS wind_speed, + grad({speed}, {U}) AS d_speed_d_u, + grad({speed}, {V}) AS d_speed_d_v + FROM era5 + ORDER BY latitude DESC, longitude + """ + ).to_dataset(dims=["latitude", "longitude"]) + + u = ref["10m_u_component_of_wind"] + v = ref["10m_v_component_of_wind"] + speed_ref = np.sqrt(u**2 + v**2) + + xr.testing.assert_allclose( + out["wind_speed"], speed_ref.rename("wind_speed") + ) + xr.testing.assert_allclose( + out["d_speed_d_u"], (u / speed_ref).rename("d_speed_d_u") + ) + xr.testing.assert_allclose( + out["d_speed_d_v"], (v / speed_ref).rename("d_speed_d_v") + ) + print(" wind-speed sensitivity matches u/|w|, v/|w| exactly") + print(out) + + +def clausius_clapeyron(ctx: xql.XarrayContext, ref: xr.Dataset) -> None: + """grad(e_s(T)) checked against the closed-form Clausius-Clapeyron slope.""" + # August-Roche-Magnus: e_s(T) = A * exp(B * tc / (tc + C)), tc = T - 273.15. + a, b, c = 6.1094, 17.625, 243.04 + tc = f"({T} - 273.15)" + es = f"{a} * exp({b} * {tc} / ({tc} + {c}))" + out = ctx.sql( + f""" + SELECT + latitude, + longitude, + {es} AS e_s, + grad({es}, {T}) AS de_s_dt + FROM era5 + ORDER BY latitude DESC, longitude + """ + ).to_dataset(dims=["latitude", "longitude"]) + + # Reference in float64 (the columns are float32): the exact derivative is + # d(e_s)/dT = e_s * B*C / (tc + C)^2. + temp = ref["2m_temperature"].astype("float64") + tc_ref = temp - 273.15 + es_ref = a * np.exp(b * tc_ref / (tc_ref + c)) + des_dt_ref = es_ref * (b * c) / (tc_ref + c) ** 2 + + xr.testing.assert_allclose(out["e_s"], es_ref.rename("e_s"), rtol=1e-5) + xr.testing.assert_allclose( + out["de_s_dt"], des_dt_ref.rename("de_s_dt"), rtol=1e-5 + ) + print(" d(e_s)/dT matches the closed-form Clausius-Clapeyron slope") + print(out) + + +def main() -> None: + t0 = time.time() + ds = load_era5_block() + print(f"loaded ERA5 block {dict(ds.sizes)} in {time.time() - t0:.1f}s") + + ctx = xql.XarrayContext() + ctx.from_dataset("era5", ds) + + print("\n== wind-speed sensitivity: grad(sqrt(u^2 + v^2)) ==") + wind_speed_sensitivity(ctx, ds) + + print("\n== Clausius-Clapeyron: grad(e_s(T)) ==") + clausius_clapeyron(ctx, ds) + + print("\nOK: symbolic SQL gradients match the analytic references.") + + +if __name__ == "__main__": + main() From b8d3e83da8e802de1afc2e3ab4e8acd7e4503d0a Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 28 Jun 2026 14:46:27 +0000 Subject: [PATCH 14/17] Add MNIST MLP trained in SQL (benchmarks/mnist_mlp.py) A one-hidden-layer MLP (196->32 tanh->10 softmax, on 2x2-pooled 14x14 MNIST) trained by gradient descent with every gradient computed in SQL. The images are registered as xarray (the library's core); the model weights and per-step intermediates are DataFusion in-memory tables (register_record_batches), so a matmul is a join over them and there's no xarray pivot per step. Reverse-mode autodiff as relational algebra: matmul = join + GROUP BY SUM; the hidden activation's local Jacobian = grad(tanh(z), z); cotangent propagation = join; parameter gradients = join + GROUP BY AVG. The only hand-written gradient is softmax + cross-entropy's delta = softmax - onehot. ~83% test accuracy in ~20s. Adds a benchmarks README entry. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_017mDoFJgsm9kS7SicGoCVF6 --- benchmarks/README.md | 21 +++ benchmarks/mnist_mlp.py | 311 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 332 insertions(+) create mode 100644 benchmarks/mnist_mlp.py diff --git a/benchmarks/README.md b/benchmarks/README.md index 5a0188c..f1c6283 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -62,3 +62,24 @@ needs the Substrait round-trip, and Substrait has no recursion — so a `grad` marker can't live inside a recursive CTE. Differentiating once to plain SQL sidesteps that.) +## `mnist_mlp.py` — train an MNIST MLP classifier in SQL + +A one-hidden-layer neural network (196 -> 32 tanh -> 10 softmax, on 2x2-pooled +14x14 MNIST) trained by gradient descent where **every gradient is computed in +SQL**; the optimisation loop is plain Python. It is reverse-mode autodiff +expressed as relational algebra: + +- **matmul = join + `GROUP BY SUM`** — a layer's pre-activation is + `SUM(input * weight)` grouped by (sample, unit). +- **local derivatives = `grad()`** — the hidden activation's Jacobian is + `grad(tanh(z), z)`, the autograd feature doing the calculus per (sample, unit). +- **cotangent propagation = join**, **parameter gradients = join + `GROUP BY + AVG`**. + +The MNIST images are registered as xarray (the library's core); the model +weights and per-step intermediates are DataFusion in-memory tables (a matmul is +a join over them). The only hand-written gradient is softmax + cross-entropy's +`delta = softmax - onehot` (softmax couples classes through a per-sample +normaliser, an aggregate `grad` does not cross). Reaches ~83% test accuracy in +~20s. Downloads MNIST on first run. + diff --git a/benchmarks/mnist_mlp.py b/benchmarks/mnist_mlp.py new file mode 100644 index 0000000..516ee83 --- /dev/null +++ b/benchmarks/mnist_mlp.py @@ -0,0 +1,311 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "xarray_sql", +# "xarray", +# "numpy", +# "pyarrow", +# ] +# +# [tool.uv.sources] +# xarray_sql = { path = "..", editable = true } +# /// +"""Train an MNIST MLP classifier in SQL. + +A one-hidden-layer neural network (196->32 tanh->10 softmax, on 2x2-pooled +14x14 = 196-pixel images) trained by gradient descent where **every gradient is +computed in SQL**. The MNIST images are registered as xarray (the library's +core); the model weights and per-step intermediates live in DataFusion +in-memory tables. The optimisation loop is plain Python; all the math is +relational. + +The design is reverse-mode autodiff expressed in relational algebra: + +* **matmul = join + GROUP BY SUM.** A layer's pre-activation is + ``SUM(input * weight)`` grouped by (sample, unit), joining the data table to a + weight table on the shared index. +* **local derivatives = grad().** The hidden activation's Jacobian is + ``grad(tanh(z), z)`` — the engine differentiates the nonlinearity for us, + evaluated per (sample, unit). This is where the autograd feature does its + work; the rest is ordinary SQL. +* **cotangent propagation = join.** The output error is pushed back through the + second weight matrix by another join + SUM, then multiplied by the local + ``grad`` factor to get the hidden-layer error. +* **parameter gradients = join + GROUP BY AVG.** ``dW = AVG(input * delta)`` + grouped by the weight's indices. + +The only hand-written gradient is softmax + cross-entropy's ``delta = softmax - +onehot`` (softmax couples classes through a per-sample normaliser, an aggregate +``grad`` does not cross — staying faithful to SQL). Everything else is grad and +joins. + +Run standalone (builds the local extension on first use): + + uv run benchmarks/mnist_mlp.py +""" + +from __future__ import annotations + +import gzip +import struct +import tempfile +import time +import urllib.request +from pathlib import Path + +import numpy as np +import pyarrow as pa +import xarray as xr + +import xarray_sql as xql + +MIRROR = "https://storage.googleapis.com/cvdf-datasets/mnist" +CACHE = Path(tempfile.gettempdir()) / "mnist-xql" + +# Network dimensions: 14x14 pooled pixels -> 32 hidden (tanh) -> 10 classes. +N_TRAIN, N_TEST, N_PIX, N_HID, N_CLS = 1000, 500, 196, 32, 10 + + +def _download(url: str, dest: Path, tries: int = 5) -> None: + """Fetch a URL to dest, reading the whole body (retries on truncation).""" + last = None + for attempt in range(tries): + try: + with urllib.request.urlopen(url, timeout=120) as resp: + data = resp.read() + if len(data) < 1024: + raise OSError(f"suspiciously small download: {len(data)} bytes") + dest.write_bytes(data) + return + except Exception as exc: # noqa: BLE001 - retry any transient failure + last = exc + raise OSError(f"failed to download {url}: {last}") + + +def _read_idx(path: Path) -> np.ndarray: + with gzip.open(path, "rb") as f: + (magic,) = struct.unpack(">I", f.read(4)) + if magic == 2051: # images + n, r, c = struct.unpack(">III", f.read(12)) + return np.frombuffer(f.read(), np.uint8).reshape(n, r, c) + (n,) = struct.unpack(">I", f.read(4)) # labels + return np.frombuffer(f.read(), np.uint8) + + +def load_mnist(): + """Download (and cache) MNIST, 2x2 mean-pool to 14x14, subsample.""" + CACHE.mkdir(exist_ok=True) + files = { + "images": "train-images-idx3-ubyte.gz", + "labels": "train-labels-idx1-ubyte.gz", + } + paths = {} + for key, name in files.items(): + dest = CACHE / name + if not dest.exists(): + _download(f"{MIRROR}/{name}", dest) + paths[key] = dest + + imgs = _read_idx(paths["images"]).astype(np.float32) / 255.0 + labs = _read_idx(paths["labels"]).astype(np.int64) + pooled = imgs.reshape(-1, 14, 2, 14, 2).mean(axis=(2, 4)).reshape(-1, N_PIX) + + rng = np.random.default_rng(0) + idx = rng.permutation(len(pooled)) + tr, te = idx[:N_TRAIN], idx[N_TRAIN : N_TRAIN + N_TEST] + return pooled[tr], labs[tr], pooled[te], labs[te] + + +class SqlTables: + """Model parameters and intermediates as DataFusion in-memory tables. + + The MNIST data stays registered as xarray (the library's core); the model + weights and the per-step intermediate results (hidden activations, errors) + are plain in-memory tables, rebuilt from Arrow each step. Matrices are stored + in long form — a weight ``W[i, j]`` is a row ``(i, j, w)`` — so a matmul is a + join + ``GROUP BY``. + """ + + def __init__(self, ctx: xql.XarrayContext): + self.ctx = ctx + + def _replace(self, name: str, batches: list[pa.RecordBatch]) -> None: + if self.ctx.table_exist(name): + self.ctx.deregister_table(name) + self.ctx.register_record_batches(name, [batches]) + + def matrix( + self, name: str, var: str, arr: np.ndarray, di: str, dj: str + ) -> None: + """Register a 2-D array as a long ``(di, dj, var)`` in-memory table.""" + ni, nj = arr.shape + ii, jj = np.meshgrid(np.arange(ni), np.arange(nj), indexing="ij") + batch = pa.RecordBatch.from_pydict( + {di: ii.ravel(), dj: jj.ravel(), var: arr.ravel()} + ) + self._replace(name, [batch]) + + def vector(self, name: str, var: str, arr: np.ndarray, d0: str) -> None: + """Register a 1-D array as a ``(d0, var)`` in-memory table.""" + batch = pa.RecordBatch.from_pydict( + {d0: np.arange(len(arr)), var: np.asarray(arr, dtype=np.float64)} + ) + self._replace(name, [batch]) + + def materialize(self, name: str, sql: str) -> None: + """Run a query and register its Arrow result as the next stage's table.""" + self._replace(name, self.ctx.sql(sql).collect()) + + +def main() -> None: + Xtr, ytr, Xte, yte = load_mnist() + print( + f"MNIST: train {Xtr.shape}, test {Xte.shape} ({N_PIX} pix, {N_HID} hidden)" + ) + + ctx = xql.XarrayContext() + # The data is registered as xarray (the library's core); model state below + # lives in DataFusion in-memory tables. + ctx.from_dataset( + "imgs", + xr.Dataset( + {"val": (("sample", "pix"), Xtr)}, + coords={"sample": np.arange(N_TRAIN), "pix": np.arange(N_PIX)}, + ), + chunks={"sample": N_TRAIN}, + ) + ctx.from_dataset( + "labels", + xr.Dataset( + {"label": (("sample",), ytr.astype(np.float64))}, + coords={"sample": np.arange(N_TRAIN)}, + ), + chunks={"sample": N_TRAIN}, + ) + t = SqlTables(ctx) + + rng = np.random.default_rng(1) + W1 = rng.standard_normal((N_PIX, N_HID)) * 0.1 + b1 = np.zeros(N_HID) + W2 = rng.standard_normal((N_HID, N_CLS)) * 0.1 + b2 = np.zeros(N_CLS) + + def dense_to(df, ni, nj, ci, cj): + g = np.zeros((ni, nj)) + g[df[ci].to_numpy(), df[cj].to_numpy()] = df["g"].to_numpy() + return g + + def step(lr: float) -> None: + nonlocal W1, b1, W2, b2 + t.matrix("w1", "w", W1, "pix", "hid") + t.vector("b1", "b", b1, "hid") + t.matrix("w2", "w", W2, "hid", "cls") + t.vector("b2", "b", b2, "cls") + + # Forward: hidden pre-activation z and activation a = tanh(z). + t.materialize( + "h", + """ + WITH z AS ( + SELECT i.sample, w.hid, SUM(i.val * w.w) + MAX(bb.b) AS z + FROM imgs i JOIN w1 w ON i.pix = w.pix + JOIN b1 bb ON w.hid = bb.hid + GROUP BY i.sample, w.hid) + SELECT sample, hid, z, tanh(z) AS a FROM z + """, + ) + # Output softmax, then output error delta2 = softmax - onehot(label). + t.materialize( + "delta2", + """ + WITH logit AS ( + SELECT h.sample, w.cls, SUM(h.a * w.w) + MAX(bb.b) AS z + FROM h JOIN w2 w ON h.hid = w.hid + JOIN b2 bb ON w.cls = bb.cls + GROUP BY h.sample, w.cls), + mx AS (SELECT sample, MAX(z) AS m FROM logit GROUP BY sample), + ex AS (SELECT l.sample, l.cls, exp(l.z - mx.m) AS e + FROM logit l JOIN mx ON l.sample = mx.sample), + zsum AS (SELECT sample, SUM(e) AS z FROM ex GROUP BY sample) + SELECT ex.sample, ex.cls, + ex.e / zsum.z + - CASE WHEN ex.cls = lb.label THEN 1.0 ELSE 0.0 END AS d + FROM ex JOIN zsum ON ex.sample = zsum.sample + JOIN labels lb ON ex.sample = lb.sample + """, + ) + # Backprop to the hidden layer: push delta2 back through W2 (join + SUM), + # then multiply by the LOCAL activation derivative grad(tanh(z), z). + t.materialize( + "delta1", + """ + WITH da AS ( + SELECT d.sample, w.hid, SUM(d.d * w.w) AS da + FROM delta2 d JOIN w2 w ON d.cls = w.cls + GROUP BY d.sample, w.hid) + SELECT da.sample, da.hid, da.da * grad(tanh(h.z), h.z) AS d + FROM da JOIN h ON da.sample = h.sample AND da.hid = h.hid + """, + ) + + # Parameter gradients: dW = AVG(input * delta) over the batch. + gW2 = dense_to( + ctx.sql( + f"SELECT h.hid, d.cls, SUM(h.a * d.d) / {N_TRAIN}.0 AS g " + "FROM h JOIN delta2 d ON h.sample = d.sample " + "GROUP BY h.hid, d.cls" + ).to_pandas(), + N_HID, + N_CLS, + "hid", + "cls", + ) + gW1 = dense_to( + ctx.sql( + f"SELECT i.pix, d.hid, SUM(i.val * d.d) / {N_TRAIN}.0 AS g " + "FROM imgs i JOIN delta1 d ON i.sample = d.sample " + "GROUP BY i.pix, d.hid" + ).to_pandas(), + N_PIX, + N_HID, + "pix", + "hid", + ) + gb2 = ctx.sql( + f"SELECT cls, SUM(d) / {N_TRAIN}.0 AS g FROM delta2 GROUP BY cls" + ).to_pandas() + gb1 = ctx.sql( + f"SELECT hid, SUM(d) / {N_TRAIN}.0 AS g FROM delta1 GROUP BY hid" + ).to_pandas() + vb2 = np.zeros(N_CLS) + vb2[gb2["cls"].to_numpy()] = gb2["g"].to_numpy() + vb1 = np.zeros(N_HID) + vb1[gb1["hid"].to_numpy()] = gb1["g"].to_numpy() + + W2 -= lr * gW2 + b2 -= lr * vb2 + W1 -= lr * gW1 + b1 -= lr * vb1 + + def accuracy(X, y) -> float: + a = np.tanh(X @ W1 + b1) + return float(((a @ W2 + b2).argmax(1) == y).mean()) + + print(f"init: test acc {accuracy(Xte, yte):.3f}") + t0 = time.time() + steps = 60 + for s in range(steps): + step(lr=0.5) + if s % 10 == 0 or s == steps - 1: + print( + f"step {s:2d}: train {accuracy(Xtr, ytr):.3f} " + f"test {accuracy(Xte, yte):.3f}" + ) + print( + f"\ntrained an MNIST MLP in SQL: test accuracy " + f"{accuracy(Xte, yte):.3f} in {time.time() - t0:.0f}s" + ) + + +if __name__ == "__main__": + main() From 06fabbc66473e33200ecde05dc8a45af094b2aa8 Mon Sep 17 00:00:00 2001 From: Alex Merose Date: Tue, 30 Jun 2026 17:09:39 +0300 Subject: [PATCH 15/17] demo: train the MNIST MLP as one append-only model table MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rewrite mnist_mlp.py so the whole model and its entire training history live in a single append-only table model(step, layer, i, j, val): every parameter is a row tagged by generation, and a training step appends the next generation's rows rather than mutating anything. Each step is a single SQL statement (forward, grad(tanh(z),z) backprop, parameter update); evaluation is SQL too (a forward pass with ROW_NUMBER() for the argmax). Python no longer holds the weights or computes any gradients — it only sequences the steps. A 2-layer net can't be one recursive CTE (the recursive relation may be referenced only once, but W1/W2 are used several times per step) and unrolling the steps as non-recursive CTEs blows up exponentially (DataFusion inlines CTEs; no MATERIALIZED). Materialising between steps is therefore host-driven; the thin loop does exactly that. Reaches ~83% test accuracy over 60 steps. Co-Authored-By: Claude Opus 4.8 --- benchmarks/README.md | 40 ++-- benchmarks/mnist_mlp.py | 424 ++++++++++++++++++++++------------------ 2 files changed, 265 insertions(+), 199 deletions(-) diff --git a/benchmarks/README.md b/benchmarks/README.md index f1c6283..e89e8f5 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -65,21 +65,37 @@ sidesteps that.) ## `mnist_mlp.py` — train an MNIST MLP classifier in SQL A one-hidden-layer neural network (196 -> 32 tanh -> 10 softmax, on 2x2-pooled -14x14 MNIST) trained by gradient descent where **every gradient is computed in -SQL**; the optimisation loop is plain Python. It is reverse-mode autodiff -expressed as relational algebra: +14x14 MNIST) where **every gradient is computed in SQL** and the whole model — +with its entire training history — lives in a single table. + +The model is one append-only table `model(step, layer, i, j, val)`: every +parameter is a row, tagged by which generation (`step`) it belongs to. **A +training step never mutates anything; it appends the next generation's rows.** +`WHERE step = N` is the model at iteration N, and the full trajectory is the +table. Each step is a *single* SQL statement that reads the current generation +and writes the next — reverse-mode autodiff as relational algebra: - **matmul = join + `GROUP BY SUM`** — a layer's pre-activation is `SUM(input * weight)` grouped by (sample, unit). - **local derivatives = `grad()`** — the hidden activation's Jacobian is `grad(tanh(z), z)`, the autograd feature doing the calculus per (sample, unit). - **cotangent propagation = join**, **parameter gradients = join + `GROUP BY - AVG`**. - -The MNIST images are registered as xarray (the library's core); the model -weights and per-step intermediates are DataFusion in-memory tables (a matmul is -a join over them). The only hand-written gradient is softmax + cross-entropy's -`delta = softmax - onehot` (softmax couples classes through a per-sample -normaliser, an aggregate `grad` does not cross). Reaches ~83% test accuracy in -~20s. Downloads MNIST on first run. - + AVG`**, and the update `w - lr*g` is emitted as the next generation's rows. + +The images are registered as xarray (the library's core); evaluation is SQL too +(a forward pass with `ROW_NUMBER()` for the argmax). The only hand-written +gradient is softmax + cross-entropy's `delta = softmax - onehot` (softmax couples +classes through a per-sample normaliser, which an aggregate `grad` does not +cross). Reaches ~83% test accuracy over 60 steps (~140s on a laptop — the +parameter updates run in SQL and every generation is kept as rows, so it trades +speed for a fully relational, fully inspectable training history). Downloads +MNIST on first run. + +Why is the *outer* loop still Python rather than one recursive query (like +`grad_descent.py`)? A recursive CTE may reference the recursive relation only +once, but a 2-layer net uses the current weights several times per step (W1 and +W2 forward, W2 again in backprop), so it can't be a single recursive statement. +Training is also sequential and reuses each step's result, so steps must be +*materialised* between iterations — which is exactly what the thin loop does +(append a generation, then query it). All the maths stays in SQL; Python only +sequences the steps. diff --git a/benchmarks/mnist_mlp.py b/benchmarks/mnist_mlp.py index 516ee83..4e1d81b 100644 --- a/benchmarks/mnist_mlp.py +++ b/benchmarks/mnist_mlp.py @@ -12,32 +12,45 @@ # /// """Train an MNIST MLP classifier in SQL. -A one-hidden-layer neural network (196->32 tanh->10 softmax, on 2x2-pooled -14x14 = 196-pixel images) trained by gradient descent where **every gradient is -computed in SQL**. The MNIST images are registered as xarray (the library's -core); the model weights and per-step intermediates live in DataFusion -in-memory tables. The optimisation loop is plain Python; all the math is -relational. +A one-hidden-layer network (196->32 tanh->10 softmax, on 2x2-pooled 14x14 = +196-pixel images) trained by gradient descent where **every gradient is computed +in SQL** — and the whole model, with its entire training history, lives in a +single table. -The design is reverse-mode autodiff expressed in relational algebra: +The model is one append-only table ``model(step, layer, i, j, val)``: every +parameter ``W1[i, j]`` / ``b1[i]`` / ``W2`` / ``b2`` is a row, tagged by which +generation (``step``) it belongs to. **A training step never mutates anything — +it appends the next generation's rows.** The full optimisation trajectory is the +table; ``WHERE step = N`` is the model at iteration N. + +Each step is a *single* SQL statement (``STEP`` below) that reads the current +generation and writes the next. It is reverse-mode autodiff as relational +algebra: * **matmul = join + GROUP BY SUM.** A layer's pre-activation is - ``SUM(input * weight)`` grouped by (sample, unit), joining the data table to a - weight table on the shared index. -* **local derivatives = grad().** The hidden activation's Jacobian is - ``grad(tanh(z), z)`` — the engine differentiates the nonlinearity for us, - evaluated per (sample, unit). This is where the autograd feature does its - work; the rest is ordinary SQL. -* **cotangent propagation = join.** The output error is pushed back through the - second weight matrix by another join + SUM, then multiplied by the local - ``grad`` factor to get the hidden-layer error. -* **parameter gradients = join + GROUP BY AVG.** ``dW = AVG(input * delta)`` - grouped by the weight's indices. + ``SUM(input * weight)`` grouped by (sample, unit), joining the data to the + current weight rows. +* **local derivatives = grad().** The hidden Jacobian is ``grad(tanh(z), z)`` — + the autograd feature differentiates the nonlinearity, per (sample, unit). +* **cotangent propagation = join.** The output error is pushed back through W2 by + another join + SUM, then scaled by the local ``grad`` factor. +* **parameter gradients = join + GROUP BY AVG**, and the update is ``w - lr*g``, + emitted as the next generation's rows. The only hand-written gradient is softmax + cross-entropy's ``delta = softmax - -onehot`` (softmax couples classes through a per-sample normaliser, an aggregate -``grad`` does not cross — staying faithful to SQL). Everything else is grad and -joins. +onehot`` (softmax couples classes through a per-sample normaliser, which an +aggregate ``grad`` does not cross — staying faithful to SQL). Everything else is +grad and joins. Evaluation is SQL too: a forward pass with ``ROW_NUMBER()`` for +the argmax. + +Why is the *outer* loop still Python rather than one recursive query (like +``grad_descent.py``)? A recursive CTE may reference the recursive relation only +once, but a 2-layer net uses the current weights several times per step (W1 and +W2 in the forward pass, W2 again in backprop), so it cannot be a single recursive +statement. Training is also inherently sequential and reuses each step's result, +so the steps must be *materialised* between iterations — which is exactly what the +thin Python loop does (append a generation, then query it). All the maths stays +in SQL; Python only sequences the steps. Run standalone (builds the local extension on first use): @@ -64,12 +77,13 @@ # Network dimensions: 14x14 pooled pixels -> 32 hidden (tanh) -> 10 classes. N_TRAIN, N_TEST, N_PIX, N_HID, N_CLS = 1000, 500, 196, 32, 10 +LR, STEPS = 0.5, 60 def _download(url: str, dest: Path, tries: int = 5) -> None: """Fetch a URL to dest, reading the whole body (retries on truncation).""" last = None - for attempt in range(tries): + for _ in range(tries): try: with urllib.request.urlopen(url, timeout=120) as resp: data = resp.read() @@ -116,194 +130,230 @@ def load_mnist(): return pooled[tr], labs[tr], pooled[te], labs[te] -class SqlTables: - """Model parameters and intermediates as DataFusion in-memory tables. +# --- the model as rows -------------------------------------------------------- - The MNIST data stays registered as xarray (the library's core); the model - weights and the per-step intermediate results (hidden activations, errors) - are plain in-memory tables, rebuilt from Arrow each step. Matrices are stored - in long form — a weight ``W[i, j]`` is a row ``(i, j, w)`` — so a matmul is a - join + ``GROUP BY``. - """ +_MODEL_SCHEMA = pa.schema( + [ + ("step", pa.int64()), + ("layer", pa.utf8()), + ("i", pa.int64()), + ("j", pa.int64()), + ("val", pa.float64()), + ] +) - def __init__(self, ctx: xql.XarrayContext): - self.ctx = ctx - - def _replace(self, name: str, batches: list[pa.RecordBatch]) -> None: - if self.ctx.table_exist(name): - self.ctx.deregister_table(name) - self.ctx.register_record_batches(name, [batches]) - - def matrix( - self, name: str, var: str, arr: np.ndarray, di: str, dj: str - ) -> None: - """Register a 2-D array as a long ``(di, dj, var)`` in-memory table.""" - ni, nj = arr.shape - ii, jj = np.meshgrid(np.arange(ni), np.arange(nj), indexing="ij") - batch = pa.RecordBatch.from_pydict( - {di: ii.ravel(), dj: jj.ravel(), var: arr.ravel()} - ) - self._replace(name, [batch]) - def vector(self, name: str, var: str, arr: np.ndarray, d0: str) -> None: - """Register a 1-D array as a ``(d0, var)`` in-memory table.""" - batch = pa.RecordBatch.from_pydict( - {d0: np.arange(len(arr)), var: np.asarray(arr, dtype=np.float64)} - ) - self._replace(name, [batch]) +def _param_rows(step: int, layer: str, arr: np.ndarray) -> dict: + """One layer's parameters as ``(step, layer, i, j, val)`` columns. - def materialize(self, name: str, sql: str) -> None: - """Run a query and register its Arrow result as the next stage's table.""" - self._replace(name, self.ctx.sql(sql).collect()) + A matrix ``W[i, j]`` becomes rows ``(i, j, w)``; a bias vector ``b[i]`` + becomes ``(i, 0, b)``. + """ + if arr.ndim == 2: + ii, jj = np.meshgrid( + np.arange(arr.shape[0]), np.arange(arr.shape[1]), indexing="ij" + ) + ii, jj = ii.ravel(), jj.ravel() + else: + ii, jj = np.arange(arr.size), np.zeros(arr.size, np.int64) + n = arr.size + return { + "step": np.full(n, step, np.int64), + "layer": [layer] * n, + "i": ii.astype(np.int64), + "j": jj.astype(np.int64), + "val": arr.ravel().astype(np.float64), + } -def main() -> None: - Xtr, ytr, Xte, yte = load_mnist() - print( - f"MNIST: train {Xtr.shape}, test {Xte.shape} ({N_PIX} pix, {N_HID} hidden)" +def _generation_batch(step, w1, b1, w2, b2) -> pa.RecordBatch: + """All four layers of one generation as a single RecordBatch.""" + cols: dict[str, list] = {k: [] for k in ("step", "layer", "i", "j", "val")} + for layer, arr in (("w1", w1), ("b1", b1), ("w2", w2), ("b2", b2)): + for k, v in _param_rows(step, layer, arr).items(): + cols[k].extend(list(v)) + return pa.RecordBatch.from_arrays( + [ + pa.array(cols["step"], pa.int64()), + pa.array(cols["layer"], pa.utf8()), + pa.array(cols["i"], pa.int64()), + pa.array(cols["j"], pa.int64()), + pa.array(cols["val"], pa.float64()), + ], + schema=_MODEL_SCHEMA, ) - ctx = xql.XarrayContext() - # The data is registered as xarray (the library's core); model state below - # lives in DataFusion in-memory tables. + +# One training step, as one SQL statement: read the current generation of the +# model table, run the forward + backward pass over the data, and SELECT the next +# generation's parameter rows (which the loop appends to the model table). +STEP = f""" +WITH cur AS (SELECT max(step) AS s FROM model), + w1 AS (SELECT i AS pix, j AS hid, val AS w FROM model, cur + WHERE step = cur.s AND layer = 'w1'), + b1 AS (SELECT i AS hid, val AS b FROM model, cur + WHERE step = cur.s AND layer = 'b1'), + w2 AS (SELECT i AS hid, j AS cls, val AS w FROM model, cur + WHERE step = cur.s AND layer = 'w2'), + b2 AS (SELECT i AS cls, val AS b FROM model, cur + WHERE step = cur.s AND layer = 'b2'), + -- forward: hidden pre-activation z and activation a = tanh(z) + zt AS (SELECT i.sample, w.hid, SUM(i.val * w.w) + MAX(bb.b) AS z + FROM imgs i JOIN w1 w ON i.pix = w.pix JOIN b1 bb ON w.hid = bb.hid + GROUP BY i.sample, w.hid), + h AS (SELECT sample, hid, z, tanh(z) AS a FROM zt), + -- output logits, then a stable softmax + lg AS (SELECT h.sample, w.cls, SUM(h.a * w.w) + MAX(bb.b) AS z + FROM h JOIN w2 w ON h.hid = w.hid JOIN b2 bb ON w.cls = bb.cls + GROUP BY h.sample, w.cls), + mx AS (SELECT sample, MAX(z) AS m FROM lg GROUP BY sample), + ex AS (SELECT l.sample, l.cls, exp(l.z - mx.m) AS e + FROM lg l JOIN mx ON l.sample = mx.sample), + zs AS (SELECT sample, SUM(e) AS z FROM ex GROUP BY sample), + -- output error delta2 = softmax - onehot(label) + d2 AS (SELECT ex.sample, ex.cls, + ex.e / zs.z + - CASE WHEN ex.cls = lb.label THEN 1.0 ELSE 0.0 END AS d + FROM ex JOIN zs ON ex.sample = zs.sample + JOIN labels lb ON lb.sample = ex.sample), + -- backprop to hidden: push delta2 through W2, scale by grad(tanh(z), z) + da AS (SELECT d.sample, w.hid, SUM(d.d * w.w) AS da + FROM d2 d JOIN w2 w ON d.cls = w.cls GROUP BY d.sample, w.hid), + d1 AS (SELECT da.sample, da.hid, da.da * grad(tanh(h.z), h.z) AS d + FROM da JOIN h ON da.sample = h.sample AND da.hid = h.hid), + -- parameter gradients: dW = AVG(input * delta) over the batch + gw1 AS (SELECT i.pix, d.hid, AVG(i.val * d.d) AS g + FROM imgs i JOIN d1 d ON i.sample = d.sample GROUP BY i.pix, d.hid), + gb1 AS (SELECT hid, AVG(d) AS g FROM d1 GROUP BY hid), + gw2 AS (SELECT h.hid, d.cls, AVG(h.a * d.d) AS g + FROM h JOIN d2 d ON h.sample = d.sample GROUP BY h.hid, d.cls), + gb2 AS (SELECT cls, AVG(d) AS g FROM d2 GROUP BY cls) +-- the next generation: w - lr*grad, tagged step+1, as model rows +SELECT (SELECT s FROM cur) + 1 AS step, 'w1' AS layer, + w.pix AS i, w.hid AS j, w.w - {LR} * g.g AS val +FROM w1 w JOIN gw1 g ON w.pix = g.pix AND w.hid = g.hid +UNION ALL +SELECT (SELECT s FROM cur) + 1, 'b1', b.hid, CAST(0 AS BIGINT), b.b - {LR} * g.g +FROM b1 b JOIN gb1 g ON b.hid = g.hid +UNION ALL +SELECT (SELECT s FROM cur) + 1, 'w2', w.hid, w.cls, w.w - {LR} * g.g +FROM w2 w JOIN gw2 g ON w.hid = g.hid AND w.cls = g.cls +UNION ALL +SELECT (SELECT s FROM cur) + 1, 'b2', b.cls, CAST(0 AS BIGINT), b.b - {LR} * g.g +FROM b2 b JOIN gb2 g ON b.cls = g.cls +""" + + +def eval_sql(imgs_table: str, labels_table: str) -> str: + """Accuracy of the latest model on a dataset — a forward pass in SQL. + + ``ROW_NUMBER()`` picks each sample's argmax class; it is compared to the + label. No softmax needed at inference: the argmax of the logits is the + prediction. + """ + return f""" + WITH cur AS (SELECT max(step) AS s FROM model), + w1 AS (SELECT i AS pix, j AS hid, val AS w FROM model, cur + WHERE step = cur.s AND layer = 'w1'), + b1 AS (SELECT i AS hid, val AS b FROM model, cur + WHERE step = cur.s AND layer = 'b1'), + w2 AS (SELECT i AS hid, j AS cls, val AS w FROM model, cur + WHERE step = cur.s AND layer = 'w2'), + b2 AS (SELECT i AS cls, val AS b FROM model, cur + WHERE step = cur.s AND layer = 'b2'), + h AS (SELECT i.sample, w.hid, + tanh(SUM(i.val * w.w) + MAX(bb.b)) AS a + FROM {imgs_table} i JOIN w1 w ON i.pix = w.pix + JOIN b1 bb ON w.hid = bb.hid + GROUP BY i.sample, w.hid), + lg AS (SELECT h.sample, w.cls, SUM(h.a * w.w) + MAX(bb.b) AS z + FROM h JOIN w2 w ON h.hid = w.hid JOIN b2 bb ON w.cls = bb.cls + GROUP BY h.sample, w.cls), + pred AS (SELECT sample, cls, + ROW_NUMBER() OVER (PARTITION BY sample ORDER BY z DESC) AS rk + FROM lg) + SELECT AVG(CASE WHEN p.cls = l.label THEN 1.0 ELSE 0.0 END) AS acc + FROM pred p JOIN {labels_table} l ON p.sample = l.sample + WHERE p.rk = 1 + """ + + +def _register_images(ctx, name, X): ctx.from_dataset( - "imgs", + name, xr.Dataset( - {"val": (("sample", "pix"), Xtr)}, - coords={"sample": np.arange(N_TRAIN), "pix": np.arange(N_PIX)}, + {"val": (("sample", "pix"), X)}, + coords={ + "sample": np.arange(X.shape[0]), + "pix": np.arange(N_PIX), + }, ), - chunks={"sample": N_TRAIN}, + chunks={"sample": X.shape[0]}, ) + + +def _register_labels(ctx, name, y): ctx.from_dataset( - "labels", + name, xr.Dataset( - {"label": (("sample",), ytr.astype(np.float64))}, - coords={"sample": np.arange(N_TRAIN)}, + {"label": (("sample",), y.astype(np.float64))}, + coords={"sample": np.arange(len(y))}, ), - chunks={"sample": N_TRAIN}, + chunks={"sample": len(y)}, + ) + + +def main() -> None: + Xtr, ytr, Xte, yte = load_mnist() + print( + f"MNIST: train {Xtr.shape}, test {Xte.shape} " + f"({N_PIX} pix, {N_HID} hidden, {N_CLS} classes)" ) - t = SqlTables(ctx) + ctx = xql.XarrayContext() + # The data is registered as xarray (the library's core); the model below is + # the one append-only table that holds every layer and every generation. + _register_images(ctx, "imgs", Xtr) + _register_labels(ctx, "labels", ytr) + _register_images(ctx, "imgs_te", Xte) + _register_labels(ctx, "labels_te", yte) + + # Generation 0: small random weights, zero biases. rng = np.random.default_rng(1) - W1 = rng.standard_normal((N_PIX, N_HID)) * 0.1 - b1 = np.zeros(N_HID) - W2 = rng.standard_normal((N_HID, N_CLS)) * 0.1 - b2 = np.zeros(N_CLS) - - def dense_to(df, ni, nj, ci, cj): - g = np.zeros((ni, nj)) - g[df[ci].to_numpy(), df[cj].to_numpy()] = df["g"].to_numpy() - return g - - def step(lr: float) -> None: - nonlocal W1, b1, W2, b2 - t.matrix("w1", "w", W1, "pix", "hid") - t.vector("b1", "b", b1, "hid") - t.matrix("w2", "w", W2, "hid", "cls") - t.vector("b2", "b", b2, "cls") - - # Forward: hidden pre-activation z and activation a = tanh(z). - t.materialize( - "h", - """ - WITH z AS ( - SELECT i.sample, w.hid, SUM(i.val * w.w) + MAX(bb.b) AS z - FROM imgs i JOIN w1 w ON i.pix = w.pix - JOIN b1 bb ON w.hid = bb.hid - GROUP BY i.sample, w.hid) - SELECT sample, hid, z, tanh(z) AS a FROM z - """, - ) - # Output softmax, then output error delta2 = softmax - onehot(label). - t.materialize( - "delta2", - """ - WITH logit AS ( - SELECT h.sample, w.cls, SUM(h.a * w.w) + MAX(bb.b) AS z - FROM h JOIN w2 w ON h.hid = w.hid - JOIN b2 bb ON w.cls = bb.cls - GROUP BY h.sample, w.cls), - mx AS (SELECT sample, MAX(z) AS m FROM logit GROUP BY sample), - ex AS (SELECT l.sample, l.cls, exp(l.z - mx.m) AS e - FROM logit l JOIN mx ON l.sample = mx.sample), - zsum AS (SELECT sample, SUM(e) AS z FROM ex GROUP BY sample) - SELECT ex.sample, ex.cls, - ex.e / zsum.z - - CASE WHEN ex.cls = lb.label THEN 1.0 ELSE 0.0 END AS d - FROM ex JOIN zsum ON ex.sample = zsum.sample - JOIN labels lb ON ex.sample = lb.sample - """, - ) - # Backprop to the hidden layer: push delta2 back through W2 (join + SUM), - # then multiply by the LOCAL activation derivative grad(tanh(z), z). - t.materialize( - "delta1", - """ - WITH da AS ( - SELECT d.sample, w.hid, SUM(d.d * w.w) AS da - FROM delta2 d JOIN w2 w ON d.cls = w.cls - GROUP BY d.sample, w.hid) - SELECT da.sample, da.hid, da.da * grad(tanh(h.z), h.z) AS d - FROM da JOIN h ON da.sample = h.sample AND da.hid = h.hid - """, - ) + gen0 = _generation_batch( + 0, + rng.standard_normal((N_PIX, N_HID)) * 0.1, + np.zeros(N_HID), + rng.standard_normal((N_HID, N_CLS)) * 0.1, + np.zeros(N_CLS), + ) + generations = [gen0] + ctx.register_record_batches("model", [generations]) - # Parameter gradients: dW = AVG(input * delta) over the batch. - gW2 = dense_to( - ctx.sql( - f"SELECT h.hid, d.cls, SUM(h.a * d.d) / {N_TRAIN}.0 AS g " - "FROM h JOIN delta2 d ON h.sample = d.sample " - "GROUP BY h.hid, d.cls" - ).to_pandas(), - N_HID, - N_CLS, - "hid", - "cls", - ) - gW1 = dense_to( - ctx.sql( - f"SELECT i.pix, d.hid, SUM(i.val * d.d) / {N_TRAIN}.0 AS g " - "FROM imgs i JOIN delta1 d ON i.sample = d.sample " - "GROUP BY i.pix, d.hid" - ).to_pandas(), - N_PIX, - N_HID, - "pix", - "hid", + def test_acc() -> float: + return float( + ctx.sql(eval_sql("imgs_te", "labels_te")).to_pandas()["acc"][0] ) - gb2 = ctx.sql( - f"SELECT cls, SUM(d) / {N_TRAIN}.0 AS g FROM delta2 GROUP BY cls" - ).to_pandas() - gb1 = ctx.sql( - f"SELECT hid, SUM(d) / {N_TRAIN}.0 AS g FROM delta1 GROUP BY hid" - ).to_pandas() - vb2 = np.zeros(N_CLS) - vb2[gb2["cls"].to_numpy()] = gb2["g"].to_numpy() - vb1 = np.zeros(N_HID) - vb1[gb1["hid"].to_numpy()] = gb1["g"].to_numpy() - - W2 -= lr * gW2 - b2 -= lr * vb2 - W1 -= lr * gW1 - b1 -= lr * vb1 - - def accuracy(X, y) -> float: - a = np.tanh(X @ W1 + b1) - return float(((a @ W2 + b2).argmax(1) == y).mean()) - - print(f"init: test acc {accuracy(Xte, yte):.3f}") + + print(f"init: test acc {test_acc():.3f}") t0 = time.time() - steps = 60 - for s in range(steps): - step(lr=0.5) - if s % 10 == 0 or s == steps - 1: - print( - f"step {s:2d}: train {accuracy(Xtr, ytr):.3f} " - f"test {accuracy(Xte, yte):.3f}" + for s in range(STEPS): + # One SQL statement computes the next generation; appending its rows to + # the model table *is* the parameter update. + generations.extend(ctx.sql(STEP).collect()) + ctx.deregister_table("model") + ctx.register_record_batches("model", [generations]) + if s % 10 == 0 or s == STEPS - 1: + tr = float( + ctx.sql(eval_sql("imgs", "labels")).to_pandas()["acc"][0] ) + print(f"step {s:2d}: train {tr:.3f} test {test_acc():.3f}") + + n_rows = ctx.sql("SELECT count(*) AS n FROM model").to_pandas()["n"][0] print( - f"\ntrained an MNIST MLP in SQL: test accuracy " - f"{accuracy(Xte, yte):.3f} in {time.time() - t0:.0f}s" + f"\ntrained an MNIST MLP in SQL: test accuracy {test_acc():.3f} " + f"in {time.time() - t0:.0f}s.\nThe model and its entire training " + f"history are one table of {n_rows} rows ({STEPS + 1} generations)." ) From c92447ae8ee1d95ee2be4fca0fb74cf056ff9f3e Mon Sep 17 00:00:00 2001 From: Alex Merose Date: Tue, 30 Jun 2026 17:45:00 +0300 Subject: [PATCH 16/17] demo: data-driven deep MLP with the model and metrics as relations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Make the architecture itself data. The whole model is one xr.Dataset: each layer's weight is a data_var w{L} over its boundary dims (u{L}, u{L+1}), sharing the dims that connect adjacent layers (the join keys). The dim sizes are the layer widths and the number of weights is the depth, so differing neuron counts are just differing dim sizes — no padding, because the relational long form is naturally ragged. from_dataset splits the one Dataset into a table per weight; changing WIDTHS trains a different network with the same code. One generic contract()-based loop trains a net of any depth: forward contracts each layer, backward is the same contraction transposed (VJP of a contraction is a contraction) with grad(tanh(z), z) for the local derivative. Validated exact against numpy at depth 3. Training metrics are a relation too: each logged step appends a (step, loss, train_acc, test_acc) row to a metrics table rather than a Python list. The trained model, predictions, and metrics all come back out as xarray via to_dataset. ~83% test accuracy in ~13s. Co-Authored-By: Claude Opus 4.8 --- benchmarks/README.md | 91 +++--- benchmarks/mnist_mlp.py | 607 +++++++++++++++++++++++----------------- 2 files changed, 398 insertions(+), 300 deletions(-) diff --git a/benchmarks/README.md b/benchmarks/README.md index e89e8f5..10b4fea 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -62,40 +62,57 @@ needs the Substrait round-trip, and Substrait has no recursion — so a `grad` marker can't live inside a recursive CTE. Differentiating once to plain SQL sidesteps that.) -## `mnist_mlp.py` — train an MNIST MLP classifier in SQL - -A one-hidden-layer neural network (196 -> 32 tanh -> 10 softmax, on 2x2-pooled -14x14 MNIST) where **every gradient is computed in SQL** and the whole model — -with its entire training history — lives in a single table. - -The model is one append-only table `model(step, layer, i, j, val)`: every -parameter is a row, tagged by which generation (`step`) it belongs to. **A -training step never mutates anything; it appends the next generation's rows.** -`WHERE step = N` is the model at iteration N, and the full trajectory is the -table. Each step is a *single* SQL statement that reads the current generation -and writes the next — reverse-mode autodiff as relational algebra: - -- **matmul = join + `GROUP BY SUM`** — a layer's pre-activation is - `SUM(input * weight)` grouped by (sample, unit). -- **local derivatives = `grad()`** — the hidden activation's Jacobian is - `grad(tanh(z), z)`, the autograd feature doing the calculus per (sample, unit). -- **cotangent propagation = join**, **parameter gradients = join + `GROUP BY - AVG`**, and the update `w - lr*g` is emitted as the next generation's rows. - -The images are registered as xarray (the library's core); evaluation is SQL too -(a forward pass with `ROW_NUMBER()` for the argmax). The only hand-written -gradient is softmax + cross-entropy's `delta = softmax - onehot` (softmax couples -classes through a per-sample normaliser, which an aggregate `grad` does not -cross). Reaches ~83% test accuracy over 60 steps (~140s on a laptop — the -parameter updates run in SQL and every generation is kept as rows, so it trades -speed for a fully relational, fully inspectable training history). Downloads -MNIST on first run. - -Why is the *outer* loop still Python rather than one recursive query (like -`grad_descent.py`)? A recursive CTE may reference the recursive relation only -once, but a 2-layer net uses the current weights several times per step (W1 and -W2 forward, W2 again in backprop), so it can't be a single recursive statement. -Training is also sequential and reuses each step's result, so steps must be -*materialised* between iterations — which is exactly what the thin loop does -(append a generation, then query it). All the maths stays in SQL; Python only -sequences the steps. +## `mnist_mlp.py` — an MNIST MLP as relational tensor algebra + +An MLP (196 -> 32 tanh -> 10 softmax on 2x2-pooled 14x14 MNIST) built on one +idea: **a neural net is a chain of tensor contractions (einsums), and an einsum +over coordinate-indexed arrays *is* relational algebra.** + +``` +C[i,k] = sum_j A[i,j] * B[j,k] <=> JOIN A, B ON A.j = B.j + GROUP BY i, k -> SUM(A.val * B.val) +``` + +Contracting a shared index is a join on it followed by a grouped `SUM` over the +indices that survive. In xarray-sql an array indexed by named dims is a table +keyed by those dims, so **the dimension names are the join keys**. + +**The architecture is data.** The whole model is *one* `xr.Dataset`: each layer's +weight is a data variable `w{L}` over dims `(u{L}, u{L+1})`, the widths it +connects, sharing the boundary dims (`u1` is layer 0's output and layer 1's +input, so it is the join key between them). The dim sizes *are* the layer widths, +and the number of weights is the depth — differing neuron counts per layer are +just differing dim sizes, no padding, because the relational (long) form is +naturally ragged. `from_dataset` splits that one Dataset into a table per weight +automatically. Change `WIDTHS` (e.g. `196, 64, 32, 10`) and the same code trains +the deeper net. + +A small `contract()` helper turns an einsum spec into one query, and a single +generic loop trains a net of any shape: + +- **forward** contracts the activation with each layer's weight, `+ bias`, + `tanh` (softmax on the last layer). +- **backward is the *same* operator with indices transposed** — the VJP of a + contraction is a contraction — and `grad(tanh(z), z)` supplies the only + genuinely-calculus part. Linear algebra is joins; the derivatives of the + nonlinearities are `grad`. + +Everything stays relational: every stage is an inspectable table (`a1`, `delta2`, +`gw0`, …); the only hand-written gradient is softmax + cross-entropy's `delta = +softmax - onehot`. Even the training metrics are a table — each logged step +appends a `(step, loss, train_acc, test_acc)` row to a `metrics` relation rather +than a Python list (NN training produces a lot of such data; it belongs in +rows). Evaluation is SQL too (a forward pass + `ROW_NUMBER()` argmax), and the +trained model, predictions, and metrics all come **back out as xarray** via +`to_dataset`. Reaches ~83% test accuracy over 60 steps. Downloads MNIST on first +run. + +This is not a numpy replacement — relational matmul carries join overhead a BLAS +inner product doesn't. What it buys is a fully declarative, inspectable pipeline +whose data side is chunked xarray (parallel over the batch, larger-than-memory). +The *outer* training loop stays in Python because the steps must be materialised +between iterations: a multi-layer net can't be one recursive CTE (the recursive +relation may be referenced only once, but the weights are used several times per +step), and unrolling the steps as non-recursive CTEs blows up exponentially +(DataFusion inlines CTEs). The thin loop does exactly that materialisation; all +the maths stays in SQL. diff --git a/benchmarks/mnist_mlp.py b/benchmarks/mnist_mlp.py index 4e1d81b..d7d97aa 100644 --- a/benchmarks/mnist_mlp.py +++ b/benchmarks/mnist_mlp.py @@ -4,53 +4,48 @@ # "xarray_sql", # "xarray", # "numpy", -# "pyarrow", # ] # # [tool.uv.sources] # xarray_sql = { path = "..", editable = true } # /// -"""Train an MNIST MLP classifier in SQL. - -A one-hidden-layer network (196->32 tanh->10 softmax, on 2x2-pooled 14x14 = -196-pixel images) trained by gradient descent where **every gradient is computed -in SQL** — and the whole model, with its entire training history, lives in a -single table. - -The model is one append-only table ``model(step, layer, i, j, val)``: every -parameter ``W1[i, j]`` / ``b1[i]`` / ``W2`` / ``b2`` is a row, tagged by which -generation (``step``) it belongs to. **A training step never mutates anything — -it appends the next generation's rows.** The full optimisation trajectory is the -table; ``WHERE step = N`` is the model at iteration N. - -Each step is a *single* SQL statement (``STEP`` below) that reads the current -generation and writes the next. It is reverse-mode autodiff as relational -algebra: - -* **matmul = join + GROUP BY SUM.** A layer's pre-activation is - ``SUM(input * weight)`` grouped by (sample, unit), joining the data to the - current weight rows. -* **local derivatives = grad().** The hidden Jacobian is ``grad(tanh(z), z)`` — - the autograd feature differentiates the nonlinearity, per (sample, unit). -* **cotangent propagation = join.** The output error is pushed back through W2 by - another join + SUM, then scaled by the local ``grad`` factor. -* **parameter gradients = join + GROUP BY AVG**, and the update is ``w - lr*g``, - emitted as the next generation's rows. - -The only hand-written gradient is softmax + cross-entropy's ``delta = softmax - -onehot`` (softmax couples classes through a per-sample normaliser, which an -aggregate ``grad`` does not cross — staying faithful to SQL). Everything else is -grad and joins. Evaluation is SQL too: a forward pass with ``ROW_NUMBER()`` for -the argmax. - -Why is the *outer* loop still Python rather than one recursive query (like -``grad_descent.py``)? A recursive CTE may reference the recursive relation only -once, but a 2-layer net uses the current weights several times per step (W1 and -W2 in the forward pass, W2 again in backprop), so it cannot be a single recursive -statement. Training is also inherently sequential and reuses each step's result, -so the steps must be *materialised* between iterations — which is exactly what the -thin Python loop does (append a generation, then query it). All the maths stays -in SQL; Python only sequences the steps. +"""Train an MNIST MLP as relational tensor algebra — with the architecture as data. + +A neural network is a chain of **tensor contractions** (einsums), and an einsum +over coordinate-indexed arrays *is* relational algebra: + + C[i,k] = sum_j A[i,j] * B[j,k] <=> JOIN A, B ON A.j = B.j + GROUP BY i, k -> SUM(A.val * B.val) + +Contracting a shared index is a join on it followed by a grouped SUM over the +indices that survive. In xarray-sql an array indexed by named dims is a table +keyed by those dims, so **the dimension names are the join keys**. + +The whole model is **one ``xr.Dataset``**. Each layer's weight is a data variable +whose two dims are the widths it connects — ``w0(u0, u1)``, ``w1(u1, u2)``, … — +sharing the boundary dims (``u1`` is the output of layer 0 and the input of layer +1, so it is the join key between them). **The architecture is therefore data: the +Dataset's dim sizes are the layer widths, and the number of layers is how many +weights it holds.** Differing neuron counts per layer are just differing dim +sizes — no padding, because the relational (long) form is naturally ragged. +``from_dataset`` splits that one Dataset into a table per weight automatically. + +A single ``contract()`` turns an einsum spec into one query, and a single generic +loop trains a net of any depth/width: + +* **forward** — contract the activation with each layer's weight, add bias, tanh + (softmax on the last layer). +* **backward is the same operator transposed** — the VJP of a contraction is a + contraction — with ``grad(tanh(z), z)`` for the one local-derivative step. + Linear algebra is joins; the derivatives of the nonlinearities are ``grad``. + +Every stage is an inspectable relation; the trained model, predictions, and loss +curve come back out as ``xarray`` via ``to_dataset``. Change ``WIDTHS`` and the +same code trains a different network. + +This is not a numpy replacement — relational matmul carries join overhead a BLAS +inner product doesn't. What it buys is a declarative, inspectable pipeline whose +data side is chunked xarray (parallel over the batch, larger-than-memory). Run standalone (builds the local extension on first use): @@ -67,7 +62,6 @@ from pathlib import Path import numpy as np -import pyarrow as pa import xarray as xr import xarray_sql as xql @@ -75,13 +69,252 @@ MIRROR = "https://storage.googleapis.com/cvdf-datasets/mnist" CACHE = Path(tempfile.gettempdir()) / "mnist-xql" -# Network dimensions: 14x14 pooled pixels -> 32 hidden (tanh) -> 10 classes. -N_TRAIN, N_TEST, N_PIX, N_HID, N_CLS = 1000, 500, 196, 32, 10 -LR, STEPS = 0.5, 60 +# The architecture, as data: layer widths. 196 pooled pixels -> 32 tanh -> 10. +# Add an entry (e.g. 196, 64, 32, 10) and the same code trains the deeper net. +WIDTHS = [196, 32, 10] +DEPTH = len(WIDTHS) - 1 # number of weight layers +N_TRAIN, N_TEST = 1000, 500 +LR, STEPS, CHUNK = 0.5, 60, 250 + + +# --- the one idea: a tensor contraction is a relational query ----------------- + + +def contract(spec: str, left: str, right: str) -> str: + """An einsum over two coordinate-indexed tables, as one SQL query. + + ``contract("sample,u0 * u0,u1 -> sample,u1", "x", "w0")`` joins ``x`` and + ``w0`` on their shared dim ``u0``, groups by the output dims, and sums the + product of values — a matmul. Every table has its dims as columns plus a + ``val`` column. Indices in the inputs but not the output are contracted; the + same helper expresses the transposed contractions of backprop. + """ + spec = spec.replace(" ", "") + lhs, out = spec.split("->") + da, db = (operand.split(",") for operand in lhs.split("*")) + out_dims = out.split(",") + shared = [d for d in da if d in db] + join = ( + f"JOIN {right} r ON " + " AND ".join(f"l.{d} = r.{d}" for d in shared) + if shared + else f"CROSS JOIN {right} r" + ) + pick = ", ".join(f"{'l' if d in da else 'r'}.{d} AS {d}" for d in out_dims) + return ( + f"SELECT {pick}, SUM(l.val * r.val) AS val " + f"FROM {left} l {join} GROUP BY {', '.join(out_dims)}" + ) + + +def register_tensor( + ctx: xql.XarrayContext, + name: str, + arr: np.ndarray, + dims: tuple[str, ...], + var: str = "val", + chunk: int | None = None, +) -> None: + """Register a numpy array as a relation, the array-relational way: wrap it as + an ``xr.Dataset`` whose named dims become the table's key columns, then hand + it to ``from_dataset``. A tensor is an array at the edge and a relation + inside; ``from_dataset`` is the bridge, and the dims become the join keys.""" + arr = np.asarray(arr, dtype=np.float64) + ds = xr.Dataset( + {var: (dims, arr)}, + coords={d: np.arange(n) for d, n in zip(dims, arr.shape)}, + ) + ctx.from_dataset(name, ds, chunks={dims[0]: chunk or arr.shape[0]}) + + +class Tensors: + """A step rewrites a handful of relations; ``put`` materialises a query as a + named table (the stages of the forward/backward pass).""" + + def __init__(self, ctx: xql.XarrayContext): + self.ctx = ctx + + def put(self, name: str, sql: str) -> None: + batches = self.ctx.sql(sql).collect() + if self.ctx.table_exist(name): + self.ctx.deregister_table(name) + self.ctx.register_record_batches(name, [batches]) + + +# --- the model as one xarray Dataset ------------------------------------------ + + +def build_model(rng: np.random.Generator) -> xr.Dataset: + """The whole model as one Dataset: weight ``w{L}`` over dims ``(u{L}, u{L+1})`` + and bias ``b{L}`` over ``(u{L+1},)``. The shared boundary dims tie the layers + together; the dim sizes *are* the architecture.""" + data_vars: dict = {} + for layer in range(DEPTH): + n_in, n_out = WIDTHS[layer], WIDTHS[layer + 1] + data_vars[f"w{layer}"] = ( + (f"u{layer}", f"u{layer + 1}"), + rng.standard_normal((n_in, n_out)) * 0.1, + ) + data_vars[f"b{layer}"] = ((f"u{layer + 1}",), np.zeros(n_out)) + coords = {f"u{i}": np.arange(w) for i, w in enumerate(WIDTHS)} + return xr.Dataset(data_vars, coords=coords) + + +def seed_weights(t: Tensors) -> None: + """Unpack the one model Dataset (registered as the ``model`` schema) into + working weight/bias relations with a uniform ``val`` column.""" + for layer in range(DEPTH): + i, o = f"u{layer}", f"u{layer + 1}" + t.put( + f"w{layer}", f"SELECT {i}, {o}, w{layer} AS val FROM model.w{layer}" + ) + t.put(f"b{layer}", f"SELECT {o}, b{layer} AS val FROM model.b{layer}") + + +# --- the network, as contractions (generic over depth) ------------------------ + + +def forward(t: Tensors, inp: str = "x") -> None: + """Forward pass from ``inp``: a contraction + bias + tanh per layer, leaving + the pre-activations ``a{L}.z`` for backprop and the output ``logits``.""" + prev = inp + for layer in range(DEPTH): + i, o = f"u{layer}", f"u{layer + 1}" + zc = contract(f"sample,{i} * {i},{o} -> sample,{o}", prev, f"w{layer}") + if layer < DEPTH - 1: + t.put( + f"a{layer + 1}", + f"""WITH zc AS ({zc}) + SELECT zc.sample, zc.{o}, zc.val + b{layer}.val AS z, + tanh(zc.val + b{layer}.val) AS val + FROM zc JOIN b{layer} ON zc.{o} = b{layer}.{o}""", + ) + prev = f"a{layer + 1}" + else: + t.put( + "logits", + f"""WITH zc AS ({zc}) + SELECT zc.sample, zc.{o}, zc.val + b{layer}.val AS z + FROM zc JOIN b{layer} ON zc.{o} = b{layer}.{o}""", + ) + + +def softmax_delta_sql() -> str: + """Output error delta = softmax(logits) - onehot(label). The one hand-derived + rule: softmax couples classes through a per-sample normaliser an aggregate + grad() does not cross.""" + o = f"u{DEPTH}" + return f""" + WITH m AS (SELECT sample, MAX(z) AS m FROM logits GROUP BY sample), + e AS (SELECT logits.sample, logits.{o}, exp(logits.z - m.m) AS e + FROM logits JOIN m ON logits.sample = m.sample), + s AS (SELECT sample, SUM(e) AS s FROM e GROUP BY sample) + SELECT e.sample, e.{o}, + e.e / s.s - CASE WHEN e.{o} = y.label THEN 1.0 ELSE 0.0 END AS val + FROM e JOIN s ON e.sample = s.sample JOIN y ON y.sample = e.sample""" + + +def train_step(t: Tensors) -> None: + """Forward, backward (the same contraction transposed), SGD update.""" + forward(t) + t.put(f"delta{DEPTH}", softmax_delta_sql()) + # Backward: walk the layers in reverse, the gradients are contractions. + for layer in reversed(range(DEPTH)): + i, o = f"u{layer}", f"u{layer + 1}" + a_in = "x" if layer == 0 else f"a{layer}" + gw = contract( + f"sample,{i} * sample,{o} -> {i},{o}", a_in, f"delta{layer + 1}" + ) + t.put( + f"gw{layer}", f"SELECT {i}, {o}, val / {N_TRAIN} AS val FROM ({gw})" + ) + t.put( + f"gb{layer}", + f"SELECT {o}, AVG(val) AS val FROM delta{layer + 1} GROUP BY {o}", + ) + if layer > 0: # propagate the cotangent, scaled by the local derivative + dc = contract( + f"sample,{o} * {i},{o} -> sample,{i}", + f"delta{layer + 1}", + f"w{layer}", + ) + t.put( + f"delta{layer}", + f"""WITH dh AS ({dc}) + SELECT dh.sample, dh.{i}, dh.val * grad(tanh(a{layer}.z), a{layer}.z) AS val + FROM dh JOIN a{layer} ON dh.sample = a{layer}.sample AND dh.{i} = a{layer}.{i}""", + ) + # SGD: each weight relation becomes w - lr * grad. + for layer in range(DEPTH): + i, o = f"u{layer}", f"u{layer + 1}" + t.put( + f"w{layer}", + f"SELECT w{layer}.{i}, w{layer}.{o}, w{layer}.val - {LR} * gw{layer}.val AS val " + f"FROM w{layer} JOIN gw{layer} ON w{layer}.{i} = gw{layer}.{i} " + f"AND w{layer}.{o} = gw{layer}.{o}", + ) + t.put( + f"b{layer}", + f"SELECT b{layer}.{o}, b{layer}.val - {LR} * gb{layer}.val AS val " + f"FROM b{layer} JOIN gb{layer} ON b{layer}.{o} = gb{layer}.{o}", + ) + + +def accuracy(t: Tensors, inp: str, lab: str) -> float: + """A forward pass over ``inp`` + argmax, compared to ``lab`` — all in SQL.""" + forward(t, inp) + o = f"u{DEPTH}" + return float( + t.ctx.sql( + f"""WITH pred AS ( + SELECT sample, {o}, + ROW_NUMBER() OVER (PARTITION BY sample ORDER BY z DESC) AS rk + FROM logits) + SELECT AVG(CASE WHEN p.{o} = l.label THEN 1.0 ELSE 0.0 END) AS acc + FROM pred p JOIN {lab} l ON p.sample = l.sample WHERE p.rk = 1""" + ).to_pandas()["acc"][0] + ) + + +def record_metrics(t: Tensors, step: int) -> None: + """Append a (step, loss, train_acc, test_acc) row to the ``metrics`` table. + + NN training emits a lot of data — loss curves, per-step accuracies — and like + everything else here it lives as rows in a relation, grown each time, not a + Python list. Read it back at the end as a tidy ``(step,)`` xarray. + """ + o = f"u{DEPTH}" + train = accuracy(t, "x", "y") # leaves the training forward in `logits` + loss = float( + t.ctx.sql( + f"""WITH m AS (SELECT sample, MAX(z) AS m FROM logits GROUP BY sample), + e AS (SELECT logits.sample, logits.{o}, exp(logits.z - m.m) AS e + FROM logits JOIN m ON logits.sample = m.sample), + s AS (SELECT sample, SUM(e) AS s FROM e GROUP BY sample) + SELECT -AVG(ln(e.e / s.s)) AS loss + FROM e JOIN s ON e.sample = s.sample JOIN y ON y.sample = e.sample + WHERE e.{o} = y.label""" + ).to_pandas()["loss"][0] + ) + test = accuracy(t, "x_te", "y_te") + row = ( + f"SELECT CAST({step} AS BIGINT) AS step, CAST({loss} AS DOUBLE) AS loss, " + f"CAST({train} AS DOUBLE) AS train_acc, CAST({test} AS DOUBLE) AS test_acc" + ) + t.put( + "metrics", + f"SELECT * FROM metrics UNION ALL {row}" + if t.ctx.table_exist("metrics") + else row, + ) + print( + f"step {step:2d}: loss {loss:.3f} train {train:.3f} test {test:.3f}" + ) + + +# --- MNIST loading ------------------------------------------------------------ def _download(url: str, dest: Path, tries: int = 5) -> None: - """Fetch a URL to dest, reading the whole body (retries on truncation).""" last = None for _ in range(tries): try: @@ -102,12 +335,11 @@ def _read_idx(path: Path) -> np.ndarray: if magic == 2051: # images n, r, c = struct.unpack(">III", f.read(12)) return np.frombuffer(f.read(), np.uint8).reshape(n, r, c) - (n,) = struct.unpack(">I", f.read(4)) # labels + struct.unpack(">I", f.read(4)) # labels: skip the count return np.frombuffer(f.read(), np.uint8) def load_mnist(): - """Download (and cache) MNIST, 2x2 mean-pool to 14x14, subsample.""" CACHE.mkdir(exist_ok=True) files = { "images": "train-images-idx3-ubyte.gz", @@ -119,241 +351,90 @@ def load_mnist(): if not dest.exists(): _download(f"{MIRROR}/{name}", dest) paths[key] = dest - imgs = _read_idx(paths["images"]).astype(np.float32) / 255.0 labs = _read_idx(paths["labels"]).astype(np.int64) - pooled = imgs.reshape(-1, 14, 2, 14, 2).mean(axis=(2, 4)).reshape(-1, N_PIX) - + side = WIDTHS[0] # pooled pixels per image + pool = int(round((28 * 28 / side) ** 0.5)) # 2 for 196 pixels + k = 28 // pool + pooled = ( + imgs.reshape(-1, k, pool, k, pool).mean(axis=(2, 4)).reshape(-1, side) + ) rng = np.random.default_rng(0) idx = rng.permutation(len(pooled)) tr, te = idx[:N_TRAIN], idx[N_TRAIN : N_TRAIN + N_TEST] return pooled[tr], labs[tr], pooled[te], labs[te] -# --- the model as rows -------------------------------------------------------- - -_MODEL_SCHEMA = pa.schema( - [ - ("step", pa.int64()), - ("layer", pa.utf8()), - ("i", pa.int64()), - ("j", pa.int64()), - ("val", pa.float64()), - ] -) - - -def _param_rows(step: int, layer: str, arr: np.ndarray) -> dict: - """One layer's parameters as ``(step, layer, i, j, val)`` columns. - - A matrix ``W[i, j]`` becomes rows ``(i, j, w)``; a bias vector ``b[i]`` - becomes ``(i, 0, b)``. - """ - if arr.ndim == 2: - ii, jj = np.meshgrid( - np.arange(arr.shape[0]), np.arange(arr.shape[1]), indexing="ij" - ) - ii, jj = ii.ravel(), jj.ravel() - else: - ii, jj = np.arange(arr.size), np.zeros(arr.size, np.int64) - n = arr.size - return { - "step": np.full(n, step, np.int64), - "layer": [layer] * n, - "i": ii.astype(np.int64), - "j": jj.astype(np.int64), - "val": arr.ravel().astype(np.float64), - } - - -def _generation_batch(step, w1, b1, w2, b2) -> pa.RecordBatch: - """All four layers of one generation as a single RecordBatch.""" - cols: dict[str, list] = {k: [] for k in ("step", "layer", "i", "j", "val")} - for layer, arr in (("w1", w1), ("b1", b1), ("w2", w2), ("b2", b2)): - for k, v in _param_rows(step, layer, arr).items(): - cols[k].extend(list(v)) - return pa.RecordBatch.from_arrays( - [ - pa.array(cols["step"], pa.int64()), - pa.array(cols["layer"], pa.utf8()), - pa.array(cols["i"], pa.int64()), - pa.array(cols["j"], pa.int64()), - pa.array(cols["val"], pa.float64()), - ], - schema=_MODEL_SCHEMA, - ) - - -# One training step, as one SQL statement: read the current generation of the -# model table, run the forward + backward pass over the data, and SELECT the next -# generation's parameter rows (which the loop appends to the model table). -STEP = f""" -WITH cur AS (SELECT max(step) AS s FROM model), - w1 AS (SELECT i AS pix, j AS hid, val AS w FROM model, cur - WHERE step = cur.s AND layer = 'w1'), - b1 AS (SELECT i AS hid, val AS b FROM model, cur - WHERE step = cur.s AND layer = 'b1'), - w2 AS (SELECT i AS hid, j AS cls, val AS w FROM model, cur - WHERE step = cur.s AND layer = 'w2'), - b2 AS (SELECT i AS cls, val AS b FROM model, cur - WHERE step = cur.s AND layer = 'b2'), - -- forward: hidden pre-activation z and activation a = tanh(z) - zt AS (SELECT i.sample, w.hid, SUM(i.val * w.w) + MAX(bb.b) AS z - FROM imgs i JOIN w1 w ON i.pix = w.pix JOIN b1 bb ON w.hid = bb.hid - GROUP BY i.sample, w.hid), - h AS (SELECT sample, hid, z, tanh(z) AS a FROM zt), - -- output logits, then a stable softmax - lg AS (SELECT h.sample, w.cls, SUM(h.a * w.w) + MAX(bb.b) AS z - FROM h JOIN w2 w ON h.hid = w.hid JOIN b2 bb ON w.cls = bb.cls - GROUP BY h.sample, w.cls), - mx AS (SELECT sample, MAX(z) AS m FROM lg GROUP BY sample), - ex AS (SELECT l.sample, l.cls, exp(l.z - mx.m) AS e - FROM lg l JOIN mx ON l.sample = mx.sample), - zs AS (SELECT sample, SUM(e) AS z FROM ex GROUP BY sample), - -- output error delta2 = softmax - onehot(label) - d2 AS (SELECT ex.sample, ex.cls, - ex.e / zs.z - - CASE WHEN ex.cls = lb.label THEN 1.0 ELSE 0.0 END AS d - FROM ex JOIN zs ON ex.sample = zs.sample - JOIN labels lb ON lb.sample = ex.sample), - -- backprop to hidden: push delta2 through W2, scale by grad(tanh(z), z) - da AS (SELECT d.sample, w.hid, SUM(d.d * w.w) AS da - FROM d2 d JOIN w2 w ON d.cls = w.cls GROUP BY d.sample, w.hid), - d1 AS (SELECT da.sample, da.hid, da.da * grad(tanh(h.z), h.z) AS d - FROM da JOIN h ON da.sample = h.sample AND da.hid = h.hid), - -- parameter gradients: dW = AVG(input * delta) over the batch - gw1 AS (SELECT i.pix, d.hid, AVG(i.val * d.d) AS g - FROM imgs i JOIN d1 d ON i.sample = d.sample GROUP BY i.pix, d.hid), - gb1 AS (SELECT hid, AVG(d) AS g FROM d1 GROUP BY hid), - gw2 AS (SELECT h.hid, d.cls, AVG(h.a * d.d) AS g - FROM h JOIN d2 d ON h.sample = d.sample GROUP BY h.hid, d.cls), - gb2 AS (SELECT cls, AVG(d) AS g FROM d2 GROUP BY cls) --- the next generation: w - lr*grad, tagged step+1, as model rows -SELECT (SELECT s FROM cur) + 1 AS step, 'w1' AS layer, - w.pix AS i, w.hid AS j, w.w - {LR} * g.g AS val -FROM w1 w JOIN gw1 g ON w.pix = g.pix AND w.hid = g.hid -UNION ALL -SELECT (SELECT s FROM cur) + 1, 'b1', b.hid, CAST(0 AS BIGINT), b.b - {LR} * g.g -FROM b1 b JOIN gb1 g ON b.hid = g.hid -UNION ALL -SELECT (SELECT s FROM cur) + 1, 'w2', w.hid, w.cls, w.w - {LR} * g.g -FROM w2 w JOIN gw2 g ON w.hid = g.hid AND w.cls = g.cls -UNION ALL -SELECT (SELECT s FROM cur) + 1, 'b2', b.cls, CAST(0 AS BIGINT), b.b - {LR} * g.g -FROM b2 b JOIN gb2 g ON b.cls = g.cls -""" - - -def eval_sql(imgs_table: str, labels_table: str) -> str: - """Accuracy of the latest model on a dataset — a forward pass in SQL. - - ``ROW_NUMBER()`` picks each sample's argmax class; it is compared to the - label. No softmax needed at inference: the argmax of the logits is the - prediction. - """ - return f""" - WITH cur AS (SELECT max(step) AS s FROM model), - w1 AS (SELECT i AS pix, j AS hid, val AS w FROM model, cur - WHERE step = cur.s AND layer = 'w1'), - b1 AS (SELECT i AS hid, val AS b FROM model, cur - WHERE step = cur.s AND layer = 'b1'), - w2 AS (SELECT i AS hid, j AS cls, val AS w FROM model, cur - WHERE step = cur.s AND layer = 'w2'), - b2 AS (SELECT i AS cls, val AS b FROM model, cur - WHERE step = cur.s AND layer = 'b2'), - h AS (SELECT i.sample, w.hid, - tanh(SUM(i.val * w.w) + MAX(bb.b)) AS a - FROM {imgs_table} i JOIN w1 w ON i.pix = w.pix - JOIN b1 bb ON w.hid = bb.hid - GROUP BY i.sample, w.hid), - lg AS (SELECT h.sample, w.cls, SUM(h.a * w.w) + MAX(bb.b) AS z - FROM h JOIN w2 w ON h.hid = w.hid JOIN b2 bb ON w.cls = bb.cls - GROUP BY h.sample, w.cls), - pred AS (SELECT sample, cls, - ROW_NUMBER() OVER (PARTITION BY sample ORDER BY z DESC) AS rk - FROM lg) - SELECT AVG(CASE WHEN p.cls = l.label THEN 1.0 ELSE 0.0 END) AS acc - FROM pred p JOIN {labels_table} l ON p.sample = l.sample - WHERE p.rk = 1 - """ - - -def _register_images(ctx, name, X): - ctx.from_dataset( - name, - xr.Dataset( - {"val": (("sample", "pix"), X)}, - coords={ - "sample": np.arange(X.shape[0]), - "pix": np.arange(N_PIX), - }, - ), - chunks={"sample": X.shape[0]}, - ) - - -def _register_labels(ctx, name, y): - ctx.from_dataset( - name, - xr.Dataset( - {"label": (("sample",), y.astype(np.float64))}, - coords={"sample": np.arange(len(y))}, - ), - chunks={"sample": len(y)}, - ) +# --- driver ------------------------------------------------------------------- def main() -> None: Xtr, ytr, Xte, yte = load_mnist() - print( - f"MNIST: train {Xtr.shape}, test {Xte.shape} " - f"({N_PIX} pix, {N_HID} hidden, {N_CLS} classes)" - ) + print(f"MNIST: train {Xtr.shape}, test {Xte.shape} architecture {WIDTHS}") ctx = xql.XarrayContext() - # The data is registered as xarray (the library's core); the model below is - # the one append-only table that holds every layer and every generation. - _register_images(ctx, "imgs", Xtr) - _register_labels(ctx, "labels", ytr) - _register_images(ctx, "imgs_te", Xte) - _register_labels(ctx, "labels_te", yte) - - # Generation 0: small random weights, zero biases. + # The whole model is one Dataset; from_dataset splits it into a table per + # weight (the shared boundary dims become the join keys). rng = np.random.default_rng(1) - gen0 = _generation_batch( - 0, - rng.standard_normal((N_PIX, N_HID)) * 0.1, - np.zeros(N_HID), - rng.standard_normal((N_HID, N_CLS)) * 0.1, - np.zeros(N_CLS), + model = build_model(rng) + ctx.from_dataset( + "model", + model, + table_names={ + (f"u{layer}", f"u{layer + 1}"): f"w{layer}" + for layer in range(DEPTH) + } + | {(f"u{layer + 1}",): f"b{layer}" for layer in range(DEPTH)}, + chunks={f"u{i}": w for i, w in enumerate(WIDTHS)}, ) - generations = [gen0] - ctx.register_record_batches("model", [generations]) + t = Tensors(ctx) + seed_weights(t) - def test_acc() -> float: - return float( - ctx.sql(eval_sql("imgs_te", "labels_te")).to_pandas()["acc"][0] - ) + # Inputs and labels, registered once; the queries read x / x_te by name. + register_tensor(ctx, "x", Xtr, ("sample", "u0"), chunk=CHUNK) + register_tensor(ctx, "y", ytr, ("sample",), var="label") + register_tensor(ctx, "x_te", Xte, ("sample", "u0")) + register_tensor(ctx, "y_te", yte, ("sample",), var="label") + + print(f"init: test acc {accuracy(t, 'x_te', 'y_te'):.3f}") - print(f"init: test acc {test_acc():.3f}") t0 = time.time() - for s in range(STEPS): - # One SQL statement computes the next generation; appending its rows to - # the model table *is* the parameter update. - generations.extend(ctx.sql(STEP).collect()) - ctx.deregister_table("model") - ctx.register_record_batches("model", [generations]) - if s % 10 == 0 or s == STEPS - 1: - tr = float( - ctx.sql(eval_sql("imgs", "labels")).to_pandas()["acc"][0] - ) - print(f"step {s:2d}: train {tr:.3f} test {test_acc():.3f}") + for step in range(STEPS): + train_step(t) + if step % 10 == 0 or step == STEPS - 1: + record_metrics(t, step) + dt = time.time() - t0 + + # The trained model comes back out as one xarray Dataset. + parts = [] + for layer in range(DEPTH): + i, o = f"u{layer}", f"u{layer + 1}" + parts.append( + ctx.sql(f"SELECT {i}, {o}, val FROM w{layer}") + .to_dataset(dims=[i, o]) + .rename({"val": f"w{layer}"}) + ) + parts.append( + ctx.sql(f"SELECT {o}, val FROM b{layer}") + .to_dataset(dims=[o]) + .rename({"val": f"b{layer}"}) + ) + trained = xr.merge(parts) + # The loss curve and accuracies were recorded as rows; read them back as a + # tidy (step,) xarray of training metrics. + metrics = ctx.sql("SELECT * FROM metrics ORDER BY step").to_dataset( + dims=["step"] + ) - n_rows = ctx.sql("SELECT count(*) AS n FROM model").to_pandas()["n"][0] print( - f"\ntrained an MNIST MLP in SQL: test accuracy {test_acc():.3f} " - f"in {time.time() - t0:.0f}s.\nThe model and its entire training " - f"history are one table of {n_rows} rows ({STEPS + 1} generations)." + f"\ntrained a {WIDTHS} MLP as relational tensor algebra in {dt:.0f}s: " + f"test accuracy {accuracy(t, 'x_te', 'y_te'):.3f}." + ) + print( + f"the model is one xarray Dataset again " + f"(vars {list(trained.data_vars)}, dims {dict(trained.sizes)}); " + f"metrics are a table -> xarray {list(metrics.data_vars)} over " + f"{dict(metrics.sizes)}." ) From 3b365f6656f887706f4dfdb8ef589eddbad0a201 Mon Sep 17 00:00:00 2001 From: Alex Merose Date: Tue, 30 Jun 2026 18:15:36 +0300 Subject: [PATCH 17/17] demo: the whole MLP as one weight relation (bias folded, layer dim) Two simplifications collapse the model to a single relation: - Bias folded into the weights (an nn.Linear): each layer's bias is the weight of a constant-1 input, kept as the row inp=width of the same weight array, so a layer is one matrix. - A layer dimension: every layer's weight lives in one weight(layer, inp, out) array, so forward/backward filter on the layer COLUMN instead of referencing a table per layer. The model is one xr.Dataset with a layer dim (NaN-padded for the ragged pyramid, dropped on seed); from_dataset registers it; the update is one query over the whole weight relation. A single contract() and a generic loop train a net of any depth (validated exact against numpy at depth 3). Tensors.put now unifies batch nullability so UNION results register cleanly. Faster too (~6s vs ~13s) at the same ~83% test accuracy; model and metrics still round-trip to xarray. Co-Authored-By: Claude Opus 4.8 --- benchmarks/README.md | 43 +++--- benchmarks/mnist_mlp.py | 297 +++++++++++++++++++++------------------- 2 files changed, 182 insertions(+), 158 deletions(-) diff --git a/benchmarks/README.md b/benchmarks/README.md index 10b4fea..5fa3fcd 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -77,31 +77,38 @@ Contracting a shared index is a join on it followed by a grouped `SUM` over the indices that survive. In xarray-sql an array indexed by named dims is a table keyed by those dims, so **the dimension names are the join keys**. -**The architecture is data.** The whole model is *one* `xr.Dataset`: each layer's -weight is a data variable `w{L}` over dims `(u{L}, u{L+1})`, the widths it -connects, sharing the boundary dims (`u1` is layer 0's output and layer 1's -input, so it is the join key between them). The dim sizes *are* the layer widths, -and the number of weights is the depth — differing neuron counts per layer are -just differing dim sizes, no padding, because the relational (long) form is -naturally ragged. `from_dataset` splits that one Dataset into a table per weight -automatically. Change `WIDTHS` (e.g. `196, 64, 32, 10`) and the same code trains -the deeper net. +**The whole network is one relation.** Two moves get there: + +- **Bias folded into the weights (an `nn.Linear`).** Each layer's bias is the + weight of a constant-`1` input, kept as the extra row `inp = width` of the same + weight array — so a layer is a single matrix. +- **A `layer` dimension.** Every layer's weight lives in one + `weight(layer, inp, out)` array, so the forward/backward filter on the `layer` + *column* instead of referencing a table per layer. + +So **the architecture is data**: the whole model is one `xr.Dataset` with a +`layer` dim, registered via `from_dataset`. The dim sizes are the layer widths +and the number of layers is the depth — differing neuron counts are just +differing sizes, NaN-padded in the dense array and dropped on the way in (the +relational form is naturally ragged). Change `WIDTHS` (e.g. `196, 64, 32, 10`) +and the same code trains the deeper net. A small `contract()` helper turns an einsum spec into one query, and a single generic loop trains a net of any shape: -- **forward** contracts the activation with each layer's weight, `+ bias`, - `tanh` (softmax on the last layer). +- **forward** contracts the activation with `weight WHERE layer = L`, adds the + bias row, `tanh` (softmax on the last layer). - **backward is the *same* operator with indices transposed** — the VJP of a - contraction is a contraction — and `grad(tanh(z), z)` supplies the only - genuinely-calculus part. Linear algebra is joins; the derivatives of the - nonlinearities are `grad`. + contraction is a contraction — accumulated into one `gweight` relation, with + `grad(tanh(z), z)` for the only genuinely-calculus part. Even the update is one + query over the whole `weight` relation. Linear algebra is joins; the + derivatives of the nonlinearities are `grad`. Everything stays relational: every stage is an inspectable table (`a1`, `delta2`, -`gw0`, …); the only hand-written gradient is softmax + cross-entropy's `delta = -softmax - onehot`. Even the training metrics are a table — each logged step -appends a `(step, loss, train_acc, test_acc)` row to a `metrics` relation rather -than a Python list (NN training produces a lot of such data; it belongs in +`gweight`, …); the only hand-written gradient is softmax + cross-entropy's +`delta = softmax - onehot`. Even the training metrics are a table — each logged +step appends a `(step, loss, train_acc, test_acc)` row to a `metrics` relation +rather than a Python list (NN training produces a lot of such data; it belongs in rows). Evaluation is SQL too (a forward pass + `ROW_NUMBER()` argmax), and the trained model, predictions, and metrics all come **back out as xarray** via `to_dataset`. Reaches ~83% test accuracy over 60 steps. Downloads MNIST on first diff --git a/benchmarks/mnist_mlp.py b/benchmarks/mnist_mlp.py index d7d97aa..fe31cee 100644 --- a/benchmarks/mnist_mlp.py +++ b/benchmarks/mnist_mlp.py @@ -9,7 +9,7 @@ # [tool.uv.sources] # xarray_sql = { path = "..", editable = true } # /// -"""Train an MNIST MLP as relational tensor algebra — with the architecture as data. +"""Train an MNIST MLP as relational tensor algebra — the whole net is one table. A neural network is a chain of **tensor contractions** (einsums), and an einsum over coordinate-indexed arrays *is* relational algebra: @@ -17,35 +17,41 @@ C[i,k] = sum_j A[i,j] * B[j,k] <=> JOIN A, B ON A.j = B.j GROUP BY i, k -> SUM(A.val * B.val) -Contracting a shared index is a join on it followed by a grouped SUM over the -indices that survive. In xarray-sql an array indexed by named dims is a table -keyed by those dims, so **the dimension names are the join keys**. - -The whole model is **one ``xr.Dataset``**. Each layer's weight is a data variable -whose two dims are the widths it connects — ``w0(u0, u1)``, ``w1(u1, u2)``, … — -sharing the boundary dims (``u1`` is the output of layer 0 and the input of layer -1, so it is the join key between them). **The architecture is therefore data: the -Dataset's dim sizes are the layer widths, and the number of layers is how many -weights it holds.** Differing neuron counts per layer are just differing dim -sizes — no padding, because the relational (long) form is naturally ragged. -``from_dataset`` splits that one Dataset into a table per weight automatically. - -A single ``contract()`` turns an einsum spec into one query, and a single generic -loop trains a net of any depth/width: - -* **forward** — contract the activation with each layer's weight, add bias, tanh - (softmax on the last layer). -* **backward is the same operator transposed** — the VJP of a contraction is a - contraction — with ``grad(tanh(z), z)`` for the one local-derivative step. - Linear algebra is joins; the derivatives of the nonlinearities are ``grad``. - -Every stage is an inspectable relation; the trained model, predictions, and loss -curve come back out as ``xarray`` via ``to_dataset``. Change ``WIDTHS`` and the -same code trains a different network. - -This is not a numpy replacement — relational matmul carries join overhead a BLAS -inner product doesn't. What it buys is a declarative, inspectable pipeline whose -data side is chunked xarray (parallel over the batch, larger-than-memory). +Contracting a shared index is a join on it followed by a grouped SUM. In +xarray-sql an array indexed by named dims is a table keyed by those dims, so the +dim names are the join keys. + +Two simplifications make the whole model **one relation**: + +* **Bias folded into the weights (an ``nn.Linear``).** Each layer's bias is the + weight of a constant-``1`` input, stored as the extra row ``inp = width`` in the + same weight array — so a layer is a single matrix. The forward reads the matmul + rows and that bias row from the one relation (no separate bias table). +* **A ``layer`` dimension.** Every layer's weight lives in one + ``weight(layer, inp, out)`` array, so the forward/backward filter on the + ``layer`` *column* instead of referencing a table per layer. The whole network + is one ``xr.Dataset`` registered with ``from_dataset``; differing layer widths + are NaN-padded in the dense array and dropped on the way in (the relational + form is naturally ragged). The architecture is data — change ``WIDTHS`` and the + same code trains a different net. + +A single ``contract()`` and one generic loop train a net of any depth: forward +contracts the activation with ``weight WHERE layer = L``; backward is the same +contraction transposed (the VJP of a contraction is a contraction), with +``grad(tanh(z), z)`` for the one local-derivative step. Even the weight update is +one query over the whole ``weight`` relation. Linear algebra is joins; the +derivatives of the nonlinearities are ``grad``. + +Everything stays relational and inspectable: activations, errors, gradients, and +the per-step training metrics are all tables; the trained model, predictions, and +metrics come back out as ``xarray`` via ``to_dataset``. + +This is not a numpy replacement — the long form puts one matrix entry per row, so +the matmul-as-join carries overhead a BLAS inner product doesn't. What it buys is +a declarative, inspectable pipeline whose data side is chunked xarray (parallel +over the batch, larger-than-memory). Recovering BLAS speed would mean storing +dense *tiles* per cell and contracting them with a tile-matmul — a future +direction, not done here. Run standalone (builds the local extension on first use): @@ -62,6 +68,7 @@ from pathlib import Path import numpy as np +import pyarrow as pa import xarray as xr import xarray_sql as xql @@ -81,13 +88,13 @@ def contract(spec: str, left: str, right: str) -> str: - """An einsum over two coordinate-indexed tables, as one SQL query. + """An einsum over two coordinate-indexed relations, as one SQL query. - ``contract("sample,u0 * u0,u1 -> sample,u1", "x", "w0")`` joins ``x`` and - ``w0`` on their shared dim ``u0``, groups by the output dims, and sums the - product of values — a matmul. Every table has its dims as columns plus a - ``val`` column. Indices in the inputs but not the output are contracted; the - same helper expresses the transposed contractions of backprop. + ``contract("sample,inp * inp,out -> sample,out", "x", w)`` joins ``x`` and + ``w`` on their shared dim ``inp``, groups by the output dims, and sums the + product of values — a matmul. ``left`` / ``right`` are table names or + parenthesised subqueries; each exposes its dims plus a ``val`` column. + Indices in the inputs but not the output are contracted (summed over). """ spec = spec.replace(" ", "") lhs, out = spec.split("->") @@ -135,66 +142,85 @@ def __init__(self, ctx: xql.XarrayContext): def put(self, name: str, sql: str) -> None: batches = self.ctx.sql(sql).collect() + # UNION branches can yield batches that differ only in field nullability; + # cast them all to one (nullable) schema so registration accepts them. + if batches: + target = pa.schema( + [pa.field(f.name, f.type) for f in batches[0].schema] + ) + batches = [b.cast(target) for b in batches] if self.ctx.table_exist(name): self.ctx.deregister_table(name) self.ctx.register_record_batches(name, [batches]) -# --- the model as one xarray Dataset ------------------------------------------ +# --- the model: one weight relation, bias folded in --------------------------- def build_model(rng: np.random.Generator) -> xr.Dataset: - """The whole model as one Dataset: weight ``w{L}`` over dims ``(u{L}, u{L+1})`` - and bias ``b{L}`` over ``(u{L+1},)``. The shared boundary dims tie the layers - together; the dim sizes *are* the architecture.""" - data_vars: dict = {} + """The whole network as one ``weight(layer, inp, out)`` Dataset. + + Layer ``L`` connects ``WIDTHS[L]`` inputs (plus a constant-1 bias input, index + ``WIDTHS[L]``) to ``WIDTHS[L+1]`` outputs. The dense array is NaN-padded to the + widest layer; the padding is dropped when the relation is seeded, so the live + table is the ragged set of real weights. + """ + max_in = max(WIDTHS[layer] + 1 for layer in range(DEPTH)) + max_out = max(WIDTHS[layer + 1] for layer in range(DEPTH)) + arr = np.full((DEPTH, max_in, max_out), np.nan) for layer in range(DEPTH): n_in, n_out = WIDTHS[layer], WIDTHS[layer + 1] - data_vars[f"w{layer}"] = ( - (f"u{layer}", f"u{layer + 1}"), - rng.standard_normal((n_in, n_out)) * 0.1, + arr[layer, :n_in, :n_out] = rng.standard_normal((n_in, n_out)) * 0.1 + arr[layer, n_in, :n_out] = ( + 0.0 # bias row (weight of the constant input) ) - data_vars[f"b{layer}"] = ((f"u{layer + 1}",), np.zeros(n_out)) - coords = {f"u{i}": np.arange(w) for i, w in enumerate(WIDTHS)} - return xr.Dataset(data_vars, coords=coords) + return xr.Dataset( + {"weight": (("layer", "inp", "out"), arr)}, + coords={ + "layer": np.arange(DEPTH), + "inp": np.arange(max_in), + "out": np.arange(max_out), + }, + ) -def seed_weights(t: Tensors) -> None: - """Unpack the one model Dataset (registered as the ``model`` schema) into - working weight/bias relations with a uniform ``val`` column.""" - for layer in range(DEPTH): - i, o = f"u{layer}", f"u{layer + 1}" - t.put( - f"w{layer}", f"SELECT {i}, {o}, w{layer} AS val FROM model.w{layer}" - ) - t.put(f"b{layer}", f"SELECT {o}, b{layer} AS val FROM model.b{layer}") +def matmul_rows(layer: int) -> str: + """The matmul (non-bias) rows of one layer's weight, as a subquery.""" + return f"(SELECT inp, out, val FROM weight WHERE layer = {layer} AND inp < {WIDTHS[layer]})" + + +def bias_row(layer: int) -> str: + """The bias row (inp = width) of one layer's weight, as a subquery over out.""" + return f"(SELECT out, val FROM weight WHERE layer = {layer} AND inp = {WIDTHS[layer]})" # --- the network, as contractions (generic over depth) ------------------------ def forward(t: Tensors, inp: str = "x") -> None: - """Forward pass from ``inp``: a contraction + bias + tanh per layer, leaving - the pre-activations ``a{L}.z`` for backprop and the output ``logits``.""" + """Forward pass from ``inp``: per layer, contract with the matmul rows and add + the bias row (both from the one weight relation), then tanh on the hidden + layers. Leaves ``a{L}.z`` for backprop and the output ``logits``.""" prev = inp for layer in range(DEPTH): - i, o = f"u{layer}", f"u{layer + 1}" - zc = contract(f"sample,{i} * {i},{o} -> sample,{o}", prev, f"w{layer}") + zc = contract( + "sample,inp * inp,out -> sample,out", prev, matmul_rows(layer) + ) if layer < DEPTH - 1: t.put( f"a{layer + 1}", f"""WITH zc AS ({zc}) - SELECT zc.sample, zc.{o}, zc.val + b{layer}.val AS z, - tanh(zc.val + b{layer}.val) AS val - FROM zc JOIN b{layer} ON zc.{o} = b{layer}.{o}""", + SELECT zc.sample, zc.out AS inp, zc.val + b.val AS z, + tanh(zc.val + b.val) AS val + FROM zc JOIN {bias_row(layer)} b ON zc.out = b.out""", ) prev = f"a{layer + 1}" else: t.put( "logits", f"""WITH zc AS ({zc}) - SELECT zc.sample, zc.{o}, zc.val + b{layer}.val AS z - FROM zc JOIN b{layer} ON zc.{o} = b{layer}.{o}""", + SELECT zc.sample, zc.out, zc.val + b.val AS z + FROM zc JOIN {bias_row(layer)} b ON zc.out = b.out""", ) @@ -202,74 +228,77 @@ def softmax_delta_sql() -> str: """Output error delta = softmax(logits) - onehot(label). The one hand-derived rule: softmax couples classes through a per-sample normaliser an aggregate grad() does not cross.""" - o = f"u{DEPTH}" - return f""" + return """ WITH m AS (SELECT sample, MAX(z) AS m FROM logits GROUP BY sample), - e AS (SELECT logits.sample, logits.{o}, exp(logits.z - m.m) AS e + e AS (SELECT logits.sample, logits.out, exp(logits.z - m.m) AS e FROM logits JOIN m ON logits.sample = m.sample), s AS (SELECT sample, SUM(e) AS s FROM e GROUP BY sample) - SELECT e.sample, e.{o}, - e.e / s.s - CASE WHEN e.{o} = y.label THEN 1.0 ELSE 0.0 END AS val + SELECT e.sample, e.out, + e.e / s.s - CASE WHEN e.out = y.label THEN 1.0 ELSE 0.0 END AS val FROM e JOIN s ON e.sample = s.sample JOIN y ON y.sample = e.sample""" def train_step(t: Tensors) -> None: - """Forward, backward (the same contraction transposed), SGD update.""" + """Forward, backward (the same contraction transposed), one SGD update.""" forward(t) t.put(f"delta{DEPTH}", softmax_delta_sql()) - # Backward: walk the layers in reverse, the gradients are contractions. + # Backward: gradients are contractions over the batch, accumulated into one + # gweight relation tagged by layer. delta{L} is the error at layer L's units. for layer in reversed(range(DEPTH)): - i, o = f"u{layer}", f"u{layer + 1}" a_in = "x" if layer == 0 else f"a{layer}" + # matmul gradient (mean over batch) + bias gradient (mean of delta), + # both tagged with this layer, as rows of one gweight relation. gw = contract( - f"sample,{i} * sample,{o} -> {i},{o}", a_in, f"delta{layer + 1}" + "sample,inp * sample,out -> inp,out", a_in, f"delta{layer + 1}" ) - t.put( - f"gw{layer}", f"SELECT {i}, {o}, val / {N_TRAIN} AS val FROM ({gw})" + rows = ( + f"SELECT CAST({layer} AS BIGINT) AS layer, inp, out, " + f"val / {N_TRAIN} AS val FROM ({gw}) " + f"UNION ALL " + f"SELECT CAST({layer} AS BIGINT) AS layer, " + f"CAST({WIDTHS[layer]} AS BIGINT) AS inp, out, AVG(val) AS val " + f"FROM delta{layer + 1} GROUP BY out" ) t.put( - f"gb{layer}", - f"SELECT {o}, AVG(val) AS val FROM delta{layer + 1} GROUP BY {o}", + "gweight", + f"SELECT * FROM gweight UNION ALL {rows}" + if t.ctx.table_exist("gweight") + else rows, ) if layer > 0: # propagate the cotangent, scaled by the local derivative dc = contract( - f"sample,{o} * {i},{o} -> sample,{i}", + "sample,out * inp,out -> sample,inp", f"delta{layer + 1}", - f"w{layer}", + matmul_rows(layer), ) t.put( f"delta{layer}", - f"""WITH dh AS ({dc}) - SELECT dh.sample, dh.{i}, dh.val * grad(tanh(a{layer}.z), a{layer}.z) AS val - FROM dh JOIN a{layer} ON dh.sample = a{layer}.sample AND dh.{i} = a{layer}.{i}""", + f"""WITH dc AS ({dc}) + SELECT dc.sample, dc.inp AS out, + dc.val * grad(tanh(a{layer}.z), a{layer}.z) AS val + FROM dc JOIN a{layer} + ON dc.sample = a{layer}.sample AND dc.inp = a{layer}.inp""", ) - # SGD: each weight relation becomes w - lr * grad. - for layer in range(DEPTH): - i, o = f"u{layer}", f"u{layer + 1}" - t.put( - f"w{layer}", - f"SELECT w{layer}.{i}, w{layer}.{o}, w{layer}.val - {LR} * gw{layer}.val AS val " - f"FROM w{layer} JOIN gw{layer} ON w{layer}.{i} = gw{layer}.{i} " - f"AND w{layer}.{o} = gw{layer}.{o}", - ) - t.put( - f"b{layer}", - f"SELECT b{layer}.{o}, b{layer}.val - {LR} * gb{layer}.val AS val " - f"FROM b{layer} JOIN gb{layer} ON b{layer}.{o} = gb{layer}.{o}", - ) + # One SGD update for the whole network: weight <- weight - lr * gweight. + t.put( + "weight", + f"""SELECT w.layer, w.inp, w.out, w.val - {LR} * g.val AS val + FROM weight w JOIN gweight g + ON w.layer = g.layer AND w.inp = g.inp AND w.out = g.out""", + ) + t.ctx.deregister_table("gweight") def accuracy(t: Tensors, inp: str, lab: str) -> float: """A forward pass over ``inp`` + argmax, compared to ``lab`` — all in SQL.""" forward(t, inp) - o = f"u{DEPTH}" return float( t.ctx.sql( f"""WITH pred AS ( - SELECT sample, {o}, + SELECT sample, out, ROW_NUMBER() OVER (PARTITION BY sample ORDER BY z DESC) AS rk FROM logits) - SELECT AVG(CASE WHEN p.{o} = l.label THEN 1.0 ELSE 0.0 END) AS acc + SELECT AVG(CASE WHEN p.out = l.label THEN 1.0 ELSE 0.0 END) AS acc FROM pred p JOIN {lab} l ON p.sample = l.sample WHERE p.rk = 1""" ).to_pandas()["acc"][0] ) @@ -282,17 +311,16 @@ def record_metrics(t: Tensors, step: int) -> None: everything else here it lives as rows in a relation, grown each time, not a Python list. Read it back at the end as a tidy ``(step,)`` xarray. """ - o = f"u{DEPTH}" train = accuracy(t, "x", "y") # leaves the training forward in `logits` loss = float( t.ctx.sql( - f"""WITH m AS (SELECT sample, MAX(z) AS m FROM logits GROUP BY sample), - e AS (SELECT logits.sample, logits.{o}, exp(logits.z - m.m) AS e + """WITH m AS (SELECT sample, MAX(z) AS m FROM logits GROUP BY sample), + e AS (SELECT logits.sample, logits.out, exp(logits.z - m.m) AS e FROM logits JOIN m ON logits.sample = m.sample), s AS (SELECT sample, SUM(e) AS s FROM e GROUP BY sample) SELECT -AVG(ln(e.e / s.s)) AS loss FROM e JOIN s ON e.sample = s.sample JOIN y ON y.sample = e.sample - WHERE e.{o} = y.label""" + WHERE e.out = y.label""" ).to_pandas()["loss"][0] ) test = accuracy(t, "x_te", "y_te") @@ -354,7 +382,7 @@ def load_mnist(): imgs = _read_idx(paths["images"]).astype(np.float32) / 255.0 labs = _read_idx(paths["labels"]).astype(np.int64) side = WIDTHS[0] # pooled pixels per image - pool = int(round((28 * 28 / side) ** 0.5)) # 2 for 196 pixels + pool = 28 // int(round(side**0.5)) # 2 for 196 pixels (14x14) k = 28 // pool pooled = ( imgs.reshape(-1, k, pool, k, pool).mean(axis=(2, 4)).reshape(-1, side) @@ -373,31 +401,32 @@ def main() -> None: print(f"MNIST: train {Xtr.shape}, test {Xte.shape} architecture {WIDTHS}") ctx = xql.XarrayContext() - # The whole model is one Dataset; from_dataset splits it into a table per - # weight (the shared boundary dims become the join keys). + # The whole model is one Dataset with a layer dim; from_dataset gives one + # `net` table, and seeding drops the NaN padding to the live `weight` relation. rng = np.random.default_rng(1) model = build_model(rng) ctx.from_dataset( - "model", + "net", model, - table_names={ - (f"u{layer}", f"u{layer + 1}"): f"w{layer}" - for layer in range(DEPTH) - } - | {(f"u{layer + 1}",): f"b{layer}" for layer in range(DEPTH)}, - chunks={f"u{i}": w for i, w in enumerate(WIDTHS)}, + chunks={ + "layer": DEPTH, + "inp": model.sizes["inp"], + "out": model.sizes["out"], + }, ) t = Tensors(ctx) - seed_weights(t) + t.put( + "weight", + "SELECT layer, inp, out, weight AS val FROM net WHERE weight IS NOT NULL", + ) - # Inputs and labels, registered once; the queries read x / x_te by name. - register_tensor(ctx, "x", Xtr, ("sample", "u0"), chunk=CHUNK) + # Inputs and labels (the bias is in the weight relation, so no augmentation). + register_tensor(ctx, "x", Xtr, ("sample", "inp"), chunk=CHUNK) register_tensor(ctx, "y", ytr, ("sample",), var="label") - register_tensor(ctx, "x_te", Xte, ("sample", "u0")) + register_tensor(ctx, "x_te", Xte, ("sample", "inp")) register_tensor(ctx, "y_te", yte, ("sample",), var="label") print(f"init: test acc {accuracy(t, 'x_te', 'y_te'):.3f}") - t0 = time.time() for step in range(STEPS): train_step(t) @@ -405,23 +434,12 @@ def main() -> None: record_metrics(t, step) dt = time.time() - t0 - # The trained model comes back out as one xarray Dataset. - parts = [] - for layer in range(DEPTH): - i, o = f"u{layer}", f"u{layer + 1}" - parts.append( - ctx.sql(f"SELECT {i}, {o}, val FROM w{layer}") - .to_dataset(dims=[i, o]) - .rename({"val": f"w{layer}"}) - ) - parts.append( - ctx.sql(f"SELECT {o}, val FROM b{layer}") - .to_dataset(dims=[o]) - .rename({"val": f"b{layer}"}) - ) - trained = xr.merge(parts) - # The loss curve and accuracies were recorded as rows; read them back as a - # tidy (step,) xarray of training metrics. + # The trained model, predictions, and metrics all come back out as xarray. + weights = ( + ctx.sql("SELECT layer, inp, out, val FROM weight") + .to_dataset(dims=["layer", "inp", "out"]) + .rename({"val": "weight"}) + ) metrics = ctx.sql("SELECT * FROM metrics ORDER BY step").to_dataset( dims=["step"] ) @@ -431,10 +449,9 @@ def main() -> None: f"test accuracy {accuracy(t, 'x_te', 'y_te'):.3f}." ) print( - f"the model is one xarray Dataset again " - f"(vars {list(trained.data_vars)}, dims {dict(trained.sizes)}); " - f"metrics are a table -> xarray {list(metrics.data_vars)} over " - f"{dict(metrics.sizes)}." + f"the whole model is one weight relation -> xarray " + f"{dict(weights.sizes)}; metrics are a table -> xarray " + f"{list(metrics.data_vars)} over {dict(metrics.sizes)}." )