Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions src/Service.Tests/UnitTests/McpStdioHelperTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#nullable enable

using System.Threading;
using System.Threading.Tasks;
using Azure.DataApiBuilder.Mcp.Core;
using Azure.DataApiBuilder.Service.Utilities;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace Azure.DataApiBuilder.Service.Tests.UnitTests
{
[TestClass]
public class McpStdioHelperTests
{
[TestMethod]
public void RunMcpStdioHost_DoesNotStartWebHost()
{
ServiceCollection services = new();
TestApplicationLifetime lifetime = new();
TestMcpStdioServer stdioServer = new();

services.AddSingleton<McpToolRegistry>();
services.AddSingleton<IHostApplicationLifetime>(lifetime);
services.AddSingleton<IMcpStdioServer>(stdioServer);

using ServiceProvider serviceProvider = services.BuildServiceProvider();
TestHost host = new(serviceProvider);

bool result = McpStdioHelper.RunMcpStdioHost(host);

Assert.IsTrue(result);
Assert.AreEqual(0, host.StartAsyncCallCount,
"MCP stdio mode should not start the ASP.NET Core web host because that binds HTTP ports.");
Assert.AreEqual(0, host.StopAsyncCallCount,
"MCP stdio mode should not stop a host that was never started.");
Assert.AreEqual(1, stdioServer.RunAsyncCallCount,
"MCP stdio mode should still run the stdio JSON-RPC loop.");
Assert.AreEqual(lifetime.ApplicationStopping, stdioServer.CancellationToken,
"The stdio loop should keep using the host lifetime cancellation token.");
Assert.AreEqual(1, host.DisposeCallCount,
"MCP stdio mode should dispose the host after the stdio loop exits.");
}

private sealed class TestHost : IHost
{
public TestHost(System.IServiceProvider services)
{
Services = services;
}

public System.IServiceProvider Services { get; }

public int StartAsyncCallCount { get; private set; }

public int StopAsyncCallCount { get; private set; }

public int DisposeCallCount { get; private set; }

public Task StartAsync(CancellationToken cancellationToken = default)
{
StartAsyncCallCount++;
return Task.CompletedTask;
}

public Task StopAsync(CancellationToken cancellationToken = default)
{
StopAsyncCallCount++;
return Task.CompletedTask;
}

public void Dispose()
{
DisposeCallCount++;
}
}

private sealed class TestApplicationLifetime : IHostApplicationLifetime
{
private readonly CancellationTokenSource _applicationStopping = new();

public CancellationToken ApplicationStarted => CancellationToken.None;

public CancellationToken ApplicationStopping => _applicationStopping.Token;

public CancellationToken ApplicationStopped => CancellationToken.None;

public void StopApplication()
{
_applicationStopping.Cancel();
}
}

private sealed class TestMcpStdioServer : IMcpStdioServer
{
public int RunAsyncCallCount { get; private set; }

public CancellationToken CancellationToken { get; private set; }

public Task RunAsync(CancellationToken cancellationToken)
{
RunAsyncCallCount++;
CancellationToken = cancellationToken;
return Task.CompletedTask;
}
}
}
}
32 changes: 18 additions & 14 deletions src/Service/Utilities/McpStdioHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,24 +76,28 @@ public static void ConfigureMcpStdio(IConfigurationBuilder builder, string? mcpR
/// <param name="host"> The host to run.</param>
public static bool RunMcpStdioHost(IHost host)
{
host.Start();

Mcp.Core.McpToolRegistry registry =
host.Services.GetRequiredService<Mcp.Core.McpToolRegistry>();
IEnumerable<Mcp.Model.IMcpTool> tools =
host.Services.GetServices<Mcp.Model.IMcpTool>();
try
{
Mcp.Core.McpToolRegistry registry =
host.Services.GetRequiredService<Mcp.Core.McpToolRegistry>();
IEnumerable<Mcp.Model.IMcpTool> tools =
host.Services.GetServices<Mcp.Model.IMcpTool>();

Mcp.Core.McpToolRegistry.InitializeAndRegisterTools(tools, registry, host.Services);
Mcp.Core.McpToolRegistry.InitializeAndRegisterTools(tools, registry, host.Services);

IHostApplicationLifetime lifetime =
host.Services.GetRequiredService<IHostApplicationLifetime>();
Mcp.Core.IMcpStdioServer stdio =
host.Services.GetRequiredService<Mcp.Core.IMcpStdioServer>();
IHostApplicationLifetime lifetime =
host.Services.GetRequiredService<IHostApplicationLifetime>();
Mcp.Core.IMcpStdioServer stdio =
host.Services.GetRequiredService<Mcp.Core.IMcpStdioServer>();

stdio.RunAsync(lifetime.ApplicationStopping).GetAwaiter().GetResult();
host.StopAsync().GetAwaiter().GetResult();
stdio.RunAsync(lifetime.ApplicationStopping).GetAwaiter().GetResult();

return true;
return true;
}
finally
{
host.Dispose();
}
}
}
}