# Source code for autogen_core.tools._static_workbench

import asyncio
import builtins
from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional

from pydantic import BaseModel, Field
from typing_extensions import Self

from .._cancellation_token import CancellationToken
from .._component_config import Component, ComponentModel
from ._base import BaseTool, StreamTool, ToolOverride, ToolSchema
from ._workbench import StreamWorkbench, TextResultContent, ToolResult, Workbench


# Serializable configuration of a StaticWorkbench: the component models of its
# tools plus any per-tool name/description overrides.
# (Documented with comments rather than a docstring so the generated JSON schema
# for this pydantic model stays unchanged.)
class StaticWorkbenchConfig(BaseModel):
    # Component models of the tools, in the order the workbench holds them.
    tools: list[ComponentModel] = []
    # Overrides keyed by each tool's original name.
    tool_overrides: dict[str, ToolOverride] = Field(default_factory=dict)


# Saved state of a StaticWorkbench: each tool's serialized state keyed by the
# tool's (original) name. The misspelled class name is kept as-is because the
# workbench's save/load methods refer to it by this name.
class StateicWorkbenchState(BaseModel):
    # Discriminator tag stored alongside the state payload.
    type: Literal["StaticWorkbenchState"] = "StaticWorkbenchState"
    # tool name -> value returned by that tool's save_state_json().
    tools: dict[str, Mapping[str, Any]] = {}


class StaticWorkbench(Workbench, Component[StaticWorkbenchConfig]):
    """
    A workbench that provides a static set of tools that do not change after
    each tool execution.

    Args:
        tools (List[BaseTool[Any, Any]]): A list of tools to be included in the workbench.
            The tools should be subclasses of :class:`~autogen_core.tools.BaseTool`.
        tool_overrides (Optional[Dict[str, ToolOverride]]): Optional mapping of original tool
            names to override configurations for name and/or description. This allows
            customizing how tools appear to consumers while maintaining the underlying
            tool functionality.
    """

    component_provider_override = "autogen_core.tools.StaticWorkbench"
    component_config_schema = StaticWorkbenchConfig

    def __init__(
        self, tools: List[BaseTool[Any, Any]], tool_overrides: Optional[Dict[str, ToolOverride]] = None
    ) -> None:
        """Store the tools and validate the optional name overrides.

        Raises:
            ValueError: If an override name equals the name of any tool in ``tools``,
                or if the same override name is assigned to more than one tool.
        """
        self._tools = tools
        self._tool_overrides = tool_overrides or {}

        # Reverse mapping from override names to original names, used by call_tool
        # to resolve the name a consumer sees back to the underlying tool.
        self._override_name_to_original: Dict[str, str] = {}
        existing_tool_names = {tool.name for tool in self._tools}
        for original_name, override in self._tool_overrides.items():
            # Empty/None names and no-op renames need no reverse mapping.
            if override.name and override.name != original_name:
                # An exposed name must not shadow any real tool name.
                if override.name in existing_tool_names:
                    raise ValueError(
                        f"Tool override name '{override.name}' conflicts with existing tool name. "
                        f"Override names must not conflict with any tool names."
                    )
                # Two tools must not be exposed under the same name.
                if override.name in self._override_name_to_original:
                    existing_original = self._override_name_to_original[override.name]
                    raise ValueError(
                        f"Tool override name '{override.name}' is used by multiple tools: "
                        f"'{existing_original}' and '{original_name}'. Override names must be unique."
                    )
                self._override_name_to_original[override.name] = original_name

    async def list_tools(self) -> List[ToolSchema]:
        """Return the schema of every tool, with name/description overrides applied.

        Tools without an override are returned with their original schema object.
        For overridden tools a new schema is built; ``parameters`` and ``strict`` are
        copied from the original schema when present.
        """
        result_schemas: List[ToolSchema] = []
        for tool in self._tools:
            original_schema = tool.schema
            if tool.name in self._tool_overrides:
                override = self._tool_overrides[tool.name]
                # Fields left as None in the override fall back to the original schema.
                schema: ToolSchema = {
                    "name": override.name if override.name is not None else original_schema["name"],
                    "description": override.description
                    if override.description is not None
                    else original_schema.get("description", ""),
                }
                # Copy optional fields.
                if "parameters" in original_schema:
                    schema["parameters"] = original_schema["parameters"]
                if "strict" in original_schema:
                    schema["strict"] = original_schema["strict"]
            else:
                schema = original_schema
            result_schemas.append(schema)
        return result_schemas

    async def call_tool(
        self,
        name: str,
        arguments: Mapping[str, Any] | None = None,
        cancellation_token: CancellationToken | None = None,
        call_id: str | None = None,
    ) -> ToolResult:
        """Run a tool by its exposed name and wrap the outcome in a ToolResult.

        ``name`` may be an override name; it is mapped back to the original tool.
        A missing tool or any exception raised while running the tool (or while
        converting its output to a string) is reported as a ToolResult with
        ``is_error=True`` rather than raised. The returned result carries the
        requested ``name``.
        """
        # Check if the name is an override name and map it back to the original.
        original_name = self._override_name_to_original.get(name, name)
        tool = next((tool for tool in self._tools if tool.name == original_name), None)
        if tool is None:
            return ToolResult(
                name=name,  # Return the requested name (which might be overridden).
                result=[TextResultContent(content=f"Tool {name} not found.")],
                is_error=True,
            )
        if not cancellation_token:
            cancellation_token = CancellationToken()
        if not arguments:
            arguments = {}
        try:
            result_future = asyncio.ensure_future(tool.run_json(arguments, cancellation_token, call_id=call_id))
            # Linking lets cancellation of the token cancel the running tool call.
            cancellation_token.link_future(result_future)
            actual_tool_output = await result_future
            is_error = False
            result_str = tool.return_value_as_string(actual_tool_output)
        except Exception as e:
            result_str = self._format_errors(e)
            is_error = True
        return ToolResult(name=name, result=[TextResultContent(content=result_str)], is_error=is_error)

    async def start(self) -> None:
        """No-op: a static workbench has nothing to start."""
        return None

    async def stop(self) -> None:
        """No-op: a static workbench has nothing to stop."""
        return None

    async def reset(self) -> None:
        """No-op: a static workbench holds no resettable state of its own."""
        return None

    async def save_state(self) -> Mapping[str, Any]:
        """Collect each tool's ``save_state_json()`` output, keyed by original tool name."""
        tool_states = StateicWorkbenchState()
        for tool in self._tools:
            tool_states.tools[tool.name] = await tool.save_state_json()
        return tool_states.model_dump()

    async def load_state(self, state: Mapping[str, Any]) -> None:
        """Restore tool states produced by :meth:`save_state`.

        Tools with no entry in ``state`` are left untouched.
        """
        parsed_state = StateicWorkbenchState.model_validate(state)
        for tool in self._tools:
            if tool.name in parsed_state.tools:
                await tool.load_state_json(parsed_state.tools[tool.name])

    def _to_config(self) -> StaticWorkbenchConfig:
        """Serialize the tools (as component models) and the overrides."""
        return StaticWorkbenchConfig(
            tools=[tool.dump_component() for tool in self._tools], tool_overrides=self._tool_overrides
        )

    @classmethod
    def _from_config(cls, config: StaticWorkbenchConfig) -> Self:
        """Rebuild a workbench by loading every tool component from ``config``."""
        return cls(tools=[BaseTool.load_component(tool) for tool in config.tools], tool_overrides=config.tool_overrides)

    def _format_errors(self, error: Exception) -> str:
        """Recursively format an error into a string.

        An ``ExceptionGroup`` (Python 3.11+) is flattened so that the message of
        each sub-exception, including nested groups, appears on its own line;
        empty messages are skipped. Any other exception is rendered with
        ``str(error)``. The result has no surrounding whitespace.
        """
        if hasattr(builtins, "ExceptionGroup") and isinstance(error, builtins.ExceptionGroup):
            # ExceptionGroup is available in Python 3.11+.
            # TODO: how to make this compatible with Python 3.10?
            # Each recursive call already returns a stripped message, so the
            # separator must be inserted here; plain concatenation would run
            # sibling messages together (e.g. "a" and "b" -> "ab").
            messages = [self._format_errors(sub_exception) for sub_exception in error.exceptions]  # type: ignore
            return "\n".join(message for message in messages if message)
        return str(error).strip()
class StaticStreamWorkbench(StaticWorkbench, StreamWorkbench):
    """
    A workbench that provides a static set of tools that do not change after
    each tool execution, and supports streaming results.
    """

    component_provider_override = "autogen_core.tools.StaticStreamWorkbench"

    async def call_tool_stream(
        self,
        name: str,
        arguments: Mapping[str, Any] | None = None,
        cancellation_token: CancellationToken | None = None,
        call_id: str | None = None,
    ) -> AsyncGenerator[Any | ToolResult, None]:
        """Run a tool by its exposed name, streaming intermediate results.

        ``name`` may be an override name (see ``tool_overrides``); it is mapped
        back to the underlying tool exactly as in ``call_tool``.

        For a StreamTool, every streamed item except the last is yielded as-is.
        The last item is treated as the tool's output and converted to a string
        in the final ToolResult. Other tools are run once via ``run_json``.

        The generator always finishes with a ToolResult carrying the requested
        ``name``. A missing tool or an exception is reported as a ToolResult with
        ``is_error=True``. If a stream fails, the last item received before the
        failure is yielded before the error result.
        """
        # Map an override name back to the original tool name; previously the
        # raw name was used, so overridden tools listed by list_tools could not
        # be called through the streaming API.
        original_name = self._override_name_to_original.get(name, name)
        tool = next((tool for tool in self._tools if tool.name == original_name), None)
        if tool is None:
            yield ToolResult(
                name=name,  # Return the requested name (which might be overridden).
                result=[TextResultContent(content=f"Tool {name} not found.")],
                is_error=True,
            )
            return
        if not cancellation_token:
            cancellation_token = CancellationToken()
        if not arguments:
            arguments = {}
        try:
            actual_tool_output: Any | None = None
            if isinstance(tool, StreamTool):
                previous_result: Any | None = None
                try:
                    # Hold each item back by one step so the final item can be
                    # used as the tool output instead of being yielded directly.
                    async for result in tool.run_json_stream(arguments, cancellation_token, call_id=call_id):
                        if previous_result is not None:
                            yield previous_result
                        previous_result = result
                    actual_tool_output = previous_result
                except Exception as e:
                    # If there was a previous result before the exception, yield it first.
                    if previous_result is not None:
                        yield previous_result
                    # Then yield the error result.
                    result_str = self._format_errors(e)
                    yield ToolResult(name=name, result=[TextResultContent(content=result_str)], is_error=True)
                    return
            else:
                # If the tool is not a stream tool, run it normally and yield the result.
                result_future = asyncio.ensure_future(tool.run_json(arguments, cancellation_token, call_id=call_id))
                cancellation_token.link_future(result_future)
                actual_tool_output = await result_future
            is_error = False
            result_str = tool.return_value_as_string(actual_tool_output)
        except Exception as e:
            result_str = self._format_errors(e)
            is_error = True
        # Report the requested name, consistent with StaticWorkbench.call_tool.
        yield ToolResult(name=name, result=[TextResultContent(content=result_str)], is_error=is_error)