# Source code for autogen_ext.tools.mcp._session

from contextlib import asynccontextmanager
from datetime import timedelta
from typing import AsyncGenerator

from mcp import ClientSession
from mcp.client.session import SamplingFnT
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client

from ._config import McpServerParams, SseServerParams, StdioServerParams, StreamableHttpServerParams


@asynccontextmanager
async def create_mcp_server_session(
    server_params: McpServerParams, sampling_callback: SamplingFnT | None = None
) -> AsyncGenerator[ClientSession, None]:
    """Create an MCP client session for the given server parameters.

    The transport is selected from the concrete type of ``server_params``:

    * ``StdioServerParams`` -> ``stdio_client``
    * ``SseServerParams`` -> ``sse_client``
    * ``StreamableHttpServerParams`` -> ``streamablehttp_client``

    The transport's read/write streams are wrapped in a ``ClientSession``,
    which is entered and yielded to the caller. Both the transport and the
    session context managers are exited when the caller's ``async with``
    block exits.

    Args:
        server_params: Connection parameters for the MCP server.
        sampling_callback: Optional callback forwarded unchanged to
            ``ClientSession`` as its ``sampling_callback``.

    Yields:
        The entered ``ClientSession`` bound to the selected transport.

    Raises:
        ValueError: If ``server_params`` is not one of the supported
            parameter types. Without this check the generator would finish
            without yielding, and ``asynccontextmanager`` would report only
            an opaque "generator didn't yield" ``RuntimeError``.
    """
    if isinstance(server_params, StdioServerParams):
        async with stdio_client(server_params) as (read, write):
            async with ClientSession(
                read_stream=read,
                write_stream=write,
                read_timeout_seconds=timedelta(seconds=server_params.read_timeout_seconds),
                sampling_callback=sampling_callback,
            ) as session:
                yield session
    elif isinstance(server_params, SseServerParams):
        # The "type" discriminator is a model field, not an sse_client argument.
        async with sse_client(**server_params.model_dump(exclude={"type"})) as (read, write):
            async with ClientSession(
                read_stream=read,
                write_stream=write,
                read_timeout_seconds=timedelta(seconds=server_params.sse_read_timeout),
                sampling_callback=sampling_callback,
            ) as session:
                yield session
    elif isinstance(server_params, StreamableHttpServerParams):
        # The params model stores timeouts as float seconds; they are converted
        # to timedelta before being passed to streamablehttp_client.
        params_dict = server_params.model_dump(exclude={"type"})
        params_dict["timeout"] = timedelta(seconds=server_params.timeout)
        params_dict["sse_read_timeout"] = timedelta(seconds=server_params.sse_read_timeout)
        async with streamablehttp_client(**params_dict) as (
            read,
            write,
            session_id_callback,  # type: ignore[assignment, unused-variable]
        ):
            # TODO: Handle session_id_callback if needed
            async with ClientSession(
                read_stream=read,
                write_stream=write,
                read_timeout_seconds=timedelta(seconds=server_params.sse_read_timeout),
                sampling_callback=sampling_callback,
            ) as session:
                yield session
    else:
        # Fail loudly with the offending type instead of letting
        # asynccontextmanager raise "generator didn't yield".
        raise ValueError(f"Unsupported server params type: {type(server_params).__name__}")