From d4de8281e18edc46d5763ee955e746423e3ec71e Mon Sep 17 00:00:00 2001 From: "shijinyu.7" Date: Wed, 1 Jul 2026 13:47:40 +0800 Subject: [PATCH 1/4] feat: support TIP auth for sandbox CLI --- .../toolkit/cli/sandbox/agentkit_client.py | 248 ++++++++++++++++++ agentkit/toolkit/cli/sandbox/cli_exec.py | 2 +- agentkit/toolkit/cli/sandbox/cli_file.py | 2 +- agentkit/toolkit/cli/sandbox/cli_get.py | 8 +- agentkit/toolkit/cli/sandbox/cli_mount.py | 2 +- .../toolkit/cli/sandbox/session_create.py | 20 +- agentkit/toolkit/cli/sandbox/session_sync.py | 2 +- agentkit/toolkit/cli/sandbox/tool_resolve.py | 16 +- .../cli/test_cli_sandbox_agentkit_client.py | 182 +++++++++++++ 9 files changed, 464 insertions(+), 18 deletions(-) create mode 100644 agentkit/toolkit/cli/sandbox/agentkit_client.py create mode 100644 tests/toolkit/cli/test_cli_sandbox_agentkit_client.py diff --git a/agentkit/toolkit/cli/sandbox/agentkit_client.py b/agentkit/toolkit/cli/sandbox/agentkit_client.py new file mode 100644 index 0000000..d193d1d --- /dev/null +++ b/agentkit/toolkit/cli/sandbox/agentkit_client.py @@ -0,0 +1,248 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Sandbox-specific AgentKit tools client helpers.""" + +from __future__ import annotations + +import json +import os +from typing import Any, Type, TypeVar +from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit + +import requests + +from agentkit.platform.constants import SERVICE_METADATA +from agentkit.sdk.tools.client import AgentkitToolsClient as _OpenapiAgentkitToolsClient +from agentkit.sdk.tools.types import ( + CreateSessionRequest, + CreateSessionResponse, + GetSessionRequest, + GetSessionResponse, + ListSessionsRequest, + ListSessionsResponse, +) + +SANDBOX_APIG_ENDPOINT_ENV = "SANDBOX_APIG_ENDPOINT" +TIP_TOKEN_ENV = "TIP_TOKEN" +_AGENTKIT_API_VERSION = SERVICE_METADATA["agentkit"].default_version + +T = TypeVar("T") + + +def _env_value(name: str) -> str: + return (os.getenv(name) or "").strip() + + +def tip_auth_env_enabled() -> bool: + return bool(_env_value(SANDBOX_APIG_ENDPOINT_ENV) and _env_value(TIP_TOKEN_ENV)) + + +def _with_action_query(endpoint: str, action: str) -> str: + split = urlsplit(endpoint) + query = dict(parse_qsl(split.query, keep_blank_values=True)) + query.setdefault("Action", action) + query.setdefault("Version", _AGENTKIT_API_VERSION) + return urlunsplit( + ( + split.scheme, + split.netloc, + split.path, + urlencode(query), + split.fragment, + ) + ) + + +def _tip_endpoint_url(endpoint: str, action: str) -> str: + if "{Action}" in endpoint or "{action}" in endpoint: + return endpoint.replace("{Action}", action).replace("{action}", action) + return _with_action_query(endpoint, action) + + +def _extract_error_message(payload: object, default: str) -> str: + if not isinstance(payload, dict): + return default + + metadata = payload.get("ResponseMetadata") + if isinstance(metadata, dict): + api_error = metadata.get("Error") + if isinstance(api_error, dict): + message = api_error.get("Message") + if isinstance(message, str) and message: + return message + + for key in ("message", "Message", "error", "Error"): + value = 
payload.get(key) + if isinstance(value, str) and value: + return value + if isinstance(value, dict): + message = value.get("message") or value.get("Message") + if isinstance(message, str) and message: + return message + return default + + +def _tip_result_payload(payload: object) -> object: + if not isinstance(payload, dict): + return payload + if "Result" in payload: + return payload.get("Result") or {} + data = payload.get("data") + if isinstance(data, dict): + return data + return payload + + +class TipAgentkitToolsClient(_OpenapiAgentkitToolsClient): + """AgentKit tools client with optional TIP bearer-token session routing. + + When both SANDBOX_APIG_ENDPOINT and TIP_TOKEN are present, session APIs are + sent directly to the APIG endpoint. Otherwise this behaves exactly like the + generated AgentkitToolsClient. + """ + + def __init__( + self, + access_key: str = "", + secret_key: str = "", + region: str = "", + session_token: str = "", + timeout: int = 30, + ) -> None: + self._tip_endpoint = _env_value(SANDBOX_APIG_ENDPOINT_ENV).rstrip("/") + self._tip_token = _env_value(TIP_TOKEN_ENV) + self._tip_timeout = timeout + self._tip_session: requests.Session | None = None + if self.uses_tip_auth: + self._tip_session = requests.Session() + return + + super().__init__( + access_key=access_key, + secret_key=secret_key, + region=region, + session_token=session_token, + ) + + @property + def uses_tip_auth(self) -> bool: + return bool(self._tip_endpoint and self._tip_token) + + def _invoke_tip_api( + self, + api_action: str, + request: Any, + response_type: Type[T], + ) -> T: + if not self.uses_tip_auth or self._tip_session is None: + return self._invoke_api( + api_action=api_action, + request=request, + response_type=response_type, + ) + + url = _tip_endpoint_url(self._tip_endpoint, api_action) + body = request.model_dump(by_alias=True, exclude_none=True) + try: + response = self._tip_session.post( + url, + json=body, + headers={ + "Accept": "application/json", + 
"Content-Type": "application/json", + "Authorization": f"Bearer {self._tip_token}", + }, + timeout=self._tip_timeout, + ) + except requests.RequestException as exc: + raise Exception(f"Failed to {api_action}: {exc}") from exc + + try: + payload = response.json() + except ValueError as exc: + raise Exception( + f"Failed to {api_action}: invalid JSON response: {response.text}" + ) from exc + + if response.status_code >= 400: + raise Exception( + f"Failed to {api_action}: " + f"{_extract_error_message(payload, response.text)}" + ) + + if isinstance(payload, dict): + metadata = payload.get("ResponseMetadata") + if isinstance(metadata, dict) and metadata.get("Error"): + raise Exception( + f"Failed to {api_action}: " + f"{_extract_error_message(payload, json.dumps(payload))}" + ) + + result = _tip_result_payload(payload) + if not isinstance(result, dict): + raise Exception(f"Failed to {api_action}: invalid response payload") + return response_type(**result) + + def create_session(self, request: CreateSessionRequest) -> CreateSessionResponse: + return self._invoke_tip_api( + api_action="CreateSession", + request=request, + response_type=CreateSessionResponse, + ) + + def get_session(self, request: GetSessionRequest) -> GetSessionResponse: + return self._invoke_tip_api( + api_action="GetSession", + request=request, + response_type=GetSessionResponse, + ) + + def list_sessions(self, request: ListSessionsRequest) -> ListSessionsResponse: + return self._invoke_tip_api( + api_action="ListSessions", + request=request, + response_type=ListSessionsResponse, + ) + + def _raise_tip_unsupported(self, api_action: str) -> None: + raise Exception( + f"{api_action} is not available with TIP sandbox auth. " + f"Unset {SANDBOX_APIG_ENDPOINT_ENV}/{TIP_TOKEN_ENV} to use the " + "standard AgentKit OpenAPI client." 
+ ) + + def create_tool(self, request: Any) -> Any: + if self.uses_tip_auth: + self._raise_tip_unsupported("CreateTool") + return super().create_tool(request) + + def get_tool(self, request: Any) -> Any: + if self.uses_tip_auth: + self._raise_tip_unsupported("GetTool") + return super().get_tool(request) + + def list_tools(self, request: Any) -> Any: + if self.uses_tip_auth: + self._raise_tip_unsupported("ListTools") + return super().list_tools(request) + + +def is_tip_agentkit_client(client: object) -> bool: + return bool(getattr(client, "uses_tip_auth", False)) + + +# Keep sandbox modules/tests able to patch a local AgentkitToolsClient symbol, +# while routing construction through the sandbox-aware implementation. +AgentkitToolsClient = TipAgentkitToolsClient diff --git a/agentkit/toolkit/cli/sandbox/cli_exec.py b/agentkit/toolkit/cli/sandbox/cli_exec.py index e9669ca..db92654 100644 --- a/agentkit/toolkit/cli/sandbox/cli_exec.py +++ b/agentkit/toolkit/cli/sandbox/cli_exec.py @@ -32,7 +32,7 @@ import typer -from agentkit.sdk.tools.client import AgentkitToolsClient +from agentkit.toolkit.cli.sandbox.agentkit_client import AgentkitToolsClient from agentkit.toolkit.cli.sandbox.cli_file import ( _build_remote_extract_command, _create_sources_upload_archive, diff --git a/agentkit/toolkit/cli/sandbox/cli_file.py b/agentkit/toolkit/cli/sandbox/cli_file.py index 2e0b800..c639705 100644 --- a/agentkit/toolkit/cli/sandbox/cli_file.py +++ b/agentkit/toolkit/cli/sandbox/cli_file.py @@ -29,7 +29,7 @@ import requests import typer -from agentkit.sdk.tools.client import AgentkitToolsClient +from agentkit.toolkit.cli.sandbox.agentkit_client import AgentkitToolsClient from agentkit.toolkit.cli.sandbox.session_create import SANDBOX_TOOL_ID_ENV from agentkit.toolkit.cli.sandbox.session_sync import sync_remote_sessions from agentkit.toolkit.cli.sandbox.tool_resolve import SandboxToolType diff --git a/agentkit/toolkit/cli/sandbox/cli_get.py b/agentkit/toolkit/cli/sandbox/cli_get.py 
index ef3a2dc..7f18cb2 100644 --- a/agentkit/toolkit/cli/sandbox/cli_get.py +++ b/agentkit/toolkit/cli/sandbox/cli_get.py @@ -20,7 +20,7 @@ import typer -from agentkit.sdk.tools.client import AgentkitToolsClient +from agentkit.toolkit.cli.sandbox.agentkit_client import AgentkitToolsClient from agentkit.toolkit.cli.sandbox.session_create import SANDBOX_TOOL_ID_ENV from agentkit.toolkit.cli.sandbox.session_sync import sync_remote_sessions from agentkit.toolkit.cli.sandbox.tool_resolve import SandboxToolType @@ -92,11 +92,7 @@ def get_command( ) raise typer.Exit(1) - if ( - session_id - and resolved_tool_id - and result.get("tool_id") != resolved_tool_id - ): + if session_id and resolved_tool_id and result.get("tool_id") != resolved_tool_id: echo_json( _session_not_found_result( session_id=session_id, diff --git a/agentkit/toolkit/cli/sandbox/cli_mount.py b/agentkit/toolkit/cli/sandbox/cli_mount.py index 276f0ff..c4a7e1d 100644 --- a/agentkit/toolkit/cli/sandbox/cli_mount.py +++ b/agentkit/toolkit/cli/sandbox/cli_mount.py @@ -26,9 +26,9 @@ import typer -from agentkit.sdk.tools.client import AgentkitToolsClient from agentkit.utils.http_defaults import http_timeout from agentkit.sdk.tools import types as tools_types +from agentkit.toolkit.cli.sandbox.agentkit_client import AgentkitToolsClient from agentkit.toolkit.cli.sandbox.session_create import SANDBOX_TOOL_ID_ENV from agentkit.toolkit.cli.sandbox.session_sync import sync_remote_sessions from agentkit.toolkit.cli.sandbox.sandbox_client import ( diff --git a/agentkit/toolkit/cli/sandbox/session_create.py b/agentkit/toolkit/cli/sandbox/session_create.py index df56ea2..060844d 100644 --- a/agentkit/toolkit/cli/sandbox/session_create.py +++ b/agentkit/toolkit/cli/sandbox/session_create.py @@ -21,8 +21,11 @@ import uuid from typing import Optional -from agentkit.sdk.tools.client import AgentkitToolsClient from agentkit.sdk.tools import types as tools_types +from agentkit.toolkit.cli.sandbox.agentkit_client import ( + 
AgentkitToolsClient, + is_tip_agentkit_client, +) from agentkit.toolkit.cli.sandbox.model_config import ( ANTHROPIC_BASE_URL_ENV_KEYS, CODEX_CONFIG_TOML_ENV, @@ -305,18 +308,21 @@ def _create_session( ttl: int, envs: Optional[list[tools_types.EnvsItemForCreateSession]] = None, ) -> dict[str, object]: - tool = client.get_tool(tools_types.GetToolRequest(tool_id=tool_id)) + tos_mount_points = None + if not is_tip_agentkit_client(client): + tool = client.get_tool(tools_types.GetToolRequest(tool_id=tool_id)) + tos_mount_points = build_session_tos_mount_points( + tool, + tool_id=tool_id, + session_id=session_id, + ) request = tools_types.CreateSessionRequest( tool_id=tool_id, ttl=ttl, ttl_unit="second", user_session_id=session_id, envs=envs, - tos_mount_points=build_session_tos_mount_points( - tool, - tool_id=tool_id, - session_id=session_id, - ), + tos_mount_points=tos_mount_points, ) try: response = client.create_session(request) diff --git a/agentkit/toolkit/cli/sandbox/session_sync.py b/agentkit/toolkit/cli/sandbox/session_sync.py index 993a23e..1037980 100644 --- a/agentkit/toolkit/cli/sandbox/session_sync.py +++ b/agentkit/toolkit/cli/sandbox/session_sync.py @@ -18,8 +18,8 @@ from typing import Optional -from agentkit.sdk.tools.client import AgentkitToolsClient from agentkit.sdk.tools import types as tools_types +from agentkit.toolkit.cli.sandbox.agentkit_client import AgentkitToolsClient from agentkit.toolkit.cli.sandbox.tool_resolve import ( SandboxToolType, resolve_existing_sandbox_tool_id, diff --git a/agentkit/toolkit/cli/sandbox/tool_resolve.py b/agentkit/toolkit/cli/sandbox/tool_resolve.py index 66a689b..8b14069 100644 --- a/agentkit/toolkit/cli/sandbox/tool_resolve.py +++ b/agentkit/toolkit/cli/sandbox/tool_resolve.py @@ -22,8 +22,11 @@ from pathlib import Path from typing import Optional -from agentkit.sdk.tools.client import AgentkitToolsClient from agentkit.sdk.tools import types as tools_types +from agentkit.toolkit.cli.sandbox.agentkit_client import ( 
+ AgentkitToolsClient, + is_tip_agentkit_client, +) from agentkit.toolkit.cli.sandbox.model_config import ( ANTHROPIC_BASE_URL_ENV_KEYS, MODEL_BASE_URL_ENV_KEYS, @@ -222,6 +225,9 @@ def _validate_existing_tool_id( tool_type: str | SandboxToolType | None, save_result: bool = False, ) -> str: + if is_tip_agentkit_client(client): + return tool_id + try: response = client.get_tool(tools_types.GetToolRequest(tool_id=tool_id)) except Exception as exc: @@ -492,6 +498,11 @@ def resolve_sandbox_tool_id( return resolved_tool_id resolved_tool_type = normalize_tool_type(tool_type) + if is_tip_agentkit_client(client): + error( + "TIP sandbox auth requires an existing sandbox tool ID. " + f"Set --tool-id or {env_var_name}." + ) return _create_tool(resolved_tool_type) @@ -536,6 +547,9 @@ def resolve_existing_sandbox_tool_id( save_result=True, ) + if is_tip_agentkit_client(client): + return None + listed_tool_id = _list_first_tool(client, resolved_tool_type) if listed_tool_id: return listed_tool_id diff --git a/tests/toolkit/cli/test_cli_sandbox_agentkit_client.py b/tests/toolkit/cli/test_cli_sandbox_agentkit_client.py new file mode 100644 index 0000000..fcef95c --- /dev/null +++ b/tests/toolkit/cli/test_cli_sandbox_agentkit_client.py @@ -0,0 +1,182 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from urllib.parse import parse_qs, urlsplit + +import pytest +import typer + +from agentkit.sdk.tools import types as tools_types + + +class _FakeTipResponse: + status_code = 200 + text = "{}" + + def __init__(self, payload): + self._payload = payload + + def json(self): + return self._payload + + +class _FakeTipSession: + def __init__(self): + self.calls = [] + + def post(self, url, *, json, headers, timeout): + self.calls.append( + { + "url": url, + "json": json, + "headers": headers, + "timeout": timeout, + } + ) + return _FakeTipResponse( + { + "Result": { + "SessionId": "instance-1", + "UserSessionId": "user-1", + "Endpoint": "https://sandbox.example.com", + } + } + ) + + +def test_tip_client_create_session_uses_apig_endpoint_and_bearer_token( + monkeypatch, +): + import agentkit.toolkit.cli.sandbox.agentkit_client as agentkit_client + + fake_session = _FakeTipSession() + monkeypatch.setenv("SANDBOX_APIG_ENDPOINT", "https://apig.example.com/sandbox") + monkeypatch.setenv("TIP_TOKEN", "tip-token") + monkeypatch.setattr( + agentkit_client.requests, + "Session", + lambda: fake_session, + ) + + client = agentkit_client.TipAgentkitToolsClient() + response = client.create_session( + tools_types.CreateSessionRequest( + tool_id="tool-1", + ttl=60, + ttl_unit="second", + user_session_id="user-1", + ) + ) + + assert response.session_id == "instance-1" + assert response.user_session_id == "user-1" + assert response.endpoint == "https://sandbox.example.com" + assert len(fake_session.calls) == 1 + + call = fake_session.calls[0] + parsed = urlsplit(call["url"]) + query = parse_qs(parsed.query) + assert parsed.scheme == "https" + assert parsed.netloc == "apig.example.com" + assert parsed.path == "/sandbox" + assert query["Action"] == ["CreateSession"] + assert query["Version"] == ["2025-10-30"] + assert call["headers"]["Authorization"] == "Bearer tip-token" + assert call["json"] == { + "ToolId": "tool-1", + "Ttl": 60, + "TtlUnit": 
"second", + "UserSessionId": "user-1", + } + + +def test_tip_create_session_skips_get_tool_for_tos_mount_resolution(): + import agentkit.toolkit.cli.sandbox.session_create as session_create + + class FakeTipClient: + uses_tip_auth = True + get_tool_called = False + last_request = None + + def get_tool(self, _request): + self.get_tool_called = True + raise AssertionError("get_tool should not be called in TIP mode") + + def create_session(self, request): + self.last_request = request + return tools_types.CreateSessionResponse( + SessionId="instance-1", + UserSessionId="user-1", + Endpoint="https://sandbox.example.com", + ) + + client = FakeTipClient() + result = session_create._create_session( + client, + session_id="user-1", + tool_id="tool-1", + ttl=60, + ) + + assert result == { + "session_id": "user-1", + "tool_id": "tool-1", + "instance_id": "instance-1", + "endpoint": "https://sandbox.example.com", + } + assert client.get_tool_called is False + assert client.last_request.tos_mount_points is None + + +def test_tip_tool_resolution_trusts_explicit_tool_id(): + import agentkit.toolkit.cli.sandbox.tool_resolve as tool_resolve + + class FakeTipClient: + uses_tip_auth = True + + def get_tool(self, _request): + raise AssertionError("get_tool should not be called in TIP mode") + + assert ( + tool_resolve.resolve_existing_sandbox_tool_id( + tool_id="tool-1", + tool_type=tool_resolve.SandboxToolType.CODE_ENV, + client=FakeTipClient(), + env_var_name="AGENTKIT_SANDBOX_TOOL_ID", + ) + == "tool-1" + ) + + +def test_tip_tool_resolution_requires_existing_tool_id(monkeypatch, tmp_path): + import agentkit.toolkit.cli.sandbox.tool_resolve as tool_resolve + + class FakeTipClient: + uses_tip_auth = True + + store_path = tmp_path / "tools.json" + monkeypatch.setattr(tool_resolve, "_get_tool_store_path", lambda: store_path) + monkeypatch.delenv("AGENTKIT_SANDBOX_TOOL_ID", raising=False) + + with pytest.raises(typer.Exit) as exc: + tool_resolve.resolve_sandbox_tool_id( + tool_id=None, 
+ tool_type=tool_resolve.SandboxToolType.CODE_ENV, + client=FakeTipClient(), + env_var_name="AGENTKIT_SANDBOX_TOOL_ID", + ) + + assert exc.value.exit_code == 1 From 5ca1edaa3fa4cb7e1608a5c264a08105685ebae8 Mon Sep 17 00:00:00 2001 From: "shijinyu.7" Date: Thu, 2 Jul 2026 16:45:08 +0800 Subject: [PATCH 2/4] feat: add sandbox invoke command --- agentkit/toolkit/cli/sandbox/a2a_client.py | 407 ++++++++++ agentkit/toolkit/cli/sandbox/cli.py | 5 + agentkit/toolkit/cli/sandbox/cli_invoke.py | 722 ++++++++++++++++++ .../toolkit/cli/sandbox/session_create.py | 101 ++- tests/toolkit/cli/test_cli_sandbox.py | 522 +++++++++++++ 5 files changed, 1731 insertions(+), 26 deletions(-) create mode 100644 agentkit/toolkit/cli/sandbox/a2a_client.py create mode 100644 agentkit/toolkit/cli/sandbox/cli_invoke.py diff --git a/agentkit/toolkit/cli/sandbox/a2a_client.py b/agentkit/toolkit/cli/sandbox/a2a_client.py new file mode 100644 index 0000000..6530a4e --- /dev/null +++ b/agentkit/toolkit/cli/sandbox/a2a_client.py @@ -0,0 +1,407 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""A2A JSON-RPC helpers for sandbox CLI commands.""" + +from __future__ import annotations + +from dataclasses import dataclass +import json +import sys +import time +from typing import Any +from urllib.parse import urlsplit, urlunsplit +import uuid + +import requests + +DEFAULT_A2A_PATH = "/a2a" +DEFAULT_A2A_TIMEOUT_SECONDS = 1200 +DEFAULT_A2A_HISTORY_LENGTH = 20 +DEFAULT_A2A_POLL_INTERVAL_SECONDS = 2.0 +DEFAULT_READY_RETRIES = 12 +DEFAULT_READY_RETRY_DELAY = 5.0 +RETRYABLE_A2A_STATUS_CODES = {502, 503, 504} +TERMINAL_STATES = { + "completed", + "failed", + "canceled", + "rejected", + "input-required", + "auth-required", +} + + +class A2AApiError(RuntimeError): + """Raised when a sandbox A2A request fails.""" + + def __init__( + self, + operation: str, + message: str, + *, + status_code: int | None = None, + response_text: str | None = None, + response_json: Any = None, + ) -> None: + super().__init__(message) + self.operation = operation + self.status_code = status_code + self.response_text = response_text + self.response_json = response_json + + +@dataclass(frozen=True) +class A2ATaskStart: + task: dict[str, Any] + task_id: str + context_id: str | None + + +def send_message_nonblocking( + *, + endpoint: object, + prompt: str, + a2a_path: str = DEFAULT_A2A_PATH, + context_id: str | None = None, + request_metadata: dict[str, str] | None = None, + history_length: int | None = DEFAULT_A2A_HISTORY_LENGTH, + timeout: int = 60, + readiness_retries: int = DEFAULT_READY_RETRIES, + readiness_retry_delay: float = DEFAULT_READY_RETRY_DELAY, +) -> A2ATaskStart: + message: dict[str, Any] = { + "kind": "message", + "messageId": str(uuid.uuid4()), + "role": "user", + "parts": [{"kind": "text", "text": prompt}], + } + if context_id: + message["contextId"] = context_id + + configuration: dict[str, Any] = {"blocking": False} + if history_length is not None: + configuration["historyLength"] = history_length + + params: dict[str, Any] = { + "message": message, + "configuration": 
configuration, + } + if request_metadata: + params["metadata"] = request_metadata + + response = _post_jsonrpc( + endpoint=endpoint, + a2a_path=a2a_path, + payload={ + "jsonrpc": "2.0", + "id": str(uuid.uuid4()), + "method": "message/send", + "params": params, + }, + timeout=timeout, + operation="A2ASendMessage", + readiness_retries=readiness_retries, + readiness_retry_delay=readiness_retry_delay, + ) + task = _jsonrpc_result_task("A2ASendMessage", response) + task_id = _task_id(task) + if not task_id: + raise A2AApiError( + "A2ASendMessage", + "response task does not contain id", + response_json=response, + ) + return A2ATaskStart( + task=task, + task_id=task_id, + context_id=task_context_id(task), + ) + + +def get_task( + *, + endpoint: object, + task_id: str, + a2a_path: str = DEFAULT_A2A_PATH, + history_length: int | None = DEFAULT_A2A_HISTORY_LENGTH, + timeout: int = 60, + readiness_retries: int = DEFAULT_READY_RETRIES, + readiness_retry_delay: float = DEFAULT_READY_RETRY_DELAY, +) -> dict[str, Any]: + params: dict[str, Any] = {"id": task_id} + if history_length is not None: + params["historyLength"] = history_length + + response = _post_jsonrpc( + endpoint=endpoint, + a2a_path=a2a_path, + payload={ + "jsonrpc": "2.0", + "id": str(uuid.uuid4()), + "method": "tasks/get", + "params": params, + }, + timeout=timeout, + operation="A2AGetTask", + readiness_retries=readiness_retries, + readiness_retry_delay=readiness_retry_delay, + ) + return _jsonrpc_result_task("A2AGetTask", response) + + +def poll_task_until_terminal( + *, + endpoint: object, + task_id: str, + a2a_path: str = DEFAULT_A2A_PATH, + history_length: int | None = DEFAULT_A2A_HISTORY_LENGTH, + timeout: int = DEFAULT_A2A_TIMEOUT_SECONDS, + interval: float = DEFAULT_A2A_POLL_INTERVAL_SECONDS, + print_events: bool = False, +) -> dict[str, Any]: + deadline = time.monotonic() + timeout + latest_task = get_task( + endpoint=endpoint, + task_id=task_id, + a2a_path=a2a_path, + history_length=history_length, + 
timeout=min(60, timeout), + ) + while task_state(latest_task) not in TERMINAL_STATES: + if print_events: + print(json.dumps(latest_task, ensure_ascii=False), file=sys.stderr) + if time.monotonic() >= deadline: + raise TimeoutError(f"Timed out while waiting for A2A task {task_id}") + time.sleep(interval) + latest_task = get_task( + endpoint=endpoint, + task_id=task_id, + a2a_path=a2a_path, + history_length=history_length, + timeout=min(60, timeout), + ) + if print_events: + print(json.dumps(latest_task, ensure_ascii=False), file=sys.stderr) + return latest_task + + +def task_result_parts(task: dict[str, Any] | None) -> list[dict[str, str]]: + if not task: + return [] + for source in ( + _artifact_parts(task), + _message_parts(_get_nested(task, ("status", "message"))), + _history_agent_parts(task), + ): + parts = _text_parts(source) + if parts: + return parts + return [] + + +def task_result_text(task: dict[str, Any] | None) -> str: + return "\n".join(part["text"] for part in task_result_parts(task)) + + +def task_state(task: dict[str, Any] | None) -> str | None: + value = _get_nested(task, ("status", "state")) + return value if isinstance(value, str) else None + + +def task_context_id(task: dict[str, Any] | None) -> str | None: + if not isinstance(task, dict): + return None + value = task.get("contextId") or task.get("context_id") + return value if isinstance(value, str) and value else None + + +def build_a2a_url(endpoint: object, a2a_path: str = DEFAULT_A2A_PATH) -> str: + if not isinstance(endpoint, str) or not endpoint.strip(): + raise A2AApiError("A2ARequest", "Sandbox session endpoint is missing") + + parts = urlsplit(endpoint.strip()) + base_path = parts.path.rstrip("/") + if not a2a_path or a2a_path == "/": + resolved_path = f"{base_path}/" if base_path else "/" + else: + suffix = "/" + a2a_path.strip("/") + resolved_path = f"{base_path}{suffix}" if base_path else suffix + return urlunsplit( + (parts.scheme, parts.netloc, resolved_path, parts.query, 
parts.fragment) + ) + + +def _post_jsonrpc( + *, + endpoint: object, + a2a_path: str, + payload: dict[str, Any], + timeout: int, + operation: str, + readiness_retries: int, + readiness_retry_delay: float, +) -> dict[str, Any]: + url = build_a2a_url(endpoint, a2a_path) + response: requests.Response | None = None + for attempt in range(readiness_retries + 1): + try: + response = requests.post(url, json=payload, timeout=timeout) + except requests.RequestException as exc: + if attempt >= readiness_retries: + raise A2AApiError(operation, str(exc)) from exc + time.sleep(readiness_retry_delay) + continue + + if not _is_retryable_a2a_response(response) or attempt >= readiness_retries: + break + time.sleep(readiness_retry_delay) + + if response is None: + raise A2AApiError(operation, "request was not sent") + + body = response.text + if response.status_code < 200 or response.status_code >= 300: + raise A2AApiError( + operation, + _failure_hint(response.status_code, body), + status_code=response.status_code, + response_text=body, + ) + + try: + parsed = response.json() + except ValueError as exc: + raise A2AApiError( + operation, + "response is not valid JSON", + status_code=response.status_code, + response_text=body, + ) from exc + if not isinstance(parsed, dict): + raise A2AApiError( + operation, + "response JSON is not an object", + status_code=response.status_code, + response_json={"response": parsed}, + ) + if parsed.get("error") is not None: + raise A2AApiError( + operation, + "A2A JSON-RPC returned error", + status_code=response.status_code, + response_json=parsed, + ) + return parsed + + +def _is_retryable_a2a_response(response: requests.Response) -> bool: + if response.status_code in RETRYABLE_A2A_STATUS_CODES: + return True + if response.status_code != 500: + return False + + body = response.text.lower() + return "function_proxy_error" in body and "connection refused" in body + + +def _jsonrpc_result_task(operation: str, response: dict[str, Any]) -> dict[str, Any]: 
+ result = response.get("result") + if not isinstance(result, dict): + raise A2AApiError( + operation, + "response does not contain result task", + response_json=response, + ) + if result.get("kind") != "task" and "status" not in result: + raise A2AApiError( + operation, + "response result is not an A2A task", + response_json=response, + ) + return result + + +def _task_id(task: dict[str, Any] | None) -> str | None: + if not isinstance(task, dict): + return None + value = task.get("id") + return value if isinstance(value, str) and value else None + + +def _artifact_parts(task: dict[str, Any]) -> list[Any]: + artifacts = task.get("artifacts") + if not isinstance(artifacts, list): + return [] + parts: list[Any] = [] + for artifact in artifacts: + if not isinstance(artifact, dict): + continue + artifact_parts = artifact.get("parts") + if isinstance(artifact_parts, list): + parts.extend(artifact_parts) + return parts + + +def _message_parts(message: Any) -> list[Any]: + if not isinstance(message, dict): + return [] + parts = message.get("parts") + return parts if isinstance(parts, list) else [] + + +def _history_agent_parts(task: dict[str, Any]) -> list[Any]: + history = task.get("history") + if not isinstance(history, list): + return [] + parts: list[Any] = [] + for message in reversed(history): + if not isinstance(message, dict) or message.get("role") != "agent": + continue + message_parts = message.get("parts") + if isinstance(message_parts, list): + text_parts = _text_parts(message_parts) + if text_parts: + parts.extend(message_parts) + break + return parts + + +def _text_parts(parts: list[Any]) -> list[dict[str, str]]: + return [ + {"text": part["text"]} + for part in parts + if isinstance(part, dict) + and part.get("kind") == "text" + and isinstance(part.get("text"), str) + and part["text"] + ] + + +def _get_nested(value: Any, path: tuple[str, ...]) -> Any: + current = value + for key in path: + if not isinstance(current, dict): + return None + current = 
current.get(key) + return current + + +def _failure_hint(status_code: int, body: str) -> str: + lower_body = body.lower() + if status_code in (404, 410) or "expired" in lower_body or "deleted" in lower_body: + return "Sandbox session or A2A task may have expired or been deleted" + if status_code in (401, 403): + return "Sandbox access credentials may have expired or become invalid" + return "A2A request returned non-2xx status" diff --git a/agentkit/toolkit/cli/sandbox/cli.py b/agentkit/toolkit/cli/sandbox/cli.py index 0af02f2..2160b09 100644 --- a/agentkit/toolkit/cli/sandbox/cli.py +++ b/agentkit/toolkit/cli/sandbox/cli.py @@ -22,6 +22,7 @@ from agentkit.toolkit.cli.sandbox.cli_exec import exec_command from agentkit.toolkit.cli.sandbox.cli_file import file_command from agentkit.toolkit.cli.sandbox.cli_get import get_command +from agentkit.toolkit.cli.sandbox.cli_invoke import invoke_command from agentkit.toolkit.cli.sandbox.cli_model_login import codex_login_command from agentkit.toolkit.cli.sandbox.cli_mount import mount_command from agentkit.toolkit.cli.sandbox.cli_run import run_command @@ -44,6 +45,10 @@ name="exec", context_settings={"allow_extra_args": True}, )(exec_command) +sandbox_app.command( + name="invoke", + context_settings={"allow_extra_args": True}, +)(invoke_command) sandbox_app.command(name="run")(run_command) sandbox_app.command( name="shell", diff --git a/agentkit/toolkit/cli/sandbox/cli_invoke.py b/agentkit/toolkit/cli/sandbox/cli_invoke.py new file mode 100644 index 0000000..6294744 --- /dev/null +++ b/agentkit/toolkit/cli/sandbox/cli_invoke.py @@ -0,0 +1,722 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# Names of the session environment entries that
# build_invoke_model_agent_envs may emit, in emission order.
MODEL_AGENT_ENV_KEYS = (
    "MODEL_AGENT_API_BASE",
    "MODEL_AGENT_API_KEY",
    "MODEL_AGENT_PROVIDER",
    "MODEL_AGENT_NAME",
    "MODEL_AGENT_EXTRA_HEADERS",
)
# Subset of MODEL_AGENT_ENV_KEYS that is always emitted, even when the
# resolved value is an empty string.
REQUIRED_MODEL_AGENT_ENV_KEYS = (
    "MODEL_AGENT_API_BASE",
    "MODEL_AGENT_API_KEY",
    "MODEL_AGENT_PROVIDER",
    "MODEL_AGENT_NAME",
)
# Default OpenClaw JSON config path used by build_invoke_model_agent_envs.
OPENCLAW_CONFIG_FILE = Path("/root/.openclaw/openclaw.json")
# Key paths inside the OpenClaw config that _find_openclaw_model_config
# searches, in order, before falling back to the whole document.
OPENCLAW_MODEL_CONFIG_ROOTS = (
    ("models",),
    ("model",),
    ("modelProviders",),
    ("model_providers",),
    ("providers",),
    ("agents", "defaults", "model"),
)
# Accepted spellings of the key holding the model API base URL; the first
# one with a non-empty string value is used.
OPENCLAW_API_BASE_KEYS = (
    "api_base_url",
    "apiBaseUrl",
    "api_base",
    "apiBase",
    "base_url",
    "baseURL",
    "baseUrl",
    "baseurl",
)
# Accepted spellings of the key holding the model API key; the first one
# with a non-empty string value is used.
OPENCLAW_API_KEY_KEYS = (
    "api_key",
    "apiKey",
    "apikey",
    "key",
    "model_key",
    "modelKey",
)
def build_invoke_model_agent_envs(
    *,
    model_name: Optional[str] = None,
    model_provider: Optional[str] = None,
    model_base_url: Optional[str] = None,
    model_api_key: Optional[str] = None,
    openclaw_config_file: Path = OPENCLAW_CONFIG_FILE,
) -> list[tools_types.EnvsItemForCreateSession]:
    """Build the MODEL_AGENT_* environment entries for a new sandbox session.

    Each key is resolved independently from three sources, first non-empty
    value wins: the CLI arguments, the process environment, then the
    OpenClaw config file.

    Args:
        model_name: CLI value for ``MODEL_AGENT_NAME``.
        model_provider: CLI value for ``MODEL_AGENT_PROVIDER``.
        model_base_url: CLI value for ``MODEL_AGENT_API_BASE``.
        model_api_key: CLI value for ``MODEL_AGENT_API_KEY``.
        openclaw_config_file: OpenClaw JSON config consulted last.

    Returns:
        One entry per key in ``MODEL_AGENT_ENV_KEYS`` order. Required keys
        are always present (possibly with an empty value); other keys are
        included only when a non-empty value was resolved.
    """
    from_cli = {
        "MODEL_AGENT_API_BASE": (model_base_url or "").strip(),
        "MODEL_AGENT_API_KEY": (model_api_key or "").strip(),
        "MODEL_AGENT_PROVIDER": (model_provider or "").strip(),
        "MODEL_AGENT_NAME": (model_name or "").strip(),
    }
    from_env = _collect_model_agent_envs_from_env()
    from_openclaw = _collect_openclaw_model_agent_envs(openclaw_config_file)
    # Highest-priority source first.
    sources = (from_cli, from_env, from_openclaw)

    envs = []
    for key in MODEL_AGENT_ENV_KEYS:
        resolved = next(
            (source[key] for source in sources if source.get(key)),
            "",
        )
        if resolved or key in REQUIRED_MODEL_AGENT_ENV_KEYS:
            envs.append(
                tools_types.EnvsItemForCreateSession(key=key, value=resolved)
            )
    return envs
values: dict[str, str] = {} + for key in MODEL_AGENT_ENV_KEYS: + value = os.getenv(key, "").strip() + if value: + values[key] = value + return values + + +def _collect_openclaw_model_agent_envs(path: Path) -> dict[str, str]: + try: + data = json.loads(path.read_text(encoding="utf-8")) + except (FileNotFoundError, OSError, json.JSONDecodeError): + return {} + if not isinstance(data, dict): + return {} + + primary = _get_nested(data, ("agents", "defaults", "model", "primary")) + if not isinstance(primary, str): + return {} + + provider, model_name = _parse_openclaw_primary(primary) + if not provider or not model_name: + return {} + + model_config = _find_openclaw_model_config(data, provider, model_name) + if not model_config: + return {} + + api_base = _pick_openclaw_text(model_config, OPENCLAW_API_BASE_KEYS) + api_key = _pick_openclaw_text(model_config, OPENCLAW_API_KEY_KEYS) + provider_api = _pick_openclaw_text(model_config, OPENCLAW_PROVIDER_API_KEYS) + if not api_base or not api_key or not provider_api: + return {} + + litellm_provider, model_agent_name = _resolve_openclaw_provider_api_route( + provider_api, + model_name, + ) + if not litellm_provider or not model_agent_name: + return {} + + values = { + "MODEL_AGENT_API_BASE": api_base, + "MODEL_AGENT_API_KEY": api_key, + "MODEL_AGENT_PROVIDER": litellm_provider, + "MODEL_AGENT_NAME": model_agent_name, + } + extra_headers = _pick_openclaw_headers_json(model_config) + if extra_headers: + values["MODEL_AGENT_EXTRA_HEADERS"] = extra_headers + return values + + +def _resolve_openclaw_provider_api_route( + provider_api: str, + model_name: str, +) -> tuple[str, str]: + litellm_provider = OPENCLAW_PROVIDER_API_ROUTES.get(provider_api) + if not litellm_provider: + return "", "" + return litellm_provider, model_name + + +def _parse_openclaw_primary(primary: str) -> tuple[str, str]: + provider, separator, model_name = primary.strip().partition("/") + if not separator: + return "", "" + return provider.strip(), 
def _merge_openclaw_model_config(
    provider_config: dict[str, Any],
    model_config: dict[str, Any] | None,
) -> dict[str, Any]:
    """Overlay a model's settings on top of its provider's settings.

    Args:
        provider_config: Provider-level config. Its ``"models"`` entry is
            dropped from the result.
        model_config: Model-level config whose keys override the provider's.
            ``None`` or an empty dict means there are no overrides.

    Returns:
        A new dict with provider keys overridden by model keys. When either
        level defines headers, ``"headers"`` holds the provider headers
        updated with the model headers (model wins on conflicts).
    """
    merged = {key: value for key, value in provider_config.items() if key != "models"}
    if model_config:
        merged.update(model_config)
    provider_headers = _pick_openclaw_headers(provider_config)
    # The signature allows model_config to be None; without this guard the
    # header lookup would call .get on None and raise AttributeError.
    model_headers = _pick_openclaw_headers(model_config) if model_config else {}
    if provider_headers or model_headers:
        merged["headers"] = {**provider_headers, **model_headers}
    return merged
isinstance(current, dict): + return None + current = current.get(key) + return current + + +def _pick_openclaw_text(value: dict[str, Any], keys: tuple[str, ...]) -> str | None: + for key in keys: + item = value.get(key) + if isinstance(item, str) and item.strip(): + return item.strip() + return None + + +def _pick_openclaw_headers_json(value: dict[str, Any]) -> str | None: + headers = _pick_openclaw_headers(value) + if not headers: + return None + return json.dumps(headers, ensure_ascii=False, sort_keys=True) + + +def _pick_openclaw_headers(value: dict[str, Any]) -> dict[str, str]: + for key in OPENCLAW_MODEL_HEADER_KEYS: + item = value.get(key) + if not isinstance(item, dict): + continue + headers: dict[str, str] = {} + for header_key, header_value in item.items(): + if ( + isinstance(header_key, str) + and isinstance(header_value, str) + and header_key.strip() + and header_value.strip() + ): + headers[header_key.strip()] = header_value.strip() + if headers: + return headers + return {} + + +def _task_failure_error(task: dict[str, Any]) -> dict[str, str]: + state = task_state(task) or "unknown" + message = _status_message_text(task) or f"Sandbox task ended with state: {state}" + return { + "type": "SandboxTaskFailed", + "message": message, + } + + +def _status_message_text(task: dict[str, Any]) -> str: + status = task.get("status") + if not isinstance(status, dict): + return "" + message = status.get("message") + if not isinstance(message, dict): + return "" + parts = message.get("parts") + if not isinstance(parts, list): + return "" + texts = [ + part["text"] + for part in parts + if isinstance(part, dict) + and part.get("kind") == "text" + and isinstance(part.get("text"), str) + and part["text"] + ] + return "\n".join(texts) + + +def _task_output( + *, + task: dict[str, Any], + session: dict[str, object], + source: str, +) -> dict[str, Any]: + state = task_state(task) + output: dict[str, Any] = { + "ok": state == "completed", + "task_state": state, + "error": 
def _error_payload(exc: Exception) -> dict[str, Any]:
    """Build the JSON error document printed when an invoke fails.

    Args:
        exc: The exception that aborted the invoke.

    Returns:
        A dict with ``ok=False``, ``status="error"`` and an ``error`` entry
        carrying the exception type and message. For ``A2AApiError`` the
        entry also carries ``operation``, ``status_code`` and, when
        available, ``response`` (the JSON body if present, otherwise the
        text body).
    """
    if isinstance(exc, A2AApiError):
        details: dict[str, Any] = {
            "type": type(exc).__name__,
            "operation": exc.operation,
            "message": str(exc),
            "status_code": exc.status_code,
        }
        # Prefer the parsed JSON body; fall back to the raw text body.
        response = exc.response_json
        if response is None:
            response = exc.response_text
        if response is not None:
            details["response"] = response
    else:
        details = {
            "type": type(exc).__name__,
            "message": str(exc),
        }
    return {
        "ok": False,
        "status": "error",
        "error": details,
        "task_id": None,
        "sandbox": {"available": False, "endpoint": None},
        "source": "sandbox-invoke",
    }
# NOTE: the docstring and every ``help=`` string below are surfaced by Typer
# as user-facing help text, so they are part of the command's behavior.
def invoke_command(
    ctx: typer.Context,
    session_id: Optional[str] = typer.Option(
        None,
        "--session-id",
        "--sid",
        "-s",
        help=(
            "Sandbox session ID. Defaults to a generated UUID and creates "
            "a sandbox session when needed."
        ),
    ),
    prompt: Optional[str] = typer.Option(
        None,
        "--prompt",
        help="Prompt to send to the sandbox A2A agent.",
    ),
    tool_id: Optional[str] = typer.Option(
        None,
        "--tool-id",
        help=(
            f"Sandbox tool ID. Defaults to {SANDBOX_TOOL_ID_ENV}; when unset, "
            "--tool-type is used as the tool ID."
        ),
    ),
    tool_type: SandboxToolType = typer.Option(
        SandboxToolType.SKILL_ENV,
        "--tool-type",
        help="Sandbox tool type used as the fallback tool ID.",
    ),
    async_mode: bool = typer.Option(
        False,
        "--async",
        help="Return immediately after creating the A2A task.",
    ),
    task_id: Optional[str] = typer.Option(
        None,
        "--task-id",
        help="Poll an existing A2A task ID instead of creating a new task.",
    ),
    ttl: Optional[int] = typer.Option(
        None,
        "--ttl",
        help=(
            "Sandbox session TTL in seconds. Defaults to AGENTKIT_SANDBOX_TTL "
            "or the exec command default."
        ),
    ),
    model_name: Optional[str] = typer.Option(
        None,
        "--model-name",
        help="Model name to inject as MODEL_AGENT_NAME when creating a session.",
    ),
    model_provider: Optional[str] = typer.Option(
        None,
        "--model-provider",
        help="Model provider to inject as MODEL_AGENT_PROVIDER.",
    ),
    model_base_url: Optional[str] = typer.Option(
        None,
        "--model-base-url",
        help="Model API base URL to inject as MODEL_AGENT_API_BASE.",
    ),
    model_api_key: Optional[str] = typer.Option(
        None,
        "--model-api-key",
        help="Model API key to inject as MODEL_AGENT_API_KEY.",
    ),
    timeout: int = typer.Option(
        DEFAULT_A2A_TIMEOUT_SECONDS,
        "--timeout",
        help="Maximum seconds to wait for synchronous invoke or task polling.",
    ),
    interval: float = typer.Option(
        DEFAULT_A2A_POLL_INTERVAL_SECONDS,
        "--interval",
        help="Polling interval in seconds.",
    ),
    history_length: int = typer.Option(
        DEFAULT_A2A_HISTORY_LENGTH,
        "--history-length",
        help="A2A task history length to request.",
    ),
    a2a_path: str = typer.Option(
        DEFAULT_A2A_PATH,
        "--a2a-path",
        help="A2A JSON-RPC path on the sandbox endpoint.",
        hidden=True,
    ),
) -> None:
    """Invoke a sandbox A2A agent."""
    resolved_task_id = (task_id or "").strip()
    resolved_prompt = (prompt or "").strip()
    # A single extra positional argument after --async (e.g. "--async true")
    # is interpreted as the flag's value; see _resolve_async_mode.
    resolved_async_mode = _resolve_async_mode(ctx, async_mode)
    # Validate option combinations before any session or A2A call is made.
    if not resolved_task_id and not resolved_prompt:
        error("--prompt is required unless --task-id is provided")
    if ttl is not None and ttl <= 0:
        error("--ttl must be greater than 0")
    if history_length < 0:
        error("--history-length must be non-negative")

    resolved_timeout = _normalize_timeout(timeout)
    resolved_interval = _normalize_interval(interval)
    # Precedence: --tool-id, then the SANDBOX_TOOL_ID_ENV variable, then the
    # --tool-type value itself.
    resolved_tool_id = _resolve_invoke_tool_id(
        tool_id=tool_id,
        tool_type=tool_type,
    )

    try:
        # The tool ID is passed through as-is (resolve_tool=False) and TOS
        # mount points are not requested for the session.
        session = ensure_sandbox_session(
            session_id=session_id,
            tool_id=resolved_tool_id,
            tool_type=tool_type.value,
            ttl=ttl,
            envs=build_invoke_model_agent_envs(
                model_name=model_name,
                model_provider=model_provider,
                model_base_url=model_base_url,
                model_api_key=model_api_key,
            ),
            resolve_tool=False,
            include_tos_mount_points=False,
        )
    except typer.Exit:
        # Let exits requested by helpers propagate unchanged.
        raise
    except Exception as exc:
        # NOTE(review): the code below relies on error() not returning
        # (presumably it raises typer.Exit); otherwise `session` would be
        # unbound. Confirm against sandbox_client.error.
        error(str(exc))

    try:
        if resolved_task_id:
            # Poll-only mode: --task-id was given, so no new message is sent.
            task = poll_task_until_terminal(
                endpoint=session.get("endpoint"),
                task_id=resolved_task_id,
                a2a_path=a2a_path,
                history_length=history_length,
                timeout=resolved_timeout,
                interval=resolved_interval,
            )
            echo_json(_task_output(task=task, session=session, source="sandbox-invoke"))
            return

        task_start = send_message_nonblocking(
            endpoint=session.get("endpoint"),
            prompt=resolved_prompt,
            a2a_path=a2a_path,
            request_metadata={
                "session_id": str(session.get("session_id") or ""),
                "user_id": "agentkit-sandbox-invoke",
            },
            history_length=history_length,
            # The initial send is capped at 60 seconds even when --timeout
            # is larger; --timeout still bounds the polling below.
            timeout=min(60, resolved_timeout),
        )
        if resolved_async_mode:
            # Async mode: report the created task and return without polling.
            echo_json(
                _task_created_output(
                    task=task_start.task,
                    task_id=task_start.task_id,
                    context_id=task_start.context_id,
                    session=session,
                )
            )
            return

        task = poll_task_until_terminal(
            endpoint=session.get("endpoint"),
            task_id=task_start.task_id,
            a2a_path=a2a_path,
            history_length=history_length,
            timeout=resolved_timeout,
            interval=resolved_interval,
        )
        echo_json(_task_output(task=task, session=session, source="sandbox-invoke"))
    except typer.Exit:
        raise
    except Exception as exc:
        # Any other failure is printed as a JSON error payload and the
        # command exits with status 1.
        echo_json(_error_payload(exc))
        raise typer.Exit(1) from exc
def _get_remote_session_by_user_session_id(
    client: AgentkitToolsClient,
    *,
    session_id: str,
    tool_id: str,
) -> dict[str, object] | None:
    """Look up a remote session of *tool_id* by its user session ID.

    Args:
        client: Tools client used to list sessions.
        session_id: User session ID to match.
        tool_id: Tool whose sessions are listed.

    Returns:
        The first listed session whose converted result has a matching
        ``"session_id"``, or ``None`` when no such session is returned.
        At most 10 sessions are requested from the service.
    """
    session_filter = tools_types.FiltersItemForListSessions(
        name="UserSessionId",
        values=[session_id],
    )
    request = tools_types.ListSessionsRequest(
        tool_id=tool_id,
        max_results=10,
        filters=[session_filter],
    )
    response = client.list_sessions(request)

    # Lazily convert each listed session and stop at the first match.
    candidates = (
        session_info_to_result(info, tool_id)
        for info in response.session_infos or []
    )
    return next(
        (
            candidate
            for candidate in candidates
            if candidate and candidate.get("session_id") == session_id
        ),
        None,
    )
tool_type=tool_type, + default_tool_id=existing.get("tool_id") if existing else None, + client=client, + env_var_name=SANDBOX_TOOL_ID_ENV, + ) + else: + resolved_tool_id = (tool_id or "").strip() + if not resolved_tool_id: + error("Sandbox tool ID is required") if existing: result = _get_existing_remote_session( @@ -416,26 +449,37 @@ def ensure_sandbox_session_with_status( save_session_result(result) return result, False - synced_tool_id = sync_remote_sessions( + if resolve_tool: + synced_tool_id = sync_remote_sessions( + session_id=resolved_session_id, + tool_id=resolved_tool_id, + tool_type=tool_type, + client=client, + env_var_name=SANDBOX_TOOL_ID_ENV, + ) + if synced_tool_id: + resolved_tool_id = synced_tool_id + existing = find_session_result(resolved_session_id) + if existing: + result = _get_existing_remote_session( + client, + existing, + resolved_session_id, + resolved_tool_id, + ) + if result: + save_session_result(result) + return result, False + + if session_id and not resolve_tool: + result = _get_remote_session_by_user_session_id( + client, session_id=resolved_session_id, tool_id=resolved_tool_id, - tool_type=tool_type, - client=client, - env_var_name=SANDBOX_TOOL_ID_ENV, ) - if synced_tool_id: - resolved_tool_id = synced_tool_id - existing = find_session_result(resolved_session_id) - if existing: - result = _get_existing_remote_session( - client, - existing, - resolved_session_id, - resolved_tool_id, - ) - if result: - save_session_result(result) - return result, False + if result: + save_session_result(result) + return result, False session_envs = envs result = _create_session( @@ -444,6 +488,7 @@ def ensure_sandbox_session_with_status( resolved_tool_id, _resolve_ttl(ttl), envs=session_envs, + include_tos_mount_points=include_tos_mount_points, ) save_session_result(result) return result, True @@ -455,6 +500,8 @@ def ensure_sandbox_session( tool_type: str = DEFAULT_SANDBOX_TOOL_TYPE, ttl: Optional[int] = None, envs: 
Optional[list[tools_types.EnvsItemForCreateSession]] = None, + resolve_tool: bool = True, + include_tos_mount_points: bool = True, ) -> dict[str, object]: result, _is_new = ensure_sandbox_session_with_status( session_id=session_id, @@ -462,5 +509,7 @@ def ensure_sandbox_session( tool_type=tool_type, ttl=ttl, envs=envs, + resolve_tool=resolve_tool, + include_tos_mount_points=include_tos_mount_points, ) return result diff --git a/tests/toolkit/cli/test_cli_sandbox.py b/tests/toolkit/cli/test_cli_sandbox.py index 5fe6e3d..795200e 100644 --- a/tests/toolkit/cli/test_cli_sandbox.py +++ b/tests/toolkit/cli/test_cli_sandbox.py @@ -248,6 +248,38 @@ def fake_ensure_sandbox_session(session_id=None, tool_id=None, **_kwargs): ) +class _FakeA2AResponse: + def __init__(self, payload, status_code=200, text=None): + self._payload = payload + self.status_code = status_code + self.text = text if text is not None else json.dumps(payload) + + def json(self): + return self._payload + + +def _patch_invoke_session(monkeypatch, cli_invoke, session, capture=None): + def fake_ensure_sandbox_session(session_id=None, tool_id=None, **kwargs): + if capture is not None: + capture["session_id"] = session_id + capture["tool_id"] = tool_id + capture.update(kwargs) + result = dict(session) + result.setdefault("tool_id", tool_id) + return result + + monkeypatch.setattr( + cli_invoke, + "ensure_sandbox_session", + fake_ensure_sandbox_session, + ) + + +def _clear_model_agent_envs(monkeypatch, cli_invoke): + for key in cli_invoke.MODEL_AGENT_ENV_KEYS: + monkeypatch.delenv(key, raising=False) + + def test_ensure_sandbox_session_uses_env_defaults(monkeypatch, tmp_path) -> None: import agentkit.toolkit.cli.sandbox.session_create as session_create @@ -316,6 +348,73 @@ def test_ensure_sandbox_session_uses_cached_tool_by_type( assert _FakeToolsClient.last_get_tool_request.tool_id == "tool-from-cache" +def test_ensure_sandbox_session_can_skip_tool_lookup_for_resolved_tool_id( + monkeypatch, + tmp_path, +) -> 
def test_ensure_sandbox_session_skip_tool_lookup_syncs_named_remote_session(
    monkeypatch,
    tmp_path,
) -> None:
    # With resolve_tool=False and an explicit session ID, an existing remote
    # session found through list_sessions is reused instead of created.
    import agentkit.toolkit.cli.sandbox.session_create as session_create

    monkeypatch.setattr(
        session_create,
        "AgentkitToolsClient",
        lambda: _FakeToolsClient(),
    )
    _patch_store_path(monkeypatch, tmp_path)
    # The fake service reports one session whose user session ID matches the
    # ID requested below.
    _FakeToolsClient.list_sessions_responses = [
        _FakeListSessionsResponse(
            [
                _FakeSessionInfo(
                    user_session_id="user-1",
                    session_id="instance-1",
                    endpoint="https://sandbox.example.com/a2a",
                )
            ]
        )
    ]

    result = session_create.ensure_sandbox_session(
        session_id="user-1",
        tool_id="SkillEnv",
        tool_type="SkillEnv",
        resolve_tool=False,
        include_tos_mount_points=False,
    )

    # No tool lookup and no session creation; exactly one list call.
    assert _FakeToolsClient.get_tool_call_count == 0
    assert _FakeToolsClient.create_call_count == 0
    assert _FakeToolsClient.list_sessions_call_count == 1
    assert result == {
        "session_id": "user-1",
        "tool_id": "SkillEnv",
        "instance_id": "instance-1",
        "endpoint": "https://sandbox.example.com/a2a",
    }
def test_build_invoke_model_agent_envs_uses_cli_values_first(
    monkeypatch,
    tmp_path,
) -> None:
    """CLI-supplied model settings win over MODEL_AGENT_* environment values.

    ``MODEL_AGENT_EXTRA_HEADERS`` has no CLI option, so it is expected to
    come from the environment.
    """
    import agentkit.toolkit.cli.sandbox.cli_invoke as cli_invoke

    monkeypatch.setenv("MODEL_AGENT_NAME", "env-model")
    monkeypatch.setenv("MODEL_AGENT_PROVIDER", "env-provider")
    monkeypatch.setenv("MODEL_AGENT_API_BASE", "https://env.example.com")
    monkeypatch.setenv("MODEL_AGENT_API_KEY", "env-key")
    monkeypatch.setenv("MODEL_AGENT_EXTRA_HEADERS", '{"X-Env":"1"}')

    envs = cli_invoke.build_invoke_model_agent_envs(
        model_name="cli-model",
        model_provider="cli-provider",
        model_base_url="https://cli.example.com",
        model_api_key="cli-key",
        # Use a path that cannot exist so the test never reads the host's
        # real OpenClaw config, matching the sibling env-value tests.
        openclaw_config_file=tmp_path / "missing-openclaw.json",
    )

    assert {item.key: item.value for item in envs} == {
        "MODEL_AGENT_API_BASE": "https://cli.example.com",
        "MODEL_AGENT_API_KEY": "cli-key",
        "MODEL_AGENT_PROVIDER": "cli-provider",
        "MODEL_AGENT_NAME": "cli-model",
        "MODEL_AGENT_EXTRA_HEADERS": '{"X-Env":"1"}',
    }
"MODEL_AGENT_NAME": "env-model", + "MODEL_AGENT_EXTRA_HEADERS": '{"X-Env":"1"}', + } + + +def test_build_invoke_model_agent_envs_uses_openclaw_config( + monkeypatch, + tmp_path, +) -> None: + import agentkit.toolkit.cli.sandbox.cli_invoke as cli_invoke + + _clear_model_agent_envs(monkeypatch, cli_invoke) + openclaw_path = tmp_path / "openclaw.json" + openclaw_path.write_text( + json.dumps( + { + "agents": {"defaults": {"model": {"primary": "provider-a/model-a"}}}, + "models": { + "provider-a": { + "api_base": "https://openclaw.example.com", + "api_key": "openclaw-key", + "api": "openai-responses", + "headers": {"X-Provider": "provider"}, + "models": { + "model-a": { + "headers": {"X-Model": "model"}, + } + }, + } + }, + } + ), + encoding="utf-8", + ) + + envs = cli_invoke.build_invoke_model_agent_envs( + openclaw_config_file=openclaw_path, + ) + + assert {item.key: item.value for item in envs} == { + "MODEL_AGENT_API_BASE": "https://openclaw.example.com", + "MODEL_AGENT_API_KEY": "openclaw-key", + "MODEL_AGENT_PROVIDER": "openai/responses", + "MODEL_AGENT_NAME": "model-a", + "MODEL_AGENT_EXTRA_HEADERS": '{"X-Model": "model", "X-Provider": "provider"}', + } + + +def test_build_invoke_model_agent_envs_uses_empty_required_values( + monkeypatch, + tmp_path, +) -> None: + import agentkit.toolkit.cli.sandbox.cli_invoke as cli_invoke + + _clear_model_agent_envs(monkeypatch, cli_invoke) + + envs = cli_invoke.build_invoke_model_agent_envs( + openclaw_config_file=tmp_path / "missing-openclaw.json", + ) + + assert [(item.key, item.value) for item in envs] == [ + ("MODEL_AGENT_API_BASE", ""), + ("MODEL_AGENT_API_KEY", ""), + ("MODEL_AGENT_PROVIDER", ""), + ("MODEL_AGENT_NAME", ""), + ] + + +def test_cli_invoke_async_uses_env_tool_id_and_sends_a2a( + monkeypatch, +) -> None: + from agentkit.toolkit.cli.cli import app + import agentkit.toolkit.cli.sandbox.a2a_client as a2a_client + import agentkit.toolkit.cli.sandbox.cli_invoke as cli_invoke + + 
_clear_model_agent_envs(monkeypatch, cli_invoke) + monkeypatch.setenv("AGENTKIT_SANDBOX_TOOL_ID", "tool-env") + capture = {} + _patch_invoke_session( + monkeypatch, + cli_invoke, + { + "session_id": "session-cli", + "instance_id": "instance-cli", + "endpoint": "https://sandbox.example.com/base?Authorization=token", + }, + capture, + ) + calls = [] + + def fake_post(url, json=None, timeout=None): + calls.append({"url": url, "json": json, "timeout": timeout}) + return _FakeA2AResponse( + { + "jsonrpc": "2.0", + "id": "rpc-1", + "result": { + "kind": "task", + "id": "task-1", + "contextId": "context-1", + "status": {"state": "working"}, + }, + } + ) + + monkeypatch.setattr(a2a_client.requests, "post", fake_post) + + result = runner.invoke( + app, + [ + "sandbox", + "invoke", + "--session-id", + "session-cli", + "--prompt", + "hello", + "--async", + "true", + ], + ) + + assert result.exit_code == 0 + assert capture["session_id"] == "session-cli" + assert capture["tool_id"] == "tool-env" + assert capture["tool_type"] == "SkillEnv" + assert capture["ttl"] is None + assert capture["resolve_tool"] is False + assert capture["include_tos_mount_points"] is False + assert [(item.key, item.value) for item in capture["envs"]] == [ + ("MODEL_AGENT_API_BASE", ""), + ("MODEL_AGENT_API_KEY", ""), + ("MODEL_AGENT_PROVIDER", ""), + ("MODEL_AGENT_NAME", ""), + ] + assert len(calls) == 1 + call = calls[0] + assert call["url"] == "https://sandbox.example.com/base/a2a?Authorization=token" + assert call["json"]["method"] == "message/send" + assert call["json"]["params"]["message"]["parts"] == [ + {"kind": "text", "text": "hello"} + ] + assert call["json"]["params"]["metadata"] == { + "session_id": "session-cli", + "user_id": "agentkit-sandbox-invoke", + } + payload = json.loads(result.output) + assert payload["ok"] is True + assert payload["task_id"] == "task-1" + assert payload["context_id"] == "context-1" + assert payload["session_id"] == "session-cli" + assert payload["tool_id"] == 
"tool-env" + + +def test_cli_invoke_sync_falls_back_to_tool_type_and_polls( + monkeypatch, +) -> None: + from agentkit.toolkit.cli.cli import app + import agentkit.toolkit.cli.sandbox.a2a_client as a2a_client + import agentkit.toolkit.cli.sandbox.cli_invoke as cli_invoke + + _clear_model_agent_envs(monkeypatch, cli_invoke) + monkeypatch.delenv("AGENTKIT_SANDBOX_TOOL_ID", raising=False) + capture = {} + _patch_invoke_session( + monkeypatch, + cli_invoke, + { + "session_id": "generated-session", + "instance_id": "instance-cli", + "endpoint": "https://sandbox.example.com", + }, + capture, + ) + calls = [] + + def fake_post(url, json=None, timeout=None): + calls.append({"url": url, "json": json, "timeout": timeout}) + method = json["method"] + if method == "message/send": + return _FakeA2AResponse( + { + "jsonrpc": "2.0", + "id": "rpc-send", + "result": { + "kind": "task", + "id": "task-sync", + "contextId": "context-sync", + "status": {"state": "working"}, + }, + } + ) + return _FakeA2AResponse( + { + "jsonrpc": "2.0", + "id": "rpc-get", + "result": { + "kind": "task", + "id": "task-sync", + "contextId": "context-sync", + "status": {"state": "completed"}, + "artifacts": [ + {"parts": [{"kind": "text", "text": "done"}]}, + ], + }, + } + ) + + monkeypatch.setattr(a2a_client.requests, "post", fake_post) + + result = runner.invoke( + app, + ["sandbox", "invoke", "--prompt", "run it", "--ttl", "123"], + ) + + assert result.exit_code == 0 + assert capture["session_id"] is None + assert capture["tool_id"] == "SkillEnv" + assert capture["tool_type"] == "SkillEnv" + assert capture["ttl"] == 123 + assert capture["resolve_tool"] is False + assert capture["include_tos_mount_points"] is False + assert [(item.key, item.value) for item in capture["envs"]] == [ + ("MODEL_AGENT_API_BASE", ""), + ("MODEL_AGENT_API_KEY", ""), + ("MODEL_AGENT_PROVIDER", ""), + ("MODEL_AGENT_NAME", ""), + ] + assert [call["json"]["method"] for call in calls] == [ + "message/send", + "tasks/get", + ] + 
assert calls[1]["json"]["params"]["id"] == "task-sync" + payload = json.loads(result.output) + assert payload["ok"] is True + assert payload["task_state"] == "completed" + assert payload["final_result"] == "done" + assert payload["task_id"] == "task-sync" + + +def test_cli_invoke_passes_model_agent_envs_from_options( + monkeypatch, +) -> None: + from agentkit.toolkit.cli.cli import app + import agentkit.toolkit.cli.sandbox.a2a_client as a2a_client + import agentkit.toolkit.cli.sandbox.cli_invoke as cli_invoke + + _clear_model_agent_envs(monkeypatch, cli_invoke) + capture = {} + _patch_invoke_session( + monkeypatch, + cli_invoke, + { + "session_id": "session-cli", + "instance_id": "instance-cli", + "endpoint": "https://sandbox.example.com", + }, + capture, + ) + + monkeypatch.setattr( + a2a_client.requests, + "post", + lambda *_args, **_kwargs: _FakeA2AResponse( + { + "jsonrpc": "2.0", + "id": "rpc-1", + "result": { + "kind": "task", + "id": "task-1", + "status": {"state": "working"}, + }, + } + ), + ) + + result = runner.invoke( + app, + [ + "sandbox", + "invoke", + "--prompt", + "hello", + "--async", + "--model-name", + "model-cli", + "--model-provider", + "provider-cli", + "--model-base-url", + "https://models.example.com", + "--model-api-key", + "key-cli", + ], + ) + + assert result.exit_code == 0 + assert {item.key: item.value for item in capture["envs"]} == { + "MODEL_AGENT_API_BASE": "https://models.example.com", + "MODEL_AGENT_API_KEY": "key-cli", + "MODEL_AGENT_PROVIDER": "provider-cli", + "MODEL_AGENT_NAME": "model-cli", + } + + +def test_cli_invoke_task_id_polls_without_prompt( + monkeypatch, +) -> None: + from agentkit.toolkit.cli.cli import app + import agentkit.toolkit.cli.sandbox.a2a_client as a2a_client + import agentkit.toolkit.cli.sandbox.cli_invoke as cli_invoke + + capture = {} + _patch_invoke_session( + monkeypatch, + cli_invoke, + { + "session_id": "session-cli", + "instance_id": "instance-cli", + "endpoint": "https://sandbox.example.com", + }, 
+ capture, + ) + calls = [] + + def fake_post(url, json=None, timeout=None): + calls.append({"url": url, "json": json, "timeout": timeout}) + return _FakeA2AResponse( + { + "jsonrpc": "2.0", + "id": "rpc-get", + "result": { + "kind": "task", + "id": "task-existing", + "status": {"state": "completed"}, + "artifacts": [ + {"parts": [{"kind": "text", "text": "existing result"}]}, + ], + }, + } + ) + + monkeypatch.setattr(a2a_client.requests, "post", fake_post) + + result = runner.invoke( + app, + [ + "sandbox", + "invoke", + "--sid", + "session-cli", + "--task-id", + "task-existing", + ], + ) + + assert result.exit_code == 0 + assert capture["session_id"] == "session-cli" + assert len(calls) == 1 + assert calls[0]["json"]["method"] == "tasks/get" + assert calls[0]["json"]["params"]["id"] == "task-existing" + payload = json.loads(result.output) + assert payload["ok"] is True + assert payload["task_id"] == "task-existing" + assert payload["final_result"] == "existing result" + + +def test_cli_invoke_requires_prompt_without_task_id() -> None: + from agentkit.toolkit.cli.cli import app + + result = runner.invoke(app, ["sandbox", "invoke"]) + + assert result.exit_code == 1 + assert "--prompt is required unless --task-id is provided" in result.output + + def test_sandbox_exec_tos_mount_option_is_disabled() -> None: from agentkit.toolkit.cli.cli import app From 3a1056ce494845bd34e0b10b6144e21c1d21e551 Mon Sep 17 00:00:00 2001 From: "shijinyu.7" Date: Thu, 2 Jul 2026 19:04:40 +0800 Subject: [PATCH 3/4] refactor: centralize sandbox env profiles --- agentkit/toolkit/cli/sandbox/cli_create.py | 169 +---- agentkit/toolkit/cli/sandbox/cli_invoke.py | 357 +--------- agentkit/toolkit/cli/sandbox/env_config.py | 609 ++++++++++++++++++ .../toolkit/cli/sandbox/session_create.py | 143 +--- 4 files changed, 626 insertions(+), 652 deletions(-) create mode 100644 agentkit/toolkit/cli/sandbox/env_config.py diff --git a/agentkit/toolkit/cli/sandbox/cli_create.py 
b/agentkit/toolkit/cli/sandbox/cli_create.py index 484c059..82a6287 100644 --- a/agentkit/toolkit/cli/sandbox/cli_create.py +++ b/agentkit/toolkit/cli/sandbox/cli_create.py @@ -26,28 +26,15 @@ from agentkit.platform import VolcConfiguration from agentkit.sdk.tools.client import AgentkitToolsClient from agentkit.sdk.tools import types as tools_types +from agentkit.toolkit.cli.sandbox.env_config import ( + DEFAULT_CREATE_TOOL_TYPE, + build_create_tool_envs, +) from agentkit.toolkit.cli.sandbox.model_config import ( - ANTHROPIC_BASE_URL_ENV_KEYS, - CODE_ENV_CODEX_HOME, - CODE_ENV_HOME, - CODEX_CONFIG_TOML_ENV, - CODEX_MODEL_CATALOG_JSON_ENV, ModelProviderType, - MODEL_API_KEY_ENV, - MODEL_API_KEY_ENV_KEYS, - MODEL_BASE_URL_ENV_KEYS, - MODEL_NAME_ENV_KEYS, - MODEL_PROVIDER_ENV, infer_model_provider_from_base_url, normalize_model_base_url, normalize_model_provider, - resolve_model_base_urls, - resolve_model_name, - should_emit_codex_model_catalog, - should_emit_codex_model_config, - validate_model_provider_base_url, - build_codex_config_toml as _shared_build_codex_config_toml, - build_codex_model_catalog_json as _shared_build_codex_model_catalog_json, ) from agentkit.toolkit.cli.sandbox.tool_resolve import save_tool_result from agentkit.toolkit.cli.sandbox.tos_config import ( @@ -63,21 +50,9 @@ SANDBOX_REGION_ENV = "AGENTKIT_SANDBOX_REGION" SANDBOX_TOS_REGION_ENV = "AGENTKIT_SANDBOX_TOS_REGION" -DEFAULT_CREATE_TOOL_TYPE = "CodeEnv" DEFAULT_CPU = 4 VALID_CPU_VALUES = (2, 4, 8, 16) MEMORY_MB_PER_CPU = 2048 -DISABLED_SERVICE_ENV_KEYS = ( - "DISABLE_JUPYTER", - "DISABLE_CODE_SERVER", - "DISABLE_NODEJS_REPL", -) -BROWSER_EXTRA_ARGS_ENV = "BROWSER_EXTRA_ARGS" -DEFAULT_BROWSER_EXTRA_ARGS = ( - "--enable-unsafe-swiftshader --use-gl=angle " - "--use-angle=swiftshader-webgl --ignore-gpu-blocklist" -) -WEB_SEARCH_API_KEY_ENV = "WEB_SEARCH_API_KEY" SKILL_ROLE_NAME_OPTION = "--skill-role-name" TOOL_READY_STATUS = "Ready" TOOL_FAILED_STATUSES = {"Error", "Failed", "CreateFailed", 
"Deleting", "Deleted"} @@ -111,140 +86,6 @@ def _cpu_to_resource_shape(cpu: int) -> tuple[int, int]: return resolved_cpu * 1000, resolved_cpu * MEMORY_MB_PER_CPU -def _append_tool_envs( - envs: list[tools_types.EnvsItemForCreateTool], - keys: tuple[str, ...], - value: Optional[str], -) -> None: - resolved = (value or "").strip() - if not resolved: - return - - envs.extend( - tools_types.EnvsItemForCreateTool(Key=key, Value=resolved) for key in keys - ) - - -def _build_codex_config_toml( - model_name: str, - model_provider: str | ModelProviderType | None = None, - model_base_url: Optional[str] = None, -) -> str: - return _shared_build_codex_config_toml(model_name, model_provider, model_base_url) - - -def _build_codex_model_catalog_json( - model_name: str, - model_provider: str | ModelProviderType | None = None, -) -> str: - return _shared_build_codex_model_catalog_json(model_name, model_provider) - - -def _append_code_env_tool_envs( - envs: list[tools_types.EnvsItemForCreateTool], - model_name: str, - model_provider: str | ModelProviderType | None, - model_base_url: Optional[str], - *, - include_codex_model_config: bool = True, -) -> None: - code_envs = [ - tools_types.EnvsItemForCreateTool( - Key="OPENCODE_DISABLE_AUTOUPDATE", - Value="1", - ), - tools_types.EnvsItemForCreateTool( - Key="HOME", - Value=CODE_ENV_HOME, - ), - tools_types.EnvsItemForCreateTool( - Key="CODEX_HOME", - Value=CODE_ENV_CODEX_HOME, - ), - ] - if include_codex_model_config: - code_envs.append( - tools_types.EnvsItemForCreateTool( - Key=CODEX_CONFIG_TOML_ENV, - Value=_build_codex_config_toml( - model_name, - model_provider, - model_base_url, - ), - ) - ) - if should_emit_codex_model_catalog(model_provider): - code_envs.append( - tools_types.EnvsItemForCreateTool( - Key=CODEX_MODEL_CATALOG_JSON_ENV, - Value=_build_codex_model_catalog_json(model_name, model_provider), - ) - ) - envs.extend(code_envs) - - -def _build_tool_model_envs( - *, - tool_type: str, - model_name: Optional[str] = None, - 
model_api_key: Optional[str] = None, - model_provider: str | ModelProviderType | None = None, - model_base_url: Optional[str] = None, - model_provider_was_provided: Optional[bool] = None, - model_base_url_was_provided: Optional[bool] = None, - websearch_apikey: Optional[str] = None, -) -> list[tools_types.EnvsItemForCreateTool] | None: - envs: list[tools_types.EnvsItemForCreateTool] = [] - validate_model_provider_base_url( - model_provider=model_provider, - model_base_url=model_base_url, - model_provider_was_provided=model_provider_was_provided, - model_base_url_was_provided=model_base_url_was_provided, - ) - resolved_model_base_url = normalize_model_base_url(model_base_url) - effective_model_provider = model_provider or infer_model_provider_from_base_url( - resolved_model_base_url - ) - resolved_model_provider = normalize_model_provider(effective_model_provider) - resolved_model_name = resolve_model_name(model_name, resolved_model_provider) - resolved_base_url, resolved_anthropic_base_url = resolve_model_base_urls( - model_provider=resolved_model_provider, - model_base_url=resolved_model_base_url, - ) - resolved_model_api_key = model_api_key or os.getenv(MODEL_API_KEY_ENV) - _append_tool_envs(envs, (MODEL_PROVIDER_ENV,), resolved_model_provider) - _append_tool_envs(envs, MODEL_NAME_ENV_KEYS, resolved_model_name) - _append_tool_envs(envs, MODEL_API_KEY_ENV_KEYS, resolved_model_api_key) - _append_tool_envs( - envs, - MODEL_BASE_URL_ENV_KEYS, - resolved_base_url, - ) - _append_tool_envs( - envs, - ANTHROPIC_BASE_URL_ENV_KEYS, - resolved_anthropic_base_url, - ) - _append_tool_envs(envs, DISABLED_SERVICE_ENV_KEYS, "true") - _append_tool_envs(envs, (BROWSER_EXTRA_ARGS_ENV,), DEFAULT_BROWSER_EXTRA_ARGS) - _append_tool_envs(envs, (WEB_SEARCH_API_KEY_ENV,), websearch_apikey) - if tool_type.strip() == DEFAULT_CREATE_TOOL_TYPE: - _append_code_env_tool_envs( - envs, - resolved_model_name, - resolved_model_provider, - resolved_model_base_url, - include_codex_model_config=( - 
bool(resolved_model_name) - and should_emit_codex_model_config( - model_provider=resolved_model_provider, - model_base_url=resolved_model_base_url, - ) - ), - ) - return envs or None - - def _build_create_tool_request( *, tool_type: str, @@ -290,7 +131,7 @@ def _build_create_tool_request( EnablePrivateNetwork=False, ), TosMountConfig=tos_mount_config, - Envs=_build_tool_model_envs( + Envs=build_create_tool_envs( tool_type=resolved_tool_type, model_name=model_name, model_api_key=model_api_key, diff --git a/agentkit/toolkit/cli/sandbox/cli_invoke.py b/agentkit/toolkit/cli/sandbox/cli_invoke.py index 6294744..8c15d59 100644 --- a/agentkit/toolkit/cli/sandbox/cli_invoke.py +++ b/agentkit/toolkit/cli/sandbox/cli_invoke.py @@ -16,14 +16,11 @@ from __future__ import annotations -import json import os -from pathlib import Path from typing import Any, Optional import typer -from agentkit.sdk.tools import types as tools_types from agentkit.toolkit.cli.sandbox.a2a_client import ( DEFAULT_A2A_HISTORY_LENGTH, DEFAULT_A2A_PATH, @@ -36,6 +33,10 @@ task_result_text, task_state, ) +from agentkit.toolkit.cli.sandbox.env_config import ( + MODEL_AGENT_ENV_KEYS as _MODEL_AGENT_ENV_KEYS, + build_invoke_session_envs as build_invoke_model_agent_envs, +) from agentkit.toolkit.cli.sandbox.session_create import ( SANDBOX_TOOL_ID_ENV, ensure_sandbox_session, @@ -43,69 +44,7 @@ from agentkit.toolkit.cli.sandbox.sandbox_client import echo_json, error from agentkit.toolkit.cli.sandbox.tool_resolve import SandboxToolType -MODEL_AGENT_ENV_KEYS = ( - "MODEL_AGENT_API_BASE", - "MODEL_AGENT_API_KEY", - "MODEL_AGENT_PROVIDER", - "MODEL_AGENT_NAME", - "MODEL_AGENT_EXTRA_HEADERS", -) -REQUIRED_MODEL_AGENT_ENV_KEYS = ( - "MODEL_AGENT_API_BASE", - "MODEL_AGENT_API_KEY", - "MODEL_AGENT_PROVIDER", - "MODEL_AGENT_NAME", -) -OPENCLAW_CONFIG_FILE = Path("/root/.openclaw/openclaw.json") -OPENCLAW_MODEL_CONFIG_ROOTS = ( - ("models",), - ("model",), - ("modelProviders",), - ("model_providers",), - ("providers",), 
- ("agents", "defaults", "model"), -) -OPENCLAW_API_BASE_KEYS = ( - "api_base_url", - "apiBaseUrl", - "api_base", - "apiBase", - "base_url", - "baseURL", - "baseUrl", - "baseurl", -) -OPENCLAW_API_KEY_KEYS = ( - "api_key", - "apiKey", - "apikey", - "key", - "model_key", - "modelKey", -) -OPENCLAW_PROVIDER_API_KEYS = ( - "api", - "model_api", - "modelApi", - "api_type", - "apiType", -) -OPENCLAW_MODEL_HEADER_KEYS = ( - "headers", - "extra_headers", - "extraHeaders", -) -OPENCLAW_PROVIDER_API_ROUTES = { - "openai-completions": "openai", - "openai-responses": "openai/responses", - "openai-codex-responses": "openai/responses", - "anthropic-messages": "anthropic", - "google-generative-ai": "gemini", - "github-copilot": "github_copilot", - "bedrock-converse-stream": "bedrock/converse", - "ollama": "ollama_chat", - "azure-openai-responses": "azure/responses", -} +MODEL_AGENT_ENV_KEYS = _MODEL_AGENT_ENV_KEYS def _resolve_invoke_tool_id( @@ -124,292 +63,6 @@ def _resolve_invoke_tool_id( return tool_type.value -def build_invoke_model_agent_envs( - *, - model_name: Optional[str] = None, - model_provider: Optional[str] = None, - model_base_url: Optional[str] = None, - model_api_key: Optional[str] = None, - openclaw_config_file: Path = OPENCLAW_CONFIG_FILE, -) -> list[tools_types.EnvsItemForCreateSession]: - cli_values = { - "MODEL_AGENT_API_BASE": (model_base_url or "").strip(), - "MODEL_AGENT_API_KEY": (model_api_key or "").strip(), - "MODEL_AGENT_PROVIDER": (model_provider or "").strip(), - "MODEL_AGENT_NAME": (model_name or "").strip(), - } - env_values = _collect_model_agent_envs_from_env() - openclaw_values = _collect_openclaw_model_agent_envs(openclaw_config_file) - - values: dict[str, str] = {} - for key in MODEL_AGENT_ENV_KEYS: - values[key] = ( - cli_values.get(key) or env_values.get(key) or openclaw_values.get(key) or "" - ) - - return [ - tools_types.EnvsItemForCreateSession(key=key, value=values[key]) - for key in MODEL_AGENT_ENV_KEYS - if key in 
REQUIRED_MODEL_AGENT_ENV_KEYS or values[key] - ] - - -def _collect_model_agent_envs_from_env() -> dict[str, str]: - values: dict[str, str] = {} - for key in MODEL_AGENT_ENV_KEYS: - value = os.getenv(key, "").strip() - if value: - values[key] = value - return values - - -def _collect_openclaw_model_agent_envs(path: Path) -> dict[str, str]: - try: - data = json.loads(path.read_text(encoding="utf-8")) - except (FileNotFoundError, OSError, json.JSONDecodeError): - return {} - if not isinstance(data, dict): - return {} - - primary = _get_nested(data, ("agents", "defaults", "model", "primary")) - if not isinstance(primary, str): - return {} - - provider, model_name = _parse_openclaw_primary(primary) - if not provider or not model_name: - return {} - - model_config = _find_openclaw_model_config(data, provider, model_name) - if not model_config: - return {} - - api_base = _pick_openclaw_text(model_config, OPENCLAW_API_BASE_KEYS) - api_key = _pick_openclaw_text(model_config, OPENCLAW_API_KEY_KEYS) - provider_api = _pick_openclaw_text(model_config, OPENCLAW_PROVIDER_API_KEYS) - if not api_base or not api_key or not provider_api: - return {} - - litellm_provider, model_agent_name = _resolve_openclaw_provider_api_route( - provider_api, - model_name, - ) - if not litellm_provider or not model_agent_name: - return {} - - values = { - "MODEL_AGENT_API_BASE": api_base, - "MODEL_AGENT_API_KEY": api_key, - "MODEL_AGENT_PROVIDER": litellm_provider, - "MODEL_AGENT_NAME": model_agent_name, - } - extra_headers = _pick_openclaw_headers_json(model_config) - if extra_headers: - values["MODEL_AGENT_EXTRA_HEADERS"] = extra_headers - return values - - -def _resolve_openclaw_provider_api_route( - provider_api: str, - model_name: str, -) -> tuple[str, str]: - litellm_provider = OPENCLAW_PROVIDER_API_ROUTES.get(provider_api) - if not litellm_provider: - return "", "" - return litellm_provider, model_name - - -def _parse_openclaw_primary(primary: str) -> tuple[str, str]: - provider, separator, 
model_name = primary.strip().partition("/") - if not separator: - return "", "" - return provider.strip(), model_name.strip() - - -def _find_openclaw_model_config( - data: dict[str, Any], - provider: str, - model_name: str, -) -> dict[str, Any] | None: - for path in OPENCLAW_MODEL_CONFIG_ROOTS: - root = _get_nested(data, path) - match = _find_openclaw_model_config_in(root, provider, model_name) - if match: - return match - return _find_openclaw_model_config_in(data, provider, model_name) - - -def _find_openclaw_model_config_in( - value: Any, - provider: str, - model_name: str, -) -> dict[str, Any] | None: - if isinstance(value, dict): - direct = _openclaw_direct_model_config(value, provider, model_name) - if direct: - return direct - - if _openclaw_model_config_matches(value, provider, model_name): - return value - - for child in value.values(): - match = _find_openclaw_model_config_in(child, provider, model_name) - if match: - return match - elif isinstance(value, list): - for item in value: - match = _find_openclaw_model_config_in(item, provider, model_name) - if match: - return match - return None - - -def _openclaw_direct_model_config( - value: dict[str, Any], - provider: str, - model_name: str, -) -> dict[str, Any] | None: - direct_keys = ( - f"{provider}/{model_name}", - model_name, - ) - for key in direct_keys: - candidate = value.get(key) - if isinstance(candidate, dict): - return candidate - - provider_config = value.get(provider) - if isinstance(provider_config, dict): - candidate = provider_config.get(model_name) - if isinstance(candidate, dict): - return _merge_openclaw_model_config(provider_config, candidate) - if _openclaw_provider_config_has_model(provider_config, model_name): - model_item = _find_openclaw_model_item(provider_config, model_name) - return _merge_openclaw_model_config(provider_config, model_item) - return None - - -def _merge_openclaw_model_config( - provider_config: dict[str, Any], - model_config: dict[str, Any] | None, -) -> 
dict[str, Any]: - merged = {key: value for key, value in provider_config.items() if key != "models"} - if model_config: - merged.update(model_config) - provider_headers = _pick_openclaw_headers(provider_config) - model_headers = _pick_openclaw_headers(model_config) - if provider_headers or model_headers: - headers = {} - headers.update(provider_headers) - headers.update(model_headers) - merged["headers"] = headers - return merged - - -def _openclaw_provider_config_has_model( - provider_config: dict[str, Any], - model_name: str, -) -> bool: - return _find_openclaw_model_item(provider_config, model_name) is not None - - -def _find_openclaw_model_item( - provider_config: dict[str, Any], - model_name: str, -) -> dict[str, Any] | None: - models = provider_config.get("models") - if isinstance(models, dict): - candidate = models.get(model_name) - if isinstance(candidate, dict): - return candidate - for item in models.values(): - if isinstance(item, dict) and _openclaw_model_config_matches_model_name( - item, model_name - ): - return item - if isinstance(models, list): - for item in models: - if isinstance(item, dict) and _openclaw_model_config_matches_model_name( - item, model_name - ): - return item - return None - - -def _openclaw_model_config_matches_model_name(value: Any, model_name: str) -> bool: - if isinstance(value, str): - return value == model_name - if not isinstance(value, dict): - return False - - name_value = _pick_openclaw_text( - value, - ("id", "name", "model", "model_name", "modelName"), - ) - return name_value == model_name - - -def _openclaw_model_config_matches( - value: dict[str, Any], - provider: str, - model_name: str, -) -> bool: - provider_value = _pick_openclaw_text( - value, - ("provider", "provider_name", "providerName", "type"), - ) - name_value = _pick_openclaw_text( - value, - ("name", "model", "model_name", "modelName", "id"), - ) - return provider_value == provider and name_value in { - model_name, - f"{provider}/{model_name}", - } - - 
-def _get_nested(value: Any, path: tuple[str, ...]) -> Any: - current = value - for key in path: - if not isinstance(current, dict): - return None - current = current.get(key) - return current - - -def _pick_openclaw_text(value: dict[str, Any], keys: tuple[str, ...]) -> str | None: - for key in keys: - item = value.get(key) - if isinstance(item, str) and item.strip(): - return item.strip() - return None - - -def _pick_openclaw_headers_json(value: dict[str, Any]) -> str | None: - headers = _pick_openclaw_headers(value) - if not headers: - return None - return json.dumps(headers, ensure_ascii=False, sort_keys=True) - - -def _pick_openclaw_headers(value: dict[str, Any]) -> dict[str, str]: - for key in OPENCLAW_MODEL_HEADER_KEYS: - item = value.get(key) - if not isinstance(item, dict): - continue - headers: dict[str, str] = {} - for header_key, header_value in item.items(): - if ( - isinstance(header_key, str) - and isinstance(header_value, str) - and header_key.strip() - and header_value.strip() - ): - headers[header_key.strip()] = header_value.strip() - if headers: - return headers - return {} - - def _task_failure_error(task: dict[str, Any]) -> dict[str, str]: state = task_state(task) or "unknown" message = _status_message_text(task) or f"Sandbox task ended with state: {state}" diff --git a/agentkit/toolkit/cli/sandbox/env_config.py b/agentkit/toolkit/cli/sandbox/env_config.py new file mode 100644 index 0000000..0ba2608 --- /dev/null +++ b/agentkit/toolkit/cli/sandbox/env_config.py @@ -0,0 +1,609 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Environment variable profiles for sandbox tool and session creation.""" + +from __future__ import annotations + +from collections import OrderedDict +import json +import os +from pathlib import Path +from typing import Any, Optional + +from agentkit.sdk.tools import types as tools_types +from agentkit.toolkit.cli.sandbox.model_config import ( + ANTHROPIC_BASE_URL_ENV_KEYS, + CODE_ENV_CODEX_HOME, + CODE_ENV_HOME, + CODEX_CONFIG_TOML_ENV, + CODEX_MODEL_CATALOG_JSON_ENV, + MODEL_API_KEY_ENV, + MODEL_API_KEY_ENV_KEYS, + MODEL_BASE_URL_ENV_KEYS, + MODEL_NAME_ENV_KEYS, + MODEL_PROVIDER_ENV, + ModelProviderType, + build_codex_config_toml, + build_codex_model_catalog_json, + infer_model_provider_from_base_url, + normalize_model_base_url, + normalize_model_provider, + normalize_optional_model_provider, + resolve_model_base_urls, + resolve_model_name, + should_emit_codex_model_catalog, + should_emit_codex_model_config, + validate_model_provider_base_url, +) + +DEFAULT_CREATE_TOOL_TYPE = "CodeEnv" +DISABLED_SERVICE_ENV_KEYS = ( + "DISABLE_JUPYTER", + "DISABLE_CODE_SERVER", + "DISABLE_NODEJS_REPL", +) +BROWSER_EXTRA_ARGS_ENV = "BROWSER_EXTRA_ARGS" +DEFAULT_BROWSER_EXTRA_ARGS = ( + "--enable-unsafe-swiftshader --use-gl=angle " + "--use-angle=swiftshader-webgl --ignore-gpu-blocklist" +) +WEB_SEARCH_API_KEY_ENV = "WEB_SEARCH_API_KEY" + +MODEL_AGENT_ENV_KEYS = ( + "MODEL_AGENT_API_BASE", + "MODEL_AGENT_API_KEY", + "MODEL_AGENT_PROVIDER", + "MODEL_AGENT_NAME", + "MODEL_AGENT_EXTRA_HEADERS", +) +REQUIRED_MODEL_AGENT_ENV_KEYS = ( + "MODEL_AGENT_API_BASE", + 
"MODEL_AGENT_API_KEY", + "MODEL_AGENT_PROVIDER", + "MODEL_AGENT_NAME", +) +OPENCLAW_CONFIG_FILE = Path("/root/.openclaw/openclaw.json") +OPENCLAW_MODEL_CONFIG_ROOTS = ( + ("models",), + ("model",), + ("modelProviders",), + ("model_providers",), + ("providers",), + ("agents", "defaults", "model"), +) +OPENCLAW_API_BASE_KEYS = ( + "api_base_url", + "apiBaseUrl", + "api_base", + "apiBase", + "base_url", + "baseURL", + "baseUrl", + "baseurl", +) +OPENCLAW_API_KEY_KEYS = ( + "api_key", + "apiKey", + "apikey", + "key", + "model_key", + "modelKey", +) +OPENCLAW_PROVIDER_API_KEYS = ( + "api", + "model_api", + "modelApi", + "api_type", + "apiType", +) +OPENCLAW_MODEL_HEADER_KEYS = ( + "headers", + "extra_headers", + "extraHeaders", +) +OPENCLAW_PROVIDER_API_ROUTES = { + "openai-completions": "openai", + "openai-responses": "openai/responses", + "openai-codex-responses": "openai/responses", + "anthropic-messages": "anthropic", + "google-generative-ai": "gemini", + "github-copilot": "github_copilot", + "bedrock-converse-stream": "bedrock/converse", + "ollama": "ollama_chat", + "azure-openai-responses": "azure/responses", +} + + +class EnvBundle: + """Ordered environment map with small adapters for SDK request types.""" + + def __init__(self) -> None: + self._values: OrderedDict[str, str] = OrderedDict() + + def add( + self, key: str, value: Optional[str], *, include_empty: bool = False + ) -> None: + resolved = (value or "").strip() + if not resolved and not include_empty: + return + self._values[key] = resolved + + def add_many( + self, + keys: tuple[str, ...], + value: Optional[str], + *, + include_empty: bool = False, + ) -> None: + for key in keys: + self.add(key, value, include_empty=include_empty) + + def to_create_tool_envs(self) -> list[tools_types.EnvsItemForCreateTool] | None: + if not self._values: + return None + return [ + tools_types.EnvsItemForCreateTool(Key=key, Value=value) + for key, value in self._values.items() + ] + + def to_create_session_envs( + self, + 
) -> list[tools_types.EnvsItemForCreateSession] | None: + if not self._values: + return None + return [ + tools_types.EnvsItemForCreateSession(key=key, value=value) + for key, value in self._values.items() + ] + + def to_required_create_session_envs( + self, + ) -> list[tools_types.EnvsItemForCreateSession]: + return [ + tools_types.EnvsItemForCreateSession(key=key, value=value) + for key, value in self._values.items() + ] + + +def build_create_tool_envs( + *, + tool_type: str, + model_name: Optional[str] = None, + model_api_key: Optional[str] = None, + model_provider: str | ModelProviderType | None = None, + model_base_url: Optional[str] = None, + model_provider_was_provided: Optional[bool] = None, + model_base_url_was_provided: Optional[bool] = None, + websearch_apikey: Optional[str] = None, +) -> list[tools_types.EnvsItemForCreateTool] | None: + """Build CreateTool.Envs for the sandbox create profile.""" + + bundle = EnvBundle() + validate_model_provider_base_url( + model_provider=model_provider, + model_base_url=model_base_url, + model_provider_was_provided=model_provider_was_provided, + model_base_url_was_provided=model_base_url_was_provided, + ) + resolved_model_base_url = normalize_model_base_url(model_base_url) + effective_model_provider = model_provider or infer_model_provider_from_base_url( + resolved_model_base_url + ) + resolved_model_provider = normalize_model_provider(effective_model_provider) + resolved_model_name = resolve_model_name(model_name, resolved_model_provider) + resolved_base_url, resolved_anthropic_base_url = resolve_model_base_urls( + model_provider=resolved_model_provider, + model_base_url=resolved_model_base_url, + ) + resolved_model_api_key = model_api_key or os.getenv(MODEL_API_KEY_ENV) + + bundle.add_many((MODEL_PROVIDER_ENV,), resolved_model_provider) + bundle.add_many(MODEL_NAME_ENV_KEYS, resolved_model_name) + bundle.add_many(MODEL_API_KEY_ENV_KEYS, resolved_model_api_key) + bundle.add_many(MODEL_BASE_URL_ENV_KEYS, 
resolved_base_url) + bundle.add_many(ANTHROPIC_BASE_URL_ENV_KEYS, resolved_anthropic_base_url) + bundle.add_many(DISABLED_SERVICE_ENV_KEYS, "true") + bundle.add(BROWSER_EXTRA_ARGS_ENV, DEFAULT_BROWSER_EXTRA_ARGS) + bundle.add(WEB_SEARCH_API_KEY_ENV, websearch_apikey) + + if tool_type.strip() == DEFAULT_CREATE_TOOL_TYPE: + bundle.add("OPENCODE_DISABLE_AUTOUPDATE", "1") + bundle.add("HOME", CODE_ENV_HOME) + bundle.add("CODEX_HOME", CODE_ENV_CODEX_HOME) + if resolved_model_name and should_emit_codex_model_config( + model_provider=resolved_model_provider, + model_base_url=resolved_model_base_url, + ): + bundle.add( + CODEX_CONFIG_TOML_ENV, + build_codex_config_toml( + resolved_model_name, + resolved_model_provider, + resolved_model_base_url, + ), + ) + if should_emit_codex_model_catalog(resolved_model_provider): + bundle.add( + CODEX_MODEL_CATALOG_JSON_ENV, + build_codex_model_catalog_json( + resolved_model_name, + resolved_model_provider, + ), + ) + return bundle.to_create_tool_envs() + + +def build_exec_session_envs( + *, + model_name: Optional[str] = None, + model_api_key: Optional[str] = None, + model_provider: str | ModelProviderType | None = None, + model_base_url: Optional[str] = None, + model_provider_was_provided: Optional[bool] = None, + model_base_url_was_provided: Optional[bool] = None, + include_codex_config: bool = False, + disable_websearch_apikey: bool = False, +) -> list[tools_types.EnvsItemForCreateSession] | None: + """Build CreateSession.Envs for the exec/shell CodeEnv profile.""" + + bundle = EnvBundle() + validate_model_provider_base_url( + model_provider=model_provider, + model_base_url=model_base_url, + model_provider_was_provided=model_provider_was_provided, + model_base_url_was_provided=model_base_url_was_provided, + ) + resolved_model_base_url = normalize_model_base_url(model_base_url) + effective_model_provider = model_provider or infer_model_provider_from_base_url( + resolved_model_base_url + ) + resolved_model_provider = 
normalize_optional_model_provider( + effective_model_provider + ) + resolved_model_name = ( + resolve_model_name(model_name, resolved_model_provider) + if resolved_model_provider + else (model_name or "").strip() + ) + resolved_base_url, resolved_anthropic_base_url = ( + resolve_model_base_urls( + model_provider=resolved_model_provider, + model_base_url=resolved_model_base_url, + ) + if resolved_model_provider or resolved_model_base_url + else (None, None) + ) + resolved_model_api_key = model_api_key or os.getenv(MODEL_API_KEY_ENV) + + bundle.add_many((MODEL_PROVIDER_ENV,), resolved_model_provider) + bundle.add_many(MODEL_NAME_ENV_KEYS, resolved_model_name) + if resolved_base_url: + bundle.add_many(MODEL_BASE_URL_ENV_KEYS, resolved_base_url) + if resolved_anthropic_base_url: + bundle.add_many(ANTHROPIC_BASE_URL_ENV_KEYS, resolved_anthropic_base_url) + if ( + include_codex_config + and resolved_model_name + and should_emit_codex_model_config( + model_provider=resolved_model_provider, + model_base_url=resolved_model_base_url, + ) + ): + bundle.add( + CODEX_CONFIG_TOML_ENV, + build_codex_config_toml( + resolved_model_name, + resolved_model_provider, + resolved_model_base_url, + ), + ) + if should_emit_codex_model_catalog(resolved_model_provider): + bundle.add( + CODEX_MODEL_CATALOG_JSON_ENV, + build_codex_model_catalog_json( + resolved_model_name, + resolved_model_provider, + ), + ) + bundle.add_many(MODEL_API_KEY_ENV_KEYS, resolved_model_api_key) + if disable_websearch_apikey: + bundle.add(WEB_SEARCH_API_KEY_ENV, "", include_empty=True) + return bundle.to_create_session_envs() + + +def build_invoke_session_envs( + *, + model_name: Optional[str] = None, + model_provider: Optional[str] = None, + model_base_url: Optional[str] = None, + model_api_key: Optional[str] = None, + openclaw_config_file: Path = OPENCLAW_CONFIG_FILE, +) -> list[tools_types.EnvsItemForCreateSession]: + """Build CreateSession.Envs for the A2A invoke SkillEnv profile.""" + + cli_values = { + 
"MODEL_AGENT_API_BASE": (model_base_url or "").strip(), + "MODEL_AGENT_API_KEY": (model_api_key or "").strip(), + "MODEL_AGENT_PROVIDER": (model_provider or "").strip(), + "MODEL_AGENT_NAME": (model_name or "").strip(), + } + env_values = _collect_model_agent_envs_from_env() + openclaw_values = _collect_openclaw_model_agent_envs(openclaw_config_file) + + bundle = EnvBundle() + for key in MODEL_AGENT_ENV_KEYS: + include_empty = key in REQUIRED_MODEL_AGENT_ENV_KEYS + bundle.add( + key, + cli_values.get(key) or env_values.get(key) or openclaw_values.get(key), + include_empty=include_empty, + ) + return bundle.to_required_create_session_envs() + + +def _collect_model_agent_envs_from_env() -> dict[str, str]: + values: dict[str, str] = {} + for key in MODEL_AGENT_ENV_KEYS: + value = os.getenv(key, "").strip() + if value: + values[key] = value + return values + + +def _collect_openclaw_model_agent_envs(path: Path) -> dict[str, str]: + try: + data = json.loads(path.read_text(encoding="utf-8")) + except (FileNotFoundError, OSError, json.JSONDecodeError): + return {} + if not isinstance(data, dict): + return {} + + primary = _get_nested(data, ("agents", "defaults", "model", "primary")) + if not isinstance(primary, str): + return {} + + provider, model_name = _parse_openclaw_primary(primary) + if not provider or not model_name: + return {} + + model_config = _find_openclaw_model_config(data, provider, model_name) + if not model_config: + return {} + + api_base = _pick_openclaw_text(model_config, OPENCLAW_API_BASE_KEYS) + api_key = _pick_openclaw_text(model_config, OPENCLAW_API_KEY_KEYS) + provider_api = _pick_openclaw_text(model_config, OPENCLAW_PROVIDER_API_KEYS) + if not api_base or not api_key or not provider_api: + return {} + + litellm_provider, model_agent_name = _resolve_openclaw_provider_api_route( + provider_api, + model_name, + ) + if not litellm_provider or not model_agent_name: + return {} + + values = { + "MODEL_AGENT_API_BASE": api_base, + "MODEL_AGENT_API_KEY": 
api_key, + "MODEL_AGENT_PROVIDER": litellm_provider, + "MODEL_AGENT_NAME": model_agent_name, + } + extra_headers = _pick_openclaw_headers_json(model_config) + if extra_headers: + values["MODEL_AGENT_EXTRA_HEADERS"] = extra_headers + return values + + +def _resolve_openclaw_provider_api_route( + provider_api: str, + model_name: str, +) -> tuple[str, str]: + litellm_provider = OPENCLAW_PROVIDER_API_ROUTES.get(provider_api) + if not litellm_provider: + return "", "" + return litellm_provider, model_name + + +def _parse_openclaw_primary(primary: str) -> tuple[str, str]: + provider, separator, model_name = primary.strip().partition("/") + if not separator: + return "", "" + return provider.strip(), model_name.strip() + + +def _find_openclaw_model_config( + data: dict[str, Any], + provider: str, + model_name: str, +) -> dict[str, Any] | None: + for path in OPENCLAW_MODEL_CONFIG_ROOTS: + root = _get_nested(data, path) + match = _find_openclaw_model_config_in(root, provider, model_name) + if match: + return match + return _find_openclaw_model_config_in(data, provider, model_name) + + +def _find_openclaw_model_config_in( + value: Any, + provider: str, + model_name: str, +) -> dict[str, Any] | None: + if isinstance(value, dict): + direct = _openclaw_direct_model_config(value, provider, model_name) + if direct: + return direct + + if _openclaw_model_config_matches(value, provider, model_name): + return value + + for child in value.values(): + match = _find_openclaw_model_config_in(child, provider, model_name) + if match: + return match + elif isinstance(value, list): + for item in value: + match = _find_openclaw_model_config_in(item, provider, model_name) + if match: + return match + return None + + +def _openclaw_direct_model_config( + value: dict[str, Any], + provider: str, + model_name: str, +) -> dict[str, Any] | None: + direct_keys = ( + f"{provider}/{model_name}", + model_name, + ) + for key in direct_keys: + candidate = value.get(key) + if isinstance(candidate, dict): 
+ return candidate + + provider_config = value.get(provider) + if isinstance(provider_config, dict): + candidate = provider_config.get(model_name) + if isinstance(candidate, dict): + return _merge_openclaw_model_config(provider_config, candidate) + if _openclaw_provider_config_has_model(provider_config, model_name): + model_item = _find_openclaw_model_item(provider_config, model_name) + return _merge_openclaw_model_config(provider_config, model_item) + return None + + +def _merge_openclaw_model_config( + provider_config: dict[str, Any], + model_config: dict[str, Any] | None, +) -> dict[str, Any]: + merged = {key: value for key, value in provider_config.items() if key != "models"} + if model_config: + merged.update(model_config) + provider_headers = _pick_openclaw_headers(provider_config) + model_headers = _pick_openclaw_headers(model_config) + if provider_headers or model_headers: + headers = {} + headers.update(provider_headers) + headers.update(model_headers) + merged["headers"] = headers + return merged + + +def _openclaw_provider_config_has_model( + provider_config: dict[str, Any], + model_name: str, +) -> bool: + return _find_openclaw_model_item(provider_config, model_name) is not None + + +def _find_openclaw_model_item( + provider_config: dict[str, Any], + model_name: str, +) -> dict[str, Any] | None: + models = provider_config.get("models") + if isinstance(models, dict): + candidate = models.get(model_name) + if isinstance(candidate, dict): + return candidate + for item in models.values(): + if isinstance(item, dict) and _openclaw_model_config_matches_model_name( + item, model_name + ): + return item + if isinstance(models, list): + for item in models: + if isinstance(item, dict) and _openclaw_model_config_matches_model_name( + item, model_name + ): + return item + return None + + +def _openclaw_model_config_matches_model_name(value: Any, model_name: str) -> bool: + if isinstance(value, str): + return value == model_name + if not isinstance(value, dict): + 
return False + + name_value = _pick_openclaw_text( + value, + ("id", "name", "model", "model_name", "modelName"), + ) + return name_value == model_name + + +def _openclaw_model_config_matches( + value: dict[str, Any], + provider: str, + model_name: str, +) -> bool: + provider_value = _pick_openclaw_text( + value, + ("provider", "provider_name", "providerName", "type"), + ) + name_value = _pick_openclaw_text( + value, + ("name", "model", "model_name", "modelName", "id"), + ) + return provider_value == provider and name_value in { + model_name, + f"{provider}/{model_name}", + } + + +def _get_nested(value: Any, path: tuple[str, ...]) -> Any: + current = value + for key in path: + if not isinstance(current, dict): + return None + current = current.get(key) + return current + + +def _pick_openclaw_text(value: dict[str, Any], keys: tuple[str, ...]) -> str | None: + for key in keys: + item = value.get(key) + if isinstance(item, str) and item.strip(): + return item.strip() + return None + + +def _pick_openclaw_headers_json(value: dict[str, Any]) -> str | None: + headers = _pick_openclaw_headers(value) + if not headers: + return None + return json.dumps(headers, ensure_ascii=False, sort_keys=True) + + +def _pick_openclaw_headers(value: dict[str, Any]) -> dict[str, str]: + for key in OPENCLAW_MODEL_HEADER_KEYS: + item = value.get(key) + if not isinstance(item, dict): + continue + headers: dict[str, str] = {} + for header_key, header_value in item.items(): + if ( + isinstance(header_key, str) + and isinstance(header_value, str) + and header_key.strip() + and header_value.strip() + ): + headers[header_key.strip()] = header_value.strip() + if headers: + return headers + return {} diff --git a/agentkit/toolkit/cli/sandbox/session_create.py b/agentkit/toolkit/cli/sandbox/session_create.py index cdc0cf5..154cac5 100644 --- a/agentkit/toolkit/cli/sandbox/session_create.py +++ b/agentkit/toolkit/cli/sandbox/session_create.py @@ -26,26 +26,9 @@ AgentkitToolsClient, 
is_tip_agentkit_client, ) -from agentkit.toolkit.cli.sandbox.model_config import ( - ANTHROPIC_BASE_URL_ENV_KEYS, - CODEX_CONFIG_TOML_ENV, - CODEX_MODEL_CATALOG_JSON_ENV, - MODEL_API_KEY_ENV, - MODEL_API_KEY_ENV_KEYS, - MODEL_BASE_URL_ENV_KEYS, - MODEL_NAME_ENV_KEYS, - MODEL_PROVIDER_ENV, - ModelProviderType, - build_codex_config_toml, - build_codex_model_catalog_json, - infer_model_provider_from_base_url, - normalize_model_base_url, - normalize_optional_model_provider, - resolve_model_base_urls, - resolve_model_name, - should_emit_codex_model_catalog, - should_emit_codex_model_config, - validate_model_provider_base_url, +from agentkit.toolkit.cli.sandbox.env_config import ( + WEB_SEARCH_API_KEY_ENV as _WEB_SEARCH_API_KEY_ENV, + build_exec_session_envs, ) from agentkit.toolkit.cli.sandbox.session_sync import ( session_info_to_result, @@ -65,128 +48,16 @@ DEFAULT_SANDBOX_TTL = 28800 SANDBOX_TOOL_ID_ENV = "AGENTKIT_SANDBOX_TOOL_ID" SANDBOX_TTL_ENV = "AGENTKIT_SANDBOX_TTL" -WEB_SEARCH_API_KEY_ENV = "WEB_SEARCH_API_KEY" +WEB_SEARCH_API_KEY_ENV = _WEB_SEARCH_API_KEY_ENV CREATE_SESSION_START_FAIL_CODE = "ErrCreateSessionFail" CREATE_SESSION_CONFIRM_ATTEMPTS = 6 CREATE_SESSION_CONFIRM_INTERVAL_SECONDS = 5 CREATE_SESSION_READY_STATUS = "ready" -def _append_envs( - envs: list[tools_types.EnvsItemForCreateSession], - keys: tuple[str, ...], - value: Optional[str], -) -> None: - resolved = (value or "").strip() - if not resolved: - return - - envs.extend( - tools_types.EnvsItemForCreateSession(key=key, value=resolved) for key in keys - ) - - -def _append_codex_config_envs( - envs: list[tools_types.EnvsItemForCreateSession], - model_name: Optional[str], - model_provider: str | ModelProviderType | None, - model_base_url: Optional[str], -) -> None: - resolved_model_name = (model_name or "").strip() - if not resolved_model_name: - return - - envs.append( - tools_types.EnvsItemForCreateSession( - key=CODEX_CONFIG_TOML_ENV, - value=build_codex_config_toml( - resolved_model_name, - 
model_provider, - model_base_url, - ), - ) - ) - if should_emit_codex_model_catalog(model_provider): - envs.append( - tools_types.EnvsItemForCreateSession( - key=CODEX_MODEL_CATALOG_JSON_ENV, - value=build_codex_model_catalog_json( - resolved_model_name, - model_provider, - ), - ) - ) - - -def build_model_envs( - *, - model_name: Optional[str] = None, - model_api_key: Optional[str] = None, - model_provider: str | ModelProviderType | None = None, - model_base_url: Optional[str] = None, - model_provider_was_provided: Optional[bool] = None, - model_base_url_was_provided: Optional[bool] = None, - include_codex_config: bool = False, - disable_websearch_apikey: bool = False, -) -> list[tools_types.EnvsItemForCreateSession] | None: - envs: list[tools_types.EnvsItemForCreateSession] = [] - validate_model_provider_base_url( - model_provider=model_provider, - model_base_url=model_base_url, - model_provider_was_provided=model_provider_was_provided, - model_base_url_was_provided=model_base_url_was_provided, - ) - resolved_model_base_url = normalize_model_base_url(model_base_url) - effective_model_provider = model_provider or infer_model_provider_from_base_url( - resolved_model_base_url - ) - resolved_model_provider = normalize_optional_model_provider( - effective_model_provider - ) - resolved_model_name = ( - resolve_model_name(model_name, resolved_model_provider) - if resolved_model_provider - else (model_name or "").strip() - ) - resolved_base_url, resolved_anthropic_base_url = ( - resolve_model_base_urls( - model_provider=resolved_model_provider, - model_base_url=resolved_model_base_url, - ) - if resolved_model_provider or resolved_model_base_url - else (None, None) - ) - resolved_model_api_key = model_api_key or os.getenv(MODEL_API_KEY_ENV) - _append_envs(envs, (MODEL_PROVIDER_ENV,), resolved_model_provider) - _append_envs(envs, MODEL_NAME_ENV_KEYS, resolved_model_name) - if resolved_base_url: - _append_envs(envs, MODEL_BASE_URL_ENV_KEYS, resolved_base_url) - if 
resolved_anthropic_base_url: - _append_envs( - envs, - ANTHROPIC_BASE_URL_ENV_KEYS, - resolved_anthropic_base_url, - ) - if ( - include_codex_config - and resolved_model_name - and should_emit_codex_model_config( - model_provider=resolved_model_provider, - model_base_url=resolved_model_base_url, - ) - ): - _append_codex_config_envs( - envs, - resolved_model_name, - resolved_model_provider, - resolved_model_base_url, - ) - _append_envs(envs, MODEL_API_KEY_ENV_KEYS, resolved_model_api_key) - if disable_websearch_apikey: - envs.append( - tools_types.EnvsItemForCreateSession(key=WEB_SEARCH_API_KEY_ENV, value="") - ) - return envs or None +def build_model_envs(**kwargs): + """Backward-compatible wrapper for the exec session env profile.""" + return build_exec_session_envs(**kwargs) def _resolve_ttl(ttl: Optional[int]) -> int: From ebcfcc8a5431e9371316050c493cd3cfb8a79882 Mon Sep 17 00:00:00 2001 From: "shijinyu.7" Date: Thu, 2 Jul 2026 19:31:45 +0800 Subject: [PATCH 4/4] refactor: reuse sandbox model env resolution --- agentkit/toolkit/cli/sandbox/env_config.py | 183 ++++++++++++++------- 1 file changed, 119 insertions(+), 64 deletions(-) diff --git a/agentkit/toolkit/cli/sandbox/env_config.py b/agentkit/toolkit/cli/sandbox/env_config.py index 0ba2608..a415a58 100644 --- a/agentkit/toolkit/cli/sandbox/env_config.py +++ b/agentkit/toolkit/cli/sandbox/env_config.py @@ -17,6 +17,7 @@ from __future__ import annotations from collections import OrderedDict +from dataclasses import dataclass import json import os from pathlib import Path @@ -177,6 +178,86 @@ def to_required_create_session_envs( ] +@dataclass(frozen=True) +class ResolvedSandboxModelEnv: + provider: str | None + model_name: str | None + base_url: str | None + anthropic_base_url: str | None + api_key: str | None + model_base_url: str | None + + +def _resolve_sandbox_model_env( + *, + model_name: Optional[str], + model_api_key: Optional[str], + model_provider: str | ModelProviderType | None, + model_base_url: 
Optional[str], + model_provider_was_provided: Optional[bool], + model_base_url_was_provided: Optional[bool], + require_provider: bool, +) -> ResolvedSandboxModelEnv: + validate_model_provider_base_url( + model_provider=model_provider, + model_base_url=model_base_url, + model_provider_was_provided=model_provider_was_provided, + model_base_url_was_provided=model_base_url_was_provided, + ) + resolved_model_base_url = normalize_model_base_url(model_base_url) + effective_model_provider = model_provider or infer_model_provider_from_base_url( + resolved_model_base_url + ) + resolved_model_provider = ( + normalize_model_provider(effective_model_provider) + if require_provider + else normalize_optional_model_provider(effective_model_provider) + ) + resolved_model_name = ( + resolve_model_name(model_name, resolved_model_provider) + if resolved_model_provider + else (model_name or "").strip() + ) + should_resolve_urls = bool( + require_provider or resolved_model_provider or resolved_model_base_url + ) + resolved_base_url, resolved_anthropic_base_url = ( + resolve_model_base_urls( + model_provider=resolved_model_provider, + model_base_url=resolved_model_base_url, + ) + if should_resolve_urls + else (None, None) + ) + return ResolvedSandboxModelEnv( + provider=resolved_model_provider, + model_name=resolved_model_name, + base_url=resolved_base_url, + anthropic_base_url=resolved_anthropic_base_url, + api_key=model_api_key or os.getenv(MODEL_API_KEY_ENV), + model_base_url=resolved_model_base_url, + ) + + +def _append_standard_model_envs( + bundle: EnvBundle, + resolved: ResolvedSandboxModelEnv, + *, + include_base_urls: bool, + include_api_key: bool, + api_key_before_base_urls: bool = False, +) -> None: + bundle.add_many((MODEL_PROVIDER_ENV,), resolved.provider) + bundle.add_many(MODEL_NAME_ENV_KEYS, resolved.model_name) + if include_api_key and api_key_before_base_urls: + bundle.add_many(MODEL_API_KEY_ENV_KEYS, resolved.api_key) + if include_base_urls: + 
bundle.add_many(MODEL_BASE_URL_ENV_KEYS, resolved.base_url) + bundle.add_many(ANTHROPIC_BASE_URL_ENV_KEYS, resolved.anthropic_base_url) + if include_api_key and not api_key_before_base_urls: + bundle.add_many(MODEL_API_KEY_ENV_KEYS, resolved.api_key) + + def build_create_tool_envs( *, tool_type: str, @@ -191,29 +272,22 @@ def build_create_tool_envs( """Build CreateTool.Envs for the sandbox create profile.""" bundle = EnvBundle() - validate_model_provider_base_url( + resolved = _resolve_sandbox_model_env( + model_name=model_name, + model_api_key=model_api_key, model_provider=model_provider, model_base_url=model_base_url, model_provider_was_provided=model_provider_was_provided, model_base_url_was_provided=model_base_url_was_provided, + require_provider=True, ) - resolved_model_base_url = normalize_model_base_url(model_base_url) - effective_model_provider = model_provider or infer_model_provider_from_base_url( - resolved_model_base_url - ) - resolved_model_provider = normalize_model_provider(effective_model_provider) - resolved_model_name = resolve_model_name(model_name, resolved_model_provider) - resolved_base_url, resolved_anthropic_base_url = resolve_model_base_urls( - model_provider=resolved_model_provider, - model_base_url=resolved_model_base_url, + _append_standard_model_envs( + bundle, + resolved, + include_base_urls=True, + include_api_key=True, + api_key_before_base_urls=True, ) - resolved_model_api_key = model_api_key or os.getenv(MODEL_API_KEY_ENV) - - bundle.add_many((MODEL_PROVIDER_ENV,), resolved_model_provider) - bundle.add_many(MODEL_NAME_ENV_KEYS, resolved_model_name) - bundle.add_many(MODEL_API_KEY_ENV_KEYS, resolved_model_api_key) - bundle.add_many(MODEL_BASE_URL_ENV_KEYS, resolved_base_url) - bundle.add_many(ANTHROPIC_BASE_URL_ENV_KEYS, resolved_anthropic_base_url) bundle.add_many(DISABLED_SERVICE_ENV_KEYS, "true") bundle.add(BROWSER_EXTRA_ARGS_ENV, DEFAULT_BROWSER_EXTRA_ARGS) bundle.add(WEB_SEARCH_API_KEY_ENV, websearch_apikey) @@ -222,24 +296,24 
@@ def build_create_tool_envs( bundle.add("OPENCODE_DISABLE_AUTOUPDATE", "1") bundle.add("HOME", CODE_ENV_HOME) bundle.add("CODEX_HOME", CODE_ENV_CODEX_HOME) - if resolved_model_name and should_emit_codex_model_config( - model_provider=resolved_model_provider, - model_base_url=resolved_model_base_url, + if resolved.model_name and should_emit_codex_model_config( + model_provider=resolved.provider, + model_base_url=resolved.model_base_url, ): bundle.add( CODEX_CONFIG_TOML_ENV, build_codex_config_toml( - resolved_model_name, - resolved_model_provider, - resolved_model_base_url, + resolved.model_name, + resolved.provider, + resolved.model_base_url, ), ) - if should_emit_codex_model_catalog(resolved_model_provider): + if should_emit_codex_model_catalog(resolved.provider): bundle.add( CODEX_MODEL_CATALOG_JSON_ENV, build_codex_model_catalog_json( - resolved_model_name, - resolved_model_provider, + resolved.model_name, + resolved.provider, ), ) return bundle.to_create_tool_envs() @@ -259,65 +333,46 @@ def build_exec_session_envs( """Build CreateSession.Envs for the exec/shell CodeEnv profile.""" bundle = EnvBundle() - validate_model_provider_base_url( + resolved = _resolve_sandbox_model_env( + model_name=model_name, + model_api_key=model_api_key, model_provider=model_provider, model_base_url=model_base_url, model_provider_was_provided=model_provider_was_provided, model_base_url_was_provided=model_base_url_was_provided, + require_provider=False, ) - resolved_model_base_url = normalize_model_base_url(model_base_url) - effective_model_provider = model_provider or infer_model_provider_from_base_url( - resolved_model_base_url + _append_standard_model_envs( + bundle, + resolved, + include_base_urls=bool(resolved.base_url or resolved.anthropic_base_url), + include_api_key=False, ) - resolved_model_provider = normalize_optional_model_provider( - effective_model_provider - ) - resolved_model_name = ( - resolve_model_name(model_name, resolved_model_provider) - if 
resolved_model_provider - else (model_name or "").strip() - ) - resolved_base_url, resolved_anthropic_base_url = ( - resolve_model_base_urls( - model_provider=resolved_model_provider, - model_base_url=resolved_model_base_url, - ) - if resolved_model_provider or resolved_model_base_url - else (None, None) - ) - resolved_model_api_key = model_api_key or os.getenv(MODEL_API_KEY_ENV) - - bundle.add_many((MODEL_PROVIDER_ENV,), resolved_model_provider) - bundle.add_many(MODEL_NAME_ENV_KEYS, resolved_model_name) - if resolved_base_url: - bundle.add_many(MODEL_BASE_URL_ENV_KEYS, resolved_base_url) - if resolved_anthropic_base_url: - bundle.add_many(ANTHROPIC_BASE_URL_ENV_KEYS, resolved_anthropic_base_url) if ( include_codex_config - and resolved_model_name + and resolved.model_name and should_emit_codex_model_config( - model_provider=resolved_model_provider, - model_base_url=resolved_model_base_url, + model_provider=resolved.provider, + model_base_url=resolved.model_base_url, ) ): bundle.add( CODEX_CONFIG_TOML_ENV, build_codex_config_toml( - resolved_model_name, - resolved_model_provider, - resolved_model_base_url, + resolved.model_name, + resolved.provider, + resolved.model_base_url, ), ) - if should_emit_codex_model_catalog(resolved_model_provider): + if should_emit_codex_model_catalog(resolved.provider): bundle.add( CODEX_MODEL_CATALOG_JSON_ENV, build_codex_model_catalog_json( - resolved_model_name, - resolved_model_provider, + resolved.model_name, + resolved.provider, ), ) - bundle.add_many(MODEL_API_KEY_ENV_KEYS, resolved_model_api_key) + bundle.add_many(MODEL_API_KEY_ENV_KEYS, resolved.api_key) if disable_websearch_apikey: bundle.add(WEB_SEARCH_API_KEY_ENV, "", include_empty=True) return bundle.to_create_session_envs()