From 27169e11e1f43be2e06974ab37d472daae804f64 Mon Sep 17 00:00:00 2001 From: "liyi.ly" Date: Fri, 3 Jul 2026 15:15:02 +0800 Subject: [PATCH 1/3] fix: address NFR review findings across security/observability/reliability/perf MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Observability & routing: - agent_server_app: remove dead duplicate /run_sse registration (leftover debug copy with print(); the telemetry-instrumented handler registered first and moved to routes[0] is the live one) - simple_app telemetry: redact credential headers (authorization/token) before recording gen_ai.request.headers on spans — they bypassed the logging RedactionFilter; mirrors agent_server middleware _EXCLUDED_HEADERS Security/concurrency & performance (ve_sign): - thread signing scope (service/version/region/host/content_type/scheme) through as parameters instead of mutating module globals — concurrent calls for different services/regions could sign with each other's scope - reuse a shared requests.Session for pooled keep-alive connections - add golden-vector tests for the signing algorithm (canonical request, HMAC chain, Authorization header) and per-call-scope isolation Reliability: - utils.misc.retry: exponential backoff with cap, explicit retry_on parameter, document the idempotency contract - ve_agentkit runtime polling: mild backoff (3s -> 10s cap) Performance (memory): - code_pipeline step-log download and ve_agentkit failure-log download now stream to disk chunk-wise; failure-log read-back bounded to displayed lines Test hygiene: - tests/platform clean_env: snapshot/restore full environ so raw os.environ writes cannot leak across tests (order-dependent failures) - global_config_io: leave debug traces on best-effort fallbacks (chmod 0600 / mtime cache), add module logger Not addressed on purpose: CORS default allow_origins=["*"] and missing default endpoint auth are behavior/product decisions (breaking for existing deployments) — flagged in the NFR report for a separate decision. Pre-existing failures (also fail on the base commit, standalone): test_cli_config_interactive_provider_resolution (2) and test_executor_platform_context_provider — order-dependent platform-context tests previously masked by env leakage; left for a follow-up. Suite: 857 passed, 3 pre-existing failures (see above). --- .../apps/agent_server_app/agent_server_app.py | 72 -------- agentkit/apps/simple_app/telemetry.py | 11 +- agentkit/toolkit/runners/ve_agentkit.py | 32 ++-- agentkit/toolkit/volcengine/code_pipeline.py | 24 ++- agentkit/utils/global_config_io.py | 12 +- agentkit/utils/misc.py | 15 +- agentkit/utils/ve_sign.py | 86 ++++++--- tests/platform/conftest.py | 11 +- .../runners/test_ve_agentkit_lifecycle.py | 19 +- tests/utils/test_engineering_standards.py | 6 +- tests/utils/test_ve_sign_signing.py | 167 ++++++++++++++++++ 11 files changed, 322 insertions(+), 133 deletions(-) create mode 100644 tests/utils/test_ve_sign_signing.py diff --git a/agentkit/apps/agent_server_app/agent_server_app.py b/agentkit/apps/agent_server_app/agent_server_app.py index 34d8a26..f342c63 100644 --- a/agentkit/apps/agent_server_app/agent_server_app.py +++ b/agentkit/apps/agent_server_app/agent_server_app.py @@ -280,78 +280,6 @@ async def event_generator(): routes.insert(0, routes.pop(i)) break - @self.app.post("/run_sse") - async def run_agent_sse(req: RunAgentRequest) -> StreamingResponse: - print("my run sse !!!") - # SSE endpoint - session = await self.server.session_service.get_session( - app_name=req.app_name, - user_id=req.user_id, - session_id=req.session_id, - ) - if not session: - raise HTTPException(status_code=404, detail="Session not found") - - # Convert the events to properly formatted SSE - async def event_generator(): - try: - stream_mode = ( - StreamingMode.SSE - if req.streaming - else StreamingMode.NONE - ) - runner = await self.server.get_runner_async(req.app_name) - async with Aclosing( - runner.run_async( - user_id=req.user_id, - session_id=req.session_id, - new_message=req.new_message, - state_delta=req.state_delta, - run_config=RunConfig(streaming_mode=stream_mode), - invocation_id=req.invocation_id, - ) - ) as agen: - async for event in agen: - # ADK Web renders artifacts from `actions.artifactDelta` - # during part processing *and* during action processing - # 1) the original event with `artifactDelta` cleared (content) - # 2) a content-less "action-only" event carrying `artifactDelta` - events_to_stream = [event] - if ( - event.actions.artifact_delta - and event.content - and event.content.parts - ): - content_event = event.model_copy(deep=True) - content_event.actions.artifact_delta = {} - artifact_event = event.model_copy(deep=True) - artifact_event.content = None - events_to_stream = [ - content_event, - artifact_event, - ] - - for event_to_stream in events_to_stream: - sse_event = event_to_stream.model_dump_json( - exclude_none=True, - by_alias=True, - ) - logger.debug( - "Generated event in agent run streaming: %s", - sse_event, - ) - yield f"data: {sse_event}\n\n" - except Exception as e: - logger.exception("Error in event_generator: %s", e) - yield f"data: {json.dumps({'error': str(e)})}\n\n" - - # Returns a streaming response with the proper media type for SSE - - return StreamingResponse( - event_generator(), - media_type="text/event-stream", - ) - # Attach ASGI middleware for unified telemetry across all routes self.app.add_middleware(AgentkitTelemetryHTTPMiddleware) diff --git a/agentkit/apps/simple_app/telemetry.py b/agentkit/apps/simple_app/telemetry.py index b6eb652..9dea82a 100644 --- a/agentkit/apps/simple_app/telemetry.py +++ b/agentkit/apps/simple_app/telemetry.py @@ -46,6 +46,14 @@ logger = logging.getLogger("agentkit." + __name__) +# Keep in sync with agent_server_app.middleware._EXCLUDED_HEADERS: credential +# headers must never be recorded on spans (they bypass the logging redaction). +_EXCLUDED_HEADERS = {"authorization", "token"} + + +def _redact_headers(headers: dict) -> dict: + return {k: v for k, v in headers.items() if k.lower() not in _EXCLUDED_HEADERS} + def dont_throw(func): """ @@ -105,7 +113,8 @@ def trace_agent( span.set_attribute(key="gen_ai.func_name", value=func.__name__) span.set_attribute( - key="gen_ai.request.headers", value=safe_serialize_to_json_string(headers) + key="gen_ai.request.headers", + value=safe_serialize_to_json_string(_redact_headers(headers)), ) session_id = headers.get("session_id") if session_id: diff --git a/agentkit/toolkit/runners/ve_agentkit.py b/agentkit/toolkit/runners/ve_agentkit.py index 7fdb353..fe4d0c6 100644 --- a/agentkit/toolkit/runners/ve_agentkit.py +++ b/agentkit/toolkit/runners/ve_agentkit.py @@ -903,22 +903,29 @@ def _download_and_show_runtime_failed_logs( self.reporter.info("Downloading failure logs...") try: - log_response = requests.get(runtime.failed_log_file_url, timeout=30) - log_response.raise_for_status() - # Create logs directory with timestamp-based filename for uniqueness log_dir = os.path.join(os.getcwd(), ".agentkit", "logs") os.makedirs(log_dir, exist_ok=True) log_filename = f"runtime_failed_{runtime_id}_{int(time.time())}.log" log_filepath = os.path.join(log_dir, log_filename) - # Save raw log content first - with open(log_filepath, "wb") as f: - f.write(log_response.content) - - # Read back with error handling for encoding issues + # Stream raw log content to disk so large logs don't spike memory + with requests.get( + runtime.failed_log_file_url, timeout=30, stream=True + ) as log_response: + log_response.raise_for_status() + with open(log_filepath, "wb") as f: + for chunk in log_response.iter_content(chunk_size=65536): + if chunk: + f.write(chunk) + + # Read back only what gets displayed, with encoding error handling + lines = [] with open(log_filepath, "r", encoding="utf-8", errors="ignore") as f: - lines = f.readlines() + for line in f: + lines.append(line) + if len(lines) >= 50: + break self.reporter.show_logs( title="Runtime Failure Logs (First 50 lines)", lines=lines, max_lines=50 @@ -992,6 +999,10 @@ def _wait_for_runtime_status_multiple( # Use reporter.long_task() for progress tracking client = self._get_runtime_client(region) + # Poll with mild backoff (3s -> 10s cap) to avoid hammering the + # control plane at a constant rate on long-running deploys. + poll_interval = 3.0 + with self.reporter.long_task(task_description, total=total_time) as task: while True: runtime = retry( @@ -1031,7 +1042,8 @@ def _wait_for_runtime_status_multiple( ) ) - time.sleep(3) + time.sleep(poll_interval) + poll_interval = min(poll_interval * 1.5, 10.0) def _needs_runtime_update( self, runtime: runtime_types.GetRuntimeResponse, config: VeAgentkitRunnerConfig diff --git a/agentkit/toolkit/volcengine/code_pipeline.py b/agentkit/toolkit/volcengine/code_pipeline.py index c81ed8f..3801d34 100644 --- a/agentkit/toolkit/volcengine/code_pipeline.py +++ b/agentkit/toolkit/volcengine/code_pipeline.py @@ -914,14 +914,22 @@ def download_and_merge_pipeline_logs( step_name=step_name, ) - # Download the log content - response = requests.get(log_url, timeout=30) - response.raise_for_status() - - log_content = response.text - out_file.write(log_content) - - if not log_content.endswith("\n"): + # Download the log content, streaming chunks to + # disk so peak memory stays flat for large logs. + with requests.get( + log_url, timeout=30, stream=True + ) as response: + response.raise_for_status() + response.encoding = response.encoding or "utf-8" + last_chunk = "" + for chunk in response.iter_content( + chunk_size=65536, decode_unicode=True + ): + if chunk: + out_file.write(chunk) + last_chunk = chunk + + if not last_chunk.endswith("\n"): out_file.write("\n") successful_downloads += 1 diff --git a/agentkit/utils/global_config_io.py b/agentkit/utils/global_config_io.py index 6cd08ea..c236ac2 100644 --- a/agentkit/utils/global_config_io.py +++ b/agentkit/utils/global_config_io.py @@ -14,9 +14,11 @@ from __future__ import annotations +import logging from pathlib import Path from typing import Any, Optional, Tuple +logger = logging.getLogger("agentkit." + __name__) _cache: Tuple[Optional[float], dict] = (None, {}) @@ -81,14 +83,16 @@ def write_global_config_dict( try: path.chmod(0o600) - except Exception: - pass + except Exception as e: + # Security-relevant fallback: the config may hold credentials, so + # leave a trace when tightening permissions fails. + logger.debug("chmod 0600 on %s failed, continuing: %s", path, e) try: mtime = path.stat().st_mtime _cache = (mtime, data) - except Exception: - pass + except Exception as e: + logger.debug("mtime cache refresh for %s failed: %s", path, e) def get_path_value(data: Any, *keys: str) -> Any: diff --git a/agentkit/utils/misc.py b/agentkit/utils/misc.py index 9d921c3..e23833e 100644 --- a/agentkit/utils/misc.py +++ b/agentkit/utils/misc.py @@ -112,11 +112,22 @@ def retry( func: Callable[[], T], retries: int = 3, delay: float = 1.0, + max_delay: float = 10.0, + retry_on: tuple = (Exception,), ) -> T: + """Retry ``func`` with exponential backoff. + + Only wrap idempotent operations (reads): a retried non-idempotent call + (Create*/Run*) would be re-executed on transient failure. Prefer passing + a narrow ``retry_on`` (connection/timeout error types) so non-retryable + failures such as auth or validation errors surface immediately instead of + burning the full retry budget. + """ for attempt in range(retries): try: return func() - except Exception: # noqa: BLE001 + except retry_on: # noqa: BLE001 if attempt == retries - 1: raise - time.sleep(delay) + time.sleep(min(delay * (2**attempt), max_delay)) + raise RuntimeError("retry: retries must be >= 1") diff --git a/agentkit/utils/ve_sign.py b/agentkit/utils/ve_sign.py index 25947ee..e03c783 100644 --- a/agentkit/utils/ve_sign.py +++ b/agentkit/utils/ve_sign.py @@ -41,6 +41,11 @@ logger = get_logger(__name__) +# Legacy module-level defaults. ``ve_request``/``request`` now thread these +# values through as function parameters (module globals were mutated per call, +# which is not thread-safe: concurrent calls for different services/regions +# could sign with each other's scope). Kept only as fallbacks for any direct +# ``request()`` callers relying on the old behavior. Service = "" Version = "" Region = "" @@ -48,6 +53,11 @@ ContentType = "" Scheme = "https" +# Shared session so signed OpenAPI calls reuse pooled keep-alive connections +# instead of re-doing a TCP+TLS handshake per request. ``Session.request`` is +# thread-safe for concurrent use. +_session = requests.Session() + MAX_X_CUSTOM_SOURCE_LENGTH = 256 @@ -90,7 +100,7 @@ def _signed_request(method, url, headers, params, data) -> requests.Response: resp: requests.Response | None = None for attempt in range(retries + 1): try: - resp = requests.request( + resp = _session.request( method=method, url=url, headers=headers, @@ -224,25 +234,47 @@ def hash_sha256(content: str): # 第二步:签名请求函数 -def request(method, date, query, header, ak, sk, action, body): +def request( + method, + date, + query, + header, + ak, + sk, + action, + body, + service=None, + version=None, + region=None, + host=None, + content_type=None, + scheme=None, +): + # 签名 scope 参数直传(None 时回退到模块级默认,兼容旧直接调用方式) + service = Service if service is None else service + version = Version if version is None else version + region = Region if region is None else region + host = Host if host is None else host + content_type = ContentType if content_type is None else content_type + scheme = Scheme if scheme is None else scheme # 第三步:创建身份证明。其中的 Service 和 Region 字段是固定的。ak 和 sk 分别代表 # AccessKeyID 和 SecretAccessKey。同时需要初始化签名结构体。一些签名计算时需要的属性也在这里处理。 # 初始化身份证明结构体 credential = { "access_key_id": ak, "secret_access_key": sk, - "service": Service, - "region": Region, + "service": service, + "region": region, } # 初始化签名结构体 request_param = { "body": body, - "host": Host, + "host": host, "path": "/", "method": method, - "content_type": ContentType, + "content_type": content_type, "date": date, - "query": {"Action": action, "Version": Version, **query}, + "query": {"Action": action, "Version": version, **query}, } if body is None: request_param["body"] = "" @@ -315,7 +347,7 @@ def request(method, date, query, header, ak, sk, action, body): # 第六步:将 Signature 签名写入 HTTP Header 中,并发送 HTTP 请求。 r = _signed_request( method=method, - url=f"{Scheme}://{request_param['host']}{request_param['path']}", + url=f"{scheme}://{request_param['host']}{request_param['path']}", headers=header, params=request_param["query"], data=request_param["body"], @@ -336,33 +368,29 @@ def ve_request( content_type: str = "application/json", scheme: str = "https", ): - # response_body = request("Get", datetime.datetime.utcnow(), {}, {}, AK, SK, "ListUsers", None) - # print(response_body) - # 以下参数视服务不同而不同,一个服务内通常是一致的 - global Service - Service = service - global Version - Version = version - global Region - Region = region - global Host - Host = host - global ContentType - ContentType = content_type - global Scheme - Scheme = scheme or "https" - - AK = ak - SK = sk - + # 以下参数视服务不同而不同,一个服务内通常是一致的。 + # 注意:签名 scope 以参数直传,不再写模块级全局(并发下不同 service/region + # 的调用互不串用)。 now = datetime.datetime.utcnow() # Body的格式需要配合Content-Type,API使用的类型请阅读具体的官方文档,如:json格式需要json.dumps(obj) - # response_body = request("GET", now, {"Limit": "2"}, {}, AK, SK, "ListUsers", None) import json response_body = request( - "POST", now, {}, header or {}, AK, SK, action, json.dumps(request_body) + "POST", + now, + {}, + header or {}, + ak, + sk, + action, + json.dumps(request_body), + service=service, + version=version, + region=region, + host=host, + content_type=content_type, + scheme=scheme or "https", ) check_error(response_body) return response_body diff --git a/tests/platform/conftest.py b/tests/platform/conftest.py index 3c19d5f..9c8a65d 100644 --- a/tests/platform/conftest.py +++ b/tests/platform/conftest.py @@ -18,7 +18,13 @@ @pytest.fixture def clean_env(monkeypatch): - """Clean up environment variables that may affect platform resolution.""" + """Clean up environment variables that may affect platform resolution. + + Snapshots the full environment and restores it on teardown, so tests that + set variables via raw ``os.environ[...]`` (instead of monkeypatch.setenv) + cannot leak into later tests and create order-dependent failures. + """ + snapshot = os.environ.copy() for key in list(os.environ.keys()): if ( key.startswith("VOLC") @@ -30,6 +36,9 @@ def clean_env(monkeypatch): } ): monkeypatch.delenv(key) + yield + os.environ.clear() + os.environ.update(snapshot) @pytest.fixture diff --git a/tests/toolkit/runners/test_ve_agentkit_lifecycle.py b/tests/toolkit/runners/test_ve_agentkit_lifecycle.py index a7d0eea..807b677 100644 --- a/tests/toolkit/runners/test_ve_agentkit_lifecycle.py +++ b/tests/toolkit/runners/test_ve_agentkit_lifecycle.py @@ -173,6 +173,16 @@ def __init__(self, content=b"log line 1\nlog line 2\n"): def raise_for_status(self): pass + # The failure-log download streams via a context manager + iter_content. + def __enter__(self): + return self + + def __exit__(self, *exc): + return False + + def iter_content(self, chunk_size=None): + yield self.content + # --------------------------------------------------------------------------- # Shared fixtures @@ -327,8 +337,8 @@ def test_create_new_runtime_init_failure_downloads_logs_and_cleans_up_when_confi get_calls = [] - def _fake_get(url, timeout=None): - get_calls.append((url, timeout)) + def _fake_get(url, timeout=None, stream=None): + get_calls.append((url, timeout, stream)) return _FakeResponse() monkeypatch.setattr(mod.requests, "get", _fake_get) @@ -341,8 +351,9 @@ def _fake_get(url, timeout=None): assert result.error_code == ErrorCode.RUNTIME_NOT_READY assert result.service_id == "rt-bad" - # Failure-log download was attempted against the failed_log_file_url. - assert get_calls == [("https://logs.example/failed.log", 30)] + # Failure-log download was attempted against the failed_log_file_url, + # streamed to disk (stream=True) with a bounded timeout. + assert get_calls == [("https://logs.example/failed.log", 30, True)] assert reporter.show_logs_calls # logs were shown # A cleanup confirmation was requested, defaulting to False. assert reporter.confirm_calls and reporter.confirm_calls[0][1] is False diff --git a/tests/utils/test_engineering_standards.py b/tests/utils/test_engineering_standards.py index a97bce4..2f90642 100644 --- a/tests/utils/test_engineering_standards.py +++ b/tests/utils/test_engineering_standards.py @@ -52,14 +52,16 @@ def no_sleep(monkeypatch): def _count_request(monkeypatch, exc_factory): - """Patch ``requests.request`` to always raise; return an attempt counter.""" + """Patch the shared session's ``request`` to always raise; return an + attempt counter. (``_signed_request`` goes through ``ve_sign._session`` + for connection pooling, so that is the seam to patch.)""" counter = {"attempts": 0} def _fake_request(**_kwargs): counter["attempts"] += 1 raise exc_factory() - monkeypatch.setattr(ve_sign.requests, "request", _fake_request) + monkeypatch.setattr(ve_sign._session, "request", _fake_request) return counter diff --git a/tests/utils/test_ve_sign_signing.py b/tests/utils/test_ve_sign_signing.py new file mode 100644 index 0000000..6675352 --- /dev/null +++ b/tests/utils/test_ve_sign_signing.py @@ -0,0 +1,167 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Golden-vector tests for the ve_sign signing algorithm. + +The signing path (canonical request -> string-to-sign -> HMAC chain -> +Authorization header) previously had zero direct coverage: a signing bug is a +silent auth failure against the OpenAPI. These tests pin the algorithm with a +fixed timestamp and fixed credentials so any change to canonicalization, +scope, or the HMAC chain shows up as a diff against known-good vectors. + +They also pin the thread-safety contract of the parameterized refactor: two +calls with different service/region scopes must not bleed into each other +(the old implementation passed scope via mutable module globals). + +No network is performed: ``ve_sign._signed_request`` is stubbed out. +""" + +import datetime +import json + +import pytest + +from agentkit.utils import ve_sign + + +FIXED_DATE = datetime.datetime(2026, 1, 2, 3, 4, 5) +SCOPE = dict( + service="test-svc", + version="2024-01-01", + region="cn-test", + host="open.test.example.com", + content_type="application/json", + scheme="https", +) + + +@pytest.fixture +def capture_signed_request(monkeypatch): + """Stub the transport; capture what the signer hands to it.""" + captured = {} + + def _fake(method, url, headers, params, data): + captured.update( + method=method, url=url, headers=headers, params=params, data=data + ) + + class _Resp: + @staticmethod + def json(): + return {"ResponseMetadata": {"Action": "ListFoo"}} + + return _Resp() + + monkeypatch.setattr(ve_sign, "_signed_request", _fake) + return captured + + +# --------------------------------------------------------------------------- # +# norm_query +# --------------------------------------------------------------------------- # + + +def test_norm_query_sorts_keys_encodes_and_expands_lists(): + # Keys sorted; list values expanded in given order; space -> %20, + -> %2B. + assert ( + ve_sign.norm_query({"b": "1", "a": ["y", "x"], "sp ace": "v+1"}) + == "a=y&a=x&b=1&sp%20ace=v%2B1" + ) + + +# --------------------------------------------------------------------------- # +# request(): golden signing vector +# --------------------------------------------------------------------------- # + + +def test_request_produces_golden_authorization_header(capture_signed_request): + ve_sign.request( + "POST", + FIXED_DATE, + {"Limit": "2"}, + {}, + "AKTEST", + "SKTEST", + "ListFoo", + json.dumps({"a": 1}), + **SCOPE, + ) + headers = capture_signed_request["headers"] + + assert capture_signed_request["url"] == "https://open.test.example.com/" + assert headers["X-Date"] == "20260102T030405Z" + assert headers["Host"] == "open.test.example.com" + assert headers["Content-Type"] == "application/json" + assert ( + headers["X-Content-Sha256"] + == "f9d86028c6e0d64e225186f96acb69338b2c59764df79162107f5c4bb34d1310" + ) + # Golden vector: any change to canonicalization/scope/HMAC chain breaks this. + assert headers["Authorization"] == ( + "HMAC-SHA256 Credential=AKTEST/20260102/cn-test/test-svc/request, " + "SignedHeaders=content-type;host;x-content-sha256;x-date, " + "Signature=1191b7baccab57749590b7da8aef8af04894e52b208da0f0fd6733f3ef25c8db" + ) + # Action/Version merged into the query ahead of caller params. + assert capture_signed_request["params"] == { + "Action": "ListFoo", + "Version": "2024-01-01", + "Limit": "2", + } + + +def test_request_scope_is_per_call_not_global(capture_signed_request): + """Two calls with different scopes must not bleed into each other.""" + ve_sign.request( + "POST", FIXED_DATE, {}, {}, "AK", "SK", "ActA", "", **SCOPE + ) + first_auth = capture_signed_request["headers"]["Authorization"] + + other = dict(SCOPE, service="other-svc", region="cn-other") + ve_sign.request( + "POST", FIXED_DATE, {}, {}, "AK", "SK", "ActA", "", **other + ) + second_auth = capture_signed_request["headers"]["Authorization"] + + assert "/cn-test/test-svc/" in first_auth + assert "/cn-other/other-svc/" in second_auth + # Module-level legacy defaults were not mutated by parameterized calls. + assert ve_sign.Service == "" + assert ve_sign.Region == "" + + +# --------------------------------------------------------------------------- # +# check_error +# --------------------------------------------------------------------------- # + + +def test_check_error_raises_on_top_level_error(): + with pytest.raises(ValueError, match="Error in response"): + ve_sign.check_error({"Error": {"Code": "X"}}) + + +def test_check_error_raises_on_response_metadata_error(): + with pytest.raises(ValueError, match="AccessDenied"): + ve_sign.check_error( + { + "ResponseMetadata": { + "Action": "ListFoo", + "Error": {"Code": "AccessDenied", "Message": "nope"}, + } + } + ) + + +def test_check_error_passes_clean_response(): + ve_sign.check_error({"ResponseMetadata": {"Action": "ListFoo"}, "Result": {}}) From a50edd6fe9dcf2ddb7f39b2099589bf151d38c63 Mon Sep 17 00:00:00 2001 From: "liyi.ly" Date: Fri, 3 Jul 2026 16:41:01 +0800 Subject: [PATCH 2/3] fix: harden error paths and close protocol/contract test gaps (NFR round 2) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Test fixes (unblocks CI): - cli_config_interactive / executor_platform_context tests: also patch global_config_exists() and clear provider env vars — they previously "passed" only when earlier tests leaked CLOUD_PROVIDER into the process environment, and never validated the global-config path they claimed to Security/robustness: - /run_sse SSE error events no longer echo internal exception details to clients (error_type only; full detail stays in server logs) - cap server-provided Retry-After at 30s so a misconfigured/hostile server cannot stall signed-request retries for arbitrary durations (+ tests) - a2a executor wrapper: initialize result before try — an executor failure previously masked the original exception with UnboundLocalError in the finally-block telemetry call (mcp_app already had this fix) Performance: - simple_app: run sync entrypoints via asyncio.to_thread so blocking user code no longer stalls the event loop for concurrent requests - code_pipeline log export: download step logs concurrently (bounded pool, streamed to temp files) and stitch in original order — export time no longer grows linearly with pipeline size New test coverage (+49 tests, suite now 910 passed / 0 failed): - agent_server route-table assembly on a real FastAPI app: /run_sse registered once and first (regression guard for the removed duplicate), /invoke not shadowed by the A2A mount, telemetry middleware attached - agentkit.apps public export surface (__all__ lazy __getattr__) - A2A protocol round-trip against the real A2AStarletteApplication: agent-card discovery + JSON-RPC message/send validated with a2a.types - MCP round-trip via fastmcp in-memory client: tools/list schema + calls - offline unit tests for the 5 previously untested SDK service clients (knowledge/mcp/memory/skills/tools): wire payload, response parsing, error mapping Known follow-ups (reported, not addressed here): base_service_client response validation can leak raw pydantic.ValidationError outside the ApiError hierarchy; a2a telemetry imports a private google-adk symbol that is not declared in dependencies. --- agentkit/apps/a2a_app/a2a_app.py | 4 + .../apps/agent_server_app/agent_server_app.py | 14 +- .../apps/simple_app/simple_app_handlers.py | 5 +- agentkit/toolkit/volcengine/code_pipeline.py | 101 +++++--- agentkit/utils/ve_sign.py | 6 +- tests/apps/test_a2a_protocol_roundtrip.py | 232 +++++++++++++++++ .../apps/test_agent_server_route_assembly.py | 238 ++++++++++++++++++ tests/apps/test_apps_public_exports.py | 63 +++++ tests/apps/test_mcp_protocol_roundtrip.py | 166 ++++++++++++ tests/sdk/test_knowledge_client.py | 190 ++++++++++++++ tests/sdk/test_mcp_client.py | 197 +++++++++++++++ tests/sdk/test_memory_client.py | 205 +++++++++++++++ tests/sdk/test_skills_client.py | 186 ++++++++++++++ tests/sdk/test_tools_client.py | 196 +++++++++++++++ ..._config_interactive_provider_resolution.py | 14 ++ ...test_executor_platform_context_provider.py | 7 + tests/utils/test_engineering_standards.py | 16 ++ 17 files changed, 1809 insertions(+), 31 deletions(-) create mode 100644 tests/apps/test_a2a_protocol_roundtrip.py create mode 100644 tests/apps/test_agent_server_route_assembly.py create mode 100644 tests/apps/test_apps_public_exports.py create mode 100644 tests/apps/test_mcp_protocol_roundtrip.py create mode 100644 tests/sdk/test_knowledge_client.py create mode 100644 tests/sdk/test_mcp_client.py create mode 100644 tests/sdk/test_memory_client.py create mode 100644 tests/sdk/test_skills_client.py create mode 100644 tests/sdk/test_tools_client.py diff --git a/agentkit/apps/a2a_app/a2a_app.py b/agentkit/apps/a2a_app/a2a_app.py index 37e9c71..dea2bf8 100644 --- a/agentkit/apps/a2a_app/a2a_app.py +++ b/agentkit/apps/a2a_app/a2a_app.py @@ -46,6 +46,10 @@ async def wrapper(*args, **kwargs): with telemetry.tracer.start_as_current_span(name="a2a_invocation") as span: exception = None + # Initialize before try: if execute_func raises, the finally block + # below would otherwise hit UnboundLocalError and mask the original + # exception (mcp_app's equivalent wrapper already does this). + result = None try: result = await execute_func( executor_instance, context=context, event_queue=event_queue diff --git a/agentkit/apps/agent_server_app/agent_server_app.py b/agentkit/apps/agent_server_app/agent_server_app.py index f342c63..ee41b65 100644 --- a/agentkit/apps/agent_server_app/agent_server_app.py +++ b/agentkit/apps/agent_server_app/agent_server_app.py @@ -261,7 +261,19 @@ async def event_generator(): telemetry.trace_agent_server_finish( path="/run_sse", func_result="", exception=e ) - yield f"data: {json.dumps({'error': str(e)})}\n\n" + # Do not echo internal exception details (paths, backend + # errors) to the client; full detail stays in server logs. + yield ( + "data: " + + json.dumps( + { + "error": "internal error while running agent; " + "see server logs", + "error_type": type(e).__name__, + } + ) + + "\n\n" + ) # Returns a streaming response with the proper media type for SSE return StreamingResponse( diff --git a/agentkit/apps/simple_app/simple_app_handlers.py b/agentkit/apps/simple_app/simple_app_handlers.py index 390fdd8..4c1340d 100644 --- a/agentkit/apps/simple_app/simple_app_handlers.py +++ b/agentkit/apps/simple_app/simple_app_handlers.py @@ -182,7 +182,10 @@ async def _process_invoke(self, request: Request) -> tuple[dict, dict, Any]: if asyncio.iscoroutinefunction(self.func): return payload, headers, await self.func(*args) else: - return payload, headers, self.func(*args) + # Run sync entrypoints in a worker thread: executing them inline + # would block the event loop and stall all concurrent requests if + # the entrypoint does blocking IO or heavy computation. + return payload, headers, await asyncio.to_thread(self.func, *args) def _convert_to_sse(self, obj) -> bytes: """Convert object to Server-Sent Events format using safe serialization. diff --git a/agentkit/toolkit/volcengine/code_pipeline.py b/agentkit/toolkit/volcengine/code_pipeline.py index 3801d34..391c870 100644 --- a/agentkit/toolkit/volcengine/code_pipeline.py +++ b/agentkit/toolkit/volcengine/code_pipeline.py @@ -15,6 +15,10 @@ import requests import logging +import os +import shutil +import tempfile +from concurrent.futures import ThreadPoolExecutor from agentkit.utils.ve_sign import ve_request from agentkit.platform import ( @@ -849,6 +853,61 @@ def download_and_merge_pipeline_logs( successful_downloads = 0 failed_downloads = 0 + # Download all step logs concurrently (bounded pool) into temp + # files first — serially fetching sign-URI + log per step made + # export time grow linearly with pipeline size. Streaming to temp + # files keeps peak memory flat; stitching below preserves the + # original output order. (get_task_run_log_download_uri -> + # ve_request signs with per-call scope, so it is thread-safe.) + def _fetch_step_log(task_run_id, task_id, step_name): + log_url = self.get_task_run_log_download_uri( + workspace_id=workspace_id, + pipeline_id=pipeline_id, + pipeline_run_id=pipeline_run_id, + task_run_id=task_run_id, + task_id=task_id, + step_name=step_name, + ) + tmp = tempfile.NamedTemporaryFile( + mode="w", encoding="utf-8", suffix=".steplog", delete=False + ) + try: + with tmp, requests.get( + log_url, timeout=30, stream=True + ) as response: + response.raise_for_status() + response.encoding = response.encoding or "utf-8" + last_chunk = "" + for chunk in response.iter_content( + chunk_size=65536, decode_unicode=True + ): + if chunk: + tmp.write(chunk) + last_chunk = chunk + if not last_chunk.endswith("\n"): + tmp.write("\n") + except Exception: + try: + os.unlink(tmp.name) + except OSError: + pass + raise + return tmp.name + + log_futures = {} + # Exiting the pool context blocks until all downloads finished; + # futures stay consumable afterwards. + with ThreadPoolExecutor(max_workers=4) as pool: + for s_i, stage in enumerate(stages): + for t_i, task in enumerate(stage.get("Tasks", [])): + for p_i, step in enumerate(task.get("Steps", [])): + log_futures[(s_i, t_i, p_i)] = pool.submit( + _fetch_step_log, + task.get("TaskRunID", "unknown"), + task.get("Id", "unknown"), + step.get("Name", "unknown"), + ) + for stage_idx, stage in enumerate(stages, 1): stage_id = stage.get("Id", "unknown") stage_name = stage.get("Name", "unknown") @@ -902,35 +961,21 @@ def download_and_merge_pipeline_logs( out_file.write(f"Finish Time: {step_finish_time}\n") out_file.write(f"{'*' * 60}\n\n") - # Try to download the step log + # Stitch in the concurrently downloaded step log try: - # Get log download URI - log_url = self.get_task_run_log_download_uri( - workspace_id=workspace_id, - pipeline_id=pipeline_id, - pipeline_run_id=pipeline_run_id, - task_run_id=task_run_id, - task_id=task_id, - step_name=step_name, - ) - - # Download the log content, streaming chunks to - # disk so peak memory stays flat for large logs. - with requests.get( - log_url, timeout=30, stream=True - ) as response: - response.raise_for_status() - response.encoding = response.encoding or "utf-8" - last_chunk = "" - for chunk in response.iter_content( - chunk_size=65536, decode_unicode=True - ): - if chunk: - out_file.write(chunk) - last_chunk = chunk - - if not last_chunk.endswith("\n"): - out_file.write("\n") + tmp_path = log_futures[ + (stage_idx - 1, task_idx - 1, step_idx - 1) + ].result() + try: + with open( + tmp_path, "r", encoding="utf-8" + ) as tmp_f: + shutil.copyfileobj(tmp_f, out_file, 65536) + finally: + try: + os.unlink(tmp_path) + except OSError: + pass successful_downloads += 1 logger.info( diff --git a/agentkit/utils/ve_sign.py b/agentkit/utils/ve_sign.py index e03c783..16464d7 100644 --- a/agentkit/utils/ve_sign.py +++ b/agentkit/utils/ve_sign.py @@ -69,6 +69,10 @@ # Tunable via env; AGENTKIT_HTTP_RETRIES=0 disables retries. _RETRYABLE_STATUS = frozenset({429, 503}) +# Cap server-provided Retry-After: without it a hostile/misconfigured server +# saying "Retry-After: 3600" would stall the client for an hour per attempt. +_RETRY_AFTER_CAP = 30.0 + def _backoff_seconds(attempt: int) -> float: return min(8.0, 0.5 * (2**attempt)) @@ -79,7 +83,7 @@ def _retry_after_seconds(resp: requests.Response) -> float | None: if not raw: return None try: - return max(0.0, float(raw)) + return min(max(0.0, float(raw)), _RETRY_AFTER_CAP) except ValueError: return None diff --git a/tests/apps/test_a2a_protocol_roundtrip.py b/tests/apps/test_a2a_protocol_roundtrip.py new file mode 100644 index 0000000..db3f404 --- /dev/null +++ b/tests/apps/test_a2a_protocol_roundtrip.py @@ -0,0 +1,232 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Offline protocol-level round-trip tests for ``AgentkitA2aApp``. + +Unlike ``test_a2a_app.py`` (which unit-tests decorator guards and never builds +the real server), these tests exercise the FULL production build path of +``AgentkitA2aApp.run()``: the real ``A2AStarletteApplication`` is assembled with +a ``DefaultRequestHandler`` + ``InMemoryTaskStore`` and ``.build()`` into a real +Starlette app. Only the final ``uvicorn.run`` socket bind is stubbed out; the +captured ASGI app is then driven in-process via ``starlette.testclient.TestClient`` +(no network). + +Contract assertions are anchored on the OFFICIAL ``a2a.types`` pydantic models +(``AgentCard.model_validate`` / ``SendMessageResponse.model_validate``), so a +wire-format regression in a2a-sdk (field renames, alias changes, envelope +shape) fails these tests even if agentkit's own code is untouched. Assertions +about agentkit-specific behavior (extra routes, fixed reply text) are kept +separate from the schema-validation assertions. + +Covered protocol surface (a2a-sdk 0.3.7): + * Agent card discovery: GET /.well-known/agent-card.json (canonical) and the + deprecated /.well-known/agent.json alias, both validated against AgentCard. + * JSON-RPC 2.0 ``message/send``: request serialized from a real + ``SendMessageRequest`` (by_alias camelCase wire format), response parsed + into ``SendMessageResponse`` -> ``SendMessageSuccessResponse`` -> Message. +""" + +from __future__ import annotations + +import uuid + +import pytest +from a2a.server.agent_execution import AgentExecutor +from a2a.types import ( + AgentCapabilities, + AgentCard, + AgentSkill, + Message, + MessageSendParams, + Part, + Role, + SendMessageRequest, + SendMessageResponse, + SendMessageSuccessResponse, + TextPart, +) +from a2a.utils import new_agent_text_message +from starlette.applications import Starlette +from starlette.testclient import TestClient + +from agentkit.apps.a2a_app import a2a_app as a2a_app_module +from agentkit.apps.a2a_app.a2a_app import AgentkitA2aApp + +AGENT_CARD_PATH = "/.well-known/agent-card.json" +DEPRECATED_AGENT_CARD_PATH = "/.well-known/agent.json" +RPC_PATH = "/" + +FIXED_REPLY_TEXT = "fixed-reply-from-stub-executor" + + +class _RecordingUvicorn: + """Stands in for the ``uvicorn`` module inside ``a2a_app``; captures the + built Starlette app instead of binding a socket.""" + + def __init__(self) -> None: + self.app: Starlette | None = None + self.kwargs: dict | None = None + + def run(self, app, **kwargs) -> None: + self.app = app + self.kwargs = kwargs + + +def _minimal_agent_card() -> AgentCard: + """The smallest AgentCard that satisfies a2a.types' required fields.""" + return AgentCard( + name="stub-agent", + description="Protocol round-trip test agent", + url="http://testserver/", + version="0.0.1", + capabilities=AgentCapabilities(streaming=False), + default_input_modes=["text"], + default_output_modes=["text"], + skills=[ + AgentSkill( + id="echo", + name="Echo", + description="Returns a fixed reply.", + tags=["test"], + ) + ], + ) + + +@pytest.fixture +def built_app(monkeypatch) -> Starlette: + """Runs the real AgentkitA2aApp.run() build path and returns the Starlette + app that would have been served by uvicorn. + + A fresh executor class is defined per invocation because the + ``agent_executor`` decorator mutates ``cls.execute`` in place (a shared + class would get double-wrapped across tests). + """ + recorder = _RecordingUvicorn() + monkeypatch.setattr(a2a_app_module, "uvicorn", recorder) + + app = AgentkitA2aApp() + + @app.agent_executor() + class _FixedReplyExecutor(AgentExecutor): + async def execute(self, context, event_queue): + # A Message event is terminal for message/send: the handler + # returns it as the JSON-RPC result. + await event_queue.enqueue_event( + new_agent_text_message( + FIXED_REPLY_TEXT, + context_id=context.context_id, + task_id=context.task_id, + ) + ) + + async def cancel(self, context, event_queue): # pragma: no cover + raise NotImplementedError + + app.run(_minimal_agent_card(), host="127.0.0.1", port=0) + + assert recorder.app is not None, "run() never reached uvicorn.run" + return recorder.app + + +# --------------------------------------------------------------------------- +# Agent card discovery (GET well-known path) -- schema contract +# --------------------------------------------------------------------------- + + +def test_agent_card_endpoint_returns_schema_valid_agent_card(built_app): + with TestClient(built_app) as client: + response = client.get(AGENT_CARD_PATH) + + assert response.status_code == 200 + # Contract assertion: the served card must round-trip through the official + # a2a.types model (camelCase wire aliases included). + card = AgentCard.model_validate(response.json()) + assert card.name == "stub-agent" + assert card.version == "0.0.1" + assert [skill.id for skill in card.skills] == ["echo"] + + +def test_deprecated_agent_json_path_serves_the_same_schema_valid_card(built_app): + # a2a-sdk 0.3.7 still serves the pre-rename path for backward compat. + with TestClient(built_app) as client: + response = client.get(DEPRECATED_AGENT_CARD_PATH) + + assert response.status_code == 200 + card = AgentCard.model_validate(response.json()) + assert card.name == "stub-agent" + + +# --------------------------------------------------------------------------- +# JSON-RPC message/send round trip -- schema contract +# --------------------------------------------------------------------------- + + +def _send_message_payload(text: str) -> dict: + """Build a message/send request through the official request model so the + outgoing wire format is also produced (and thus checked) by a2a.types.""" + request = SendMessageRequest( + id=str(uuid.uuid4()), + params=MessageSendParams( + message=Message( + message_id=str(uuid.uuid4()), + role=Role.user, + parts=[Part(root=TextPart(text=text))], + ) + ), + ) + return request.model_dump(mode="json", by_alias=True, exclude_none=True) + + +def test_message_send_roundtrip_returns_schema_valid_success_response(built_app): + payload = _send_message_payload("hello over the real protocol stack") + + with TestClient(built_app) as client: + response = client.post(RPC_PATH, json=payload) + + assert response.status_code == 200 + + # Contract assertions: parse with the official union response model. + parsed = SendMessageResponse.model_validate(response.json()) + assert isinstance(parsed.root, SendMessageSuccessResponse) + assert parsed.root.jsonrpc == "2.0" + assert parsed.root.id == payload["id"] + + result = parsed.root.result + assert isinstance(result, Message) + assert result.role == Role.agent + + # Behavior assertion (agentkit + stub executor): the fixed reply text made + # it through executor -> event queue -> request handler -> JSON-RPC result. + text_parts = [ + part.root.text for part in result.parts if isinstance(part.root, TextPart) + ] + assert text_parts == [FIXED_REPLY_TEXT] + + +# --------------------------------------------------------------------------- +# agentkit-specific routes registered by run() on top of the protocol app +# --------------------------------------------------------------------------- + + +def test_run_build_path_registers_agentkit_extra_routes_alongside_protocol_routes( + built_app, +): + # These are implementation details of AgentkitA2aApp.run(), asserted + # separately from the protocol contract above. + route_paths = {route.path for route in built_app.routes} + assert RPC_PATH in route_paths + assert AGENT_CARD_PATH in route_paths + assert "/ping" in route_paths + assert "/env" in route_paths diff --git a/tests/apps/test_agent_server_route_assembly.py b/tests/apps/test_agent_server_route_assembly.py new file mode 100644 index 0000000..0bcf58e --- /dev/null +++ b/tests/apps/test_agent_server_route_assembly.py @@ -0,0 +1,238 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Offline route-table assembly guards for ``AgentkitAgentServerApp``. + +Target: ``AgentkitAgentServerApp.__init__`` +(``agentkit/apps/agent_server_app/agent_server_app.py``). The constructor +assembles the serving surface on the FastAPI app returned by ADK's +``AdkWebServer.get_fast_api_app``: it registers the custom ``POST /run_sse`` +override and moves it to the *front* of the route table for priority matching +(without deleting the ADK default route), attaches +``AgentkitTelemetryHTTPMiddleware``, adds the ``POST /invoke`` compatibility +route, and finally mounts the A2A server app at ``/`` -- deliberately last so +the catch-all mount cannot shadow API routes. + +None of this was previously covered by a test that assembles a *real* FastAPI +app and inspects ``router.routes``; a duplicated ``POST /run_sse`` +registration (dead code, since removed) shipped unnoticed because of that +gap. These tests close it. + +Seam: ``get_fast_api_app`` is a *bound method* of the ``AdkWebServer`` +instance created inside ``__init__`` (there is no module-level symbol to +patch), so we monkeypatch the module's ``AdkWebServer`` reference with a +subclass that keeps the real (trivial, assignment-only) constructor and +overrides only ``get_fast_api_app`` to return a plain ``fastapi.FastAPI``. +``__init__`` then runs its full assembly logic against that real app -- real +route objects, real middleware stack -- with no ADK server, no sockets, no +network. The two heavy collaborators that cannot be built offline from a +plain ``BaseAgent`` are stubbed at module level: ``Runner`` (veadk; reaches +for ``agent.short_term_memory``) and ``to_a2a`` (builds a full A2A app); the +stubbed A2A app is a real ``starlette`` application so the ``mount()`` and +lifespan surfaces stay genuine. +""" + +from __future__ import annotations + +from fastapi import FastAPI +from fastapi.routing import APIRoute +from google.adk.agents.base_agent import BaseAgent +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from starlette.applications import Starlette +from starlette.routing import Mount + +import agentkit.apps.agent_server_app.agent_server_app as mod +from agentkit.apps.agent_server_app.middleware import ( + AgentkitTelemetryHTTPMiddleware, +) + +# Env vars read by resolve_agentkit_allow_origins(); cleared per-assembly so +# CORS resolution is deterministic regardless of the host environment. +_ORIGIN_ENV_VARS = ( + "AGENTKIT_ALLOW_ORIGINS", + "ADK_ALLOW_ORIGINS", + "AGENTKIT_ALLOW_ORIGIN_REGEX", + "ADK_ALLOW_ORIGIN_REGEX", + "AGENTKIT_DISABLE_DEFAULT_ALLOW_ORIGINS", +) + + +# --------------------------------------------------------------------------- +# Assembly helper: run the real __init__ against a real FastAPI base app. +# --------------------------------------------------------------------------- + + +def _assemble(monkeypatch, *, with_adk_default_run_sse: bool = False): + """Construct ``AgentkitAgentServerApp`` on a real FastAPI base app. + + ``with_adk_default_run_sse`` simulates the ADK default ``POST /run_sse`` + route already being present on the app returned by ``get_fast_api_app``, + which is what the priority-move logic exists for in production. + + Returns ``(server_app, records)`` where ``records`` captures the base + app, the stubbed A2A app instance, and the veadk Runner kwargs. + """ + for name in _ORIGIN_ENV_VARS: + monkeypatch.delenv(name, raising=False) + + records: dict = {"base_app": None, "a2a_app": None, "runner_kwargs": None} + + class _StubAdkWebServer(mod.AdkWebServer): + """Real AdkWebServer wiring, minus the ADK FastAPI factory.""" + + def get_fast_api_app(self, lifespan=None, allow_origins=None, **kwargs): + del allow_origins, kwargs + base_app = FastAPI(lifespan=lifespan) + if with_adk_default_run_sse: + # Stand-in for the route the real ADK factory registers. + async def _adk_default_run_sse(): + return {"source": "adk-default"} + + base_app.post("/run_sse")(_adk_default_run_sse) + records["base_app"] = base_app + return base_app + + class _StubRunner: + """veadk Runner stand-in: record kwargs, build nothing.""" + + def __init__(self, agent=None, short_term_memory=None, **kwargs): + records["runner_kwargs"] = { + "agent": agent, + "short_term_memory": short_term_memory, + **kwargs, + } + + def _stub_to_a2a(agent=None, runner=None, **kwargs): + del agent, runner, kwargs + # A real starlette app: mountable ASGI callable with a genuine + # ``router.on_startup`` for the lifespan hook in __init__. + records["a2a_app"] = Starlette() + return records["a2a_app"] + + monkeypatch.setattr(mod, "AdkWebServer", _StubAdkWebServer) + monkeypatch.setattr(mod, "Runner", _StubRunner) + monkeypatch.setattr(mod, "to_a2a", _stub_to_a2a) + + server_app = mod.AgentkitAgentServerApp( + agent=BaseAgent(name="route_assembly_agent"), + short_term_memory=InMemorySessionService(), + ) + return server_app, records + + +def _post_routes(app, path): + return [ + r + for r in app.router.routes + if getattr(r, "path", None) == path + and "POST" in getattr(r, "methods", set()) + ] + + +def _endpoint_name(route): + return getattr(getattr(route, "endpoint", None), "__name__", None) + + +# =========================================================================== +# POST /run_sse: exactly one custom registration, at the front of the table +# =========================================================================== + + +def test_run_sse_registered_exactly_once(monkeypatch): + # Regression guard for the removed dead code that registered the custom + # POST /run_sse override twice. On a plain base app the route table must + # contain exactly one POST /run_sse: the custom override. + server_app, _records = _assemble(monkeypatch) + + run_sse_routes = _post_routes(server_app.app, "/run_sse") + + assert len(run_sse_routes) == 1 + assert _endpoint_name(run_sse_routes[0]) == "run_agent_sse" + + +def test_run_sse_is_the_first_route_for_priority_matching(monkeypatch): + # __init__ pops the custom route and insert(0)s it so it wins matching + # over anything else (FastAPI puts /openapi.json, /docs, ... first by + # default). Assert it really landed at index 0. + server_app, _records = _assemble(monkeypatch) + + routes = server_app.app.router.routes + + first = routes[0] + assert isinstance(first, APIRoute) + assert first.path == "/run_sse" + assert "POST" in first.methods + assert _endpoint_name(first) == "run_agent_sse" + # And it is the only POST /run_sse anywhere else in the table. + assert _post_routes(server_app.app, "/run_sse") == [first] + + +def test_run_sse_priority_move_keeps_adk_default_route(monkeypatch): + # When the base app already carries the ADK default POST /run_sse (the + # production case), the custom override must be moved ahead of it while + # the default stays in the table -- moved, not deleted, and no duplicate + # of the custom endpoint. + server_app, _records = _assemble(monkeypatch, with_adk_default_run_sse=True) + + run_sse_routes = _post_routes(server_app.app, "/run_sse") + names = [_endpoint_name(r) for r in run_sse_routes] + + assert names == ["run_agent_sse", "_adk_default_run_sse"] + assert server_app.app.router.routes[0] is run_sse_routes[0] + + +# =========================================================================== +# POST /invoke: present and not shadowed by the A2A root mount +# =========================================================================== + + +def test_invoke_route_registered_exactly_once(monkeypatch): + server_app, _records = _assemble(monkeypatch) + + invoke_routes = _post_routes(server_app.app, "/invoke") + + assert len(invoke_routes) == 1 + assert _endpoint_name(invoke_routes[0]) == "_invoke_compat" + + +def test_root_mount_is_last_and_does_not_shadow_invoke(monkeypatch): + # The A2A app is mounted at "/" (a catch-all): if it preceded /invoke in + # the route table it would swallow the request. __init__ mounts it last + # on purpose; pin both the relative order and the mounted app identity. + server_app, records = _assemble(monkeypatch) + + routes = server_app.app.router.routes + mounts = [r for r in routes if isinstance(r, Mount)] + + assert len(mounts) == 1 + mount = mounts[0] + assert mount.app is records["a2a_app"] + + [invoke_route] = _post_routes(server_app.app, "/invoke") + assert routes.index(invoke_route) < routes.index(mount) + # Nothing is registered after the catch-all mount. + assert routes[-1] is mount + + +# =========================================================================== +# Middleware: unified telemetry attached at the app level +# =========================================================================== + + +def test_telemetry_http_middleware_is_attached(monkeypatch): + server_app, _records = _assemble(monkeypatch) + + middleware_classes = [m.cls for m in server_app.app.user_middleware] + + assert AgentkitTelemetryHTTPMiddleware in middleware_classes diff --git a/tests/apps/test_apps_public_exports.py b/tests/apps/test_apps_public_exports.py new file mode 100644 index 0000000..e277ebe --- /dev/null +++ b/tests/apps/test_apps_public_exports.py @@ -0,0 +1,63 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Public export guards for the ``agentkit.apps`` package. + +Target: ``agentkit/apps/__init__.py``. The package exports its four app +classes through a *lazy* module-level ``__getattr__`` (one ``if`` branch per +name) rather than eager imports, so a typo in a branch -- or a branch/` +``__all__`` mismatch -- would not fail at import time and could ship +silently. These tests resolve every advertised name through the real lazy +path and pin ``__all__`` to the documented public surface. +""" + +from __future__ import annotations + +import inspect + +import pytest + +import agentkit.apps + +# The four app classes the package documents as its public surface. +_DOCUMENTED_APP_CLASSES = { + "AgentkitA2aApp", + "AgentkitMCPApp", + "AgentkitSimpleApp", + "AgentkitAgentServerApp", +} + + +def test_all_matches_the_documented_app_classes(): + assert set(agentkit.apps.__all__) == _DOCUMENTED_APP_CLASSES + # No duplicate entries hiding in the list form. + assert len(agentkit.apps.__all__) == len(set(agentkit.apps.__all__)) + + +@pytest.mark.parametrize("name", sorted(_DOCUMENTED_APP_CLASSES)) +def test_every_export_resolves_through_the_lazy_getattr_to_a_class(name): + # getattr() on the module exercises the real lazy __getattr__ branch. + obj = getattr(agentkit.apps, name) + + assert inspect.isclass(obj), f"{name} did not resolve to a class" + # The branch must return the class it advertises, not a lookalike. + assert obj.__name__ == name + assert obj.__module__.startswith("agentkit.apps.") + + +def test_unknown_attribute_raises_attribute_error(): + # The lazy __getattr__ fallthrough must raise AttributeError (not return + # None or raise something else), so hasattr()/import-from behave normally. + with pytest.raises(AttributeError, match="has no attribute 'NotAnApp'"): + getattr(agentkit.apps, "NotAnApp") diff --git a/tests/apps/test_mcp_protocol_roundtrip.py b/tests/apps/test_mcp_protocol_roundtrip.py new file mode 100644 index 0000000..0c4ec73 --- /dev/null +++ b/tests/apps/test_mcp_protocol_roundtrip.py @@ -0,0 +1,166 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Offline protocol-level round-trip tests for ``AgentkitMCPApp``. + +Unlike ``test_mcp_app.py`` (which swaps the FastMCP server for a capturing fake +and asserts on telemetry call shapes), these tests keep the REAL ``FastMCP`` +instance that ``AgentkitMCPApp`` builds and drive it through fastmcp's +in-memory transport: ``fastmcp.Client(server)`` connects directly to the server +object, so a full MCP session (initialize, tools/list, tools/call) runs with +real protocol serialization and zero network. + +Contract assertions are anchored on the official ``mcp.types`` models that the +client returns (``mcp.types.Tool`` with its JSON-Schema ``inputSchema``, +content blocks from ``tools/call``), so a serialization regression in +fastmcp/mcp surfaces here even when agentkit's own wrapper code is untouched. + +Covered protocol surface (fastmcp 2.12.3 / mcp 1.26.0): + * tools/list: registered sync + async tools and the built-in env-detect tool + are advertised with name, description, and complete input schemas derived + from the ORIGINAL function signatures (the telemetry wrapper must stay + transparent via functools.wraps/__wrapped__). + * tools/call: real invocations of a sync tool, an async tool, and the + env-detect tool, asserting both unstructured text content and structured + output. + +Tests follow this repo's convention of driving coroutines with ``asyncio.run`` +from synchronous test functions (no pytest-asyncio markers). +""" + +from __future__ import annotations + +import asyncio + +import mcp.types +import pytest +from fastmcp import Client, FastMCP + +from agentkit.apps.mcp_app.mcp_app import AgentkitMCPApp + + +@pytest.fixture +def app() -> AgentkitMCPApp: + """A real AgentkitMCPApp with one sync tool, one async tool, and the + env-detect tool registered through the production decorators.""" + instance = AgentkitMCPApp() + + @instance.tool + def add(a: int, b: int) -> int: + """Add two integers.""" + return a + b + + @instance.tool + async def greet(name: str) -> str: + """Greet someone by name.""" + return f"Hello, {name}!" + + instance.add_env_detect_tool() + return instance + + +def _run_session(server: FastMCP, scenario): + """Open an in-memory MCP session against the real server and run the + given async scenario inside it.""" + + async def runner(): + async with Client(server) as client: + return await scenario(client) + + return asyncio.run(runner()) + + +# --------------------------------------------------------------------------- +# tools/list -- schema contract +# --------------------------------------------------------------------------- + + +def test_tools_list_advertises_all_registered_tools(app): + tools = _run_session(app._mcp_server, lambda client: client.list_tools()) + + # Contract: the client hands back official mcp.types.Tool models. + assert all(isinstance(tool, mcp.types.Tool) for tool in tools) + assert {tool.name for tool in tools} == {"add", "greet", "get_env"} + + +def test_tools_list_exposes_complete_input_schema_for_sync_tool(app): + tools = _run_session(app._mcp_server, lambda client: client.list_tools()) + add_tool = next(tool for tool in tools if tool.name == "add") + + assert add_tool.description == "Add two integers." + # The schema must reflect the ORIGINAL signature, not the (*args, **kwargs) + # telemetry wrapper -- this pins that @wraps keeps the wrapper transparent + # to FastMCP's schema derivation. + schema = add_tool.inputSchema + assert schema["type"] == "object" + assert set(schema["required"]) == {"a", "b"} + assert schema["properties"]["a"]["type"] == "integer" + assert schema["properties"]["b"]["type"] == "integer" + + +def test_tools_list_exposes_complete_input_schema_for_async_tool(app): + tools = _run_session(app._mcp_server, lambda client: client.list_tools()) + greet_tool = next(tool for tool in tools if tool.name == "greet") + + assert greet_tool.description == "Greet someone by name." + schema = greet_tool.inputSchema + assert schema["type"] == "object" + assert schema["required"] == ["name"] + assert schema["properties"]["name"]["type"] == "string" + + +# --------------------------------------------------------------------------- +# tools/call -- real invocations through the protocol +# --------------------------------------------------------------------------- + + +def test_tools_call_sync_tool_returns_text_and_structured_content(app): + result = _run_session( + app._mcp_server, + lambda client: client.call_tool("add", {"a": 2, "b": 3}), + ) + + assert result.is_error is False + # Contract: unstructured content is a list of official content blocks. + assert isinstance(result.content[0], mcp.types.TextContent) + assert result.content[0].text == "5" + # Contract: fastmcp also returns structured output for typed returns. + assert result.structured_content == {"result": 5} + assert result.data == 5 + + +def test_tools_call_async_tool_executes_and_returns_content(app): + result = _run_session( + app._mcp_server, + lambda client: client.call_tool("greet", {"name": "world"}), + ) + + assert result.is_error is False + assert isinstance(result.content[0], mcp.types.TextContent) + assert result.content[0].text == "Hello, world!" + assert result.data == "Hello, world!" + + +def test_tools_call_env_detect_tool_reports_runtime_over_the_protocol( + app, monkeypatch +): + monkeypatch.setenv("RUNTIME_IAM_ROLE_TRN", "trn:some:role") + + result = _run_session( + app._mcp_server, + lambda client: client.call_tool("get_env", {}), + ) + + assert result.is_error is False + assert result.data == {"env": "agentkit"} diff --git a/tests/sdk/test_knowledge_client.py b/tests/sdk/test_knowledge_client.py new file mode 100644 index 0000000..958112d --- /dev/null +++ b/tests/sdk/test_knowledge_client.py @@ -0,0 +1,190 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Offline tests for ``AgentkitKnowledgeClient``. + +The client is exercised through its public methods only; the volcengine +transport (``Service.json``) is stubbed at the same seam used by +``tests/client/test_base_service_client_errors.py``, so no network is +performed. Covers request construction (action + payload shape), response +parsing (alias -> snake_case field mapping) and error mapping. + +Explicit credentials are supplied so the vefaas auto-refresh path is never +taken and construction needs neither a credential file nor network. +""" + +from __future__ import annotations + +import json +import types + +# Import the toolkit package first to fully initialise the import graph before +# touching ``agentkit.client`` (the package wiring is order-sensitive). +import agentkit.toolkit # noqa: F401 + +import pytest +import requests + +from agentkit.auth.errors import NetworkError +from agentkit.sdk.knowledge.client import AgentkitKnowledgeClient +from agentkit.sdk.knowledge.types import ( + AddKnowledgeBaseRequest, + GetKnowledgeBaseRequest, + KnowledgeBasesItemForAddKnowledgeBase, + ListKnowledgeBasesRequest, +) +from agentkit.toolkit.errors import ApiError + + +@pytest.fixture +def client(): + return AgentkitKnowledgeClient( + access_key="AK_LOCAL_TEST_ONLY", + secret_key="SK_LOCAL_TEST_ONLY", + region="cn-beijing", + ) + + +def _stub_transport(client, result): + """Replace the transport with a stub returning a successful envelope. + + Returns the captured calls; each entry records the api action, query + params and the decoded JSON body handed to the transport layer. + """ + calls = [] + + def _json(self, api, params, body): + calls.append({"api": api, "params": params, "body": json.loads(body)}) + return json.dumps( + {"ResponseMetadata": {"RequestId": "req-test"}, "Result": result} + ) + + client.json = types.MethodType(_json, client) + return calls + + +def _stub_transport_raising(client, exc): + def _json(self, api, params, body): + raise exc + + client.json = types.MethodType(_json, client) + + +def test_get_knowledge_base_sends_action_and_payload(client): + calls = _stub_transport(client, {}) + + client.get_knowledge_base(GetKnowledgeBaseRequest(knowledge_id="kb-123")) + + assert calls == [ + {"api": "GetKnowledgeBase", "params": {}, "body": {"KnowledgeId": "kb-123"}} + ] + # The action is wired as a plain POST to / with Action/Version query params. + info = client.api_info["GetKnowledgeBase"] + assert info.method == "POST" + assert info.path == "/" + assert info.query == {"Action": "GetKnowledgeBase", "Version": client.api_version} + + +def test_add_knowledge_base_serializes_nested_payload_by_alias(client): + calls = _stub_transport(client, {}) + + client.add_knowledge_base( + AddKnowledgeBaseRequest( + project_name="default", + knowledge_bases=[ + KnowledgeBasesItemForAddKnowledgeBase( + name="kb-a", + provider_knowledge_id="pkb-1", + provider_type="viking", + ) + ], + ) + ) + + assert calls[0]["api"] == "AddKnowledgeBase" + # PascalCase aliases on the wire; unset optional fields are excluded. + assert calls[0]["body"] == { + "ProjectName": "default", + "KnowledgeBases": [ + { + "Name": "kb-a", + "ProviderKnowledgeId": "pkb-1", + "ProviderType": "viking", + } + ], + } + + +def test_list_knowledge_bases_parses_result_fields(client): + _stub_transport( + client, + { + "NextToken": "tok-2", + "KnowledgeBases": [ + { + "KnowledgeId": "kb-1", + "Name": "first", + "Status": "Ready", + "AssociatedRuntimes": [{"Id": "rt-1", "Name": "runtime-1"}], + } + ], + }, + ) + + resp = client.list_knowledge_bases(ListKnowledgeBasesRequest()) + + assert resp.next_token == "tok-2" + assert len(resp.knowledge_bases) == 1 + kb = resp.knowledge_bases[0] + assert kb.knowledge_id == "kb-1" + assert kb.name == "first" + assert kb.status == "Ready" + assert kb.associated_runtimes[0].id == "rt-1" + assert kb.associated_runtimes[0].name == "runtime-1" + + +def test_backend_error_metadata_raises_apierror_with_code(client): + def _json(self, api, params, body): + return json.dumps( + { + "ResponseMetadata": { + "Error": { + "Code": "InvalidKnowledgeBase.NotFound", + "Message": "no such knowledge base", + } + } + } + ) + + client.json = types.MethodType(_json, client) + + with pytest.raises(ApiError) as excinfo: + client.get_knowledge_base(GetKnowledgeBaseRequest(knowledge_id="kb-404")) + + assert excinfo.value.error_code == "InvalidKnowledgeBase.NotFound" + assert "GetKnowledgeBase" in str(excinfo.value) + assert "no such knowledge base" in str(excinfo.value) + + +def test_transport_failure_raises_networkerror(client): + _stub_transport_raising( + client, requests.exceptions.ConnectionError("socket reset") + ) + + with pytest.raises(NetworkError) as excinfo: + client.get_knowledge_base(GetKnowledgeBaseRequest(knowledge_id="kb-123")) + + assert isinstance(excinfo.value.__cause__, requests.exceptions.ConnectionError) + # Transport detail must not leak into the domain message. + assert "socket reset" not in str(excinfo.value) diff --git a/tests/sdk/test_mcp_client.py b/tests/sdk/test_mcp_client.py new file mode 100644 index 0000000..4c10364 --- /dev/null +++ b/tests/sdk/test_mcp_client.py @@ -0,0 +1,197 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Offline tests for ``AgentkitMCPClient``. + +The client is exercised through its public methods only; the volcengine +transport (``Service.json``) is stubbed at the same seam used by +``tests/client/test_base_service_client_errors.py``, so no network is +performed. Covers request construction (action + payload shape), response +parsing (alias -> snake_case field mapping) and error mapping for the core +create/get/delete service methods. + +Explicit credentials are supplied so the vefaas auto-refresh path is never +taken and construction needs neither a credential file nor network. +""" + +from __future__ import annotations + +import json +import types + +# Import the toolkit package first to fully initialise the import graph before +# touching ``agentkit.client`` (the package wiring is order-sensitive). +import agentkit.toolkit # noqa: F401 + +import pytest +import requests + +from agentkit.auth.errors import NetworkError +from agentkit.sdk.mcp.client import AgentkitMCPClient +from agentkit.sdk.mcp.types import ( + CreateMCPServiceRequest, + DeleteMCPServiceRequest, + GetMCPServiceRequest, +) +from agentkit.toolkit.errors import ApiError + + +@pytest.fixture +def client(): + return AgentkitMCPClient( + access_key="AK_LOCAL_TEST_ONLY", + secret_key="SK_LOCAL_TEST_ONLY", + region="cn-beijing", + ) + + +def _stub_transport(client, result): + """Replace the transport with a stub returning a successful envelope. + + Returns the captured calls; each entry records the api action, query + params and the decoded JSON body handed to the transport layer. + """ + calls = [] + + def _json(self, api, params, body): + calls.append({"api": api, "params": params, "body": json.loads(body)}) + return json.dumps( + {"ResponseMetadata": {"RequestId": "req-test"}, "Result": result} + ) + + client.json = types.MethodType(_json, client) + return calls + + +def _stub_transport_raising(client, exc): + def _json(self, api, params, body): + raise exc + + client.json = types.MethodType(_json, client) + + +def test_create_mcp_service_sends_action_and_payload(client): + calls = _stub_transport(client, {}) + + client.create_mcp_service( + CreateMCPServiceRequest( + name="my-mcp", + path="/mcp", + backend_type="VeFaaS", + protocol_type="SSE", + ) + ) + + assert calls == [ + { + "api": "CreateMCPService", + "params": {}, + # PascalCase aliases on the wire; unset optional fields excluded. + "body": { + "Name": "my-mcp", + "Path": "/mcp", + "BackendType": "VeFaaS", + "ProtocolType": "SSE", + }, + } + ] + info = client.api_info["CreateMCPService"] + assert info.method == "POST" + assert info.path == "/" + assert info.query == {"Action": "CreateMCPService", "Version": client.api_version} + + +def test_create_mcp_service_parses_service_id(client): + _stub_transport(client, {"MCPServiceId": "mcp-svc-1"}) + + resp = client.create_mcp_service( + CreateMCPServiceRequest( + name="my-mcp", path="/mcp", backend_type="VeFaaS", protocol_type="SSE" + ) + ) + + assert resp.mcp_service_id == "mcp-svc-1" + + +def test_get_mcp_service_parses_nested_service(client): + _stub_transport( + client, + { + "MCPService": { + "MCPServiceId": "mcp-svc-1", + "Name": "my-mcp", + "Path": "/mcp", + "Status": "Running", + "ProtocolType": "SSE", + } + }, + ) + + resp = client.get_mcp_service(GetMCPServiceRequest(mcp_service_id="mcp-svc-1")) + + assert resp.mcp_service.mcp_service_id == "mcp-svc-1" + assert resp.mcp_service.name == "my-mcp" + assert resp.mcp_service.path == "/mcp" + assert resp.mcp_service.status == "Running" + + +def test_delete_mcp_service_sends_id_and_parses_echo(client): + calls = _stub_transport(client, {"MCPServiceId": "mcp-svc-1"}) + + resp = client.delete_mcp_service( + DeleteMCPServiceRequest(mcp_service_id="mcp-svc-1") + ) + + assert calls == [ + { + "api": "DeleteMCPService", + "params": {}, + "body": {"MCPServiceId": "mcp-svc-1"}, + } + ] + assert resp.mcp_service_id == "mcp-svc-1" + + +def test_backend_error_metadata_raises_apierror_with_code(client): + def _json(self, api, params, body): + return json.dumps( + { + "ResponseMetadata": { + "Error": { + "Code": "InvalidMCPService.NotFound", + "Message": "no such mcp service", + } + } + } + ) + + client.json = types.MethodType(_json, client) + + with pytest.raises(ApiError) as excinfo: + client.get_mcp_service(GetMCPServiceRequest(mcp_service_id="mcp-404")) + + assert excinfo.value.error_code == "InvalidMCPService.NotFound" + assert "GetMCPService" in str(excinfo.value) + assert "no such mcp service" in str(excinfo.value) + + +def test_transport_failure_raises_networkerror(client): + _stub_transport_raising(client, requests.exceptions.Timeout("read timed out")) + + with pytest.raises(NetworkError) as excinfo: + client.get_mcp_service(GetMCPServiceRequest(mcp_service_id="mcp-svc-1")) + + assert isinstance(excinfo.value.__cause__, requests.exceptions.Timeout) + # Transport detail must not leak into the domain message. + assert "read timed out" not in str(excinfo.value) diff --git a/tests/sdk/test_memory_client.py b/tests/sdk/test_memory_client.py new file mode 100644 index 0000000..fe39db9 --- /dev/null +++ b/tests/sdk/test_memory_client.py @@ -0,0 +1,205 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Offline tests for ``AgentkitMemoryClient``. + +The client is exercised through its public methods only; the volcengine +transport (``Service.json``) is stubbed at the same seam used by +``tests/client/test_base_service_client_errors.py``, so no network is +performed. Covers request construction (action + payload shape), response +parsing (alias -> snake_case field mapping) and error mapping for the core +create/get/delete collection methods. + +Explicit credentials are supplied so the vefaas auto-refresh path is never +taken and construction needs neither a credential file nor network. +""" + +from __future__ import annotations + +import json +import types + +# Import the toolkit package first to fully initialise the import graph before +# touching ``agentkit.client`` (the package wiring is order-sensitive). +import agentkit.toolkit # noqa: F401 + +import pytest +import requests + +from agentkit.auth.errors import NetworkError +from agentkit.sdk.memory.client import AgentkitMemoryClient +from agentkit.sdk.memory.types import ( + CreateMemoryCollectionRequest, + DeleteMemoryCollectionRequest, + GetMemoryCollectionRequest, +) +from agentkit.toolkit.errors import ApiError + + +@pytest.fixture +def client(): + return AgentkitMemoryClient( + access_key="AK_LOCAL_TEST_ONLY", + secret_key="SK_LOCAL_TEST_ONLY", + region="cn-beijing", + ) + + +def _stub_transport(client, result): + """Replace the transport with a stub returning a successful envelope. + + Returns the captured calls; each entry records the api action, query + params and the decoded JSON body handed to the transport layer. + """ + calls = [] + + def _json(self, api, params, body): + calls.append({"api": api, "params": params, "body": json.loads(body)}) + return json.dumps( + {"ResponseMetadata": {"RequestId": "req-test"}, "Result": result} + ) + + client.json = types.MethodType(_json, client) + return calls + + +def _stub_transport_raising(client, exc): + def _json(self, api, params, body): + raise exc + + client.json = types.MethodType(_json, client) + + +def test_create_memory_collection_sends_action_and_payload(client): + calls = _stub_transport(client, {}) + + client.create_memory_collection( + CreateMemoryCollectionRequest( + name="mem-a", + description="test collection", + provider_type="viking", + ) + ) + + assert calls == [ + { + "api": "CreateMemoryCollection", + "params": {}, + # PascalCase aliases on the wire; unset optional fields excluded. + "body": { + "Name": "mem-a", + "Description": "test collection", + "ProviderType": "viking", + }, + } + ] + info = client.api_info["CreateMemoryCollection"] + assert info.method == "POST" + assert info.path == "/" + assert info.query == { + "Action": "CreateMemoryCollection", + "Version": client.api_version, + } + + +def test_create_memory_collection_parses_result(client): + _stub_transport( + client, + { + "MemoryId": "mem-1", + "ProviderCollectionId": "pc-1", + "ProviderType": "viking", + "Status": "Creating", + }, + ) + + resp = client.create_memory_collection( + CreateMemoryCollectionRequest(name="mem-a") + ) + + assert resp.memory_id == "mem-1" + assert resp.provider_collection_id == "pc-1" + assert resp.provider_type == "viking" + assert resp.status == "Creating" + + +def test_get_memory_collection_parses_fields(client): + _stub_transport( + client, + { + "MemoryId": "mem-1", + "Name": "mem-a", + "Managed": True, + "Status": "Ready", + "CreateTime": "2026-01-01T00:00:00Z", + }, + ) + + resp = client.get_memory_collection(GetMemoryCollectionRequest(memory_id="mem-1")) + + assert resp.memory_id == "mem-1" + assert resp.name == "mem-a" + assert resp.managed is True + assert resp.status == "Ready" + assert resp.create_time == "2026-01-01T00:00:00Z" + + +def test_delete_memory_collection_sends_id(client): + calls = _stub_transport(client, {"MemoryId": "mem-1", "Status": "Deleting"}) + + resp = client.delete_memory_collection( + DeleteMemoryCollectionRequest(memory_id="mem-1") + ) + + assert calls == [ + {"api": "DeleteMemoryCollection", "params": {}, "body": {"MemoryId": "mem-1"}} + ] + assert resp.memory_id == "mem-1" + assert resp.status == "Deleting" + + +def test_backend_error_metadata_raises_apierror_with_code(client): + def _json(self, api, params, body): + return json.dumps( + { + "ResponseMetadata": { + "Error": { + "Code": "InvalidMemoryCollection.NotFound", + "Message": "no such memory collection", + } + } + } + ) + + client.json = types.MethodType(_json, client) + + with pytest.raises(ApiError) as excinfo: + client.get_memory_collection(GetMemoryCollectionRequest(memory_id="mem-404")) + + assert excinfo.value.error_code == "InvalidMemoryCollection.NotFound" + assert "GetMemoryCollection" in str(excinfo.value) + assert "no such memory collection" in str(excinfo.value) + + +def test_transport_failure_raises_networkerror(client): + _stub_transport_raising( + client, requests.exceptions.ConnectionError("connection refused") + ) + + with pytest.raises(NetworkError) as excinfo: + client.get_memory_collection(GetMemoryCollectionRequest(memory_id="mem-1")) + + assert isinstance(excinfo.value.__cause__, requests.exceptions.ConnectionError) + # Transport detail must not leak into the domain message. + assert "connection refused" not in str(excinfo.value) diff --git a/tests/sdk/test_skills_client.py b/tests/sdk/test_skills_client.py new file mode 100644 index 0000000..3dadbc8 --- /dev/null +++ b/tests/sdk/test_skills_client.py @@ -0,0 +1,186 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Offline tests for ``AgentkitSkillsClient``. + +The client is exercised through its public methods only; the volcengine +transport (``Service.json``) is stubbed at the same seam used by +``tests/client/test_base_service_client_errors.py``, so no network is +performed. The client exposes ~20 methods; the core create/delete/list +triple is covered here since every method funnels through the same +``_invoke_api`` path. Covers request construction, response parsing and +error mapping. + +Explicit credentials are supplied so the vefaas auto-refresh path is never +taken and construction needs neither a credential file nor network. +""" + +from __future__ import annotations + +import json +import types + +# Import the toolkit package first to fully initialise the import graph before +# touching ``agentkit.client`` (the package wiring is order-sensitive). +import agentkit.toolkit # noqa: F401 + +import pytest +import requests + +from agentkit.auth.errors import NetworkError +from agentkit.sdk.skills.client import AgentkitSkillsClient +from agentkit.sdk.skills.types import ( + CreateSkillRequest, + DeleteSkillRequest, + DeleteSkillResponse, + GetSkillRequest, + ListSkillsRequest, +) +from agentkit.toolkit.errors import ApiError + + +@pytest.fixture +def client(): + return AgentkitSkillsClient( + access_key="AK_LOCAL_TEST_ONLY", + secret_key="SK_LOCAL_TEST_ONLY", + region="cn-beijing", + ) + + +def _stub_transport(client, result): + """Replace the transport with a stub returning a successful envelope. + + Returns the captured calls; each entry records the api action, query + params and the decoded JSON body handed to the transport layer. + """ + calls = [] + + def _json(self, api, params, body): + calls.append({"api": api, "params": params, "body": json.loads(body)}) + return json.dumps( + {"ResponseMetadata": {"RequestId": "req-test"}, "Result": result} + ) + + client.json = types.MethodType(_json, client) + return calls + + +def _stub_transport_raising(client, exc): + def _json(self, api, params, body): + raise exc + + client.json = types.MethodType(_json, client) + + +def test_create_skill_sends_action_and_payload(client): + calls = _stub_transport(client, {"Id": "skill-1"}) + + resp = client.create_skill( + CreateSkillRequest(name="my-skill", tos_url="tos://bucket/skill.zip") + ) + + assert calls == [ + { + "api": "CreateSkill", + "params": {}, + # PascalCase aliases on the wire; unset optional fields excluded. + "body": {"Name": "my-skill", "TosUrl": "tos://bucket/skill.zip"}, + } + ] + assert resp.id == "skill-1" + info = client.api_info["CreateSkill"] + assert info.method == "POST" + assert info.path == "/" + assert info.query == {"Action": "CreateSkill", "Version": client.api_version} + + +def test_delete_skill_sends_id(client): + calls = _stub_transport(client, {}) + + resp = client.delete_skill(DeleteSkillRequest(id="skill-1")) + + assert calls == [ + {"api": "DeleteSkill", "params": {}, "body": {"Id": "skill-1"}} + ] + # DeleteSkillResponse carries no fields; a typed empty object is returned. + assert isinstance(resp, DeleteSkillResponse) + + +def test_list_skills_parses_items_and_total_count(client): + _stub_transport( + client, + { + "TotalCount": 1, + "Items": [ + { + "Id": "skill-1", + "Name": "my-skill", + "Status": "Ready", + "Description": "demo", + "CreateTimeStamp": "1700000000", + "UpdateTimeStamp": "1700000001", + "Versions": ["v1", "v2"], + "ProjectName": "default", + } + ], + }, + ) + + resp = client.list_skills(ListSkillsRequest(page_number=1, page_size=10)) + + assert resp.total_count == 1 + assert len(resp.items) == 1 + skill = resp.items[0] + assert skill.id == "skill-1" + assert skill.name == "my-skill" + assert skill.status == "Ready" + assert skill.versions == ["v1", "v2"] + assert skill.project_name == "default" + + +def test_backend_error_metadata_raises_apierror_with_code(client): + def _json(self, api, params, body): + return json.dumps( + { + "ResponseMetadata": { + "Error": { + "Code": "InvalidSkill.NotFound", + "Message": "no such skill", + } + } + } + ) + + client.json = types.MethodType(_json, client) + + with pytest.raises(ApiError) as excinfo: + client.get_skill(GetSkillRequest(id="skill-404")) + + assert excinfo.value.error_code == "InvalidSkill.NotFound" + assert "GetSkill" in str(excinfo.value) + assert "no such skill" in str(excinfo.value) + + +def test_transport_failure_raises_networkerror(client): + _stub_transport_raising( + client, requests.exceptions.ConnectionError("socket reset") + ) + + with pytest.raises(NetworkError) as excinfo: + client.get_skill(GetSkillRequest(id="skill-1")) + + assert isinstance(excinfo.value.__cause__, requests.exceptions.ConnectionError) + # Transport detail must not leak into the domain message. + assert "socket reset" not in str(excinfo.value) diff --git a/tests/sdk/test_tools_client.py b/tests/sdk/test_tools_client.py new file mode 100644 index 0000000..5d73431 --- /dev/null +++ b/tests/sdk/test_tools_client.py @@ -0,0 +1,196 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Offline tests for ``AgentkitToolsClient``. + +The client is exercised through its public methods only; the volcengine +transport (``Service.json``) is stubbed at the same seam used by +``tests/client/test_base_service_client_errors.py``, so no network is +performed. Covers request construction (action + payload shape), response +parsing (alias -> snake_case field mapping) and error mapping for the core +create/get/list tool methods. + +Explicit credentials are supplied so the vefaas auto-refresh path is never +taken and construction needs neither a credential file nor network. +""" + +from __future__ import annotations + +import json +import types + +# Import the toolkit package first to fully initialise the import graph before +# touching ``agentkit.client`` (the package wiring is order-sensitive). +import agentkit.toolkit # noqa: F401 + +import pytest +import requests + +from agentkit.auth.errors import NetworkError +from agentkit.sdk.tools.client import AgentkitToolsClient +from agentkit.sdk.tools.types import ( + CreateToolRequest, + GetToolRequest, + ListToolsRequest, +) +from agentkit.toolkit.errors import ApiError + + +@pytest.fixture +def client(): + return AgentkitToolsClient( + access_key="AK_LOCAL_TEST_ONLY", + secret_key="SK_LOCAL_TEST_ONLY", + region="cn-beijing", + ) + + +def _stub_transport(client, result): + """Replace the transport with a stub returning a successful envelope. + + Returns the captured calls; each entry records the api action, query + params and the decoded JSON body handed to the transport layer. + """ + calls = [] + + def _json(self, api, params, body): + calls.append({"api": api, "params": params, "body": json.loads(body)}) + return json.dumps( + {"ResponseMetadata": {"RequestId": "req-test"}, "Result": result} + ) + + client.json = types.MethodType(_json, client) + return calls + + +def _stub_transport_raising(client, exc): + def _json(self, api, params, body): + raise exc + + client.json = types.MethodType(_json, client) + + +def test_create_tool_sends_action_and_payload(client): + calls = _stub_transport(client, {"ToolId": "tool-1"}) + + resp = client.create_tool( + CreateToolRequest( + name="sandbox-a", + tool_type="Sandbox", + port=8080, + ) + ) + + assert calls == [ + { + "api": "CreateTool", + "params": {}, + # PascalCase aliases on the wire; unset optional fields excluded. + "body": {"Name": "sandbox-a", "ToolType": "Sandbox", "Port": 8080}, + } + ] + assert resp.tool_id == "tool-1" + info = client.api_info["CreateTool"] + assert info.method == "POST" + assert info.path == "/" + assert info.query == {"Action": "CreateTool", "Version": client.api_version} + + +def test_get_tool_sends_id_and_parses_fields(client): + calls = _stub_transport( + client, + { + "ToolId": "tool-1", + "Name": "sandbox-a", + "ToolType": "Sandbox", + "Status": "Ready", + "Port": 8080, + "ImageUrl": "registry.example.com/sandbox:latest", + }, + ) + + resp = client.get_tool(GetToolRequest(tool_id="tool-1")) + + assert calls == [ + {"api": "GetTool", "params": {}, "body": {"ToolId": "tool-1"}} + ] + assert resp.tool_id == "tool-1" + assert resp.name == "sandbox-a" + assert resp.tool_type == "Sandbox" + assert resp.status == "Ready" + assert resp.port == 8080 + assert resp.image_url == "registry.example.com/sandbox:latest" + + +def test_list_tools_parses_tools_and_next_token(client): + _stub_transport( + client, + { + "NextToken": "tok-9", + "Tools": [ + { + "ToolId": "tool-1", + "Name": "sandbox-a", + "Status": "Ready", + "Port": 8080, + } + ], + }, + ) + + resp = client.list_tools(ListToolsRequest(max_results=10)) + + assert resp.next_token == "tok-9" + assert len(resp.tools) == 1 + tool = resp.tools[0] + assert tool.tool_id == "tool-1" + assert tool.name == "sandbox-a" + assert tool.status == "Ready" + assert tool.port == 8080 + + +def test_backend_error_metadata_raises_apierror_with_code(client): + def _json(self, api, params, body): + return json.dumps( + { + "ResponseMetadata": { + "Error": { + "Code": "InvalidTool.NotFound", + "Message": "no such tool", + } + } + } + ) + + client.json = types.MethodType(_json, client) + + with pytest.raises(ApiError) as excinfo: + client.get_tool(GetToolRequest(tool_id="tool-404")) + + assert excinfo.value.error_code == "InvalidTool.NotFound" + assert "GetTool" in str(excinfo.value) + assert "no such tool" in str(excinfo.value) + + +def test_transport_failure_raises_networkerror(client): + _stub_transport_raising( + client, requests.exceptions.ConnectionError("socket reset") + ) + + with pytest.raises(NetworkError) as excinfo: + client.get_tool(GetToolRequest(tool_id="tool-1")) + + assert isinstance(excinfo.value.__cause__, requests.exceptions.ConnectionError) + # Transport detail must not leak into the domain message. + assert "socket reset" not in str(excinfo.value) diff --git a/tests/toolkit/cli/test_cli_config_interactive_provider_resolution.py b/tests/toolkit/cli/test_cli_config_interactive_provider_resolution.py index 9a29794..2609015 100644 --- a/tests/toolkit/cli/test_cli_config_interactive_provider_resolution.py +++ b/tests/toolkit/cli/test_cli_config_interactive_provider_resolution.py @@ -50,6 +50,13 @@ def test_interactive_config_strategy_context_uses_resolved_provider( global_cfg = GlobalConfig() global_cfg.defaults.cloud_provider = "byteplus" monkeypatch.setattr(global_cfg_mod, "get_global_config", lambda: global_cfg) + # The resolver checks global_config_exists() before consulting + # get_global_config(), and env vars outrank global config — patch/clear + # both so the test is hermetic (it previously "passed" only when earlier + # tests leaked CLOUD_PROVIDER into the process environment). + monkeypatch.setattr(global_cfg_mod, "global_config_exists", lambda: True) + monkeypatch.delenv("AGENTKIT_CLOUD_PROVIDER", raising=False) + monkeypatch.delenv("CLOUD_PROVIDER", raising=False) def fake_create_common_config_interactively(existing_config): return CommonConfig.from_dict(existing_config or {}) @@ -107,6 +114,13 @@ def test_interactive_config_common_input_uses_raw_yaml_common( global_cfg = GlobalConfig() global_cfg.defaults.cloud_provider = "byteplus" monkeypatch.setattr(global_cfg_mod, "get_global_config", lambda: global_cfg) + # The resolver checks global_config_exists() before consulting + # get_global_config(), and env vars outrank global config — patch/clear + # both so the test is hermetic (it previously "passed" only when earlier + # tests leaked CLOUD_PROVIDER into the process environment). + monkeypatch.setattr(global_cfg_mod, "global_config_exists", lambda: True) + monkeypatch.delenv("AGENTKIT_CLOUD_PROVIDER", raising=False) + monkeypatch.delenv("CLOUD_PROVIDER", raising=False) def fake_create_common_config_interactively(existing_config): assert isinstance(existing_config, dict) diff --git a/tests/toolkit/executors/test_executor_platform_context_provider.py b/tests/toolkit/executors/test_executor_platform_context_provider.py index b4c0c48..b9c9556 100644 --- a/tests/toolkit/executors/test_executor_platform_context_provider.py +++ b/tests/toolkit/executors/test_executor_platform_context_provider.py @@ -49,6 +49,13 @@ def test_executor_platform_context_uses_resolved_provider( global_cfg.defaults.cloud_provider = "byteplus" monkeypatch.setattr(global_cfg_mod, "get_global_config", lambda: global_cfg) + # The resolver checks global_config_exists() before consulting + # get_global_config(), and env vars outrank global config — patch/clear + # both so the test is hermetic (it previously "passed" only when earlier + # tests leaked CLOUD_PROVIDER into the process environment). + monkeypatch.setattr(global_cfg_mod, "global_config_exists", lambda: True) + monkeypatch.delenv("AGENTKIT_CLOUD_PROVIDER", raising=False) + monkeypatch.delenv("CLOUD_PROVIDER", raising=False) clear_config_cache() cfg = get_config(config_path=str(config_path), force_reload=True) diff --git a/tests/utils/test_engineering_standards.py b/tests/utils/test_engineering_standards.py index 2f90642..1972b62 100644 --- a/tests/utils/test_engineering_standards.py +++ b/tests/utils/test_engineering_standards.py @@ -111,6 +111,22 @@ def test_signed_request_retries_zero_means_single_attempt(monkeypatch, no_sleep) assert counter["attempts"] == 1 +def test_retry_after_is_capped_and_sanitized(): + """Server-provided Retry-After must be capped (a misconfigured/hostile + server saying 3600 must not stall the client for an hour) and invalid + values must be ignored.""" + + class _Resp: + def __init__(self, value): + self.headers = {} if value is None else {"Retry-After": value} + + assert ve_sign._retry_after_seconds(_Resp("5")) == 5.0 + assert ve_sign._retry_after_seconds(_Resp("3600")) == ve_sign._RETRY_AFTER_CAP + assert ve_sign._retry_after_seconds(_Resp("-1")) == 0.0 + assert ve_sign._retry_after_seconds(_Resp("not-a-number")) is None + assert ve_sign._retry_after_seconds(_Resp(None)) is None + + # --------------------------------------------------------------------------- # # http_defaults env clamping # --------------------------------------------------------------------------- # From 25d737ced7c2ab0d2cfb616bd45a05a25fa55744 Mon Sep 17 00:00:00 2001 From: "liyi.ly" Date: Fri, 3 Jul 2026 19:57:58 +0800 Subject: [PATCH 3/3] fix: address NFR shadow-review blockers and cross-cutting majors (round 3) - cr: bound the instance-creation status poll with a 30min deadline (was an unbounded while True that hung the CLI forever on a stuck instance) - sandbox exec: give the terminal websocket a connect timeout and protocol-level ping keepalive so a silently dead server is detected instead of hanging the session - apps: centralize the sensitive-header exclusion list in apps.utils (single source of truth, extended to x-security-token / cookie / x-api-key / proxy-authorization; was two hand-synced two-entry sets plus one inline copy) - a2a telemetry: stop importing google-adk's private _get_user_id (undeclared dependency; fallback logic inlined), protect trace_a2a_agent with dont_throw so a telemetry failure can no longer mask the original business exception from the finally block, and log execute failures with stack + context_id - agent server /invoke: stop echoing raw str(e) into the SSE error frame (parity with /run_sse: generic message + error_type only, json.dumps-built) - CI: add a test workflow (3.10/3.12, pytest + coverage report) and declare pytest/pytest-cov/google-adk in the dev dependency group; add pytest/coverage config to pyproject --- .github/workflows/test.yml | 35 +++++++++++++++++ agentkit/apps/a2a_app/a2a_app.py | 6 ++- agentkit/apps/a2a_app/telemetry.py | 23 ++++++++--- .../apps/agent_server_app/agent_server_app.py | 18 ++++++++- agentkit/apps/agent_server_app/middleware.py | 3 +- agentkit/apps/simple_app/telemetry.py | 32 +++------------- agentkit/apps/utils.py | 38 +++++++++++++++++++ agentkit/toolkit/cli/sandbox/cli_exec.py | 16 +++++++- agentkit/toolkit/volcengine/cr.py | 13 ++++++- pyproject.toml | 11 ++++++ tests/apps/test_agent_server_invoke.py | 10 ++++- tests/apps/test_agent_server_middleware.py | 6 +++ 12 files changed, 171 insertions(+), 40 deletions(-) create mode 100644 .github/workflows/test.yml diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..3cf1c57 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,35 @@ +name: Tests + +on: + push: + branches: [main] + pull_request: + +jobs: + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.10", "3.12"] + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: pip + - name: Install package and test tooling + run: | + python -m pip install --upgrade pip + pip install -e . + pip install --group dev + - name: Run tests with coverage + # No --cov-fail-under yet: establish the baseline first, then ratchet + # the gate up instead of picking an aspirational number that blocks CI. + run: | + pytest -q --cov=agentkit --cov-report=term --cov-report=xml + - name: Upload coverage report + uses: actions/upload-artifact@v4 + with: + name: coverage-${{ matrix.python-version }} + path: coverage.xml diff --git a/agentkit/apps/a2a_app/a2a_app.py b/agentkit/apps/a2a_app/a2a_app.py index dea2bf8..edaa7cb 100644 --- a/agentkit/apps/a2a_app/a2a_app.py +++ b/agentkit/apps/a2a_app/a2a_app.py @@ -56,7 +56,11 @@ async def wrapper(*args, **kwargs): ) except Exception as e: - logger.error("Invoke agent execute function failed: %s", e) + logger.exception( + "Invoke agent execute function failed (context_id=%s): %s", + context.context_id, + e, + ) exception = e raise e finally: diff --git a/agentkit/apps/a2a_app/telemetry.py b/agentkit/apps/a2a_app/telemetry.py index a07a9a1..a8ed69a 100644 --- a/agentkit/apps/a2a_app/telemetry.py +++ b/agentkit/apps/a2a_app/telemetry.py @@ -22,9 +22,8 @@ from opentelemetry.metrics import get_meter from opentelemetry.trace.span import Span from a2a.server.agent_execution.context import RequestContext -from google.adk.a2a.converters.request_converter import _get_user_id -from agentkit.apps.utils import safe_serialize_to_json_string +from agentkit.apps.utils import dont_throw, safe_serialize_to_json_string _GEN_AI_CLIENT_OPERATION_DURATION_BUCKETS = [ 0.01, @@ -47,6 +46,15 @@ logger = logging.getLogger("agentkit." + __name__) +def _get_user_id(request: RequestContext) -> str: + # Inlined from google-adk's private a2a request converter helper: a2a_app + # must not import google-adk (an undeclared dependency for pure-A2A apps). + call_context = getattr(request, "call_context", None) + if call_context and call_context.user and call_context.user.user_name: + return call_context.user.user_name + return f"A2A_USER_{request.context_id}" + + class Telemetry: def __init__(self): self.tracer = get_tracer("agentkit.a2a_app") @@ -58,6 +66,7 @@ def __init__(self): explicit_bucket_boundaries_advisory=_GEN_AI_CLIENT_OPERATION_DURATION_BUCKETS, ) + @dont_throw def trace_a2a_agent( self, func: Callable, @@ -89,10 +98,12 @@ def trace_a2a_agent( if user_id: span.set_attribute(key="gen_ai.user.id", value=user_id) - span.set_attribute( - key="gen_ai.input", - value=safe_serialize_to_json_string(request.message.parts), - ) + message = getattr(request, "message", None) + if message is not None: + span.set_attribute( + key="gen_ai.input", + value=safe_serialize_to_json_string(message.parts), + ) span.set_attribute(key="gen_ai.span.kind", value="a2a_agent") span.set_attribute(key="gen_ai.operation.name", value="invoke_agent") diff --git a/agentkit/apps/agent_server_app/agent_server_app.py b/agentkit/apps/agent_server_app/agent_server_app.py index ee41b65..566f71f 100644 --- a/agentkit/apps/agent_server_app/agent_server_app.py +++ b/agentkit/apps/agent_server_app/agent_server_app.py @@ -58,6 +58,7 @@ ) from agentkit.apps.agent_server_app.telemetry import telemetry from agentkit.apps.base_app import BaseAgentkitApp +from agentkit.apps.utils import SENSITIVE_HEADERS logger = logging.getLogger(__name__) @@ -304,7 +305,7 @@ async def _invoke_compat(request: Request): telemetry_headers = { k: v for k, v in dict(headers).items() - if k.lower() not in {"authorization", "token"} + if k.lower() not in SENSITIVE_HEADERS } # trace request attributes on current span telemetry.trace_agent_server( @@ -383,10 +384,23 @@ async def event_generator(): # finish span on successful end of stream handled by middleware pass except Exception as e: + logger.exception("Error in /invoke event_generator: %s", e) telemetry.trace_agent_server_finish( path="/invoke", func_result="", exception=e ) - yield f'data: {{"error": "{str(e)}"}}\n\n' + # Do not echo internal exception details to the client; + # keep parity with the /run_sse error frame above. + yield ( + "data: " + + json.dumps( + { + "error": "internal error while running agent; " + "see server logs", + "error_type": type(e).__name__, + } + ) + + "\n\n" + ) return StreamingResponse( event_generator(), diff --git a/agentkit/apps/agent_server_app/middleware.py b/agentkit/apps/agent_server_app/middleware.py index b006c2b..7ea8321 100644 --- a/agentkit/apps/agent_server_app/middleware.py +++ b/agentkit/apps/agent_server_app/middleware.py @@ -18,8 +18,9 @@ from opentelemetry import context as context_api from agentkit.apps.agent_server_app.telemetry import telemetry +from agentkit.apps.utils import SENSITIVE_HEADERS -_EXCLUDED_HEADERS = {"authorization", "token"} +_EXCLUDED_HEADERS = SENSITIVE_HEADERS class AgentkitTelemetryHTTPMiddleware: diff --git a/agentkit/apps/simple_app/telemetry.py b/agentkit/apps/simple_app/telemetry.py index 9dea82a..ea68759 100644 --- a/agentkit/apps/simple_app/telemetry.py +++ b/agentkit/apps/simple_app/telemetry.py @@ -15,7 +15,6 @@ import logging import time -import traceback from typing import Callable from opentelemetry import trace @@ -23,7 +22,11 @@ from opentelemetry.metrics import get_meter from opentelemetry.trace.span import Span -from agentkit.apps.utils import safe_serialize_to_json_string +from agentkit.apps.utils import ( + SENSITIVE_HEADERS, + dont_throw, + safe_serialize_to_json_string, +) _GEN_AI_CLIENT_OPERATION_DURATION_BUCKETS = [ 0.01, @@ -46,36 +49,13 @@ logger = logging.getLogger("agentkit." + __name__) -# Keep in sync with agent_server_app.middleware._EXCLUDED_HEADERS: credential -# headers must never be recorded on spans (they bypass the logging redaction). -_EXCLUDED_HEADERS = {"authorization", "token"} +_EXCLUDED_HEADERS = SENSITIVE_HEADERS def _redact_headers(headers: dict) -> dict: return {k: v for k, v in headers.items() if k.lower() not in _EXCLUDED_HEADERS} -def dont_throw(func): - """ - A decorator that wraps the passed in function and logs exceptions instead of throwing them. - - @param func: The function to wrap - @return: The wrapper function - """ - - def wrapper(*args, **kwargs): - try: - return func(*args, **kwargs) - except Exception: - logger.error( - "Agentkit failed to trace in %s, error: %s", - func.__name__, - traceback.format_exc(), - ) - - return wrapper - - class Telemetry: def __init__(self): self.tracer = get_tracer("agentkit.simple_app") diff --git a/agentkit/apps/utils.py b/agentkit/apps/utils.py index 2144ba4..2b3c994 100644 --- a/agentkit/apps/utils.py +++ b/agentkit/apps/utils.py @@ -14,9 +14,47 @@ import json import logging +import traceback logger = logging.getLogger("agentkit." + __name__) +# Credential-bearing headers that must never be recorded on telemetry spans: +# span attributes bypass the logging redaction filter. Single source of truth +# for every app's telemetry/middleware header filtering. +SENSITIVE_HEADERS = frozenset( + { + "authorization", + "proxy-authorization", + "token", + "x-security-token", # STS credentials (see agentkit.auth._sigv4) + "x-api-key", + "api-key", + "cookie", + "set-cookie", + } +) + + +def dont_throw(func): + """ + A decorator that wraps the passed in function and logs exceptions instead of throwing them. + + @param func: The function to wrap + @return: The wrapper function + """ + + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception: + logger.error( + "Agentkit failed to trace in %s, error: %s", + func.__name__, + traceback.format_exc(), + ) + + return wrapper + def safe_serialize_to_json_string(obj): """Safely serialize object directly to JSON string with progressive fallback handling. diff --git a/agentkit/toolkit/cli/sandbox/cli_exec.py b/agentkit/toolkit/cli/sandbox/cli_exec.py index b51798d..da810d0 100644 --- a/agentkit/toolkit/cli/sandbox/cli_exec.py +++ b/agentkit/toolkit/cli/sandbox/cli_exec.py @@ -158,6 +158,13 @@ def _write_output(data: object) -> None: sys.stdout.flush() +# Handshake must not block forever; keepalive detects a silently dead server +# (the app-level ping is server-initiated, so the client needs its own probe). +WS_CONNECT_TIMEOUT_SECONDS = 30 +WS_PING_INTERVAL_SECONDS = 30 +WS_PING_TIMEOUT_SECONDS = 10 + + def _connect_terminal( ws_url: str, initial_command: Optional[str], @@ -171,6 +178,10 @@ def _connect_terminal( "Install with: pip install websocket-client" ) + # run_forever applies getdefaulttimeout() to the connect handshake; + # process-global is acceptable for a one-shot CLI command. + websocket.setdefaulttimeout(WS_CONNECT_TIMEOUT_SECONDS) + stop_event = threading.Event() initial_command_sent = {"value": False} websocket_app = None @@ -245,7 +256,10 @@ def on_resize(_signum, _frame) -> None: err=True, ) with _raw_terminal_mode(): - websocket_app.run_forever() + websocket_app.run_forever( + ping_interval=WS_PING_INTERVAL_SECONDS, + ping_timeout=WS_PING_TIMEOUT_SECONDS, + ) except KeyboardInterrupt: websocket_app.close() finally: diff --git a/agentkit/toolkit/volcengine/cr.py b/agentkit/toolkit/volcengine/cr.py index 1e63ff4..32ee9eb 100644 --- a/agentkit/toolkit/volcengine/cr.py +++ b/agentkit/toolkit/volcengine/cr.py @@ -24,6 +24,11 @@ DEFAULT_CR_NAMESPACE_NAME = "agenkit-platform-namespace" DEFAULT_CR_REPO_NAME = "agentkit-platform-repo" +# Instance provisioning normally finishes within minutes; the deadline only +# guards against an instance stuck in Creating/unknown status forever. +CR_INSTANCE_CREATE_TIMEOUT_SECONDS = 1800 +CR_INSTANCE_POLL_INTERVAL_SECONDS = 30 + class VeCR: def __init__( @@ -107,15 +112,21 @@ def _create_instance( f"Error create cr instance {instance_name}: {error_code} {error_message}" ) + deadline = time.monotonic() + CR_INSTANCE_CREATE_TIMEOUT_SECONDS while True: status = self._check_instance(instance_name) if status == "Running": break elif status == "Failed": raise ValueError(f"cr instance {instance_name} create failed") + elif time.monotonic() >= deadline: + raise ValueError( + f"cr instance {instance_name} not Running after " + f"{CR_INSTANCE_CREATE_TIMEOUT_SECONDS}s (last status: {status})" + ) else: logger.debug(f"cr instance status: {status}") - time.sleep(30) + time.sleep(CR_INSTANCE_POLL_INTERVAL_SECONDS) return instance_name diff --git a/pyproject.toml b/pyproject.toml index 1c9f8f5..771d6ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,11 @@ include-package-data = true [dependency-groups] dev = [ "pre-commit>=4.3.0", + "pytest>=8.0", + "pytest-cov>=5.0", + # tests/apps drives the ADK-based agent server; google-adk is otherwise + # an optional runtime integration, not a package dependency. + "google-adk>=2.0", ] toolkit = [ "pyyaml>=6.0.2", @@ -74,6 +79,12 @@ extensions = [] toolkit = [] dev = [] +[tool.pytest.ini_options] +testpaths = ["tests"] + +[tool.coverage.run] +source = ["agentkit"] + [tool.ruff.lint] [tool.ruff] diff --git a/tests/apps/test_agent_server_invoke.py b/tests/apps/test_agent_server_invoke.py index 2570651..742c33b 100644 --- a/tests/apps/test_agent_server_invoke.py +++ b/tests/apps/test_agent_server_invoke.py @@ -658,8 +658,14 @@ def test_invoke_stream_emits_error_frame_when_runner_raises(): response = asyncio.run(invoke(request)) chunks = asyncio.run(_drain(response)) - # The generator catches the exception and yields a single error SSE frame. - assert chunks == ['data: {"error": "runner exploded"}\n\n'] + # The generator catches the exception and yields a single error SSE frame + # carrying only the exception type — never the exception message, which + # may leak internal paths/backend detail (parity with /run_sse). + assert len(chunks) == 1 + assert chunks[0].startswith("data: ") and chunks[0].endswith("\n\n") + frame = json.loads(chunks[0][len("data: ") :]) + assert frame["error_type"] == "RuntimeError" + assert "runner exploded" not in chunks[0] def test_invoke_stream_error_path_traces_finish_with_the_exception(): diff --git a/tests/apps/test_agent_server_middleware.py b/tests/apps/test_agent_server_middleware.py index 83a9d39..a4d1246 100644 --- a/tests/apps/test_agent_server_middleware.py +++ b/tests/apps/test_agent_server_middleware.py @@ -124,6 +124,9 @@ async def _app(scope, receive, send): "headers": [ (b"authorization", b"Bearer secret-token"), (b"token", b"another-secret"), + (b"x-security-token", b"sts-session-token"), + (b"cookie", b"session=abc"), + (b"x-api-key", b"key-123"), (b"content-type", b"text/plain"), ], } @@ -135,6 +138,9 @@ async def _app(scope, receive, send): headers = kwargs["headers"] assert "authorization" not in headers assert "token" not in headers + assert "x-security-token" not in headers + assert "cookie" not in headers + assert "x-api-key" not in headers # The non-sensitive header survives with its decoded value. assert headers == {"content-type": "text/plain"}