From 53a685d5901cb3eceec706f53bf199da82e8ec34 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Fri, 26 Jun 2026 15:07:40 +0200 Subject: [PATCH] Integrate `datafusion-distributed` with `datafusion-python` --- Cargo.lock | 321 ++++++++++++++++++ Cargo.toml | 2 + crates/core/Cargo.toml | 2 + crates/core/src/context.rs | 34 +- crates/core/src/distributed_worker.rs | 206 +++++++++++ .../core/src/distributed_worker_resolver.rs | 84 +++++ crates/core/src/lib.rs | 4 + crates/core/src/physical_plan.rs | 36 ++ examples/README.md | 14 + examples/distributed-localhost-worker.py | 37 ++ examples/distributed-run.py | 79 +++++ python/datafusion/__init__.py | 8 +- python/datafusion/context.py | 39 +++ python/datafusion/distributed.py | 175 ++++++++++ python/datafusion/plan.py | 25 +- python/tests/test_context.py | 49 +++ python/tests/test_distributed.py | 48 +++ python/tests/test_imports.py | 11 + python/tests/test_plans.py | 18 + 19 files changed, 1186 insertions(+), 6 deletions(-) create mode 100644 crates/core/src/distributed_worker.rs create mode 100644 crates/core/src/distributed_worker_resolver.rs create mode 100644 examples/distributed-localhost-worker.py create mode 100644 examples/distributed-run.py create mode 100644 python/datafusion/distributed.py create mode 100644 python/tests/test_distributed.py diff --git a/Cargo.lock b/Cargo.lock index fc9b74cdb..20d3fba1b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -238,6 +238,26 @@ dependencies = [ "num-traits", ] +[[package]] +name = "arrow-flight" +version = "58.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28abfe8bf9f124e5fc83b334af4fa58f8d0323ad25312ccb2d1da50178415704" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-ipc", + "arrow-schema", + "base64", + "bytes", + "futures", + "prost", + "prost-types", + "tonic", + "tonic-prost", +] + [[package]] name = "arrow-ipc" version = "58.3.0" @@ -377,6 +397,17 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f4de21c0feef7e5a556e51af767c953f0501f7f300ba785cc99c47bdc8081a50" +[[package]] +name = "async-lock" +version = "3.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290f7f2596bd5b78a9fec8088ccd89180d7f9f55b94b0576823bbbdc72ee8311" +dependencies = [ + "event-listener", + "event-listener-strategy", + "pin-project-lite", +] + [[package]] name = "async-recursion" version = "1.1.1" @@ -420,6 +451,49 @@ version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2032f911046de80f0a198e0901378627c33f59ea0ac00e363d481118bd70a53" +[[package]] +name = "axum" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31b698c5f9a010f6573133b09e0de5408834d0c82f8d7475a89fc1867a71cd90" +dependencies = [ + "axum-core", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "serde_core", + "sync_wrapper", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", +] + [[package]] name = "base64" version = "0.22.1" @@ -439,6 +513,15 @@ dependencies = [ "num-traits", ] +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + [[package]] name = "bitflags" version = "2.13.0" @@ -576,8 +659,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1aa79e62e7697b8e29b513a68abacf485adcd1fe8284a4316c5ae868e6633327" dependencies = [ "iana-time-zone", + "js-sys", "num-traits", "serde", + "wasm-bindgen", "windows-link", ] @@ -631,6 +716,15 @@ version = "0.4.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cc14f565cf027a105f7a44ccf9e5b424348421a1d8952a8fc9d499d313107789" +[[package]] +name = "concurrent-queue" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "const-oid" version = "0.10.2" @@ -712,6 +806,33 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-channel" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-queue" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.21" @@ -1083,6 +1204,43 @@ dependencies = [ "tokio", ] +[[package]] +name = "datafusion-distributed" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5467bab3520c7641427f723f6df70ebad684aef934e66d0a55a08aeabf95b9f7" +dependencies = [ + "arrow-flight", + "arrow-ipc", + "arrow-select", + "async-trait", + "bincode", + "bytes", + "chrono", + "crossbeam-queue", + "dashmap", + "datafusion", + "datafusion-proto", + "delegate", + "futures", + "http", + "itertools", + "moka", + "object_store", + "pin-project", + "prost", + "rand 0.9.4", + "sketches-ddsketch", + "tokio", + "tokio-stream", + "tokio-util", + "tonic", + "tonic-prost", + "tower", + "url", + "uuid", +] + [[package]] name = "datafusion-doc" version = "54.0.0" @@ -1535,6 +1693,7 @@ dependencies = [ "chrono", "cstr", "datafusion", + "datafusion-distributed", "datafusion-ffi", "datafusion-proto", "datafusion-python-util", @@ -1553,6 +1712,7 @@ dependencies = [ "pyo3-log", "serde_json", "tokio", + "tonic", "url", "uuid", ] @@ -1652,6 +1812,17 @@ dependencies = [ "url", ] +[[package]] +name = "delegate" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "780eb241654bf097afb00fc5f054a09b687dad862e485fdcf8399bb056565370" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "digest" version = "0.10.7" @@ -1713,6 +1884,27 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "event-listener" +version = "5.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13b66accf52311f30a0db42147dadea9850cb48cd070028831ae5f5d4b856ab" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + +[[package]] +name = "event-listener-strategy" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93" +dependencies = [ + "event-listener", + "pin-project-lite", +] + [[package]] name = "fastrand" version = "2.4.1" @@ -2041,6 +2233,12 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + [[package]] name = "humantime" version = "2.3.0" @@ -2070,6 +2268,7 @@ dependencies = [ "http", "http-body", "httparse", + "httpdate", "itoa", "pin-project-lite", "smallvec", @@ -2093,6 +2292,19 @@ dependencies = [ "tower-service", ] +[[package]] +name = "hyper-timeout" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0" +dependencies = [ + "hyper", + "hyper-util", + "pin-project-lite", + "tokio", + "tower-service", +] + [[package]] name = "hyper-util" version = "0.1.20" @@ -2459,6 +2671,12 @@ dependencies = [ "twox-hash", ] +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "md-5" version = "0.10.6" @@ -2494,6 +2712,12 @@ dependencies = [ "libmimalloc-sys", ] +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + [[package]] name = "miniz_oxide" version = "0.8.9" @@ -2515,6 +2739,26 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "moka" +version = "0.12.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "957228ad12042ee839f93c8f257b62b4c0ab5eaae1d4fa60de53b27c9d7c5046" +dependencies = [ + "async-lock", + "crossbeam-channel", + "crossbeam-epoch", + "crossbeam-utils", + "equivalent", + "event-listener", + "futures-util", + "parking_lot", + "portable-atomic", + "smallvec", + "tagptr", + "uuid", +] + [[package]] name = "multimap" version = "0.10.1" @@ -2629,6 +2873,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "parking" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" + [[package]] name = "parking_lot" version = "0.12.5" @@ -3570,6 +3820,16 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8fadd59c855ef2080decdef8ff161eb6661b86933c9d82e5ba29dc602a55aba" +[[package]] +name = "signal-hook-registry" +version = "1.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" +dependencies = [ + "errno", + "libc", +] + [[package]] name = "simd-adler32" version = "0.3.9" @@ -3588,6 +3848,15 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ee5873ec9cce0195efcb7a4e9507a04cd49aec9c83d0389df45b1ef7ba2e649" +[[package]] +name = "sketches-ddsketch" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c6f73aeb92d671e0cc4dca167e59b2deb6387c375391bc99ee743f326994a2b" +dependencies = [ + "serde", +] + [[package]] name = "slab" version = "0.4.12" @@ -3767,6 +4036,12 @@ dependencies = [ "syn", ] +[[package]] +name = "tagptr" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" + [[package]] name = "target-lexicon" version = "0.13.5" @@ -3860,7 +4135,9 @@ dependencies = [ "bytes", "libc", "mio", + "parking_lot", "pin-project-lite", + "signal-hook-registry", "socket2", "tokio-macros", "windows-sys 0.61.2", @@ -3942,6 +4219,46 @@ dependencies = [ "winnow", ] +[[package]] +name = "tonic" +version = "0.14.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac2a5518c70fa84342385732db33fb3f44bc4cc748936eb5833d2df34d6445ef" +dependencies = [ + "async-trait", + "axum", + "base64", + "bytes", + "h2", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-timeout", + "hyper-util", + "percent-encoding", + "pin-project", + "socket2", + "sync_wrapper", + "tokio", + "tokio-stream", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tonic-prost" +version = "0.14.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50849f68853be452acf590cde0b146665b8d507b3b8af17261df47e02c209ea0" +dependencies = [ + "bytes", + "prost", + "tonic", +] + [[package]] name = "tower" version = "0.5.3" @@ -3950,11 +4267,15 @@ checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" dependencies = [ "futures-core", "futures-util", + "indexmap", "pin-project-lite", + "slab", "sync_wrapper", "tokio", + "tokio-util", "tower-layer", "tower-service", + "tracing", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 0f3236ecf..564b69e74 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,7 @@ resolver = "3" [workspace.dependencies] tokio = { version = "1.52" } +tonic = { version = "0.14", features = ["transport"] } pyo3 = { version = "0.28" } pyo3-async-runtimes = { version = "0.28" } pyo3-log = "0.13.3" @@ -50,6 +51,7 @@ datafusion-functions-aggregate = { version = "54" } datafusion-functions-window = { version = "54" } datafusion-spark = { version = "54" } datafusion-expr = { version = "54" } +datafusion-distributed = { version = "2" } prost = "0.14.3" serde_json = "1" uuid = { version = "1.23" } diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index e2e922a82..17fbe7412 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -40,6 +40,7 @@ tokio = { workspace = true, features = [ "rt-multi-thread", "sync", ] } +tonic = { workspace = true } pyo3 = { workspace = true, features = [ "extension-module", "generate-import-lib", @@ -54,6 +55,7 @@ datafusion-substrait = { workspace = true, optional = true } datafusion-proto = { workspace = true } datafusion-ffi = { workspace = true } datafusion-spark = { workspace = true } +datafusion-distributed = { workspace = true } prost = { workspace = true } # keep in line with `datafusion-substrait` serde_json = { workspace = true } uuid = { workspace = true, features = ["v4"] } diff --git a/crates/core/src/context.rs b/crates/core/src/context.rs index 0613a96dc..96df5f9e5 100644 --- a/crates/core/src/context.rs +++ b/crates/core/src/context.rs @@ -43,10 +43,11 @@ use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, Unboun use datafusion::execution::options::{ArrowReadOptions, ReadOptions}; use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::execution::session_state::SessionStateBuilder; -use datafusion::execution::{FunctionRegistry, TaskContextProvider}; +use datafusion::execution::{FunctionRegistry, SessionState, TaskContextProvider}; use datafusion::prelude::{ AvroReadOptions, CsvReadOptions, DataFrame, JsonReadOptions, ParquetReadOptions, }; +use datafusion_distributed::{DistributedConfig, DistributedExt, SessionStateBuilderExt}; use datafusion_ffi::catalog_provider::FFI_CatalogProvider; use datafusion_ffi::catalog_provider_list::FFI_CatalogProviderList; use datafusion_ffi::config::extension_options::FFI_ExtensionOptions; @@ -78,6 +79,7 @@ use crate::common::data_type::PyScalarValue; use crate::common::df_schema::PyDFSchema; use crate::dataframe::PyDataFrame; use crate::dataset::Dataset; +use crate::distributed_worker_resolver::PyWorkerResolver; use crate::errors::{ PyDataFusionError, PyDataFusionResult, from_datafusion_error, py_datafusion_err, }; @@ -219,6 +221,15 @@ impl PySessionConfig { Ok(Self::from(config)) } + + #[pyo3(signature = (worker_resolver))] + fn with_distributed(&self, worker_resolver: PyWorkerResolver) -> Self { + let config = self + .config + .clone() + .with_distributed_worker_resolver(worker_resolver); + Self::from(config) + } } /// Runtime options for a SessionContext @@ -392,13 +403,20 @@ impl PySessionContext { } else { RuntimeEnvBuilder::default() }; + let distributed = DistributedConfig::from_config_options(config.options()).is_ok(); + let runtime = Arc::new(runtime_env_builder.build()?); - let session_state = SessionStateBuilder::new() + let mut builder = SessionStateBuilder::new() .with_config(config) .with_runtime_env(runtime) .with_default_features() - .with_analyzer_rule(Arc::new(crate::analyzer::ResolveLambdaVariables::new())) - .build(); + .with_analyzer_rule(Arc::new(crate::analyzer::ResolveLambdaVariables::new())); + + if distributed { + builder = builder.with_distributed_planner(); + } + + let session_state = builder.build(); let ctx = Arc::new(SessionContext::new_with_state(session_state)); Ok(PySessionContext { ctx, @@ -1430,6 +1448,14 @@ impl PySessionContext { } impl PySessionContext { + pub(crate) fn from_session_state(session_state: SessionState) -> Self { + Self { + ctx: Arc::new(SessionContext::new_with_state(session_state)), + logical_codec: Arc::new(PythonLogicalCodec::default()), + physical_codec: Arc::new(PythonPhysicalCodec::default()), + } + } + async fn _table(&self, name: &str) -> datafusion::common::Result { self.ctx.table(name).await } diff --git a/crates/core/src/distributed_worker.rs b/crates/core/src/distributed_worker.rs new file mode 100644 index 000000000..431bed6d3 --- /dev/null +++ b/crates/core/src/distributed_worker.rs @@ -0,0 +1,206 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashMap; +use std::net::SocketAddr; + +use async_trait::async_trait; +use datafusion::common::{DataFusionError, Result as DataFusionResult}; +use datafusion::execution::{SessionState, SessionStateBuilder}; +use datafusion_distributed::{Worker, WorkerQueryContext, WorkerSessionBuilder}; +use datafusion_python_util::wait_for_future; +use pyo3::Borrowed; +use pyo3::exceptions::{PyRuntimeError, PyTypeError}; +use pyo3::prelude::*; +use tonic::transport::Server; + +use crate::context::PySessionContext; +use crate::errors::{PyDataFusionError, PyDataFusionResult}; + +#[pyclass( + from_py_object, + frozen, + name = "Worker", + module = "datafusion", + subclass +)] +#[derive(Clone)] +pub struct PyWorker { + worker: Worker, +} + +#[pymethods] +impl PyWorker { + #[new] + fn new() -> Self { + Self { + worker: Worker::default(), + } + } + + #[staticmethod] + fn from_session_builder(session_builder: PyWorkerSessionBuilder) -> Self { + Self { + worker: Worker::from_session_builder(session_builder), + } + } + + fn with_version(&self, version: String) -> Self { + Self { + worker: self.worker.clone().with_version(version), + } + } + + fn with_max_message_size(&self, size: usize) -> Self { + Self { + worker: self.worker.clone().with_max_message_size(size), + } + } + + #[pyo3(signature = (host = "127.0.0.1", port = 50051))] + fn serve(&self, py: Python<'_>, host: &str, port: u16) -> PyDataFusionResult<()> { + let addr = parse_socket_addr(host, port)?; + let worker = self.worker.clone(); + wait_for_future(py, serve_worker(worker, addr))?.map_err(PyDataFusionError::from) + } + + #[pyo3(signature = (host = "127.0.0.1", port = 50051))] + fn serve_async<'py>( + &self, + py: Python<'py>, + host: &str, + port: u16, + ) -> PyResult> { + let addr = parse_socket_addr(host, port)?; + let worker = self.worker.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + serve_worker(worker, addr) + .await + .map_err(PyDataFusionError::from)?; + Ok(()) + }) + } +} + +#[pyclass(name = "WorkerQueryContext", module = "datafusion", subclass)] +pub struct PyWorkerQueryContext { + builder: Option, + headers: HashMap, +} + +impl PyWorkerQueryContext { + fn new(ctx: WorkerQueryContext) -> Self { + let headers = ctx + .headers + .iter() + .map(|(name, value)| { + ( + name.as_str().to_owned(), + value.to_str().unwrap_or_default().to_owned(), + ) + }) + .collect(); + + Self { + builder: Some(ctx.builder), + headers, + } + } +} + +#[pymethods] +impl PyWorkerQueryContext { + fn session_context(mut slf: PyRefMut<'_, Self>) -> PyResult { + let builder = slf.builder.take().ok_or_else(|| { + PyRuntimeError::new_err("WorkerQueryContext.session_context() can only be called once") + })?; + Ok(PySessionContext::from_session_state(builder.build())) + } + + #[getter] + fn headers(&self) -> HashMap { + self.headers.clone() + } +} + +pub(crate) struct PyWorkerSessionBuilder { + callback: Py, +} + +impl FromPyObject<'_, '_> for PyWorkerSessionBuilder { + type Error = PyErr; + + fn extract(obj: Borrowed<'_, '_, PyAny>) -> Result { + if !obj.is_callable() { + return Err(PyTypeError::new_err( + "Expected worker session builder to be callable", + )); + } + + Ok(Self { + callback: obj.to_owned().unbind(), + }) + } +} + +#[async_trait] +impl WorkerSessionBuilder for PyWorkerSessionBuilder { + async fn build_session_state( + &self, + ctx: WorkerQueryContext, + ) -> Result { + Python::attach(|py| -> PyResult { + let ctx = Py::new(py, PyWorkerQueryContext::new(ctx))?; + let result = self.callback.call1(py, (ctx,))?; + let session_context = extract_session_context(result.bind(py))?; + Ok(session_context.ctx.state()) + }) + .map_err(|error| DataFusionError::External(Box::new(error))) + } +} + +fn extract_session_context(obj: &Bound<'_, PyAny>) -> PyResult { + if let Ok(session_context) = obj.extract::() { + return Ok(session_context); + } + + if let Ok(ctx_attr) = obj.getattr("ctx") + && let Ok(session_context) = ctx_attr.extract::() + { + return Ok(session_context); + } + + Err(PyTypeError::new_err( + "WorkerSessionBuilder.build_session_state() must return a datafusion.SessionContext", + )) +} + +fn parse_socket_addr(host: &str, port: u16) -> PyDataFusionResult { + format!("{host}:{port}").parse().map_err(|error| { + PyDataFusionError::Common(format!( + "invalid worker bind address {host}:{port}: {error}" + )) + }) +} + +async fn serve_worker(worker: Worker, addr: SocketAddr) -> DataFusionResult<()> { + Server::builder() + .add_service(worker.into_worker_server()) + .serve(addr) + .await + .map_err(|error| DataFusionError::External(Box::new(error))) +} diff --git a/crates/core/src/distributed_worker_resolver.rs b/crates/core/src/distributed_worker_resolver.rs new file mode 100644 index 000000000..a155bae64 --- /dev/null +++ b/crates/core/src/distributed_worker_resolver.rs @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::common::DataFusionError; +use datafusion_distributed::WorkerResolver; +use pyo3::Borrowed; +use pyo3::exceptions::{PyTypeError, PyValueError}; +use pyo3::prelude::*; +use pyo3::types::PyString; +use url::Url; + +pub(crate) struct PyWorkerResolver { + get_urls: Py, +} + +impl FromPyObject<'_, '_> for PyWorkerResolver { + type Error = PyErr; + + fn extract(obj: Borrowed<'_, '_, PyAny>) -> Result { + let get_urls = obj.getattr("get_urls")?; + if !get_urls.is_callable() { + return Err(PyTypeError::new_err( + "Expected worker_resolver.get_urls to be callable", + )); + } + + Ok(Self { + get_urls: get_urls.unbind(), + }) + } +} + +struct WorkerUrls(Vec); + +impl FromPyObject<'_, '_> for WorkerUrls { + type Error = PyErr; + + fn extract(obj: Borrowed<'_, '_, PyAny>) -> Result { + if obj.is_instance_of::() { + return Err(PyTypeError::new_err( + "WorkerResolver.get_urls() must return an iterable of URL strings, not a string", + )); + } + + let mut parsed_urls = Vec::new(); + for url in obj.try_iter()? { + let url = url?; + let url = url.extract::()?; + let parsed_url = Url::parse(&url).map_err(|error| { + PyValueError::new_err(format!( + "WorkerResolver.get_urls() returned invalid URL {url:?}: {error}" + )) + })?; + parsed_urls.push(parsed_url); + } + + Ok(Self(parsed_urls)) + } +} + +impl WorkerResolver for PyWorkerResolver { + fn get_urls(&self) -> Result, DataFusionError> { + Python::attach(|py| -> PyResult> { + let urls = self.get_urls.call0(py)?; + let urls = urls.extract::(py)?; + Ok(urls.0) + }) + .map_err(|error| DataFusionError::External(Box::new(error))) + } +} diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index 7f0f9cb39..dd184068c 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -51,6 +51,8 @@ pub mod table; pub mod unparser; mod array; +mod distributed_worker; +mod distributed_worker_resolver; #[cfg(feature = "substrait")] pub mod substrait; mod udaf; @@ -96,6 +98,8 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; let catalog = PyModule::new(py, "catalog")?; catalog::init_module(&catalog)?; diff --git a/crates/core/src/physical_plan.rs b/crates/core/src/physical_plan.rs index 594655a60..b25561391 100644 --- a/crates/core/src/physical_plan.rs +++ b/crates/core/src/physical_plan.rs @@ -18,7 +18,11 @@ use std::sync::Arc; use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties, displayable}; +use datafusion_distributed::{ + DistributedMetricsFormat, display_plan_ascii, rewrite_distributed_plan_with_metrics, +}; use datafusion_proto::physical_plan::AsExecutionPlan; +use datafusion_python_util::wait_for_future; use prost::Message; use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; @@ -69,6 +73,25 @@ impl PyExecutionPlan { format!("{}", d.indent(false)) } + #[pyo3(signature = (metrics_format="none"))] + pub fn display_distributed( + &self, + py: Python<'_>, + metrics_format: &str, + ) -> PyDataFusionResult { + let metrics_format = parse_distributed_metrics_format(metrics_format)?; + let show_metrics = metrics_format.is_some(); + let plan = match metrics_format { + Some(metrics_format) => wait_for_future( + py, + rewrite_distributed_plan_with_metrics(self.plan.clone(), metrics_format), + )??, + None => self.plan.clone(), + }; + + Ok(display_plan_ascii(plan.as_ref(), show_metrics)) + } + #[pyo3(signature = (ctx=None))] pub fn to_bytes<'py>( &'py self, @@ -128,6 +151,19 @@ impl PyExecutionPlan { } } +fn parse_distributed_metrics_format( + format: &str, +) -> PyDataFusionResult> { + match format { + "none" => Ok(None), + "aggregated" => Ok(Some(DistributedMetricsFormat::Aggregated)), + "per_task" => Ok(Some(DistributedMetricsFormat::PerTask)), + _ => Err(crate::errors::PyDataFusionError::Common(format!( + "invalid distributed metrics format {format:?}; expected 'none', 'aggregated', or 'per_task'" + ))), + } +} + impl From for Arc { fn from(plan: PyExecutionPlan) -> Arc { plan.plan.clone() diff --git a/examples/README.md b/examples/README.md index e0e3056d9..c524aae85 100644 --- a/examples/README.md +++ b/examples/README.md @@ -49,6 +49,20 @@ Here is a direct link to the file used in the examples: - [Fan out distinct expressions to a multiprocessing pool](./multiprocessing_pickle_expr.py) - [Distribute expression evaluation across Ray actors](./ray_pickle_expr.py) +### Distributed Query Execution + +- [Start a localhost datafusion-distributed worker](./distributed-localhost-worker.py) +- [Run a distributed query against localhost workers](./distributed-run.py) + +Example: + +```bash +curl -LO https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2021-01.parquet +python examples/distributed-localhost-worker.py 50051 +python examples/distributed-localhost-worker.py 50052 +WORKERS=50051,50052 python examples/distributed-run.py --plan yellow_tripdata_2021-01.parquet +``` + ### Substrait Support - [Serialize query plans using Substrait](./substrait.py) diff --git a/examples/distributed-localhost-worker.py b/examples/distributed-localhost-worker.py new file mode 100644 index 000000000..56ccdbeba --- /dev/null +++ b/examples/distributed-localhost-worker.py @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import sys + +from datafusion import SessionContext, Worker, WorkerQueryContext + + +def build_worker_session(context: WorkerQueryContext) -> SessionContext: + return context.session_context() + + +def main() -> None: + port = int(sys.argv[1]) if len(sys.argv) > 1 else 50051 + worker = Worker.from_session_builder(build_worker_session).with_version( + "python-example" + ) + print(f"Starting datafusion-distributed worker on http://127.0.0.1:{port}") + worker.serve("127.0.0.1", port) + + +if __name__ == "__main__": + main() diff --git a/examples/distributed-run.py b/examples/distributed-run.py new file mode 100644 index 000000000..1e24226d1 --- /dev/null +++ b/examples/distributed-run.py @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +from argparse import ArgumentParser +from pathlib import Path + +from datafusion import SessionConfig, SessionContext + +DEFAULT_PARQUET_PATH = "yellow_tripdata_2021-01.parquet" +PARQUET_DOWNLOAD_URL = ( + "https://d37ci6vzurychx.cloudfront.net/trip-data/" + "yellow_tripdata_2021-01.parquet" +) + + +class LocalhostWorkerResolver: + def __init__(self, ports: list[str]) -> None: + self.ports = ports + + def get_urls(self) -> list[str]: + return [f"http://127.0.0.1:{port}" for port in self.ports] + + +def worker_ports_from_env() -> list[str]: + workers = os.environ.get("WORKERS", "") + ports = [port.strip() for port in workers.split(",") if port.strip()] + if not ports: + msg = "Set WORKERS to a comma-separated list of localhost worker ports" + raise RuntimeError(msg) + return ports + + +parser = ArgumentParser() +parser.add_argument("parquet_path", nargs="?", default=DEFAULT_PARQUET_PATH) +parser.add_argument( + "--plan", + action="store_true", + help="print the distributed physical plan instead of running the query", +) +args = parser.parse_args() +parquet_path = args.parquet_path +if "://" not in parquet_path: + local_parquet_path = Path(parquet_path).expanduser() + if not local_parquet_path.exists(): + parser.error( + f"Parquet file {parquet_path!r} was not found. Download the example " + f"data with:\n curl -LO {PARQUET_DOWNLOAD_URL}\n" + "or pass the path to an existing parquet file." + ) + parquet_path = str(local_parquet_path) + +config = SessionConfig().with_distributed( + LocalhostWorkerResolver(worker_ports_from_env()), +) +ctx = SessionContext(config) +ctx.sql("SET distributed.file_scan_config_bytes_per_partition = 1") +ctx.register_parquet("taxi", parquet_path) +df = ctx.sql( + "select passenger_count, count(*) from taxi where passenger_count is not null group by passenger_count order by passenger_count" +) +if args.plan: + print(df.execution_plan().display_distributed()) +else: + df.show() diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 9c55f446c..c4e2fc604 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -65,7 +65,7 @@ import importlib_metadata # type: ignore[import] # Public submodules -from . import functions, ipc, object_store, substrait, unparser +from . import distributed, functions, ipc, object_store, substrait, unparser # The following imports are okay to remain as opaque to the user. from .catalog import ( @@ -81,6 +81,7 @@ SessionConfig, SessionContext, SQLOptions, + WorkerResolver, ) from .dataframe import ( DataFrame, @@ -91,6 +92,7 @@ ParquetWriterOptions, ) from .dataframe_formatter import configure_formatter +from .distributed import Worker, WorkerQueryContext from .expr import Expr, WindowFrame from .io import read_avro, read_csv, read_json, read_parquet from .options import CsvReadOptions @@ -140,11 +142,15 @@ "TableProviderFactoryExportable", "WindowFrame", "WindowUDF", + "Worker", + "WorkerQueryContext", + "WorkerResolver", "catalog", "col", "column", "common", "configure_formatter", + "distributed", "expr", "functions", "ipc", diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 0bfc59bfe..ca3770be8 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -143,6 +143,22 @@ class PhysicalOptimizerRuleExportable(Protocol): def __datafusion_physical_optimizer_rule__(self) -> object: ... # noqa: D105 +class WorkerResolver(Protocol): + """Type hint for datafusion-distributed worker discovery objects.""" + + def get_urls(self) -> Iterable[str]: + """Return worker URLs available for distributed query execution. + + Examples: + >>> class StaticWorkerResolver: + ... def get_urls(self) -> list[str]: + ... return ["http://127.0.0.1:50051"] + >>> StaticWorkerResolver().get_urls() + ['http://127.0.0.1:50051'] + """ + ... + + class SessionConfig: """Session configuration options.""" @@ -348,6 +364,29 @@ def with_extension(self, extension: Any) -> SessionConfig: self.config_internal = self.config_internal.with_extension(extension) return self + def with_distributed( + self, + worker_resolver: WorkerResolver, + ) -> SessionConfig: + """Enable datafusion-distributed planning with a worker resolver. + + Args: + worker_resolver: Object whose ``get_urls()`` method returns worker + URL strings. + + Returns: + A new :py:class:`SessionConfig` object with distributed planning enabled. + + Examples: + >>> from datafusion import SessionConfig + >>> class StaticWorkerResolver: + ... def get_urls(self) -> list[str]: + ... return ["http://127.0.0.1:50051"] + >>> config = SessionConfig().with_distributed(StaticWorkerResolver()) + """ + self.config_internal = self.config_internal.with_distributed(worker_resolver) + return self + class RuntimeEnvBuilder: """Runtime configuration options.""" diff --git a/python/datafusion/distributed.py b/python/datafusion/distributed.py new file mode 100644 index 000000000..a8c4a12b0 --- /dev/null +++ b/python/datafusion/distributed.py @@ -0,0 +1,175 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Bindings for datafusion-distributed workers.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from datafusion.context import SessionContext + +from ._internal import Worker as WorkerInternal +from ._internal import WorkerQueryContext as WorkerQueryContextInternal + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + + +class WorkerQueryContext: + """Context passed to :class:`WorkerSessionBuilder` callbacks.""" + + def __init__(self, context: WorkerQueryContextInternal) -> None: + """Wrap the internal worker query context. + + This is created by DataFusion when a worker receives a query; user code + normally only sees it as the argument to ``build_session_state``. + """ + self._raw = context + + @property + def headers(self) -> dict[str, str]: + """Return incoming gRPC request headers for the worker query.""" + return dict(self._raw.headers) + + def session_context(self) -> SessionContext: + """Build a public :class:`SessionContext` from the upstream worker builder. + + The upstream builder is consumed, so this method can only be called once + for each ``WorkerQueryContext``. + + Examples: + >>> from datafusion import SessionContext + >>> from datafusion.distributed import WorkerQueryContext + >>> def build_session_state( + ... context: WorkerQueryContext, + ... ) -> SessionContext: + ... return context.session_context() + """ + wrapper = SessionContext.__new__(SessionContext) + wrapper.ctx = self._raw.session_context() + return wrapper + + +def _wrap_worker_session_builder( + callback: Callable[[WorkerQueryContext], SessionContext], +) -> Callable[[WorkerQueryContextInternal], SessionContext]: + def adapter(context: WorkerQueryContextInternal) -> SessionContext: + wrapped_context = WorkerQueryContext(context) + return callback(wrapped_context) + + return adapter + + +class Worker: + """A datafusion-distributed worker service.""" + + def __init__( + self, + session_builder: Callable[[WorkerQueryContext], SessionContext] | None = None, + ) -> None: + """Create a worker. + + Args: + session_builder: Optional custom session builder callback or object. + + Examples: + >>> from datafusion import Worker + >>> worker = Worker() + >>> isinstance(worker, Worker) + True + """ + if session_builder is None: + self._raw = WorkerInternal() + else: + if not callable(session_builder): + msg = "Expected session_builder to be callable" + raise TypeError(msg) + adapter = _wrap_worker_session_builder(session_builder) + self._raw = WorkerInternal.from_session_builder(adapter) + + @classmethod + def from_session_builder( + cls, + session_builder: Callable[[WorkerQueryContext], SessionContext], + ) -> Worker: + """Create a worker with a custom session builder. + + Examples: + >>> from datafusion import SessionContext, Worker + >>> from datafusion.distributed import WorkerQueryContext + >>> def build_session_state( + ... context: WorkerQueryContext, + ... ) -> SessionContext: + ... return context.session_context() + >>> worker = Worker.from_session_builder(build_session_state) + >>> isinstance(worker, Worker) + True + """ + return cls(session_builder=session_builder) + + def with_version(self, version: str) -> Worker: + """Set the worker version string returned by the worker service. + + Examples: + >>> from datafusion import Worker + >>> Worker().with_version("local").__class__ is Worker + True + """ + self._raw = self._raw.with_version(version) + return self + + def with_max_message_size(self, size: int) -> Worker: + """Set the maximum FlightData chunk size for this worker. + + Examples: + >>> from datafusion import Worker + >>> Worker().with_max_message_size(1024).__class__ is Worker + True + """ + self._raw = self._raw.with_max_message_size(size) + return self + + def serve(self, host: str = "127.0.0.1", port: int = 50051) -> None: + """Run the worker service on a tonic server until it is stopped. + + Examples: + >>> from datafusion import Worker + >>> Worker().serve("127.0.0.1", 50051) # doctest: +SKIP + """ + self._raw.serve(host, port) + + def serve_async( + self, host: str = "127.0.0.1", port: int = 50051 + ) -> Awaitable[None]: + """Return an awaitable that serves this worker. + + Examples: + >>> import asyncio + >>> from datafusion import Worker + >>> async def main() -> None: + ... task = asyncio.create_task(Worker().serve_async()) + ... task.cancel() + >>> asyncio.run(main()) # doctest: +SKIP + """ + return self._raw.serve_async(host, port) + + +__all__ = [ + "Worker", + "WorkerQueryContext", +] diff --git a/python/datafusion/plan.py b/python/datafusion/plan.py index b2c6eab3e..7245e99c6 100644 --- a/python/datafusion/plan.py +++ b/python/datafusion/plan.py @@ -20,7 +20,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal import datafusion._internal as df_internal @@ -160,6 +160,29 @@ def display_indent(self) -> str: """Print an indented form of the physical plan.""" return self._raw_plan.display_indent() + def display_distributed( + self, + metrics_format: Literal["none", "aggregated", "per_task"] = "none", + ) -> str: + """Print the physical plan with datafusion-distributed formatting. + + Args: + metrics_format: ``"none"`` prints the plan without metrics. + ``"aggregated"`` and ``"per_task"`` include execution metrics. + For distributed plans, metrics are first collected from workers. + The plan must have already been executed when metrics are requested. + + Examples: + >>> from datafusion import SessionContext + >>> ctx = SessionContext() + >>> plan = ctx.sql("SELECT 1").execution_plan() + >>> isinstance(plan.display_distributed(), str) + True + >>> isinstance(plan.display_distributed(metrics_format="aggregated"), str) + True + """ + return self._raw_plan.display_distributed(metrics_format) + def __repr__(self) -> str: """Print a string representation of the physical plan.""" return self._raw_plan.__repr__() diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 112a6fd7b..58f3536d4 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -29,6 +29,7 @@ SessionContext, SQLOptions, Table, + WorkerResolver, column, literal, udf, @@ -99,6 +100,54 @@ def test_create_context_with_all_valid_args(): ctx.catalog("datafusion") +def test_create_context_with_distributed_worker_resolver(): + class StaticWorkerResolver: + def __init__(self) -> None: + self.calls = 0 + + def get_urls(self) -> list[str]: + self.calls += 1 + return ["http://localhost:50051", "http://localhost:50052"] + + resolver: WorkerResolver = StaticWorkerResolver() + config = SessionConfig().with_distributed(resolver) + ctx = SessionContext(config) + + ctx.sql("SELECT 1").execution_plan() + + assert isinstance(resolver, StaticWorkerResolver) + assert resolver.calls > 0 + + +def test_distributed_worker_resolver_does_not_accept_file_scan_config_argument(): + class StaticWorkerResolver: + def get_urls(self) -> list[str]: + return ["http://localhost:50051"] + + with pytest.raises(TypeError): + SessionConfig().with_distributed(StaticWorkerResolver(), 1) # type: ignore[call-arg] + + +def test_distributed_worker_resolver_requires_callable_get_urls(): + class InvalidWorkerResolver: + get_urls = "not-callable" + + with pytest.raises(TypeError, match="get_urls"): + SessionConfig().with_distributed(InvalidWorkerResolver()) + + +def test_distributed_worker_resolver_requires_valid_urls(): + class InvalidWorkerResolver: + def get_urls(self) -> list[str]: + return ["not-a-url"] + + config = SessionConfig().with_distributed(InvalidWorkerResolver()) + ctx = SessionContext(config) + + with pytest.raises(Exception, match="invalid URL"): + ctx.sql("SELECT 1").execution_plan() + + def test_register_record_batches(ctx): # create a RecordBatch and register it as memtable batch = pa.RecordBatch.from_arrays( diff --git a/python/tests/test_distributed.py b/python/tests/test_distributed.py new file mode 100644 index 000000000..8935cd56d --- /dev/null +++ b/python/tests/test_distributed.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +from datafusion import SessionContext, Worker, WorkerQueryContext + + +def test_create_worker(): + worker = Worker() + + assert worker.with_version("local") is worker + assert worker.with_max_message_size(1024) is worker + + +def test_create_worker_with_callable_session_builder(): + def build_session_state(context: WorkerQueryContext) -> SessionContext: + return context.session_context() + + worker = Worker.from_session_builder(build_session_state) + + assert isinstance(worker, Worker) + + +def test_worker_session_builder_requires_callable(): + class InvalidWorkerSessionBuilder: + build_session_state = "not-callable" + + with pytest.raises(TypeError, match="callable"): + Worker.from_session_builder(InvalidWorkerSessionBuilder()) + + +def test_worker_serve_rejects_invalid_bind_address(): + with pytest.raises(Exception, match="invalid worker bind address"): + Worker().serve("not-an-address", 50051) diff --git a/python/tests/test_imports.py b/python/tests/test_imports.py index fea4cc91f..333ffa830 100644 --- a/python/tests/test_imports.py +++ b/python/tests/test_imports.py @@ -24,6 +24,9 @@ SessionContext, TableProviderFactory, TableProviderFactoryExportable, + Worker, + WorkerQueryContext, + WorkerResolver, functions, ) from datafusion.common import ( @@ -98,9 +101,17 @@ def test_class_module_is_datafusion(): # context for klass in [ SessionContext, + WorkerResolver, ]: assert klass.__module__ == "datafusion.context" + # distributed + for klass in [ + Worker, + WorkerQueryContext, + ]: + assert klass.__module__ == "datafusion.distributed" + # dataframe for klass in [ DataFrame, diff --git a/python/tests/test_plans.py b/python/tests/test_plans.py index 11e709f6b..9527ee565 100644 --- a/python/tests/test_plans.py +++ b/python/tests/test_plans.py @@ -54,6 +54,24 @@ def test_execution_plan_to_bytes_roundtrip(ctx, df) -> None: assert str(original_execution_plan) == str(execution_plan) +@pytest.mark.parametrize("metrics_format", ["none", "aggregated", "per_task"]) +def test_execution_plan_display_distributed(ctx, metrics_format) -> None: + text = ( + ctx.sql("SELECT 1") + .execution_plan() + .display_distributed(metrics_format=metrics_format) + ) + + assert text + + +def test_execution_plan_display_distributed_validates_metrics_format(ctx) -> None: + plan = ctx.sql("SELECT 1").execution_plan() + + with pytest.raises(Exception, match="metrics format"): + plan.display_distributed(metrics_format="invalid") # type: ignore[arg-type] + + def test_logical_plan_to_proto_is_deprecated(ctx, df) -> None: """to_proto / from_proto still work but emit DeprecationWarning.""" plan = df.logical_plan()