diff --git a/docs/proxying.md b/docs/proxying.md new file mode 100644 index 000000000..9c6aaaec6 --- /dev/null +++ b/docs/proxying.md @@ -0,0 +1,141 @@ +# Proxying MCP Transports + +The `mcp_proxy()` helper bridges two MCP transports and forwards messages in both directions. + +It is useful when you want to put a transport boundary between an MCP client and an upstream MCP server without +rewriting the forwarding loop yourself. + +## What It Does + +`mcp_proxy()` takes two transport pairs: + +- a transport facing the downstream client +- a transport facing the upstream server + +While the context manager is active, it: + +- forwards `SessionMessage` objects from client to server +- forwards `SessionMessage` objects from server to client +- sends transport exceptions to an optional `on_error` callback +- closes the paired write side when the corresponding read side stops + +## What It Does Not Do + +`mcp_proxy()` is a transport relay, not a full proxy server. + +It does not add: + +- authentication +- authorization +- request or response rewriting +- routing across multiple upstream servers +- retries or buffering policies +- metrics or tracing by default + +If you need those behaviors, build them around the helper. + +## Weather Service Example + +This example proxies a small weather service. The upstream service is defined with `MCPServer` and exposed over +streamable HTTP. The proxy bridges a downstream transport to that upstream transport. + +- `get_weather(city)` for a structured weather snapshot +- `get_weather_alerts(region)` for active alerts + +The client talks only to the downstream side of the proxy. + +```python +import anyio +import uvicorn + +from mcp.client.session import ClientSession +from mcp.client.streamable_http import streamable_http_client +from mcp.proxy import mcp_proxy +from mcp.server.mcpserver import MCPServer +from mcp.shared.memory import create_client_server_memory_streams + + +app = MCPServer("Weather Service") + + +@app.tool() +def get_weather(city: str) -> dict[str, str | float]: + return { + "city": city, + "temperature_c": 22.5, + "condition": "partly cloudy", + "wind_speed_kmh": 12.3, + } + + +@app.tool() +def get_weather_alerts(region: str) -> dict[str, object]: + return { + "region": region, + "alerts": [{"severity": "medium", "title": "Heat advisory"}], + } + + +async def main() -> None: + starlette_app = app.streamable_http_app(streamable_http_path="/mcp") + config = uvicorn.Config(starlette_app, host="127.0.0.1", port=8765, log_level="warning") + upstream_server = uvicorn.Server(config) + + async with ( + create_client_server_memory_streams() as (client_streams, proxy_client_streams), + streamable_http_client("http://127.0.0.1:8765/mcp") as proxy_server_streams, + anyio.create_task_group() as tg, + ): + tg.start_soon(upstream_server.serve) + + async with mcp_proxy( + proxy_client_streams, + proxy_server_streams, + ): + async with ClientSession(client_streams[0], client_streams[1]) as session: + await session.initialize() + weather = await session.call_tool("get_weather", {"city": "London"}) + alerts = await session.call_tool("get_weather_alerts", {"region": "California"}) + + print(weather.content[0].text) + print(alerts.content[0].text) + + upstream_server.should_exit = True + tg.cancel_scope.cancel() + + +anyio.run(main) +``` + +## Error Handling + +Use `on_error` to observe transport-level exceptions: + +```python +async with mcp_proxy( + downstream_transport, + upstream_transport, + on_error=handle_transport_error, +): + ... +``` + +`on_error` is keyword-only. It may be either: + +- an async callable +- a sync callable, which will run in a worker thread + +Exceptions raised by `on_error` are swallowed. Transport exceptions still terminate the proxy instead of being silently +consumed. + +## When To Use It + +`mcp_proxy()` is a good fit when you are: + +- exposing an upstream MCP server through a different transport boundary +- inserting middleware-like behavior between two MCP transports +- building a local relay for testing or development +- experimenting with transport adapters + +If all you need is to test a server directly, prefer [`Client`](testing.md), which already provides an in-memory +transport for that use case. diff --git a/mkdocs.yml b/mkdocs.yml index 3a555785a..a035b6db6 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -17,6 +17,7 @@ nav: - Documentation: - Concepts: concepts.md - Low-Level Server: low-level-server.md + - Proxying Transports: proxying.md - Authorization: authorization.md - Testing: testing.md - Experimental: diff --git a/src/mcp/__init__.py b/src/mcp/__init__.py index 4b5caa9cc..0e875b6c5 100644 --- a/src/mcp/__init__.py +++ b/src/mcp/__init__.py @@ -2,6 +2,7 @@ from .client.session import ClientSession from .client.session_group import ClientSessionGroup from .client.stdio import StdioServerParameters, stdio_client +from .proxy import mcp_proxy from .server.session import ServerSession from .server.stdio import stdio_server from .shared.exceptions import MCPError, UrlElicitationRequiredError @@ -97,6 +98,7 @@ "LoggingLevel", "LoggingMessageNotification", "MCPError", + "mcp_proxy", "Notification", "PingRequest", "ProgressNotification", diff --git a/src/mcp/proxy.py b/src/mcp/proxy.py new file mode 100644 index 000000000..0b76e010d --- /dev/null +++ b/src/mcp/proxy.py @@ -0,0 +1,99 @@ +"""Provide utilities for proxying messages between two MCP transports.""" + +from __future__ import annotations + +import contextvars +from collections.abc import AsyncGenerator, Awaitable, Callable +from contextlib import asynccontextmanager +from functools import partial +from typing import Any, Protocol, cast + +import anyio +from anyio import to_thread + +from mcp.shared._callable_inspection import is_async_callable +from mcp.shared._stream_protocols import ReadStream, WriteStream +from mcp.shared.message import SessionMessage + +MessageStream = tuple[ReadStream[SessionMessage | Exception], WriteStream[SessionMessage]] +ErrorHandler = Callable[[Exception], None | Awaitable[None]] + + +class ContextualWriteStream(Protocol): + async def send_with_context(self, context: contextvars.Context, item: SessionMessage | Exception) -> None: ... + + +@asynccontextmanager +async def mcp_proxy( + transport_to_client: MessageStream, + transport_to_server: MessageStream, + *, + on_error: ErrorHandler | None = None, +) -> AsyncGenerator[None]: + """Proxy messages bidirectionally between two MCP transports.""" + client_read, client_write = transport_to_client + server_read, server_write = transport_to_server + + async with anyio.create_task_group() as task_group: + task_group.start_soon(_forward_messages, client_read, server_write, on_error) + task_group.start_soon(_forward_messages, server_read, client_write, on_error) + try: + yield + finally: + task_group.cancel_scope.cancel() + + +async def _forward_messages( + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], + on_error: ErrorHandler | None, +) -> None: + try: + async with write_stream: + async with read_stream: + async for item in read_stream: + if isinstance(item, Exception): + await _run_error_handler(item, on_error) + raise item + + try: + await _forward_message(item, write_stream, read_stream) + except anyio.ClosedResourceError: + break + except anyio.ClosedResourceError: + return + + +async def _forward_message( + item: SessionMessage, + write_stream: WriteStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception], +) -> None: + sender_context: contextvars.Context | None = getattr(read_stream, "last_context", None) + context_write_stream = cast(ContextualWriteStream | None, _get_contextual_write_stream(write_stream)) + + if sender_context is not None and context_write_stream is not None: + await context_write_stream.send_with_context(sender_context, item) + return + + await write_stream.send(item) + + +def _get_contextual_write_stream(write_stream: WriteStream[SessionMessage]) -> Any: + send_with_context = getattr(write_stream, "send_with_context", None) + if callable(send_with_context): + return write_stream + return None + + +async def _run_error_handler(error: Exception, on_error: ErrorHandler | None) -> None: + if on_error is None: + return + + try: + if is_async_callable(on_error): + await cast(Awaitable[None], on_error(error)) + else: + await to_thread.run_sync(partial(on_error, error)) + except Exception: + return diff --git a/src/mcp/server/mcpserver/tools/base.py b/src/mcp/server/mcpserver/tools/base.py index dc65be988..754313eb8 100644 --- a/src/mcp/server/mcpserver/tools/base.py +++ b/src/mcp/server/mcpserver/tools/base.py @@ -1,7 +1,5 @@ from __future__ import annotations -import functools -import inspect from collections.abc import Callable from functools import cached_property from typing import TYPE_CHECKING, Any @@ -11,6 +9,7 @@ from mcp.server.mcpserver.exceptions import ToolError from mcp.server.mcpserver.utilities.context_injection import find_context_parameter from mcp.server.mcpserver.utilities.func_metadata import FuncMetadata, func_metadata +from mcp.shared._callable_inspection import is_async_callable from mcp.shared.exceptions import UrlElicitationRequiredError from mcp.shared.tool_name_validation import validate_and_warn_tool_name from mcp.types import Icon, ToolAnnotations @@ -63,7 +62,7 @@ def from_function( raise ValueError("You must provide a name for lambda functions") func_doc = description or fn.__doc__ or "" - is_async = _is_async_callable(fn) + is_async = is_async_callable(fn) if context_kwarg is None: # pragma: no branch context_kwarg = find_context_parameter(fn) @@ -118,12 +117,3 @@ async def run( raise except Exception as e: raise ToolError(f"Error executing tool {self.name}: {e}") from e - - -def _is_async_callable(obj: Any) -> bool: - while isinstance(obj, functools.partial): # pragma: lax no cover - obj = obj.func - - return inspect.iscoroutinefunction(obj) or ( - callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None)) - ) diff --git a/src/mcp/shared/_callable_inspection.py b/src/mcp/shared/_callable_inspection.py new file mode 100644 index 000000000..ced945f4a --- /dev/null +++ b/src/mcp/shared/_callable_inspection.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +import functools +import inspect +from typing import Any + + +def is_async_callable(obj: Any) -> bool: + while isinstance(obj, functools.partial): # pragma: lax no cover + obj = obj.func + + return inspect.iscoroutinefunction(obj) or ( + callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None)) + ) diff --git a/src/mcp/shared/_context_streams.py b/src/mcp/shared/_context_streams.py index 04c33306d..83e3710e6 100644 --- a/src/mcp/shared/_context_streams.py +++ b/src/mcp/shared/_context_streams.py @@ -36,6 +36,9 @@ def __init__(self, inner: MemoryObjectSendStream[_Envelope[T]]) -> None: async def send(self, item: T) -> None: await self._inner.send((contextvars.copy_context(), item)) + async def send_with_context(self, context: contextvars.Context, item: T) -> None: + await self._inner.send((context, item)) + def close(self) -> None: self._inner.close() diff --git a/tests/test_proxy.py b/tests/test_proxy.py new file mode 100644 index 000000000..16bde60f5 --- /dev/null +++ b/tests/test_proxy.py @@ -0,0 +1,437 @@ +from __future__ import annotations + +import contextvars +from types import TracebackType + +import anyio +import pytest + +from mcp.proxy import _forward_message, _forward_messages, mcp_proxy +from mcp.shared._context_streams import create_context_streams +from mcp.shared.message import SessionMessage +from mcp.types import JSONRPCRequest + + +def make_message(request_id: str, method: str) -> SessionMessage: + return SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=request_id, method=method, params={})) + + +def assert_request(message: SessionMessage, request_id: str, method: str) -> None: + assert isinstance(message.message, JSONRPCRequest) + assert message.message.id == request_id + assert message.message.method == method + + +class StaticReadStream: + def __init__(self, *items: SessionMessage | Exception, error: Exception | None = None) -> None: + self._items = list(items) + self._error = error + self.closed = False + + async def receive(self) -> SessionMessage | Exception: + try: + return await self.__anext__() + except StopAsyncIteration as exc: + raise anyio.EndOfStream from exc + + async def aclose(self) -> None: + self.closed = True + + def __aiter__(self) -> StaticReadStream: + return self + + async def __anext__(self) -> SessionMessage | Exception: + if self._items: + return self._items.pop(0) + if self._error is not None: + error = self._error + self._error = None + raise error + raise StopAsyncIteration + + async def __aenter__(self) -> StaticReadStream: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + await self.aclose() + return None + + +class TrackingWriteStream: + def __init__(self, error: Exception | None = None) -> None: + self.items: list[SessionMessage] = [] + self.error = error + self.closed = anyio.Event() + + async def send(self, item: SessionMessage, /) -> None: + if self.error is not None: + raise self.error + self.items.append(item) + + async def aclose(self) -> None: + self.closed.set() + + async def __aenter__(self) -> TrackingWriteStream: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + await self.aclose() + return None + + +class ReadStreamWithContext: + def __init__(self, context: contextvars.Context) -> None: + self.last_context = context + + async def receive(self) -> SessionMessage | Exception: + raise NotImplementedError + + async def aclose(self) -> None: + return None + + def __aiter__(self) -> ReadStreamWithContext: + return self + + async def __anext__(self) -> SessionMessage | Exception: + raise StopAsyncIteration + + async def __aenter__(self) -> ReadStreamWithContext: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + return None + + +class NestedException(Exception): + def __init__(self, *exceptions: BaseException) -> None: + super().__init__("nested") + self.exceptions = exceptions + + +def assert_contains_exception(exc: BaseException, expected_type: type[Exception], expected_message: str) -> None: + nested_exceptions = getattr(exc, "exceptions", None) + if nested_exceptions is not None: + for nested in nested_exceptions: + try: + assert_contains_exception(nested, expected_type, expected_message) + return + except AssertionError: + continue + raise AssertionError(f"Did not find {expected_type.__name__} containing {expected_message!r} in {exc!r}") + + assert isinstance(exc, expected_type) + assert expected_message in str(exc) + + +@pytest.mark.anyio +async def test_static_read_stream_receive_raises_end_of_stream_when_exhausted() -> None: + stream = StaticReadStream() + + with pytest.raises(anyio.EndOfStream): + await stream.receive() + + +@pytest.mark.anyio +async def test_tracking_write_stream_send_raises_configured_error() -> None: + stream = TrackingWriteStream(RuntimeError("write boom")) + + with pytest.raises(RuntimeError, match="write boom"): + await stream.send(make_message("client", "client/method")) + + +@pytest.mark.anyio +async def test_read_stream_with_context_support_methods() -> None: + stream = ReadStreamWithContext(contextvars.copy_context()) + + assert stream.__aiter__() is stream + assert await stream.__aenter__() is stream + assert await stream.aclose() is None + assert await stream.__aexit__(None, None, None) is None + + with pytest.raises(StopAsyncIteration): + await stream.__anext__() + + +def test_assert_contains_exception_reports_missing_nested_exception() -> None: + exc = NestedException(ValueError("boom")) + + with pytest.raises(AssertionError, match="Did not find RuntimeError containing 'missing'"): + assert_contains_exception(exc, RuntimeError, "missing") + + +@pytest.mark.anyio +async def test_proxy_forwards_messages_bidirectionally() -> None: + client_read_send, client_read = anyio.create_memory_object_stream[SessionMessage | Exception](1) + client_write, client_write_read = anyio.create_memory_object_stream[SessionMessage](1) + server_read_send, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](1) + server_write, server_write_read = anyio.create_memory_object_stream[SessionMessage](1) + + async with ( + client_read_send, + client_read, + client_write, + client_write_read, + server_read_send, + server_read, + server_write, + server_write_read, + ): + async with mcp_proxy((client_read, client_write), (server_read, server_write)): + await client_read_send.send(make_message("client", "client/method")) + await server_read_send.send(make_message("server", "server/method")) + + assert_request(await server_write_read.receive(), "client", "client/method") + assert_request(await client_write_read.receive(), "server", "server/method") + + +@pytest.mark.anyio +async def test_proxy_calls_sync_error_handler_before_raising_transport_exception() -> None: + errors: list[Exception] = [] + handled = anyio.Event() + + def on_error(error: Exception) -> None: + errors.append(error) + handled.set() + + client_read_send, client_read = anyio.create_memory_object_stream[SessionMessage | Exception](1) + client_write, _client_write_read = anyio.create_memory_object_stream[SessionMessage](1) + server_read_send, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](1) + server_write, server_write_read = anyio.create_memory_object_stream[SessionMessage](1) + + async with ( + client_read_send, + client_read, + client_write, + _client_write_read, + server_read_send, + server_read, + server_write, + server_write_read, + ): + with pytest.raises(Exception) as exc_info: + async with mcp_proxy((client_read, client_write), (server_read, server_write), on_error=on_error): + await client_read_send.send(ValueError("boom")) + await handled.wait() + + assert_contains_exception(exc_info.value, ValueError, "boom") + assert len(errors) == 1 + assert isinstance(errors[0], ValueError) + assert str(errors[0]) == "boom" + + +@pytest.mark.anyio +async def test_proxy_calls_async_error_handler_before_raising_transport_exception() -> None: + errors: list[Exception] = [] + handled = anyio.Event() + + async def on_error(error: Exception) -> None: + errors.append(error) + handled.set() + + client_read_send, client_read = anyio.create_memory_object_stream[SessionMessage | Exception](1) + client_write, _client_write_read = anyio.create_memory_object_stream[SessionMessage](1) + server_read_send, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](1) + server_write, _server_write_read = anyio.create_memory_object_stream[SessionMessage](1) + + async with ( + client_read_send, + client_read, + client_write, + _client_write_read, + server_read_send, + server_read, + server_write, + _server_write_read, + ): + with pytest.raises(Exception) as exc_info: + async with mcp_proxy((client_read, client_write), (server_read, server_write), on_error=on_error): + await client_read_send.send(ValueError("async-boom")) + await handled.wait() + + assert_contains_exception(exc_info.value, ValueError, "async-boom") + assert len(errors) == 1 + assert isinstance(errors[0], ValueError) + assert str(errors[0]) == "async-boom" + + +@pytest.mark.anyio +async def test_proxy_ignores_sync_error_handler_failures_and_raises_transport_exception() -> None: + def on_error(error: Exception) -> None: + raise RuntimeError(f"handler failed for {error}") + + client_read_send, client_read = anyio.create_memory_object_stream[SessionMessage | Exception](1) + client_write, _client_write_read = anyio.create_memory_object_stream[SessionMessage](1) + server_read_send, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](1) + server_write, server_write_read = anyio.create_memory_object_stream[SessionMessage](1) + + async with ( + client_read_send, + client_read, + client_write, + _client_write_read, + server_read_send, + server_read, + server_write, + server_write_read, + ): + with pytest.raises(Exception) as exc_info: + async with mcp_proxy((client_read, client_write), (server_read, server_write), on_error=on_error): + await client_read_send.send(ValueError("boom")) + await anyio.sleep(0.05) + + assert_contains_exception(exc_info.value, ValueError, "boom") + + +@pytest.mark.anyio +async def test_proxy_ignores_async_error_handler_failures_and_raises_transport_exception() -> None: + async def on_error(error: Exception) -> None: + raise RuntimeError(f"handler failed for {error}") + + client_read_send, client_read = anyio.create_memory_object_stream[SessionMessage | Exception](1) + client_write, _client_write_read = anyio.create_memory_object_stream[SessionMessage](1) + server_read_send, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](1) + server_write, server_write_read = anyio.create_memory_object_stream[SessionMessage](1) + + async with ( + client_read_send, + client_read, + client_write, + _client_write_read, + server_read_send, + server_read, + server_write, + server_write_read, + ): + with pytest.raises(Exception) as exc_info: + async with mcp_proxy((client_read, client_write), (server_read, server_write), on_error=on_error): + await client_read_send.send(ValueError("boom")) + await anyio.sleep(0.05) + + assert_contains_exception(exc_info.value, ValueError, "boom") + + +@pytest.mark.anyio +async def test_proxy_raises_transport_exception_without_error_handler() -> None: + client_read_send, client_read = anyio.create_memory_object_stream[SessionMessage | Exception](1) + client_write, _client_write_read = anyio.create_memory_object_stream[SessionMessage](1) + server_read_send, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](1) + server_write, server_write_read = anyio.create_memory_object_stream[SessionMessage](1) + + async with ( + client_read_send, + client_read, + client_write, + _client_write_read, + server_read_send, + server_read, + server_write, + server_write_read, + ): + with pytest.raises(Exception) as exc_info: + async with mcp_proxy((client_read, client_write), (server_read, server_write)): + await client_read_send.send(ValueError("boom")) + await anyio.sleep(0.05) + + assert_contains_exception(exc_info.value, ValueError, "boom") + + +@pytest.mark.anyio +async def test_proxy_stops_forwarding_when_target_stream_is_closed() -> None: + server_write = TrackingWriteStream(anyio.ClosedResourceError()) + client_write = TrackingWriteStream() + + async with mcp_proxy( + (StaticReadStream(make_message("client", "client/method")), server_write), + (StaticReadStream(), client_write), + ): + await server_write.closed.wait() + + assert server_write.items == [] + assert server_write.closed.is_set() + assert client_write.closed.is_set() + + +@pytest.mark.anyio +async def test_forward_messages_stops_on_closed_target_stream() -> None: + await _forward_messages( + StaticReadStream(make_message("client", "client/method")), + TrackingWriteStream(anyio.ClosedResourceError()), + on_error=None, + ) + + +@pytest.mark.anyio +async def test_proxy_closes_target_stream_when_source_stream_is_closed() -> None: + server_write = TrackingWriteStream() + client_write = TrackingWriteStream() + + async with mcp_proxy((StaticReadStream(), server_write), (StaticReadStream(), client_write)): + await server_write.closed.wait() + await client_write.closed.wait() + + assert server_write.items == [] + assert client_write.items == [] + + +@pytest.mark.anyio +async def test_proxy_handles_closed_resource_error_from_source_stream() -> None: + server_write = TrackingWriteStream() + client_write = TrackingWriteStream() + + async with mcp_proxy( + (StaticReadStream(error=anyio.ClosedResourceError()), server_write), + (StaticReadStream(), client_write), + ): + await server_write.closed.wait() + + assert server_write.items == [] + + +@pytest.mark.anyio +async def test_proxy_preserves_sender_context_for_context_aware_streams() -> None: + request_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("request_id") + server_write, server_receive = create_context_streams[SessionMessage | Exception](1) + + request_id_var.set("proxy-request-123") + sender_context = contextvars.copy_context() + + async with server_write, server_receive: + await _forward_message( + make_message("client", "client/method"), + server_write, + ReadStreamWithContext(sender_context), + ) + received = await server_receive.receive() + + assert isinstance(received, SessionMessage) + assert server_receive.last_context is not None + assert server_receive.last_context.get(request_id_var) == "proxy-request-123" + + +@pytest.mark.anyio +async def test_proxy_raises_transport_exceptions() -> None: + client_send, client_read = create_context_streams[SessionMessage | Exception](1) + plain_write_send, plain_write_receive = anyio.create_memory_object_stream[SessionMessage](1) + + async with client_send, client_read, plain_write_send, plain_write_receive: + with pytest.raises(Exception) as exc_info: + async with mcp_proxy((client_read, plain_write_send), (StaticReadStream(), TrackingWriteStream())): + await client_send.send(ValueError("transport boom")) + await anyio.sleep(0.05) + + assert_contains_exception(exc_info.value, ValueError, "transport boom")