diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py index 3c13035220..33f31e7b8d 100644 --- a/sqlmesh/core/config/connection.py +++ b/sqlmesh/core/config/connection.py @@ -12,7 +12,7 @@ from functools import partial import pydantic -from pydantic import Field +from pydantic import Field, computed_field from pydantic_core import from_json from packaging import version from sqlglot import exp @@ -108,7 +108,10 @@ class ConnectionConfig(abc.ABC, BaseConfig): catalog_type_overrides: t.Optional[t.Dict[str, str]] = None # Whether to share a single connection across threads or create a new connection per thread. - shared_connection: t.ClassVar[bool] = False + @computed_field + @property + def shared_connection(self) -> bool: + return False @property @abc.abstractmethod @@ -309,7 +312,10 @@ class BaseDuckDBConnectionConfig(ConnectionConfig): token: t.Optional[str] = None - shared_connection: t.ClassVar[bool] = True + @computed_field + @property + def shared_connection(self) -> bool: + return True _data_file_to_adapter: t.ClassVar[t.Dict[str, EngineAdapter]] = {} @@ -818,11 +824,15 @@ class DatabricksConnectionConfig(ConnectionConfig): DISPLAY_NAME: t.ClassVar[t.Literal["Databricks"]] = "Databricks" DISPLAY_ORDER: t.ClassVar[t.Literal[3]] = 3 - shared_connection: t.ClassVar[bool] = True - _concurrent_tasks_validator = concurrent_tasks_validator _http_headers_validator = http_headers_validator + @computed_field + @property + def shared_connection(self) -> bool: + """The connection should only be shared if U2M OAuth is being used""" + return self.auth_type is not None and self.oauth_client_id is None + @model_validator(mode="before") def _databricks_connect_validator(cls, data: t.Any) -> t.Any: # SQLQueryContextLogger will output any error SQL queries even if they are in a try/except block. diff --git a/tests/core/test_connection_config.py b/tests/core/test_connection_config.py index b0ea640819..d525fb3e1b 100644 --- a/tests/core/test_connection_config.py +++ b/tests/core/test_connection_config.py @@ -1424,18 +1424,19 @@ def test_databricks(make_config): ) -def test_databricks_shared_connection(make_config): - """Databricks should use a shared connection pool to prevent OAuth CSRF races. +def test_databricks__u2m_oauth__shared_connection_pool(make_config): + """Databricks should use a shared connection pool when using OAuth to prevent CSRF races. When concurrent_tasks > 1, ThreadLocalConnectionPool creates one connection per thread. For U2M OAuth, each thread triggers its own browser-based OAuth flow; these race on the CSRF state parameter and cause MismatchingStateError. - Setting shared_connection = True causes ThreadLocalSharedConnectionPool to be - used instead: a single connection is created (behind a lock) and each thread - gets its own cursor, so only one OAuth flow is ever initiated. + For non-U2M OAuth authentication types (e.g. access_token and M2M OAuth) then + ThreadLocalConnectionPool should still be used. - See: https://github.com/tobymao/sqlmesh/issues/5646 + See: + https://github.com/tobymao/sqlmesh/issues/5646 + https://github.com/SQLMesh/sqlmesh/issues/5858 """ from sqlmesh.utils.connection_pool import ThreadLocalSharedConnectionPool @@ -1443,7 +1444,7 @@ def test_databricks_shared_connection(make_config): type="databricks", server_hostname="dbc-test.cloud.databricks.com", http_path="sql/test/foo", - access_token="test-token", + auth_type="databricks-oauth", concurrent_tasks=4, ) assert isinstance(config, DatabricksConnectionConfig) @@ -1453,6 +1454,41 @@ def test_databricks_shared_connection(make_config): assert isinstance(adapter._connection_pool, ThreadLocalSharedConnectionPool) +def test_databricks__m2m_oauth__connection_pool(make_config): + from sqlmesh.utils.connection_pool import ThreadLocalConnectionPool + + config = make_config( + type="databricks", + server_hostname="dbc-test.cloud.databricks.com", + http_path="sql/test/foo", + auth_type="databricks-oauth", + oauth_client_id="oauth_client_id", + concurrent_tasks=4, + ) + assert isinstance(config, DatabricksConnectionConfig) + assert config.shared_connection is False + + adapter = config.create_engine_adapter() + assert isinstance(adapter._connection_pool, ThreadLocalConnectionPool) + + +def test_databricks__access_token__connection_pool(make_config): + from sqlmesh.utils.connection_pool import ThreadLocalConnectionPool + + config = make_config( + type="databricks", + server_hostname="dbc-test.cloud.databricks.com", + http_path="sql/test/foo", + access_token="any-token", + concurrent_tasks=4, + ) + assert isinstance(config, DatabricksConnectionConfig) + assert config.shared_connection is False + + adapter = config.create_engine_adapter() + assert isinstance(adapter._connection_pool, ThreadLocalConnectionPool) + + def test_engine_import_validator(): with pytest.raises( ConfigError,