Source code for autogen_ext.tools.mcp._actor

import asyncio
import atexit
import base64
import io
import logging
from typing import Any, Coroutine, Dict, Mapping, TypedDict

from autogen_core import Component, ComponentBase, ComponentModel, Image
from autogen_core.models import (
    AssistantMessage,
    ChatCompletionClient,
    LLMMessage,
    ModelInfo,
    SystemMessage,
    UserMessage,
)
from PIL import Image as PILImage
from pydantic import BaseModel
from typing_extensions import Self

from mcp import types as mcp_types
from mcp.client.session import ClientSession
from mcp.shared.context import RequestContext

from ._config import McpServerParams
from ._session import create_mcp_server_session

logger = logging.getLogger(__name__)

McpResult = (
    Coroutine[Any, Any, mcp_types.ListToolsResult]
    | Coroutine[Any, Any, mcp_types.CallToolResult]
    | Coroutine[Any, Any, mcp_types.ListPromptsResult]
    | Coroutine[Any, Any, mcp_types.ListResourcesResult]
    | Coroutine[Any, Any, mcp_types.ListResourceTemplatesResult]
    | Coroutine[Any, Any, mcp_types.ReadResourceResult]
    | Coroutine[Any, Any, mcp_types.GetPromptResult]
)
McpFuture = asyncio.Future[McpResult]


def _parse_sampling_content(
    content: mcp_types.TextContent | mcp_types.ImageContent | mcp_types.AudioContent, model_info: ModelInfo
) -> str | Image:
    """Convert MCP content types to Autogen content types."""
    if content.type == "text":
        return content.text
    elif content.type == "image":
        if not model_info["vision"]:
            raise ValueError("Sampling model does not support image content.")
        # Decode base64 image data and create PIL Image
        image_data = base64.b64decode(content.data)
        pil_image = PILImage.open(io.BytesIO(image_data))
        return Image.from_pil(pil_image)
    else:
        raise ValueError(f"Unsupported content type: {content.type}")


def _parse_sampling_message(message: mcp_types.SamplingMessage, model_info: ModelInfo) -> LLMMessage:
    """Convert MCP sampling messages to Autogen messages."""
    content = _parse_sampling_content(message.content, model_info=model_info)
    if message.role == "user":
        return UserMessage(
            source="user",
            content=[content],
        )
    elif message.role == "assistant":
        assert isinstance(content, str), "Assistant messages only support string content."
        return AssistantMessage(
            source="assistant",
            content=content,
        )
    else:
        raise ValueError(f"Unrecognized message role: {message.role}")


class McpActorArgs(TypedDict):
    name: str | None
    kargs: Mapping[str, Any]


class McpSessionActorConfig(BaseModel):
    server_params: McpServerParams
    model_client: ComponentModel | Dict[str, Any] | None = None


[docs] class McpSessionActor(ComponentBase[BaseModel], Component[McpSessionActorConfig]): component_type = "mcp_session_actor" component_config_schema = McpSessionActorConfig component_provider_override = "autogen_ext.tools.mcp.McpSessionActor" server_params: McpServerParams # model_config = ConfigDict(arbitrary_types_allowed=True) def __init__(self, server_params: McpServerParams, model_client: ChatCompletionClient | None = None) -> None: self.server_params: McpServerParams = server_params self._model_client = model_client self.name = "mcp_session_actor" self.description = "MCP session actor" self._command_queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue() self._actor_task: asyncio.Task[Any] | None = None self._shutdown_future: asyncio.Future[Any] | None = None self._active = False self._initialize_result: mcp_types.InitializeResult | None = None atexit.register(self._sync_shutdown) @property def initialize_result(self) -> mcp_types.InitializeResult | None: return self._initialize_result
[docs] async def initialize(self) -> None: if not self._active: self._active = True self._actor_task = asyncio.create_task(self._run_actor())
[docs] async def call(self, type: str, args: McpActorArgs | None = None) -> McpFuture: if not self._active: raise RuntimeError("MCP Actor not running, call initialize() first") if self._actor_task and self._actor_task.done(): raise RuntimeError("MCP actor task crashed", self._actor_task.exception()) fut: asyncio.Future[McpFuture] = asyncio.Future() if type in {"list_tools", "list_prompts", "list_resources", "list_resource_templates", "shutdown"}: await self._command_queue.put({"type": type, "future": fut}) res = await fut elif type in {"call_tool", "read_resource", "get_prompt"}: if args is None: raise ValueError(f"args is required for {type}") name = args.get("name", None) kwargs = args.get("kargs", {}) if type == "call_tool" and name is None: raise ValueError("name is required for call_tool") elif type == "read_resource": uri = kwargs.get("uri", None) if uri is None: raise ValueError("uri is required for read_resource") await self._command_queue.put({"type": type, "uri": uri, "future": fut}) elif type == "get_prompt": if name is None: raise ValueError("name is required for get_prompt") prompt_args = kwargs.get("arguments", None) await self._command_queue.put({"type": type, "name": name, "args": prompt_args, "future": fut}) else: # call_tool await self._command_queue.put({"type": type, "name": name, "args": kwargs, "future": fut}) res = await fut else: raise ValueError(f"Unknown command type: {type}") return res
[docs] async def close(self) -> None: if not self._active or self._actor_task is None: return self._shutdown_future = asyncio.Future() await self._command_queue.put({"type": "shutdown", "future": self._shutdown_future}) await self._shutdown_future await self._actor_task self._active = False
async def _sampling_callback( self, context: RequestContext[ClientSession, Any], params: mcp_types.CreateMessageRequestParams, ) -> mcp_types.CreateMessageResult | mcp_types.ErrorData: """Handle sampling requests using the provided model client.""" if self._model_client is None: # Return an error when no model client is available return mcp_types.ErrorData( code=mcp_types.INVALID_REQUEST, message="No model client available for sampling.", data=None, ) llm_messages: list[LLMMessage] = [] try: if params.systemPrompt: llm_messages.append(SystemMessage(content=params.systemPrompt)) for mcp_message in params.messages: llm_messages.append(_parse_sampling_message(mcp_message, model_info=self._model_client.model_info)) except Exception as e: return mcp_types.ErrorData( code=mcp_types.INVALID_PARAMS, message="Error processing sampling messages.", data=f"{type(e).__name__}: {e}", ) try: result = await self._model_client.create(messages=llm_messages) content = result.content if not isinstance(content, str): content = str(content) return mcp_types.CreateMessageResult( role="assistant", content=mcp_types.TextContent(type="text", text=content), model=self._model_client.model_info["family"], stopReason=result.finish_reason, ) except Exception as e: return mcp_types.ErrorData( code=mcp_types.INTERNAL_ERROR, message="Error sampling from model client.", data=f"{type(e).__name__}: {e}", ) async def _run_actor(self) -> None: result: McpResult try: async with create_mcp_server_session( self.server_params, sampling_callback=self._sampling_callback ) as session: # Save the initialize result self._initialize_result = await session.initialize() while True: cmd = await self._command_queue.get() if cmd["type"] == "shutdown": cmd["future"].set_result("ok") break elif cmd["type"] == "call_tool": try: result = session.call_tool(name=cmd["name"], arguments=cmd["args"]) cmd["future"].set_result(result) except Exception as e: cmd["future"].set_exception(e) elif cmd["type"] == "read_resource": try: result = session.read_resource(uri=cmd["uri"]) cmd["future"].set_result(result) except Exception as e: cmd["future"].set_exception(e) elif cmd["type"] == "get_prompt": try: result = session.get_prompt(name=cmd["name"], arguments=cmd["args"]) cmd["future"].set_result(result) except Exception as e: cmd["future"].set_exception(e) elif cmd["type"] == "list_tools": try: result = session.list_tools() cmd["future"].set_result(result) except Exception as e: cmd["future"].set_exception(e) elif cmd["type"] == "list_prompts": try: result = session.list_prompts() cmd["future"].set_result(result) except Exception as e: cmd["future"].set_exception(e) elif cmd["type"] == "list_resources": try: result = session.list_resources() cmd["future"].set_result(result) except Exception as e: cmd["future"].set_exception(e) elif cmd["type"] == "list_resource_templates": try: result = session.list_resource_templates() cmd["future"].set_result(result) except Exception as e: cmd["future"].set_exception(e) except Exception as e: if self._shutdown_future and not self._shutdown_future.done(): self._shutdown_future.set_exception(e) else: logger.exception("Exception in MCP actor task") finally: self._active = False self._actor_task = None def _sync_shutdown(self) -> None: if not self._active or self._actor_task is None: return try: loop = asyncio.get_event_loop() except RuntimeError: # No loop available — interpreter is likely shutting down return if loop.is_closed(): return if loop.is_running(): loop.create_task(self.close()) else: loop.run_until_complete(self.close()) def _to_config(self) -> McpSessionActorConfig: """ Convert the adapter to its configuration representation. Returns: McpSessionConfig: The configuration of the adapter. """ return McpSessionActorConfig(server_params=self.server_params) @classmethod def _from_config(cls, config: McpSessionActorConfig) -> Self: """ Create an instance of McpSessionActor from its configuration. Args: config (McpSessionConfig): The configuration of the adapter. Returns: McpSessionActor: An instance of SseMcpToolAdapter. """ return cls(server_params=config.server_params)