diff --git a/.github/actions/conformance/expected-failures.2026-07-28.yml b/.github/actions/conformance/expected-failures.2026-07-28.yml index 39fcdde48..702575ce4 100644 --- a/.github/actions/conformance/expected-failures.2026-07-28.yml +++ b/.github/actions/conformance/expected-failures.2026-07-28.yml @@ -28,10 +28,6 @@ client: # neither run nor evaluated on this leg. server: - # The stateless 2026 path now reaches handlers for plain request/response - # scenarios; tools-call-with-progress still fails because the stateless - # server has no channel for server→client progress notifications. - - tools-call-with-progress # SEP-2322 (multi-round-trip requests / IncompleteResult): the prompt pipeline # cannot return InputRequiredResult from MCPServer yet (tools/call can). - input-required-result-non-tool-request diff --git a/examples/stories/manifest.toml b/examples/stories/manifest.toml index 7a1f079e8..22c09aa4a 100644 --- a/examples/stories/manifest.toml +++ b/examples/stories/manifest.toml @@ -31,8 +31,6 @@ era = "dual-in-body" multi_connection = true [story.streaming] -# progress + log notifications dropped on the modern streamable-HTTP path pending SSE wiring -xfail = ["http-asgi:modern"] [story.legacy_elicitation] era = "legacy" diff --git a/examples/stories/streaming/README.md b/examples/stories/streaming/README.md index 86e2e7478..e6bedb915 100644 --- a/examples/stories/streaming/README.md +++ b/examples/stories/streaming/README.md @@ -17,16 +17,12 @@ uv run python -m stories.streaming.client uv run python -m stories.streaming.client --server server_lowlevel # HTTP — the client self-hosts the server on a free port, runs, then tears it -# down (--legacy: see the note below) -uv run python -m stories.streaming.client --http --legacy +# down +uv run python -m stories.streaming.client --http # same, against the lowlevel-API server variant -uv run python -m stories.streaming.client --http --legacy --server server_lowlevel +uv run python -m stories.streaming.client --http --server server_lowlevel ``` -The modern HTTP leg (drop `--legacy`) is `xfail` until the SSE wiring lands — -mid-call progress and log notifications are currently dropped there (see -Caveats). - ## What to look at - `client.py` `main` — opens with `async with Client(target, mode=mode, @@ -60,9 +56,6 @@ Caveats). OpenTelemetry instead of `notifications/message`. It is shown here because servers still need to support 2025-era clients during that window. Progress and cancellation are **not** deprecated. TODO(maxisbey): revisit before beta. -- On the modern (2026-07-28) streamable-HTTP path, mid-call progress and log - notifications are currently dropped pending the SSE wiring; the - `http-asgi:modern` leg of this story is `xfail` until that lands. - When a request is cancelled the server currently replies with `ErrorData(code=0, message="Request cancelled")`; the spec says it should not reply at all. The client never observes it (its awaiting task is already diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 551bfa5f3..f28eb7c7a 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -156,7 +156,10 @@ async def _handle_sse_event( # Otherwise, return False to continue listening return isinstance(message, JSONRPCResponse | JSONRPCError) - except Exception as exc: # pragma: no cover + # Forwarding to a closed read stream lands here when the caller cancels mid-SSE + # (BrokenResourceError, not a parse failure); coverage is timing-dependent in the + # streaming story's modern HTTP cancellation leg. + except Exception as exc: # pragma: lax no cover logger.exception("Error parsing SSE message") if original_request_id is not None: error_data = ErrorData(code=PARSE_ERROR, message=f"Failed to parse SSE message: {exc}") @@ -372,7 +375,7 @@ async def _handle_sse_response( await response.aclose() return # Normal completion, no reconnect needed except Exception: - logger.debug("SSE stream ended", exc_info=True) # pragma: no cover + logger.debug("SSE stream ended", exc_info=True) # pragma: lax no cover # Stream ended without response - reconnect if we received an event with ID if last_event_id is not None: # pragma: no branch diff --git a/src/mcp/server/_streamable_http_modern.py b/src/mcp/server/_streamable_http_modern.py index cecf21f08..e36ac7dd4 100644 --- a/src/mcp/server/_streamable_http_modern.py +++ b/src/mcp/server/_streamable_http_modern.py @@ -5,9 +5,15 @@ path for earlier protocol revisions. A 2026-07-28 request is a self-contained POST: no `initialize` handshake, no -`Mcp-Session-Id`, one JSON-RPC request in, one JSON-RPC response out. This -module handles such a request directly in the ASGI task - no memory streams, -no per-request task group, no `JSONRPCDispatcher`. +`Mcp-Session-Id`, one JSON-RPC request in, one JSON-RPC response out. JSON +mode handles the request directly in the ASGI task. SSE mode runs the handler +as a sibling task and defers committing to `text/event-stream` until the +handler emits a notification or `_SSE_PING_INTERVAL` elapses, whichever +comes first: a handler that completes (or raises) within that window without +emitting still gets a JSON response with the table-mapped HTTP status, so +the spec's `404`/`400` MUSTs hold for kernel-dispatch errors; a handler that +runs silent past the window commits SSE so the keepalive ping can keep the +connection open behind a proxy idle-read timeout. """ from __future__ import annotations @@ -16,9 +22,10 @@ import logging from collections.abc import Awaitable, Mapping from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, Any, Final, TypeVar import anyio +from anyio.streams.memory import MemoryObjectSendStream from mcp_types import ( INTERNAL_ERROR, INVALID_REQUEST, @@ -27,8 +34,10 @@ ErrorData, Implementation, JSONRPCError, + JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, + ProgressToken, RequestId, ) from pydantic import BaseModel, ValidationError @@ -38,6 +47,7 @@ from mcp.server.connection import Connection from mcp.server.runner import serve_one +from mcp.server.streamable_http import check_accept_headers from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings from mcp.shared.dispatcher import CallOptions from mcp.shared.exceptions import NoBackChannelError @@ -46,7 +56,7 @@ InboundLadderRejection, classify_inbound_request, ) -from mcp.shared.jsonrpc_dispatcher import handler_exception_to_error_data +from mcp.shared.jsonrpc_dispatcher import handler_exception_to_error_data, progress_token_from_params from mcp.shared.message import MessageMetadata, ServerMessageMetadata from mcp.shared.transport_context import TransportContext @@ -66,12 +76,15 @@ class _SingleExchangeDispatchContext: Structurally satisfies `mcp.shared.dispatcher.DispatchContext`. The back-channel is closed by construction: a 2026-07-28 server cannot send - requests to the client. + requests to the client. The SSE sink, when present, carries request-scoped + notifications onto this request's response stream. """ transport: TransportContext request_id: RequestId message_metadata: MessageMetadata + progress_token: ProgressToken | None = None + sink: MemoryObjectSendStream[bytes] | None = None cancel_requested: anyio.Event = field(default_factory=anyio.Event) can_send_request: bool = field(default=False, init=False) @@ -84,12 +97,23 @@ async def send_raw_request( raise NoBackChannelError(method) async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: - # TODO(D-005a): buffer and stream as SSE once the JSON-vs-SSE response mode lands. - return None + if self.sink is None: + return + body = dict(params) if params is not None else None + try: + await self.sink.send(_sse_event(JSONRPCNotification(jsonrpc="2.0", method=method, params=body))) + except (anyio.ClosedResourceError, anyio.BrokenResourceError): + logger.debug("dropped %s: response stream closed", method) async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: - # TODO(D-005a): no progressToken plumbing yet; ships with the SSE response mode. - return None + if self.progress_token is None: + return + params: dict[str, Any] = {"progressToken": self.progress_token, "progress": progress} + if total is not None: + params["total"] = total + if message is not None: + params["message"] = message + await self.notify("notifications/progress", params) def _typed(model: type[_ModelT], raw: Any) -> _ModelT | None: @@ -126,6 +150,28 @@ async def _to_jsonrpc_response( return JSONRPCResponse(jsonrpc="2.0", id=request_id, result=result) +_SSE_PING_INTERVAL: float = 15.0 +"""Seconds between SSE comment-line keepalives once `text/event-stream` has committed.""" + +_SSE_HEADERS: Final[list[tuple[bytes, bytes]]] = [ + (b"content-type", b"text/event-stream"), + (b"cache-control", b"no-cache, no-transform"), + (b"connection", b"keep-alive"), + (b"x-accel-buffering", b"no"), +] + + +def _sse_event(msg: JSONRPCResponse | JSONRPCError | JSONRPCNotification) -> bytes: + """Serialise a JSON-RPC message as one SSE `event: message` frame. + + SSE mode begins after the handler has emitted, so a `JSONRPCError` here + always carries the request's id; the `id: null` case lives in `_write`. + """ + body = msg.model_dump(mode="json", by_alias=True, exclude_none=True) + data = json.dumps(body, separators=(",", ":")) + return f"event: message\r\ndata: {data}\r\n\r\n".encode() + + async def _write( msg: JSONRPCResponse | JSONRPCError, scope: Scope, @@ -149,6 +195,7 @@ async def _write( async def handle_modern_request( app: Server[Any], security_settings: TransportSecuritySettings | None, + json_response: bool, lifespan_state: Any, scope: Scope, receive: Receive, @@ -169,14 +216,17 @@ async def handle_modern_request( await err(scope, receive, send) return - # TODO(D-005a): validate Accept once the JSON-vs-SSE response mode is settled. - if request.method != "POST": # HTTP-layer rejection (Allow accompanies 405 per RFC 9110) — happens # before JSON-RPC parsing, so it doesn't go through `_write`. await Response(status_code=405, headers={"Allow": "POST"})(scope, receive, send) return + has_json, has_sse = check_accept_headers(request) + if not has_json or (not json_response and not has_sse): + await Response(status_code=406)(scope, receive, send) + return + body = await request.body() try: decoded = json.loads(body) @@ -219,8 +269,65 @@ async def handle_modern_request( transport=TransportContext(kind="streamable-http", can_send_request=False, headers=request.headers), request_id=req.id, message_metadata=ServerMessageMetadata(request_context=request), + progress_token=progress_token_from_params(req.params), ) - msg = await _to_jsonrpc_response( - req.id, serve_one(app, dctx, req.method, req.params, connection=connection, lifespan_state=lifespan_state) - ) - await _write(msg, scope, receive, send) + + if json_response: + msg = await _to_jsonrpc_response( + req.id, serve_one(app, dctx, req.method, req.params, connection=connection, lifespan_state=lifespan_state) + ) + await _write(msg, scope, receive, send) + return + + send_ch, recv_ch = anyio.create_memory_object_stream[bytes](0) + dctx.sink = send_ch + result: list[JSONRPCResponse | JSONRPCError] = [] + + async def run_handler() -> None: + async with send_ch: + result.append( + await _to_jsonrpc_response( + req.id, + serve_one(app, dctx, req.method, req.params, connection=connection, lifespan_state=lifespan_state), + ) + ) + + async def watch_disconnect(cancel_scope: anyio.CancelScope) -> None: + while (await receive()).get("type") != "http.disconnect": + pass # pragma: no cover + cancel_scope.cancel() + + async with recv_ch, anyio.create_task_group() as tg: + tg.start_soon(run_handler) + tg.start_soon(watch_disconnect, tg.cancel_scope) + + event: bytes | None = None + done = False + with anyio.move_on_after(_SSE_PING_INTERVAL): + try: + event = await recv_ch.receive() + except anyio.EndOfStream: + done = True + + if done: + # Handler completed within the deferral window without emitting: + # `application/json` with the table-mapped status. Kernel-dispatch + # errors (METHOD_NOT_FOUND, missing-capability, INVALID_PARAMS) + # resolve here in practice. + await _write(result[0], scope, receive, send) + else: + # First notification arrived, or the deferral window elapsed: commit + # `text/event-stream` and start pinging so a proxy idle-read timeout + # cannot close the stream (which on this path cancels the handler). + await send({"type": "http.response.start", "status": _OK_STATUS, "headers": _SSE_HEADERS}) + while not done: + await send({"type": "http.response.body", "body": event or b": ping\r\n\r\n", "more_body": True}) + event = None + with anyio.move_on_after(_SSE_PING_INTERVAL): + try: + event = await recv_ch.receive() + except anyio.EndOfStream: + done = True + await send({"type": "http.response.body", "body": _sse_event(result[0]), "more_body": False}) + + tg.cancel_scope.cancel() diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index b6128d3e0..d316345c7 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -76,6 +76,24 @@ SSEEvent = dict[str, Any] +def check_accept_headers(request: Request) -> tuple[bool, bool]: + """Return (has_json, has_sse) for the request's Accept header, with RFC 7231 wildcard handling. + + Supports wildcard media types per RFC 7231, section 5.3.2: + - */* matches any media type + - application/* matches any application/ subtype + - text/* matches any text/ subtype + """ + accept_header = request.headers.get("accept", "") + accept_types = [media_type.strip().split(";")[0].strip().lower() for media_type in accept_header.split(",")] + + has_wildcard = "*/*" in accept_types + has_json = has_wildcard or any(t in (CONTENT_TYPE_JSON, "application/*") for t in accept_types) + has_sse = has_wildcard or any(t in (CONTENT_TYPE_SSE, "text/*") for t in accept_types) + + return has_json, has_sse + + @dataclass class EventMessage: """A JSONRPCMessage with an optional event ID for stream resumability.""" @@ -415,23 +433,6 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No else: await self._handle_unsupported_request(request, send) - def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: - """Check if the request accepts the required media types. - - Supports wildcard media types per RFC 7231, section 5.3.2: - - */* matches any media type - - application/* matches any application/ subtype - - text/* matches any text/ subtype - """ - accept_header = request.headers.get("accept", "") - accept_types = [media_type.strip().split(";")[0].strip().lower() for media_type in accept_header.split(",")] - - has_wildcard = "*/*" in accept_types - has_json = has_wildcard or any(t in (CONTENT_TYPE_JSON, "application/*") for t in accept_types) - has_sse = has_wildcard or any(t in (CONTENT_TYPE_SSE, "text/*") for t in accept_types) - - return has_json, has_sse - def _check_content_type(self, request: Request) -> bool: """Check if the request has the correct Content-Type.""" content_type = request.headers.get("content-type", "") @@ -441,7 +442,7 @@ def _check_content_type(self, request: Request) -> bool: async def _validate_accept_header(self, request: Request, scope: Scope, send: Send) -> bool: """Validate Accept header based on response mode. Returns True if valid.""" - has_json, has_sse = self._check_accept_headers(request) + has_json, has_sse = check_accept_headers(request) if self.is_json_response_enabled: # For JSON-only responses, only require application/json if not has_json: @@ -661,7 +662,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: raise ValueError("No read stream writer available. Ensure connect() is called first.") # Validate Accept header - must include text/event-stream - _, has_sse = self._check_accept_headers(request) + _, has_sse = check_accept_headers(request) if not has_sse: response = self._create_error_response( diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 578639853..60b098961 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -170,7 +170,9 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No header = MCP_PROTOCOL_VERSION_HEADER.encode("ascii") pv = next((v.decode("latin-1") for k, v in scope["headers"] if k == header), None) if pv is not None and pv not in HANDSHAKE_PROTOCOL_VERSIONS: - await handle_modern_request(self.app, self.security_settings, self._lifespan_state, scope, receive, send) + await handle_modern_request( + self.app, self.security_settings, self.json_response, self._lifespan_state, scope, receive, send + ) return # Dispatch to the appropriate handler diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 2e784e2c2..64fcd3298 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -49,7 +49,7 @@ ) from mcp.shared.transport_context import TransportContext -__all__ = ["JSONRPCDispatcher", "handler_exception_to_error_data"] +__all__ = ["JSONRPCDispatcher", "handler_exception_to_error_data", "progress_token_from_params"] logger = logging.getLogger(__name__) @@ -84,6 +84,15 @@ def handler_exception_to_error_data(exc: BaseException) -> ErrorData | None: return None +def progress_token_from_params(params: Mapping[str, Any] | None) -> ProgressToken | None: + """Read `params._meta.progressToken`; reject bool (bool subclasses int, so True would alias 1).""" + match params: + case {"_meta": {"progressToken": str() | int() as token}} if not isinstance(token, bool): + return token + case _: + return None + + def _coerce_id(request_id: RequestId) -> RequestId: """Coerce a stringified int request ID back to int so a peer-echoed ID still correlates (matches the TS SDK).""" if isinstance(request_id, str): @@ -515,13 +524,7 @@ async def _dispatch_request( on_request: OnRequest, sender_ctx: contextvars.Context | None, ) -> None: - progress_token: ProgressToken | None - match req.params: - # bool subclasses int: without the guard True would alias request id 1. - case {"_meta": {"progressToken": str() | int() as progress_token}} if not isinstance(progress_token, bool): - pass - case _: - progress_token = None + progress_token = progress_token_from_params(req.params) try: transport_ctx = self._transport_builder(metadata) except Exception: diff --git a/tests/examples/conftest.py b/tests/examples/conftest.py index ffe22caad..48c5bffa5 100644 --- a/tests/examples/conftest.py +++ b/tests/examples/conftest.py @@ -99,7 +99,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: for leg, cfg in _legs(): marks: list[pytest.MarkDecorator] = [] if f"{leg.transport}:{leg.era}" in cfg["xfail"]: - marks.append(pytest.mark.xfail(strict=True, reason="manifest xfail")) + marks.append(pytest.mark.xfail(strict=True, reason="manifest xfail")) # pragma: lax no cover params.append(pytest.param(leg, marks=marks, id=leg.id)) metafunc.parametrize("leg", params) diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 0150513e2..341dde877 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -93,12 +93,6 @@ "unimplemented." ) -_MODERN_NOTIFY_DROP = ( - "The modern single-exchange dispatch context no-ops notify() on the streamable-http driver; " - "handler-emitted logging/progress notifications never reach the per-request SSE response. " - "Passes once SSE response mode lands." -) - @dataclass(frozen=True, kw_only=True) class Divergence: @@ -656,9 +650,6 @@ def __post_init__(self) -> None: "Progress notifications emitted by a handler during a request are delivered to the caller's " "progress callback, in order, with their progress, total, and message." ), - known_failures=( - KnownFailure(spec_version="2026-07-28", transport="streamable-http", note=_MODERN_NOTIFY_DROP, issue=None), - ), ), "protocol:progress:token-injected": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", @@ -676,9 +667,6 @@ def __post_init__(self) -> None: "interleaved emission. Token distinctness is the JSON-RPC mechanism for that; the in-process " "direct dispatcher carries the callback per-request without a wire-level token." ), - known_failures=( - KnownFailure(spec_version="2026-07-28", transport="streamable-http", note=_MODERN_NOTIFY_DROP, issue=None), - ), ), "protocol:progress:monotonic": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", @@ -691,9 +679,6 @@ def __post_init__(self) -> None: "handler that emits non-increasing values has them forwarded to the callback unchanged." ), ), - known_failures=( - KnownFailure(spec_version="2026-07-28", transport="streamable-http", note=_MODERN_NOTIFY_DROP, issue=None), - ), ), "protocol:progress:stops-after-completion": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/progress#behavior-requirements", @@ -831,9 +816,6 @@ def __post_init__(self) -> None: "Log notifications emitted by a tool handler during execution reach the client's logging " "callback before the tool result returns." ), - known_failures=( - KnownFailure(spec_version="2026-07-28", transport="streamable-http", note=_MODERN_NOTIFY_DROP, issue=None), - ), ), "tools:call:progress": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", @@ -841,9 +823,6 @@ def __post_init__(self) -> None: "Progress notifications emitted by a tool handler reach the caller's progress callback before " "the tool result returns." ), - known_failures=( - KnownFailure(spec_version="2026-07-28", transport="streamable-http", note=_MODERN_NOTIFY_DROP, issue=None), - ), ), "tools:call:sampling-roundtrip": Requirement( source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", @@ -1064,18 +1043,12 @@ def __post_init__(self) -> None: "The Context logging helpers (debug/info/warning/error) send log message notifications at the " "corresponding severity." ), - known_failures=( - KnownFailure(spec_version="2026-07-28", transport="streamable-http", note=_MODERN_NOTIFY_DROP, issue=None), - ), ), "mcpserver:context:progress": Requirement( source="sdk", behavior=( "Context.report_progress sends a progress notification against the requesting client's progress token." ), - known_failures=( - KnownFailure(spec_version="2026-07-28", transport="streamable-http", note=_MODERN_NOTIFY_DROP, issue=None), - ), ), "mcpserver:context:elicit": Requirement( source="sdk", @@ -1433,9 +1406,6 @@ def __post_init__(self) -> None: "logging:message:all-levels": Requirement( source=f"{SPEC_BASE_URL}/server/utilities/logging#log-levels", behavior="All eight RFC 5424 severity levels are deliverable as log message notifications.", - known_failures=( - KnownFailure(spec_version="2026-07-28", transport="streamable-http", note=_MODERN_NOTIFY_DROP, issue=None), - ), ), "logging:message:fields": Requirement( source=f"{SPEC_BASE_URL}/server/utilities/logging#log-message-notifications", @@ -1443,9 +1413,6 @@ def __post_init__(self) -> None: "A log message sent by a server handler is delivered to the client's logging callback with its " "severity level, logger name, and data." ), - known_failures=( - KnownFailure(spec_version="2026-07-28", transport="streamable-http", note=_MODERN_NOTIFY_DROP, issue=None), - ), ), "logging:message:filtered": Requirement( source=f"{SPEC_BASE_URL}/server/utilities/logging#setting-log-level", @@ -2065,17 +2032,6 @@ def __post_init__(self) -> None: "client cannot learn that the set changed without polling." ), ), - known_failures=( - KnownFailure( - spec_version="2026-07-28", - transport="streamable-http", - note=( - "List-mutation assertions hold; only the sentinel ctx.info() never reaches the client. " - + _MODERN_NOTIFY_DROP - ), - issue=None, - ), - ), ), # ═══════════════════════════════════════════════════════════════════════════ # Pagination diff --git a/tests/server/test_streamable_http_modern.py b/tests/server/test_streamable_http_modern.py index 0ba61cf39..6e8df458d 100644 --- a/tests/server/test_streamable_http_modern.py +++ b/tests/server/test_streamable_http_modern.py @@ -6,6 +6,7 @@ ``handle_modern_request``. """ +import json import logging from typing import Any @@ -26,11 +27,14 @@ JSONRPCError, JSONRPCResponse, ListToolsResult, + LoggingMessageNotification, + LoggingMessageNotificationParams, PaginatedRequestParams, Tool, ) from mcp_types.version import LATEST_MODERN_VERSION -from starlette.types import Receive, Scope, Send +from starlette.types import Message, Receive, Scope, Send +from trio.testing import MockClock from mcp.server import Server, ServerRequestContext, runner from mcp.server._streamable_http_modern import ( @@ -42,12 +46,14 @@ from mcp.shared.exceptions import MCPError, NoBackChannelError from mcp.shared.inbound import MCP_METHOD_HEADER, MCP_NAME_HEADER, MCP_PROTOCOL_VERSION_HEADER from mcp.shared.transport_context import TransportContext +from tests.interaction.transports import StreamingASGITransport pytestmark = pytest.mark.anyio async def test_single_exchange_dispatch_context_has_no_back_channel() -> None: - """The per-request dispatch context refuses server-initiated requests and drops notify/progress.""" + """The per-request dispatch context refuses server-initiated requests; without an SSE sink, + notify/progress are no-ops.""" dctx = _SingleExchangeDispatchContext( transport=TransportContext(kind="streamable-http", can_send_request=False), request_id=1, @@ -60,17 +66,24 @@ async def test_single_exchange_dispatch_context_has_no_back_channel() -> None: assert await dctx.progress(0.5, total=1.0, message="half") is None -def _asgi_client(server: Server[Any], security_settings: TransportSecuritySettings | None = None) -> httpx.AsyncClient: +def _asgi_client( + server: Server[Any], + security_settings: TransportSecuritySettings | None = None, + *, + json_response: bool = True, + accept: str = "application/json, text/event-stream", +) -> httpx.AsyncClient: async def app(scope: Scope, receive: Receive, send: Send) -> None: async with server.lifespan(server) as lifespan_state: - await handle_modern_request(server, security_settings, lifespan_state, scope, receive, send) + await handle_modern_request(server, security_settings, json_response, lifespan_state, scope, receive, send) return httpx.AsyncClient( - transport=httpx.ASGITransport(app=app), + transport=StreamingASGITransport(app), base_url="http://testserver", headers={ MCP_PROTOCOL_VERSION_HEADER: LATEST_MODERN_VERSION, "content-type": "application/json", + "accept": accept, }, ) @@ -301,3 +314,336 @@ async def test_handle_modern_request_rejects_mismatched_name_header_with_400_and ) assert response.status_code == 400 assert response.json()["error"]["code"] == HEADER_MISMATCH + + +# --- SSE response mode --------------------------------------------------------- + + +def _sse_payloads(body: str) -> list[dict[str, Any]]: + """Parse an SSE body into the list of JSON `data:` payloads, in delivery order.""" + return [ + json.loads(line.removeprefix("data:").strip()) + for line in body.replace("\r\n", "\n").splitlines() + if line.startswith("data:") + ] + + +def _list_tools_body_with_token(token: str | int) -> dict[str, Any]: + body = _list_tools_body() + body["params"]["_meta"]["progressToken"] = token + return body + + +async def test_sse_mode_streams_progress_then_result() -> None: + """SSE mode: a handler's `report_progress` calls stream as `notifications/progress` events + (carrying the request's progressToken) before the terminal JSON-RPC response event. + + Spec-mandated: `notifications/progress` carries the caller's token; the per-request SSE stream + closes after the terminal response. Asserted at the wire because Content-Type and event order + are the contract. + """ + + async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + await ctx.session.report_progress(1.0, total=3.0) + await ctx.session.report_progress(2.0, total=3.0, message="almost") + return ListToolsResult(tools=[], ttl_ms=0, cache_scope="public") + + async with _asgi_client(Server("test", on_list_tools=list_tools), json_response=False) as http: + with anyio.fail_after(5): + response = await http.post( + "/mcp", json=_list_tools_body_with_token("tok-1"), headers={MCP_METHOD_HEADER: "tools/list"} + ) + + assert response.status_code == 200 + assert response.headers["content-type"].split(";", 1)[0] == "text/event-stream" + events = _sse_payloads(response.text) + assert len(events) == 3 + assert events[0] == { + "jsonrpc": "2.0", + "method": "notifications/progress", + "params": {"progressToken": "tok-1", "progress": 1.0, "total": 3.0}, + } + assert events[1] == { + "jsonrpc": "2.0", + "method": "notifications/progress", + "params": {"progressToken": "tok-1", "progress": 2.0, "total": 3.0, "message": "almost"}, + } + assert events[2]["id"] == 1 + assert events[2]["result"]["tools"] == [] + + +@pytest.mark.parametrize( + "anyio_backend", + [pytest.param(("trio", {"clock": MockClock(autojump_threshold=0)}), id="trio-mockclock")], +) +async def test_sse_mode_emits_keepalive_comment_between_events(monkeypatch: pytest.MonkeyPatch) -> None: + """SSE mode: while the stream is idle between events the server emits an SSE comment line so a + proxy idle-read timeout does not close the stream (which would cancel the handler). + SDK-defined: spec encourages keepalive comments for long-lived streams. + + Runs on trio's autojumping MockClock so the `move_on_after(_SSE_PING_INTERVAL)` deadlines and + the handler's `anyio.sleep` advance without wall-clock time.""" + monkeypatch.setattr("mcp.server._streamable_http_modern._SSE_PING_INTERVAL", 1.0) + + async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + await ctx.session.report_progress(1.0) + await anyio.sleep(2.5) + return ListToolsResult(tools=[], ttl_ms=0, cache_scope="public") + + async with _asgi_client(Server("test", on_list_tools=list_tools), json_response=False) as http: + with anyio.fail_after(5): + response = await http.post( + "/mcp", json=_list_tools_body_with_token("tok"), headers={MCP_METHOD_HEADER: "tools/list"} + ) + + assert response.headers["content-type"].split(";", 1)[0] == "text/event-stream" + assert response.content.count(b": ping\r\n\r\n") == 2 + events = _sse_payloads(response.text) + assert len(events) == 2 + assert events[0]["method"] == "notifications/progress" + assert events[1]["result"]["tools"] == [] + + +@pytest.mark.parametrize( + "anyio_backend", + [pytest.param(("trio", {"clock": MockClock(autojump_threshold=0)}), id="trio-mockclock")], +) +async def test_sse_mode_silent_handler_commits_sse_after_ping_interval(monkeypatch: pytest.MonkeyPatch) -> None: + """SSE mode: a handler that runs silent past the deferral window commits `text/event-stream` + and starts pinging — even though it never emits a notification — so a proxy idle-read timeout + does not close the connection and cancel it. SDK-defined: the deferral window is bounded by + `_SSE_PING_INTERVAL`. + + Runs on trio's autojumping MockClock; the 2.5s handler sleep takes no wall-clock time.""" + monkeypatch.setattr("mcp.server._streamable_http_modern._SSE_PING_INTERVAL", 1.0) + + async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + await anyio.sleep(2.5) + return ListToolsResult(tools=[], ttl_ms=0, cache_scope="public") + + async with _asgi_client(Server("test", on_list_tools=list_tools), json_response=False) as http: + with anyio.fail_after(5): + response = await http.post("/mcp", json=_list_tools_body(), headers={MCP_METHOD_HEADER: "tools/list"}) + + assert response.status_code == 200 + assert response.headers["content-type"].split(";", 1)[0] == "text/event-stream" + assert response.content.count(b": ping\r\n\r\n") == 2 + events = _sse_payloads(response.text) + assert len(events) == 1 + assert events[0]["result"]["tools"] == [] + + +async def test_sse_mode_streams_log_notification() -> None: + """SSE mode: a request-scoped `notifications/message` emitted by the handler precedes the + terminal response on the same stream. SDK-defined: notifications sent on the request's outbound + channel reach the per-request SSE response.""" + + async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + await ctx.session.send_notification( + LoggingMessageNotification(params=LoggingMessageNotificationParams(level="info", data="hello")), + related_request_id=ctx.request_id, + ) + return ListToolsResult(tools=[], ttl_ms=0, cache_scope="public") + + async with _asgi_client(Server("test", on_list_tools=list_tools), json_response=False) as http: + with anyio.fail_after(5): + response = await http.post("/mcp", json=_list_tools_body(), headers={MCP_METHOD_HEADER: "tools/list"}) + + assert response.headers["content-type"].split(";", 1)[0] == "text/event-stream" + events = _sse_payloads(response.text) + assert len(events) == 2 + assert events[0]["method"] == "notifications/message" + assert events[0]["params"] == {"level": "info", "data": "hello"} + assert events[1]["result"]["tools"] == [] + + +async def test_json_mode_drops_progress() -> None: + """JSON mode: `report_progress` is a no-op (no sink); the response is a plain + `application/json` body carrying only the terminal result. SDK-defined.""" + + async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + await ctx.session.report_progress(1, total=2) + return ListToolsResult(tools=[], ttl_ms=0, cache_scope="public") + + async with _asgi_client(Server("test", on_list_tools=list_tools), json_response=True) as http: + response = await http.post( + "/mcp", json=_list_tools_body_with_token("tok"), headers={MCP_METHOD_HEADER: "tools/list"} + ) + + assert response.headers["content-type"].split(";", 1)[0] == "application/json" + body = response.json() + assert body["id"] == 1 + assert body["result"]["tools"] == [] + assert "notifications/progress" not in response.text + + +async def test_sse_mode_error_before_any_notify_is_json_with_mapped_status() -> None: + """SSE mode: an error raised before the handler emits any notification is written as + `application/json` with the table-mapped HTTP status — SSE has not committed yet. + Spec-mandated: METHOD_NOT_FOUND MUST be `404 Not Found`.""" + + async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + raise MCPError(code=METHOD_NOT_FOUND, message="nope") + + async with _asgi_client(Server("test", on_list_tools=list_tools), json_response=False) as http: + with anyio.fail_after(5): + response = await http.post("/mcp", json=_list_tools_body(), headers={MCP_METHOD_HEADER: "tools/list"}) + + assert response.status_code == 404 + assert response.headers["content-type"].split(";", 1)[0] == "application/json" + assert response.json() == {"jsonrpc": "2.0", "id": 1, "error": {"code": METHOD_NOT_FOUND, "message": "nope"}} + + +async def test_sse_mode_error_after_notify_is_sse_event() -> None: + """SSE mode: an error raised after the handler has emitted is delivered as the terminal SSE + event (HTTP 200) — `text/event-stream` headers were committed on the first notification.""" + + async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + await ctx.session.report_progress(1.0) + raise MCPError(code=INTERNAL_ERROR, message="boom") + + async with _asgi_client(Server("test", on_list_tools=list_tools), json_response=False) as http: + with anyio.fail_after(5): + response = await http.post( + "/mcp", json=_list_tools_body_with_token("tok"), headers={MCP_METHOD_HEADER: "tools/list"} + ) + + assert response.status_code == 200 + assert response.headers["content-type"].split(";", 1)[0] == "text/event-stream" + events = _sse_payloads(response.text) + assert len(events) == 2 + assert events[0]["method"] == "notifications/progress" + assert events[1] == {"jsonrpc": "2.0", "id": 1, "error": {"code": INTERNAL_ERROR, "message": "boom"}} + + +async def test_sse_mode_no_notify_response_is_json() -> None: + """SSE mode: a handler that emits nothing (here `report_progress` is a no-op because no + `progressToken` was supplied) gets a plain `application/json` response. SDK-defined: SSE only + commits once there is something to stream.""" + + async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + await ctx.session.report_progress(1, total=2) + return ListToolsResult(tools=[], ttl_ms=0, cache_scope="public") + + async with _asgi_client(Server("test", on_list_tools=list_tools), json_response=False) as http: + with anyio.fail_after(5): + response = await http.post("/mcp", json=_list_tools_body(), headers={MCP_METHOD_HEADER: "tools/list"}) + + assert response.status_code == 200 + assert response.headers["content-type"].split(";", 1)[0] == "application/json" + assert response.json()["result"]["tools"] == [] + + +async def test_accept_missing_sse_406_in_sse_mode() -> None: + """SDK-defined: in SSE mode the client must accept both `application/json` and + `text/event-stream`; an Accept header naming only JSON is rejected at HTTP 406 before any + JSON-RPC parsing.""" + async with _asgi_client(Server("test"), json_response=False, accept="application/json") as http: + response = await http.post("/mcp", json=_list_tools_body(), headers={MCP_METHOD_HEADER: "tools/list"}) + assert response.status_code == 406 + assert response.content == b"" + + +async def test_accept_missing_sse_ok_in_json_mode() -> None: + """SDK-defined: in JSON mode only `application/json` need be acceptable; an Accept header that + omits `text/event-stream` still routes (200 + result).""" + + async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[], ttl_ms=0, cache_scope="public") + + async with _asgi_client( + Server("test", on_list_tools=list_tools), json_response=True, accept="application/json" + ) as http: + response = await http.post("/mcp", json=_list_tools_body(), headers={MCP_METHOD_HEADER: "tools/list"}) + assert response.status_code == 200 + assert response.headers["content-type"].split(";", 1)[0] == "application/json" + + +@pytest.mark.parametrize("json_response", [True, False]) +async def test_accept_wildcard_satisfies_both_response_modes(json_response: bool) -> None: + """SDK-defined: `Accept: */*` satisfies both representations (RFC 7231 wildcard) in either + response mode.""" + + async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[], ttl_ms=0, cache_scope="public") + + async with _asgi_client( + Server("test", on_list_tools=list_tools), json_response=json_response, accept="*/*" + ) as http: + with anyio.fail_after(5): + response = await http.post("/mcp", json=_list_tools_body(), headers={MCP_METHOD_HEADER: "tools/list"}) + assert response.status_code == 200 + + +async def test_late_notify_after_terminal_dropped() -> None: + """SDK-defined: a `notify()` after the SSE sink has closed is silently dropped — the closed + stream must not propagate as an exception out of the dispatch context.""" + send_ch, recv_ch = anyio.create_memory_object_stream[bytes](0) + dctx = _SingleExchangeDispatchContext( + transport=TransportContext(kind="streamable-http", can_send_request=False), + request_id=1, + message_metadata=None, + sink=send_ch, + ) + await recv_ch.aclose() + # Neither raises despite the receiver being gone (BrokenResourceError caught and dropped). + assert await dctx.notify("notifications/message", {"level": "info", "data": "late"}) is None + dctx.progress_token = "tok" + assert await dctx.progress(1.0) is None + await send_ch.aclose() + + +async def test_disconnect_cancels_handler_and_runs_exit_stack() -> None: + """SSE mode: when the client disconnects mid-stream the handler task is cancelled and + `connection.exit_stack` still unwinds. SDK-defined: `serve_one`'s shielded cleanup runs in the + cancellation path so handler-registered teardown is not skipped on disconnect.""" + handler_started = anyio.Event() + cleanup_ran = anyio.Event() + + async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + ctx.session._connection.exit_stack.callback(cleanup_ran.set) + handler_started.set() + await anyio.Event().wait() + raise AssertionError("unreachable") # pragma: no cover + + server: Server[Any] = Server("test", on_list_tools=list_tools) + body = json.dumps(_list_tools_body()).encode() + scope: Scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": "POST", + "scheme": "http", + "server": ("testserver", 80), + "client": ("127.0.0.1", 1234), + "path": "/mcp", + "raw_path": b"/mcp", + "query_string": b"", + "root_path": "", + "headers": [ + (b"host", b"testserver"), + (b"content-type", b"application/json"), + (b"accept", b"application/json, text/event-stream"), + (MCP_PROTOCOL_VERSION_HEADER.encode(), LATEST_MODERN_VERSION.encode()), + (MCP_METHOD_HEADER.encode(), b"tools/list"), + ], + } + request_delivered = anyio.Event() + + async def receive() -> Message: + # First call delivers the request body; once the handler is parked, deliver disconnect. + if not request_delivered.is_set(): + request_delivered.set() + return {"type": "http.request", "body": body, "more_body": False} + await handler_started.wait() + return {"type": "http.disconnect"} + + async def send(message: Message) -> None: # pragma: no cover + pass + + with anyio.fail_after(5): + async with server.lifespan(server) as lifespan_state: + await handle_modern_request(server, None, False, lifespan_state, scope, receive, send) + await cleanup_ran.wait() + + assert handler_started.is_set()