From 62e1357dfcb7b694df5eb8585d19dd73788fc7fa Mon Sep 17 00:00:00 2001 From: Selim Acerbas <91225118+selimacerbas@users.noreply.github.com> Date: Sat, 27 Jun 2026 11:32:33 +0200 Subject: [PATCH] fix(pydantic): support pydantic-ai v2 while keeping v1 (follow-up to #201) restate.ext.pydantic was written against pydantic-ai v1 and crashes on import under pydantic-ai-slim 2.0.0: from restate.ext.pydantic import RestateAgent -> ModuleNotFoundError: No module named 'pydantic_ai.builtin_tools' The v1 -> v2 changes that break it, and how each is handled (all kept v1-compatible): * pydantic_ai.builtin_tools deleted -> pydantic_ai.native_tools; AbstractBuiltinTool -> AbstractNativeTool, tools.BuiltinToolFunc -> tools.NativeToolFunc. Import the v2 names first, fall back to the v1 module. * pydantic_ai.mcp.MCPServer renamed to pydantic_ai.mcp.MCPToolset. The two are *separate* classes (1.x concrete servers subclass MCPServer; 2.0 only has MCPToolset), so detection matches against the tuple of MCP base classes present in the installed version rather than a single aliased name. * The run-time builtin_tools= argument (and its **_deprecated_kwargs shim) was removed from Agent.run in v2; native tools now go via capabilities=[NativeTool(tool), ...]. RestateAgent.run translates builtin_tools -> capabilities when the NativeTool capability is available (1.68+), falling back to forwarding builtin_tools= on the oldest 1.x. The RestateAgent(event_stream_handler=...) public API is unchanged: the handler was only ever passed at run time (still accepted in v2) and stored on RestateModelWrapper, never to a pydantic-ai constructor. Verified: import + RestateAgent construction + event_stream_handler flow + run() kwarg translation + MCP-toolset wrapping all pass on both 1.107.0 and 2.0.0. ruff clean; pyright introduces no new errors. --- python/restate/ext/pydantic/_agent.py | 64 ++++++++++++++++++++----- python/restate/ext/pydantic/_toolset.py | 35 ++++++++++++-- 2 files changed, 82 insertions(+), 17 deletions(-) diff --git a/python/restate/ext/pydantic/_agent.py b/python/restate/ext/pydantic/_agent.py index 29002e9..bdc5e23 100644 --- a/python/restate/ext/pydantic/_agent.py +++ b/python/restate/ext/pydantic/_agent.py @@ -12,7 +12,6 @@ from pydantic_ai._run_context import AgentDepsT from pydantic_ai.agent.abstract import AbstractAgent, AgentMetadata, EventStreamHandler, RunOutputDataT, Instructions from pydantic_ai.agent.wrapper import WrapperAgent -from pydantic_ai.builtin_tools import AbstractBuiltinTool from pydantic_ai.exceptions import UserError from pydantic_ai.messages import AgentStreamEvent, ModelMessage, UserContent from pydantic_ai.models import Model @@ -20,13 +19,37 @@ from pydantic_ai.result import StreamedRunResult from pydantic_ai.run import AgentRunResult from pydantic_ai.settings import ModelSettings -from pydantic_ai.tools import DeferredToolResults, RunContext, BuiltinToolFunc +from pydantic_ai.tools import DeferredToolResults, RunContext + +# pydantic-ai >= 2.0 renamed `builtin_tools` -> `native_tools` (and the matching +# `BuiltinToolFunc` -> `NativeToolFunc`). The new names are also available on the +# pydantic-ai 1.x deprecation shim, so import the v2 names first and fall back to the +# v1 module only when running against an older 1.x release that predates the rename. +try: + from pydantic_ai.native_tools import AbstractNativeTool + from pydantic_ai.tools import NativeToolFunc +except ImportError: # pragma: no cover - pydantic-ai < the native_tools rename + from pydantic_ai.builtin_tools import ( # type: ignore[import-not-found, no-redef] + AbstractBuiltinTool as AbstractNativeTool, + ) + from pydantic_ai.tools import BuiltinToolFunc as NativeToolFunc # type: ignore[attr-defined, no-redef] from pydantic_ai.toolsets.abstract import AbstractToolset from pydantic_ai.toolsets.function import FunctionToolset from pydantic_ai.usage import RunUsage, UsageLimits from ._model import RestateModelWrapper from ._toolset import RestateContextRunToolSet +# pydantic-ai >= 2.0 removed the `builtin_tools=` run-time argument (and its deprecation +# shim): native tools must now be supplied as `capabilities=[NativeTool(tool), ...]`. +# pydantic-ai 1.68+ already exposes the `NativeTool` capability alongside the legacy +# `builtin_tools=` argument, so when `NativeTool` is importable we translate uniformly +# (works on both lines); otherwise we fall back to forwarding `builtin_tools=` for the +# oldest supported 1.x releases that predate the capability. +try: + from pydantic_ai.capabilities import NativeTool +except ImportError: # pragma: no cover - pydantic-ai < the NativeTool capability + NativeTool = None # type: ignore[assignment, misc] + class RestateAgent(WrapperAgent[AgentDepsT, OutputDataT]): """An agent that integrates with Restate framework for building resilient applications. @@ -111,13 +134,15 @@ def set_context(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDe if isinstance(toolset, FunctionToolset) and auto_wrap_tools: return RestateContextRunToolSet(toolset, run_options) try: - from pydantic_ai.mcp import MCPServer - - from ._toolset import RestateMCPServer + # `MCP_TOOLSET_CLASSES` is the tuple of MCP toolset base classes present + # in the installed pydantic-ai (`MCPServer` on 1.x, `MCPToolset` on >= 2.0; + # the concrete servers subclass different bases across the rename), resolved + # once in `_toolset` so the version detection lives in a single place. + from ._toolset import MCP_TOOLSET_CLASSES, RestateMCPServer except ImportError: pass else: - if isinstance(toolset, MCPServer): + if MCP_TOOLSET_CLASSES and isinstance(toolset, MCP_TOOLSET_CLASSES): return RestateMCPServer(toolset, run_options) return toolset @@ -184,7 +209,7 @@ async def run( metadata: AgentMetadata[AgentDepsT] | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None, + builtin_tools: Sequence[AbstractNativeTool | NativeToolFunc[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> AgentRunResult[OutputDataT]: ... @@ -205,7 +230,7 @@ async def run( metadata: AgentMetadata[AgentDepsT] | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None, + builtin_tools: Sequence[AbstractNativeTool | NativeToolFunc[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... @@ -225,7 +250,7 @@ async def run( metadata: AgentMetadata[AgentDepsT] | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None, + builtin_tools: Sequence[AbstractNativeTool | NativeToolFunc[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> AgentRunResult[Any]: """Run the agent with a user prompt in async mode. @@ -269,6 +294,19 @@ async def main(): raise TerminalError( "An agent needs to have a `model` in order to be used with Restate, it cannot be set at agent run time." ) + + # pydantic-ai >= 2.0 removed the `builtin_tools=` run-time argument; native tools + # must be supplied through `capabilities=[NativeTool(tool), ...]`. When the + # `NativeTool` capability is available (pydantic-ai 1.68+) translate uniformly so + # the same code path works on both lines; only the oldest 1.x releases that + # predate the capability still take the legacy `builtin_tools=` argument. + forward_kwargs: dict[str, Any] = {} + if NativeTool is not None: + if builtin_tools: + forward_kwargs["capabilities"] = [NativeTool(tool) for tool in builtin_tools] + elif builtin_tools is not None: # pragma: no cover - pydantic-ai < the NativeTool capability + forward_kwargs["builtin_tools"] = builtin_tools + with self._restate_overrides(): return await super(WrapperAgent, self).run( user_prompt=user_prompt, @@ -284,8 +322,8 @@ async def main(): metadata=metadata, infer_name=infer_name, toolsets=toolsets, - builtin_tools=builtin_tools, event_stream_handler=event_stream_handler, + **forward_kwargs, ) @overload @@ -305,7 +343,7 @@ def run_stream( metadata: AgentMetadata[AgentDepsT] | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None, + builtin_tools: Sequence[AbstractNativeTool | NativeToolFunc[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, OutputDataT]]: ... @@ -326,7 +364,7 @@ def run_stream( metadata: AgentMetadata[AgentDepsT] | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None, + builtin_tools: Sequence[AbstractNativeTool | NativeToolFunc[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... @@ -347,7 +385,7 @@ async def run_stream( metadata: AgentMetadata[AgentDepsT] | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None, + builtin_tools: Sequence[AbstractNativeTool | NativeToolFunc[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, **_deprecated_kwargs: Any, ) -> AsyncIterator[StreamedRunResult[AgentDepsT, Any]]: diff --git a/python/restate/ext/pydantic/_toolset.py b/python/restate/ext/pydantic/_toolset.py index 74473bf..79ebf8e 100644 --- a/python/restate/ext/pydantic/_toolset.py +++ b/python/restate/ext/pydantic/_toolset.py @@ -11,7 +11,7 @@ from pydantic_ai import ToolDefinition from pydantic_ai._run_context import AgentDepsT from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry, UserError -from pydantic_ai.mcp import MCPServer, ToolResult +from pydantic_ai.mcp import ToolResult # unchanged across the v1 -> v2 rename from pydantic_ai.tools import RunContext from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool from pydantic_ai.toolsets.wrapper import WrapperToolset @@ -19,6 +19,30 @@ from ._serde import PydanticTypeAdapter from ._utils import current_state +# pydantic-ai renamed the MCP server toolset `pydantic_ai.mcp.MCPServer` -> +# `pydantic_ai.mcp.MCPToolset`. Crucially the two are *separate* classes (both subclass +# `AbstractToolset`, neither inherits the other): on pydantic-ai 1.x the concrete servers +# (`MCPServerStdio`/`MCPServerSSE`/`MCPServerStreamableHTTP`) subclass `MCPServer`, while +# on >= 2.0 only `MCPToolset` remains. We therefore match against *all* of the MCP base +# classes that exist in the installed version (`isinstance` accepts a tuple), so an MCP +# toolset is detected regardless of which base name its concrete class derives from. +_mcp_bases: list[type] = [] +try: + from pydantic_ai.mcp import MCPServer # type: ignore[attr-defined] +except ImportError: # pragma: no cover - pydantic-ai dropped the legacy MCPServer name + pass +else: + _mcp_bases.append(MCPServer) +try: + from pydantic_ai.mcp import MCPToolset +except ImportError: # pragma: no cover - pydantic-ai < the MCPToolset rename + pass +else: + _mcp_bases.append(MCPToolset) + +# Tuple for runtime `isinstance` checks (matches whichever MCP bases are installed). +MCP_TOOLSET_CLASSES: tuple[type, ...] = tuple(_mcp_bases) + @dataclass class RestateContextRunResult: @@ -119,7 +143,7 @@ def visit_and_replace( class RestateMCPServer(WrapperToolset[AgentDepsT]): """A wrapper for MCPServer that integrates with restate.""" - def __init__(self, wrapped: MCPServer, run_options: RunOptions): + def __init__(self, wrapped: AbstractToolset[AgentDepsT], run_options: RunOptions): super().__init__(wrapped) self._wrapped = wrapped self.get_tools_options = replace(run_options, serde=MCP_GET_TOOLS_SERDE) @@ -151,8 +175,11 @@ async def get_tools_in_context() -> RestateMCPGetToolsContextRunResult: raise Exception("Internal error during get_tools call") from e def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]: - assert isinstance(self.wrapped, MCPServer) - return self.wrapped.tool_for_tool_def(tool_def) + assert isinstance(self.wrapped, MCP_TOOLSET_CLASSES) + # `MCP_TOOLSET_CLASSES` is a runtime tuple, so it cannot narrow `self.wrapped` + # (typed `AbstractToolset` by `WrapperToolset`) for the type-checker; the assert + # above guarantees the concrete MCP toolset method is present. + return self.wrapped.tool_for_tool_def(tool_def) # type: ignore[attr-defined] async def call_tool( self,