# Source code for autogen_ext.tools.mcp._session
from contextlib import asynccontextmanager
from datetime import timedelta
from typing import AsyncGenerator
from mcp import ClientSession
from mcp.client.session import SamplingFnT
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client
from ._config import McpServerParams, SseServerParams, StdioServerParams, StreamableHttpServerParams
# [docs]
@asynccontextmanager
async def create_mcp_server_session(
    server_params: McpServerParams, sampling_callback: SamplingFnT | None = None
) -> AsyncGenerator[ClientSession, None]:
    """Create an MCP client session for the given server parameters.

    The transport is selected from the concrete type of ``server_params``
    (stdio, SSE, or streamable HTTP). The transport's read/write streams are
    wrapped in a :class:`ClientSession`, which is yielded to the caller. Both
    the session and the underlying transport context are exited when the
    caller's ``async with`` block ends. No ``initialize()`` call is made here.

    Args:
        server_params: Connection parameters; must be an instance of
            ``StdioServerParams``, ``SseServerParams`` or
            ``StreamableHttpServerParams``.
        sampling_callback: Optional callback passed through unchanged to
            ``ClientSession``.

    Yields:
        The ``ClientSession`` bound to the opened transport.

    Raises:
        TypeError: If ``server_params`` is not one of the supported parameter
            types. Without this check the generator would finish without
            yielding and ``asynccontextmanager`` would report an opaque
            ``RuntimeError("generator didn't yield")``.
    """
    if isinstance(server_params, StdioServerParams):
        async with stdio_client(server_params) as (read, write):
            async with ClientSession(
                read_stream=read,
                write_stream=write,
                read_timeout_seconds=timedelta(seconds=server_params.read_timeout_seconds),
                sampling_callback=sampling_callback,
            ) as session:
                yield session
    elif isinstance(server_params, SseServerParams):
        # Every field except the "type" discriminator is forwarded to sse_client as a keyword.
        async with sse_client(**server_params.model_dump(exclude={"type"})) as (read, write):
            async with ClientSession(
                read_stream=read,
                write_stream=write,
                # The SSE read timeout is reused as the session read timeout.
                read_timeout_seconds=timedelta(seconds=server_params.sse_read_timeout),
                sampling_callback=sampling_callback,
            ) as session:
                yield session
    elif isinstance(server_params, StreamableHttpServerParams):
        # The params hold timeouts as seconds; they are passed to
        # streamablehttp_client as timedelta objects instead.
        params_dict = server_params.model_dump(exclude={"type"})
        params_dict["timeout"] = timedelta(seconds=server_params.timeout)
        params_dict["sse_read_timeout"] = timedelta(seconds=server_params.sse_read_timeout)
        async with streamablehttp_client(**params_dict) as (
            read,
            write,
            session_id_callback,  # type: ignore[assignment, unused-variable]
        ):
            # TODO: Handle session_id_callback if needed
            async with ClientSession(
                read_stream=read,
                write_stream=write,
                # The SSE read timeout is reused as the session read timeout.
                read_timeout_seconds=timedelta(seconds=server_params.sse_read_timeout),
                sampling_callback=sampling_callback,
            ) as session:
                yield session
    else:
        # Fail loudly on unsupported input instead of falling off the end of
        # the generator without yielding.
        raise TypeError(f"Unsupported MCP server params type: {type(server_params).__name__}")