# Source code for autogen_ext.memory.chromadb._chromadb

import logging
import uuid
from typing import Any, List

from autogen_core import CancellationToken, Component, Image
from autogen_core.memory import Memory, MemoryContent, MemoryMimeType, MemoryQueryResult, UpdateContextResult
from autogen_core.model_context import ChatCompletionContext
from autogen_core.models import SystemMessage
from chromadb import HttpClient, PersistentClient
from chromadb.api.models.Collection import Collection
from chromadb.api.types import Document, Metadata
from typing_extensions import Self

from ._chroma_configs import (
    ChromaDBVectorMemoryConfig,
    CustomEmbeddingFunctionConfig,
    DefaultEmbeddingFunctionConfig,
    HttpChromaDBVectorMemoryConfig,
    OpenAIEmbeddingFunctionConfig,
    PersistentChromaDBVectorMemoryConfig,
    SentenceTransformerEmbeddingFunctionConfig,
)

logger = logging.getLogger(__name__)


# Fail fast with an actionable message when the optional `chromadb` extra is
# missing: every method below depends on the ChromaDB client API.
try:
    from chromadb.api import ClientAPI
except ImportError as e:
    raise ImportError(
        "To use the ChromaDBVectorMemory the chromadb extra must be installed. Run `pip install autogen-ext[chromadb]`"
    ) from e


class ChromaDBVectorMemory(Memory, Component[ChromaDBVectorMemoryConfig]):
    """
    Store and retrieve memory using vector similarity search powered by ChromaDB.

    `ChromaDBVectorMemory` provides a vector-based memory implementation that uses
    ChromaDB for storing and retrieving content based on semantic similarity. It
    enhances agents with the ability to recall contextually relevant information
    during conversations by leveraging vector embeddings to find similar content.

    This implementation serves as a reference for more complex memory systems using
    vector embeddings. For advanced use cases requiring specialized formatting of
    retrieved content, users should extend this class and override the
    `update_context()` method.

    This implementation requires the ChromaDB extra to be installed. Install with:

    .. code-block:: bash

        pip install "autogen-ext[chromadb]"

    Args:
        config (ChromaDBVectorMemoryConfig | None): Configuration for the ChromaDB
            memory. If None, defaults to a PersistentChromaDBVectorMemoryConfig with
            default values. Two config types are supported:

            * PersistentChromaDBVectorMemoryConfig: For local storage
            * HttpChromaDBVectorMemoryConfig: For connecting to a remote ChromaDB server

    Example:

        .. code-block:: python

            import os
            import asyncio
            from pathlib import Path
            from autogen_agentchat.agents import AssistantAgent
            from autogen_agentchat.ui import Console
            from autogen_core.memory import MemoryContent, MemoryMimeType
            from autogen_ext.memory.chromadb import (
                ChromaDBVectorMemory,
                PersistentChromaDBVectorMemoryConfig,
                SentenceTransformerEmbeddingFunctionConfig,
                OpenAIEmbeddingFunctionConfig,
            )
            from autogen_ext.models.openai import OpenAIChatCompletionClient


            def get_weather(city: str) -> str:
                return f"The weather in {city} is sunny with a high of 90°F and a low of 70°F."


            def fahrenheit_to_celsius(fahrenheit: float) -> float:
                return (fahrenheit - 32) * 5.0 / 9.0


            async def main() -> None:
                # Use default embedding function
                default_memory = ChromaDBVectorMemory(
                    config=PersistentChromaDBVectorMemoryConfig(
                        collection_name="user_preferences",
                        persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"),
                        k=3,  # Return top 3 results
                        score_threshold=0.5,  # Minimum similarity score
                    )
                )

                # Using a custom SentenceTransformer model
                custom_memory = ChromaDBVectorMemory(
                    config=PersistentChromaDBVectorMemoryConfig(
                        collection_name="multilingual_memory",
                        persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"),
                        embedding_function_config=SentenceTransformerEmbeddingFunctionConfig(
                            model_name="paraphrase-multilingual-mpnet-base-v2"
                        ),
                    )
                )

                # Using OpenAI embeddings
                openai_memory = ChromaDBVectorMemory(
                    config=PersistentChromaDBVectorMemoryConfig(
                        collection_name="openai_memory",
                        persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"),
                        embedding_function_config=OpenAIEmbeddingFunctionConfig(
                            api_key=os.environ["OPENAI_API_KEY"], model_name="text-embedding-3-small"
                        ),
                    )
                )

                # Add user preferences to memory
                await openai_memory.add(
                    MemoryContent(
                        content="The user prefers weather temperatures in Celsius",
                        mime_type=MemoryMimeType.TEXT,
                        metadata={"category": "preferences", "type": "units"},
                    )
                )

                # Create assistant agent with ChromaDB memory
                assistant = AssistantAgent(
                    name="assistant",
                    model_client=OpenAIChatCompletionClient(
                        model="gpt-4.1",
                    ),
                    tools=[
                        get_weather,
                        fahrenheit_to_celsius,
                    ],
                    max_tool_iterations=10,
                    memory=[openai_memory],
                )

                # The memory will automatically retrieve relevant content during conversations
                await Console(assistant.run_stream(task="What's the temperature in New York?"))

                # Remember to close the memory when finished
                await default_memory.close()
                await custom_memory.close()
                await openai_memory.close()


            asyncio.run(main())

        Output (abbreviated):

        .. code-block:: text

            ---------- TextMessage (user) ----------
            What's the temperature in New York?
            ---------- MemoryQueryEvent (assistant) ----------
            [MemoryContent(content='The user prefers weather temperatures in Celsius', ...)]
            ---------- ToolCallRequestEvent (assistant) ----------
            [FunctionCall(id='...', arguments='{"city":"New York"}', name='get_weather')]
            ---------- ToolCallExecutionEvent (assistant) ----------
            [FunctionExecutionResult(content='The weather in New York is sunny with a high of 90°F and a low of 70°F.', ...)]
            ---------- ToolCallRequestEvent (assistant) ----------
            [FunctionCall(id='...', arguments='{"fahrenheit": 90}', name='fahrenheit_to_celsius'), ...]
            ---------- ToolCallExecutionEvent (assistant) ----------
            [FunctionExecutionResult(content='32.22222222222222', name='fahrenheit_to_celsius', ...), ...]
            ---------- TextMessage (assistant) ----------
            The temperature in New York today is sunny with a high of about 32°C and a low of about 21°C.
    """

    component_config_schema = ChromaDBVectorMemoryConfig
    component_provider_override = "autogen_ext.memory.chromadb.ChromaDBVectorMemory"

    def __init__(self, config: ChromaDBVectorMemoryConfig | None = None) -> None:
        # Fall back to local persistent storage with default values when no
        # configuration is supplied.
        self._config = config or PersistentChromaDBVectorMemoryConfig()
        # Client and collection are created lazily in _ensure_initialized().
        self._client: ClientAPI | None = None
        self._collection: Collection | None = None

    @property
    def collection_name(self) -> str:
        """Get the name of the ChromaDB collection."""
        return self._config.collection_name

    def _create_embedding_function(self) -> Any:
        """Create an embedding function based on the configuration.

        Returns:
            A ChromaDB-compatible embedding function.

        Raises:
            ValueError: If the embedding function type is unsupported.
            ImportError: If required dependencies are not installed.
        """
        # Imported lazily so a missing optional dependency surfaces as a clear
        # ImportError only when embeddings are actually needed.
        try:
            from chromadb.utils import embedding_functions
        except ImportError as e:
            raise ImportError(
                "ChromaDB embedding functions not available. Ensure chromadb is properly installed."
            ) from e

        config = self._config.embedding_function_config

        if isinstance(config, DefaultEmbeddingFunctionConfig):
            return embedding_functions.DefaultEmbeddingFunction()
        elif isinstance(config, SentenceTransformerEmbeddingFunctionConfig):
            try:
                return embedding_functions.SentenceTransformerEmbeddingFunction(model_name=config.model_name)
            except Exception as e:
                # Model download/load failures are reported as ImportError with context.
                raise ImportError(
                    f"Failed to create SentenceTransformer embedding function with model '{config.model_name}'. "
                    f"Ensure sentence-transformers is installed and the model is available. Error: {e}"
                ) from e
        elif isinstance(config, OpenAIEmbeddingFunctionConfig):
            try:
                return embedding_functions.OpenAIEmbeddingFunction(api_key=config.api_key, model_name=config.model_name)
            except Exception as e:
                raise ImportError(
                    f"Failed to create OpenAI embedding function with model '{config.model_name}'. "
                    f"Ensure openai is installed and API key is valid. Error: {e}"
                ) from e
        elif isinstance(config, CustomEmbeddingFunctionConfig):
            try:
                # User-supplied factory; params are forwarded as keyword arguments.
                return config.function(**config.params)
            except Exception as e:
                raise ValueError(f"Failed to create custom embedding function. Error: {e}") from e
        else:
            raise ValueError(f"Unsupported embedding function config type: {type(config)}")

    def _ensure_initialized(self) -> None:
        """Ensure ChromaDB client and collection are initialized."""
        # Client first: the collection cannot exist without it.
        if self._client is None:
            try:
                from chromadb.config import Settings

                settings = Settings(allow_reset=self._config.allow_reset)

                if isinstance(self._config, PersistentChromaDBVectorMemoryConfig):
                    self._client = PersistentClient(
                        path=self._config.persistence_path,
                        settings=settings,
                        tenant=self._config.tenant,
                        database=self._config.database,
                    )
                elif isinstance(self._config, HttpChromaDBVectorMemoryConfig):
                    self._client = HttpClient(
                        host=self._config.host,
                        port=self._config.port,
                        ssl=self._config.ssl,
                        headers=self._config.headers,
                        settings=settings,
                        tenant=self._config.tenant,
                        database=self._config.database,
                    )
                else:
                    raise ValueError(f"Unsupported config type: {type(self._config)}")
            except Exception as e:
                logger.error(f"Failed to initialize ChromaDB client: {e}")
                raise

        if self._collection is None:
            try:
                # Create embedding function
                embedding_function = self._create_embedding_function()

                # Create or get collection with embedding function
                self._collection = self._client.get_or_create_collection(
                    name=self._config.collection_name,
                    metadata={"distance_metric": self._config.distance_metric},
                    embedding_function=embedding_function,
                )
            except Exception as e:
                logger.error(f"Failed to get/create collection: {e}")
                raise

    def _extract_text(self, content_item: str | MemoryContent) -> str:
        """Extract searchable text from content.

        Raises:
            ValueError: For image content or unsupported mime types.
        """
        if isinstance(content_item, str):
            return content_item

        content = content_item.content
        mime_type = content_item.mime_type

        if mime_type in [MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN]:
            return str(content)
        elif mime_type == MemoryMimeType.JSON:
            if isinstance(content, dict):
                # Store original JSON string representation
                # NOTE(review): this is str(dict) lower-cased, not json.dumps —
                # confirm the lower-casing (and dict repr format) is intentional.
                return str(content).lower()
            raise ValueError("JSON content must be a dict")
        elif isinstance(content, Image):
            raise ValueError("Image content cannot be converted to text")
        else:
            raise ValueError(f"Unsupported content type: {mime_type}")

    def _calculate_score(self, distance: float) -> float:
        """Convert ChromaDB distance to a similarity score."""
        if self._config.distance_metric == "cosine":
            # Map cosine distance (assumed in [0, 2]) onto a [0, 1] similarity.
            return 1.0 - (distance / 2.0)
        # For any other metric, use a monotonically decreasing transform so that
        # smaller distances yield higher scores.
        return 1.0 / (1.0 + distance)
[docs] async def update_context( self, model_context: ChatCompletionContext, ) -> UpdateContextResult: messages = await model_context.get_messages() if not messages: return UpdateContextResult(memories=MemoryQueryResult(results=[])) # Extract query from last message last_message = messages[-1] query_text = last_message.content if isinstance(last_message.content, str) else str(last_message) # Query memory and get results query_results = await self.query(query_text) if query_results.results: # Format results for context memory_strings = [f"{i}. {str(memory.content)}" for i, memory in enumerate(query_results.results, 1)] memory_context = "\nRelevant memory content:\n" + "\n".join(memory_strings) # Add to context await model_context.add_message(SystemMessage(content=memory_context)) return UpdateContextResult(memories=query_results)
[docs] async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None: self._ensure_initialized() if self._collection is None: raise RuntimeError("Failed to initialize ChromaDB") try: # Extract text from content text = self._extract_text(content) # Use metadata directly from content metadata_dict = content.metadata or {} metadata_dict["mime_type"] = str(content.mime_type) # Add to ChromaDB self._collection.add(documents=[text], metadatas=[metadata_dict], ids=[str(uuid.uuid4())]) except Exception as e: logger.error(f"Failed to add content to ChromaDB: {e}") raise
[docs] async def query( self, query: str | MemoryContent, cancellation_token: CancellationToken | None = None, **kwargs: Any, ) -> MemoryQueryResult: self._ensure_initialized() if self._collection is None: raise RuntimeError("Failed to initialize ChromaDB") try: # Extract text for query query_text = self._extract_text(query) # Query ChromaDB results = self._collection.query( query_texts=[query_text], n_results=self._config.k, include=["documents", "metadatas", "distances"], **kwargs, ) # Convert results to MemoryContent list memory_results: List[MemoryContent] = [] if ( not results or not results.get("documents") or not results.get("metadatas") or not results.get("distances") ): return MemoryQueryResult(results=memory_results) documents: List[Document] = results["documents"][0] if results["documents"] else [] metadatas: List[Metadata] = results["metadatas"][0] if results["metadatas"] else [] distances: List[float] = results["distances"][0] if results["distances"] else [] ids: List[str] = results["ids"][0] if results["ids"] else [] for doc, metadata_dict, distance, doc_id in zip(documents, metadatas, distances, ids, strict=False): # Calculate score score = self._calculate_score(distance) metadata = dict(metadata_dict) metadata["score"] = score metadata["id"] = doc_id if self._config.score_threshold is not None and score < self._config.score_threshold: continue # Extract mime_type from metadata mime_type = str(metadata_dict.get("mime_type", MemoryMimeType.TEXT.value)) # Create MemoryContent content = MemoryContent( content=doc, mime_type=mime_type, metadata=metadata, ) memory_results.append(content) return MemoryQueryResult(results=memory_results) except Exception as e: logger.error(f"Failed to query ChromaDB: {e}") raise
[docs] async def clear(self) -> None: self._ensure_initialized() if self._collection is None: raise RuntimeError("Failed to initialize ChromaDB") try: results = self._collection.get() if results and results["ids"]: self._collection.delete(ids=results["ids"]) except Exception as e: logger.error(f"Failed to clear ChromaDB collection: {e}") raise
[docs] async def close(self) -> None: """Clean up ChromaDB client and resources.""" self._collection = None self._client = None
[docs] async def reset(self) -> None: self._ensure_initialized() if not self._config.allow_reset: raise RuntimeError("Reset not allowed. Set allow_reset=True in config to enable.") if self._client is not None: try: self._client.reset() except Exception as e: logger.error(f"Error during ChromaDB reset: {e}") finally: self._collection = None
    def _to_config(self) -> ChromaDBVectorMemoryConfig:
        """Serialize the memory configuration."""
        # The config object alone fully describes this component.
        return self._config

    @classmethod
    def _from_config(cls, config: ChromaDBVectorMemoryConfig) -> Self:
        """Deserialize the memory configuration."""
        # Alternate constructor used by the Component loading machinery.
        return cls(config=config)