Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 51 additions & 13 deletions python/restate/ext/pydantic/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,44 @@
from pydantic_ai._run_context import AgentDepsT
from pydantic_ai.agent.abstract import AbstractAgent, AgentMetadata, EventStreamHandler, RunOutputDataT, Instructions
from pydantic_ai.agent.wrapper import WrapperAgent
from pydantic_ai.builtin_tools import AbstractBuiltinTool
from pydantic_ai.exceptions import UserError
from pydantic_ai.messages import AgentStreamEvent, ModelMessage, UserContent
from pydantic_ai.models import Model
from pydantic_ai.output import OutputDataT, OutputSpec
from pydantic_ai.result import StreamedRunResult
from pydantic_ai.run import AgentRunResult
from pydantic_ai.settings import ModelSettings
from pydantic_ai.tools import DeferredToolResults, RunContext, BuiltinToolFunc
from pydantic_ai.tools import DeferredToolResults, RunContext

# pydantic-ai >= 2.0 renamed `builtin_tools` -> `native_tools` (and the matching
# `BuiltinToolFunc` -> `NativeToolFunc`). The new names are also available on the
# pydantic-ai 1.x deprecation shim, so import the v2 names first and fall back to the
# v1 module only when running against an older 1.x release that predates the rename.
try:
from pydantic_ai.native_tools import AbstractNativeTool
from pydantic_ai.tools import NativeToolFunc
except ImportError: # pragma: no cover - pydantic-ai < the native_tools rename
from pydantic_ai.builtin_tools import ( # type: ignore[import-not-found, no-redef]
AbstractBuiltinTool as AbstractNativeTool,
)
from pydantic_ai.tools import BuiltinToolFunc as NativeToolFunc # type: ignore[attr-defined, no-redef]
from pydantic_ai.toolsets.abstract import AbstractToolset
from pydantic_ai.toolsets.function import FunctionToolset
from pydantic_ai.usage import RunUsage, UsageLimits
from ._model import RestateModelWrapper
from ._toolset import RestateContextRunToolSet

# pydantic-ai >= 2.0 removed the `builtin_tools=` run-time argument (and its deprecation
# shim): native tools must now be supplied as `capabilities=[NativeTool(tool), ...]`.
# pydantic-ai 1.68+ already exposes the `NativeTool` capability alongside the legacy
# `builtin_tools=` argument, so when `NativeTool` is importable we translate uniformly
# (works on both lines); otherwise we fall back to forwarding `builtin_tools=` for the
# oldest supported 1.x releases that predate the capability.
try:
from pydantic_ai.capabilities import NativeTool
except ImportError: # pragma: no cover - pydantic-ai < the NativeTool capability
NativeTool = None # type: ignore[assignment, misc]


class RestateAgent(WrapperAgent[AgentDepsT, OutputDataT]):
"""An agent that integrates with Restate framework for building resilient applications.
Expand Down Expand Up @@ -111,13 +134,15 @@ def set_context(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDe
if isinstance(toolset, FunctionToolset) and auto_wrap_tools:
return RestateContextRunToolSet(toolset, run_options)
try:
from pydantic_ai.mcp import MCPServer

from ._toolset import RestateMCPServer
# `MCP_TOOLSET_CLASSES` is the tuple of MCP toolset base classes present
# in the installed pydantic-ai (`MCPServer` on 1.x, `MCPToolset` on >= 2.0;
# the concrete servers subclass different bases across the rename), resolved
# once in `_toolset` so the version detection lives in a single place.
from ._toolset import MCP_TOOLSET_CLASSES, RestateMCPServer
except ImportError:
pass
else:
if isinstance(toolset, MCPServer):
if MCP_TOOLSET_CLASSES and isinstance(toolset, MCP_TOOLSET_CLASSES):
return RestateMCPServer(toolset, run_options)

return toolset
Expand Down Expand Up @@ -184,7 +209,7 @@ async def run(
metadata: AgentMetadata[AgentDepsT] | None = None,
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractNativeTool | NativeToolFunc[AgentDepsT]] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
) -> AgentRunResult[OutputDataT]: ...

Expand All @@ -205,7 +230,7 @@ async def run(
metadata: AgentMetadata[AgentDepsT] | None = None,
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractNativeTool | NativeToolFunc[AgentDepsT]] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
) -> AgentRunResult[RunOutputDataT]: ...

Expand All @@ -225,7 +250,7 @@ async def run(
metadata: AgentMetadata[AgentDepsT] | None = None,
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractNativeTool | NativeToolFunc[AgentDepsT]] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
) -> AgentRunResult[Any]:
"""Run the agent with a user prompt in async mode.
Expand Down Expand Up @@ -269,6 +294,19 @@ async def main():
raise TerminalError(
"An agent needs to have a `model` in order to be used with Restate, it cannot be set at agent run time."
)

# pydantic-ai >= 2.0 removed the `builtin_tools=` run-time argument; native tools
# must be supplied through `capabilities=[NativeTool(tool), ...]`. When the
# `NativeTool` capability is available (pydantic-ai 1.68+) translate uniformly so
# the same code path works on both lines; only the oldest 1.x releases that
# predate the capability still take the legacy `builtin_tools=` argument.
forward_kwargs: dict[str, Any] = {}
if NativeTool is not None:
if builtin_tools:
forward_kwargs["capabilities"] = [NativeTool(tool) for tool in builtin_tools]
elif builtin_tools is not None: # pragma: no cover - pydantic-ai < the NativeTool capability
forward_kwargs["builtin_tools"] = builtin_tools

with self._restate_overrides():
return await super(WrapperAgent, self).run(
user_prompt=user_prompt,
Expand All @@ -284,8 +322,8 @@ async def main():
metadata=metadata,
infer_name=infer_name,
toolsets=toolsets,
builtin_tools=builtin_tools,
event_stream_handler=event_stream_handler,
**forward_kwargs,
)

@overload
Expand All @@ -305,7 +343,7 @@ def run_stream(
metadata: AgentMetadata[AgentDepsT] | None = None,
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractNativeTool | NativeToolFunc[AgentDepsT]] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, OutputDataT]]: ...

Expand All @@ -326,7 +364,7 @@ def run_stream(
metadata: AgentMetadata[AgentDepsT] | None = None,
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractNativeTool | NativeToolFunc[AgentDepsT]] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, RunOutputDataT]]: ...

Expand All @@ -347,7 +385,7 @@ async def run_stream(
metadata: AgentMetadata[AgentDepsT] | None = None,
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractNativeTool | NativeToolFunc[AgentDepsT]] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
**_deprecated_kwargs: Any,
) -> AsyncIterator[StreamedRunResult[AgentDepsT, Any]]:
Expand Down
35 changes: 31 additions & 4 deletions python/restate/ext/pydantic/_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,38 @@
from pydantic_ai import ToolDefinition
from pydantic_ai._run_context import AgentDepsT
from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry, UserError
from pydantic_ai.mcp import MCPServer, ToolResult
from pydantic_ai.mcp import ToolResult # unchanged across the v1 -> v2 rename
from pydantic_ai.tools import RunContext
from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
from pydantic_ai.toolsets.wrapper import WrapperToolset

from ._serde import PydanticTypeAdapter
from ._utils import current_state

# pydantic-ai renamed the MCP server toolset `pydantic_ai.mcp.MCPServer` ->
# `pydantic_ai.mcp.MCPToolset`. Crucially the two are *separate* classes (both subclass
# `AbstractToolset`, neither inherits the other): on pydantic-ai 1.x the concrete servers
# (`MCPServerStdio`/`MCPServerSSE`/`MCPServerStreamableHTTP`) subclass `MCPServer`, while
# on >= 2.0 only `MCPToolset` remains. We therefore match against *all* of the MCP base
# classes that exist in the installed version (`isinstance` accepts a tuple), so an MCP
# toolset is detected regardless of which base name its concrete class derives from.
_mcp_bases: list[type] = []
try:
from pydantic_ai.mcp import MCPServer # type: ignore[attr-defined]
except ImportError: # pragma: no cover - pydantic-ai dropped the legacy MCPServer name
pass
else:
_mcp_bases.append(MCPServer)
try:
from pydantic_ai.mcp import MCPToolset
except ImportError: # pragma: no cover - pydantic-ai < the MCPToolset rename
pass
else:
_mcp_bases.append(MCPToolset)

# Tuple for runtime `isinstance` checks (matches whichever MCP bases are installed).
MCP_TOOLSET_CLASSES: tuple[type, ...] = tuple(_mcp_bases)


@dataclass
class RestateContextRunResult:
Expand Down Expand Up @@ -119,7 +143,7 @@ def visit_and_replace(
class RestateMCPServer(WrapperToolset[AgentDepsT]):
"""A wrapper for MCPServer that integrates with restate."""

def __init__(self, wrapped: MCPServer, run_options: RunOptions):
def __init__(self, wrapped: AbstractToolset[AgentDepsT], run_options: RunOptions):
super().__init__(wrapped)
self._wrapped = wrapped
self.get_tools_options = replace(run_options, serde=MCP_GET_TOOLS_SERDE)
Expand Down Expand Up @@ -151,8 +175,11 @@ async def get_tools_in_context() -> RestateMCPGetToolsContextRunResult:
raise Exception("Internal error during get_tools call") from e

def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
assert isinstance(self.wrapped, MCPServer)
return self.wrapped.tool_for_tool_def(tool_def)
assert isinstance(self.wrapped, MCP_TOOLSET_CLASSES)
# `MCP_TOOLSET_CLASSES` is a runtime tuple, so it cannot narrow `self.wrapped`
# (typed `AbstractToolset` by `WrapperToolset`) for the type-checker; the assert
# above guarantees the concrete MCP toolset method is present.
return self.wrapped.tool_for_tool_def(tool_def) # type: ignore[attr-defined]

async def call_tool(
self,
Expand Down
Loading