diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java index 9a785e150..48263ac0b 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java @@ -33,6 +33,8 @@ import io.modelcontextprotocol.json.McpJsonMapper; import io.modelcontextprotocol.util.KeepAliveScheduler; import jakarta.servlet.AsyncContext; +import jakarta.servlet.AsyncEvent; +import jakarta.servlet.AsyncListener; import jakarta.servlet.ServletException; import jakarta.servlet.annotation.WebServlet; import jakarta.servlet.http.HttpServlet; @@ -326,6 +328,8 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) if (request.getHeader(HttpHeaders.LAST_EVENT_ID) != null) { String lastId = request.getHeader(HttpHeaders.LAST_EVENT_ID); + registerAsyncLifecycle(asyncContext, sessionId, sessionTransport::close); + try { session.replay(lastId) .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) @@ -338,13 +342,14 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) } catch (Exception e) { logger.error("Failed to replay message: {}", e.getMessage()); - asyncContext.complete(); + sessionTransport.close(); } }); + sessionTransport.close(); } catch (Exception e) { logger.error("Failed to replay messages: {}", e.getMessage()); - asyncContext.complete(); + sessionTransport.close(); } } else { @@ -352,30 +357,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session .listeningStream(sessionTransport); - asyncContext.addListener(new jakarta.servlet.AsyncListener() { - @Override - public void onComplete(jakarta.servlet.AsyncEvent event) throws IOException { - logger.debug("SSE connection completed for session: {}", sessionId); - listeningStream.close(); - } - - @Override - public void onTimeout(jakarta.servlet.AsyncEvent event) throws IOException { - logger.debug("SSE connection timed out for session: {}", sessionId); - listeningStream.close(); - } - - @Override - public void onError(jakarta.servlet.AsyncEvent event) throws IOException { - logger.debug("SSE connection error for session: {}", sessionId); - listeningStream.close(); - } - - @Override - public void onStartAsync(jakarta.servlet.AsyncEvent event) throws IOException { - // No action needed - } - }); + registerAsyncLifecycle(asyncContext, sessionId, listeningStream::close); } } catch (Exception e) { @@ -528,6 +510,8 @@ else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { HttpServletStreamableMcpSessionTransport sessionTransport = new HttpServletStreamableMcpSessionTransport( sessionId, asyncContext, response.getWriter()); + registerAsyncLifecycle(asyncContext, sessionId, sessionTransport::close); + try { session.responseStream(jsonrpcRequest, sessionTransport) .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) @@ -535,7 +519,7 @@ else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { } catch (Exception e) { logger.error("Failed to handle request stream: {}", e.getMessage()); - asyncContext.complete(); + sessionTransport.close(); } } else { @@ -648,6 +632,32 @@ public void responseError(HttpServletResponse response, int httpCode, McpError m return; } + private void registerAsyncLifecycle(AsyncContext asyncContext, String sessionId, Runnable onClose) { + asyncContext.addListener(new AsyncListener() { + @Override + public void onComplete(AsyncEvent event) { + logger.debug("Async context completed for session: {}", sessionId); + onClose.run(); + } + + @Override + public void onTimeout(AsyncEvent event) { + logger.debug("Async context timed out for session: {}", sessionId); + onClose.run(); + } + + @Override + public void onError(AsyncEvent event) { + logger.debug("Async context error for session: {}", sessionId); + onClose.run(); + } + + @Override + public void onStartAsync(AsyncEvent event) { + } + }); + } + /** * Sends an SSE event to a client with a specific ID. * @param writer The writer to send the event through @@ -755,8 +765,7 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId } catch (Exception e) { logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage()); - HttpServletStreamableServerTransportProvider.this.sessions.remove(this.sessionId); - this.asyncContext.complete(); + HttpServletStreamableMcpSessionTransport.this.close(); } finally { lock.unlock();