From 1a12bf985a0f9732aa0376a2eb9f3b45e36dd1fe Mon Sep 17 00:00:00 2001 From: King Star Date: Mon, 22 Jun 2026 02:29:26 +0800 Subject: [PATCH] fix: avoid starting web host in mcp stdio mode --- .../UnitTests/McpStdioHelperTests.cs | 111 ++++++++++++++++++ src/Service/Utilities/McpStdioHelper.cs | 32 ++--- 2 files changed, 129 insertions(+), 14 deletions(-) create mode 100644 src/Service.Tests/UnitTests/McpStdioHelperTests.cs diff --git a/src/Service.Tests/UnitTests/McpStdioHelperTests.cs b/src/Service.Tests/UnitTests/McpStdioHelperTests.cs new file mode 100644 index 0000000000..daf0c9e3b1 --- /dev/null +++ b/src/Service.Tests/UnitTests/McpStdioHelperTests.cs @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#nullable enable + +using System.Threading; +using System.Threading.Tasks; +using Azure.DataApiBuilder.Mcp.Core; +using Azure.DataApiBuilder.Service.Utilities; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Azure.DataApiBuilder.Service.Tests.UnitTests +{ + [TestClass] + public class McpStdioHelperTests + { + [TestMethod] + public void RunMcpStdioHost_DoesNotStartWebHost() + { + ServiceCollection services = new(); + TestApplicationLifetime lifetime = new(); + TestMcpStdioServer stdioServer = new(); + + services.AddSingleton(); + services.AddSingleton(lifetime); + services.AddSingleton(stdioServer); + + using ServiceProvider serviceProvider = services.BuildServiceProvider(); + TestHost host = new(serviceProvider); + + bool result = McpStdioHelper.RunMcpStdioHost(host); + + Assert.IsTrue(result); + Assert.AreEqual(0, host.StartAsyncCallCount, + "MCP stdio mode should not start the ASP.NET Core web host because that binds HTTP ports."); + Assert.AreEqual(0, host.StopAsyncCallCount, + "MCP stdio mode should not stop a host that was never started."); + Assert.AreEqual(1, stdioServer.RunAsyncCallCount, + "MCP stdio mode should still run the stdio JSON-RPC loop."); + Assert.AreEqual(lifetime.ApplicationStopping, stdioServer.CancellationToken, + "The stdio loop should keep using the host lifetime cancellation token."); + Assert.AreEqual(1, host.DisposeCallCount, + "MCP stdio mode should dispose the host after the stdio loop exits."); + } + + private sealed class TestHost : IHost + { + public TestHost(System.IServiceProvider services) + { + Services = services; + } + + public System.IServiceProvider Services { get; } + + public int StartAsyncCallCount { get; private set; } + + public int StopAsyncCallCount { get; private set; } + + public int DisposeCallCount { get; private set; } + + public Task StartAsync(CancellationToken cancellationToken = default) + { + StartAsyncCallCount++; + return Task.CompletedTask; + } + + public Task StopAsync(CancellationToken cancellationToken = default) + { + StopAsyncCallCount++; + return Task.CompletedTask; + } + + public void Dispose() + { + DisposeCallCount++; + } + } + + private sealed class TestApplicationLifetime : IHostApplicationLifetime + { + private readonly CancellationTokenSource _applicationStopping = new(); + + public CancellationToken ApplicationStarted => CancellationToken.None; + + public CancellationToken ApplicationStopping => _applicationStopping.Token; + + public CancellationToken ApplicationStopped => CancellationToken.None; + + public void StopApplication() + { + _applicationStopping.Cancel(); + } + } + + private sealed class TestMcpStdioServer : IMcpStdioServer + { + public int RunAsyncCallCount { get; private set; } + + public CancellationToken CancellationToken { get; private set; } + + public Task RunAsync(CancellationToken cancellationToken) + { + RunAsyncCallCount++; + CancellationToken = cancellationToken; + return Task.CompletedTask; + } + } + } +} diff --git a/src/Service/Utilities/McpStdioHelper.cs b/src/Service/Utilities/McpStdioHelper.cs index b2758d4ac2..4ee403b98e 100644 --- a/src/Service/Utilities/McpStdioHelper.cs +++ b/src/Service/Utilities/McpStdioHelper.cs @@ -76,24 +76,28 @@ public static void ConfigureMcpStdio(IConfigurationBuilder builder, string? mcpR /// The host to run. public static bool RunMcpStdioHost(IHost host) { - host.Start(); - - Mcp.Core.McpToolRegistry registry = - host.Services.GetRequiredService(); - IEnumerable tools = - host.Services.GetServices(); + try + { + Mcp.Core.McpToolRegistry registry = + host.Services.GetRequiredService(); + IEnumerable tools = + host.Services.GetServices(); - Mcp.Core.McpToolRegistry.InitializeAndRegisterTools(tools, registry, host.Services); + Mcp.Core.McpToolRegistry.InitializeAndRegisterTools(tools, registry, host.Services); - IHostApplicationLifetime lifetime = - host.Services.GetRequiredService(); - Mcp.Core.IMcpStdioServer stdio = - host.Services.GetRequiredService(); + IHostApplicationLifetime lifetime = + host.Services.GetRequiredService(); + Mcp.Core.IMcpStdioServer stdio = + host.Services.GetRequiredService(); - stdio.RunAsync(lifetime.ApplicationStopping).GetAwaiter().GetResult(); - host.StopAsync().GetAwaiter().GetResult(); + stdio.RunAsync(lifetime.ApplicationStopping).GetAwaiter().GetResult(); - return true; + return true; + } + finally + { + host.Dispose(); + } } } }