diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index b6128d3e0..6bdeab7e9 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -625,9 +625,18 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re # Start the SSE response (this will send headers immediately) try: + + async def run_response_with_cleanup() -> None: + try: + await response(scope, receive, send) + finally: + self._sse_stream_writers.pop(request_id, None) + await sse_stream_writer.aclose() + await self._clean_up_memory_streams(request_id) + # First send the response to establish the SSE connection async with anyio.create_task_group() as tg: - tg.start_soon(response, scope, receive, send) + tg.start_soon(run_response_with_cleanup) # Then send the message to be processed by the server session_message = self._create_session_message(message, request, request_id, protocol_version) await writer.send(session_message) diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 0b3a28083..d63f084c6 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -2,13 +2,13 @@ import json import logging -from typing import Any +from typing import Any, cast from unittest.mock import AsyncMock, patch import anyio import httpx import pytest -from mcp_types import INVALID_REQUEST, ListToolsResult, PaginatedRequestParams +from mcp_types import INVALID_REQUEST, JSONRPCRequest, ListToolsResult, PaginatedRequestParams from starlette.types import Message, Scope from mcp import Client @@ -18,6 +18,7 @@ from mcp.server.auth.provider import AccessToken from mcp.server.streamable_http import MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.shared.message import SessionMessage @pytest.mark.anyio @@ -101,6 +102,49 @@ async def running_manager(): yield manager, app +@pytest.mark.anyio +async def test_streamable_http_post_sse_cleans_up_streams_when_response_returns(monkeypatch: pytest.MonkeyPatch): + transport = StreamableHTTPServerTransport(mcp_session_id=None) + sent_messages: list[Message] = [] + body = json.dumps({"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {}}).encode() + + class DisconnectingEventSourceResponse: + def __init__(self, *args: Any, **kwargs: Any) -> None: + pass + + async def __call__(self, scope: Scope, receive: Any, send: Any) -> None: + await send({"type": "http.response.start", "status": 200, "headers": []}) + + async def send(message: Message) -> None: + sent_messages.append(message) + + async def receive() -> Message: + return {"type": "http.request", "body": body, "more_body": False} + + scope: Scope = { + "type": "http", + "method": "POST", + "path": "/mcp", + "headers": [ + (b"accept", b"application/json, text/event-stream"), + (b"content-type", b"application/json"), + ], + } + + monkeypatch.setattr("mcp.server.streamable_http.EventSourceResponse", DisconnectingEventSourceResponse) + + async with transport.connect() as (read_stream, _write_stream): + async with anyio.create_task_group() as tg: + tg.start_soon(transport.handle_request, scope, receive, send) + session_message = cast(SessionMessage, await read_stream.receive()) + assert isinstance(session_message.message, JSONRPCRequest) + assert session_message.message.method == "tools/list" + + assert transport._request_streams == {} + assert transport._sse_stream_writers == {} + assert any(message["type"] == "http.response.start" for message in sent_messages) + + @pytest.mark.anyio async def test_stateful_session_cleanup_on_graceful_exit(running_manager: tuple[StreamableHTTPSessionManager, Server]): manager, _app = running_manager