diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..fc419a1 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,37 @@ +# AGENTS.md + +Guidance for contributors (including AI assistants) working on `xarray-sql`. It +summarizes recurring maintainer review feedback so changes land clean. + +## Documentation and comments + +- Keep docstrings and comments self-contained. Do **not** put GitHub issue or PR + numbers in docstrings or code comments; a reader should not need the issue + tracker to understand the code. Issue references belong in the commit message + and PR description (e.g. `Closes #189`), not in the source. +- Do not reference the review conversation, chat, or "the reporter" in comments. + Describe the behavior, not how it came up. + +## API surface + +- Mark internal helpers private with a leading underscore when they are not part + of the public API. + +## Tests + +- Test the public contract (values, dims, coords, attrs), not internal call + counts or private classes, so the suite survives refactors. +- Avoid redundant tests: if a public-path test already covers a behavior, do not + add a second lower-level test for the same thing. +- Make query results deterministic with `ORDER BY` so assertions do not have to + re-sort the output. +- Do not pass `dims=` to `to_dataset()` when inference already resolves them. + Reserve explicit `dims=` / `template=` for genuinely ambiguous cases (multiple + registered Datasets, or a test that is specifically exercising those + arguments). + +## Imports + +- Keep imports at the top of the file. Assume transitive dependencies are safe + to import non-locally, rather than deferring imports into functions to avoid + a dependency. diff --git a/README.md b/README.md index 0356c6a..a254d29 100644 --- a/README.md +++ b/README.md @@ -47,9 +47,8 @@ clim = ctx.sql(''' ORDER BY month ''') -# Write the SQL result back to an Xarray Dataset. `month` is a derived -# column, so name it as the dimension; the variable's units are recovered -# from the registered table. The result is one value per month: air(month). +# Round-trip the result back to Xarray. `month` is a derived column, so name +# it as the dimension. clim_ds = clim.to_dataset(dims=["month"]) # Plot the annual cycle as a time series. @@ -138,7 +137,9 @@ ctx.sql(''' AND TIMESTAMP '2020-01-01 05:00:00' GROUP BY latitude, longitude ORDER BY latitude DESC, longitude -''').to_dataset(dims=['latitude', 'longitude'], template=ds) +# `latitude`/`longitude` are inferred from the registered table's surviving +# dims; `template` is kept only to recover metadata (attrs, encoding). +''').to_dataset(template=ds) # Size: 8MB # Dimensions: (latitude: 721, longitude: 1440) # Coordinates: diff --git a/docs/examples.md b/docs/examples.md index 01b9956..f80c5b5 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -34,8 +34,10 @@ clim = ctx.sql(''' clim.to_pandas().head() # Option 2: round-trip back to an Xarray Dataset and plot the annual cycle as -# a time series. `month` is a derived column, so name it as the dimension; the -# variable's units are recovered from the registered table. +# a time series. `to_dataset()` infers dimensions from the registered table's +# surviving dims, so a GROUP BY on a real dimension needs no `dims=`. Here +# `month` is a derived column, not a registered dim, so name it explicitly; +# the variable's units are recovered from the registered table. clim_ds = clim.to_dataset(dims=["month"]) clim_ds["air"].plot() ``` diff --git a/tests/test_ds.py b/tests/test_ds.py index aa3deb2..3682dac 100644 --- a/tests/test_ds.py +++ b/tests/test_ds.py @@ -125,7 +125,7 @@ def test_aggregation_drops_dim(air_dataset_small): ctx.from_dataset("air", air_dataset_small) out = ctx.sql( "SELECT lat, lon, AVG(air) AS air_avg FROM air GROUP BY lat, lon" - ).to_dataset(dims=["lat", "lon"]) + ).to_dataset() assert set(out.dims) == {"lat", "lon"} assert "air_avg" in out.data_vars assert "air" not in out.data_vars @@ -139,6 +139,23 @@ def test_aggregation_drops_dim(air_dataset_small): np.testing.assert_allclose(actual, expected) +def test_aggregation_infers_dims(air_dataset_small): + """to_dataset() infers the surviving GROUP BY dim when dims is omitted.""" + ctx = XarrayContext() + ctx.from_dataset("air", air_dataset_small) + + # Grouping by the time coordinate keeps time as the sole dimension; the + # ORDER BY makes the result order deterministic so no sort is needed below. + out = ctx.sql( + 'SELECT "time", AVG("air") AS air FROM "air" ' + 'GROUP BY "time" ORDER BY "time"' + ).to_dataset() + assert set(out.dims) == {"time"} + assert "air" in out.data_vars + expected = air_dataset_small.compute().mean(dim=["lat", "lon"])["air"] + np.testing.assert_allclose(out["air"].values, expected.values) + + def test_barrier_query_scans_source_once(air_dataset_small): """A barrier plan (aggregation) executes the source exactly once. @@ -166,7 +183,7 @@ def test_barrier_query_scans_source_once(air_dataset_small): out = ctx.sql( "SELECT lat, lon, AVG(air) AS air_avg FROM air GROUP BY lat, lon" - ).to_dataset(dims=["lat", "lon"]) + ).to_dataset() reads_after_construct = len(reads) out.compute() reads_after_compute = len(reads) @@ -188,7 +205,7 @@ def test_order_by_direction_sets_dim_order(air_dataset_small): ctx.from_dataset("air", air_dataset_small) out = ctx.sql( "SELECT lat, AVG(air) AS air_avg FROM air GROUP BY lat ORDER BY lat DESC" - ).to_dataset(dims=["lat"]) + ).to_dataset() lat = out["lat"].values assert (np.diff(lat) < 0).all(), f"expected descending lat, got {lat}" @@ -289,7 +306,7 @@ def test_fast_path_uses_scanned_tables_coords_not_user_template( def test_round_trip_preserves_descending_lat_on_lazy_path(air_dataset_small): - """Lazy round-trip preserves source dim order (xarray-sql#171). + """Lazy round-trip preserves source dim order. NCEP ``air_temperature`` ships descending lat (75.0 -> 15.0). The discovery path's ``.distinct().sort()`` previously flipped lat to @@ -383,14 +400,12 @@ def test_to_dataset_multi_registered_requires_explicit_template( assert set(out.dims) == {"time", "lat", "lon"} -def test_to_dataset_infer_fails_when_no_template_fits(air_dataset_small): - """If no registered Dataset's dims fit the result -> clear error.""" +def test_to_dataset_infer_fails_when_no_dim_survives(air_dataset_small): + """A global aggregation leaves no registered dim in the result -> clear error.""" ctx = XarrayContext() ctx.from_dataset("air", air_dataset_small) with pytest.raises(ValueError, match="dims cannot be inferred"): - ctx.sql( - "SELECT lat, lon, AVG(air) AS air_avg FROM air GROUP BY lat, lon" - ).to_dataset() + ctx.sql("SELECT AVG(air) AS air_avg FROM air").to_dataset() def test_template_accepts_name_or_dataset(air_dataset_small): @@ -447,7 +462,7 @@ def test_template_aggregation_alias_no_attrs(air_dataset_small): ctx.from_dataset("air", ds) out = ctx.sql( "SELECT lat, lon, AVG(air) AS air_avg FROM air GROUP BY lat, lon" - ).to_dataset(dims=["lat", "lon"]) + ).to_dataset() assert "air_avg" in out.data_vars assert out["air_avg"].attrs == {} diff --git a/xarray_sql/ds.py b/xarray_sql/ds.py index 5dfdf42..cd1db39 100644 --- a/xarray_sql/ds.py +++ b/xarray_sql/ds.py @@ -757,10 +757,12 @@ def to_dataset( Args: dims: Result columns to use as Dataset dimensions. When - ``None``, defaults to the dims of the registered Dataset - referenced by the SQL ``FROM`` clause (if exactly one - matches), or any single registered Dataset whose dims are - all present in the result columns. + ``None``, defaults to a registered Dataset's dimensions that + survive into the result columns, so an aggregation that drops + dims (e.g. ``GROUP BY time`` over a ``(time, lat, lon)`` grid) + round-trips on the remaining dim. Raises when no dimension + survives, or when several registered Datasets imply different + dims (pass ``dims`` explicitly then). template: Source to recover metadata (attrs, encoding, non-dim coordinates, dim-coord dtype) from. Either an ``xr.Dataset`` used directly, or the name of a registered table (e.g. @@ -879,33 +881,37 @@ def _infer_dimension_columns( ) -> list[str]: """Pick a default ``dimension_columns`` from the registry, or raise. - Uses the data variable's dim order (via :func:`_ds_var_dims`) so - the round-trip preserves the original axis order. + A registered Dataset's dims that survive into the result columns + become the dimensions, so aggregations that drop dims (e.g. + ``GROUP BY time`` over a ``(time, lat, lon)`` grid) round-trip on the + surviving dim(s). Uses the data variable's dim order (via + :func:`_ds_var_dims`) so the original axis order is preserved. """ result_cols = set(self._result_columns()) - if ( - preferred_template is not None - and set(preferred_template.dims) <= result_cols - ): - return _ds_var_dims(preferred_template) + + def surviving(template: xr.Dataset) -> list[str]: + # Template dims still present in the result, in var axis order. + return [d for d in _ds_var_dims(template) if d in result_cols] + + if preferred_template is not None: + preferred = surviving(preferred_template) + if preferred: + return preferred if not self._templates: raise ValueError( "dims cannot be inferred (no registered " "Dataset on this result); pass dims=[...] " "explicitly." ) - candidates = [ - _ds_var_dims(t) - for t in self._templates.values() - if set(t.dims) <= result_cols - ] + candidates = {tuple(surviving(t)) for t in self._templates.values()} + candidates.discard(()) # templates with no surviving dim if len(candidates) == 1: - return candidates[0] + return list(next(iter(candidates))) if not candidates: raise ValueError( - "dims cannot be inferred: no registered " - "Dataset has all of its dims present in the result " - "columns. Pass dims=[...] explicitly." + "dims cannot be inferred: no registered Dataset " + "dimension survives in the result columns. Pass " + "dims=[...] explicitly." ) raise ValueError( "dims cannot be inferred unambiguously: multiple "