Source code for autogen_ext.cache_store.redis

import json
import logging
from typing import Any, Dict, Optional, TypeVar, cast

import redis
from autogen_core import CacheStore, Component
from pydantic import BaseModel
from typing_extensions import Self

# Type of the values held by a RedisStore; it is the type parameter passed to CacheStore[T].
T = TypeVar("T")


class RedisStoreConfig(BaseModel):
    """Connection settings used to build, and to describe, the ``redis.Redis`` client of a RedisStore.

    Every field has a default, so a bare ``RedisStoreConfig()`` points at database 0 on
    ``localhost:6379`` with no username/password and TLS disabled.
    """

    # Address of the Redis server.
    host: str = "localhost"
    port: int = 6379
    # Database number handed to ``redis.Redis(db=...)``.
    db: int = 0
    # Add other relevant redis connection parameters

    # Credentials; None means no username/password is configured.
    username: Optional[str] = None
    password: Optional[str] = None

    # Whether the client connects with TLS enabled.
    ssl: bool = False
    # Passed through to ``redis.Redis(socket_timeout=...)``; None means no explicit timeout is set.
    socket_timeout: Optional[float] = None
class RedisStore(CacheStore[T], Component[RedisStoreConfig]):
    """
    A typed CacheStore implementation that uses redis as the underlying storage.
    See :class:`~autogen_ext.models.cache.ChatCompletionCache` for an example of usage.

    This implementation provides automatic serialization and deserialization for:

    - Pydantic models: written with ``model_dump_json`` and read back as the parsed
      JSON object (a ``dict``); callers are expected to re-validate it into a model.
    - Lists, tuples and dicts, including ones that contain Pydantic models: written
      as JSON and read back as lists/dicts.
    - Booleans: written as JSON (``true``/``false``), because the redis client
      rejects them as raw values.
    - Other primitive types (strings, numbers, bytes): written as-is.

    Note:
        Reads always attempt ``json.loads`` on the stored text. A plain string that
        happens to be valid JSON (for example ``"123"`` or ``"true"``) is therefore
        returned in its parsed form.

    Args:
        redis_instance: An instance of `redis.Redis`. The user is responsible for managing
            the Redis instance's lifetime.
    """

    component_config_schema = RedisStoreConfig
    component_provider_override = "autogen_ext.cache_store.redis.RedisStore"

    def __init__(self, redis_instance: redis.Redis):
        self.cache = redis_instance

    @classmethod
    def _to_jsonable(cls, value: Any) -> Any:
        """Recursively convert Pydantic models found in lists/tuples/dicts into JSON-compatible data."""
        if isinstance(value, BaseModel):
            # mode="json" turns datetimes, enums, bytes, ... into JSON-encodable values.
            return value.model_dump(mode="json")
        if isinstance(value, (list, tuple)):
            sequence: Any = value
            return [cls._to_jsonable(item) for item in sequence]
        if isinstance(value, dict):
            mapping: Dict[Any, Any] = cast(Dict[Any, Any], value)
            return {key: cls._to_jsonable(item) for key, item in mapping.items()}
        return value

    def get(self, key: str, default: Optional[T] = None) -> Optional[T]:
        """
        Retrieve a value from the Redis cache.

        The stored data is decoded as follows:

        - ``bytes`` (what a default ``redis.Redis`` client returns) are decoded as UTF-8.
          ``str`` (what a client created with ``decode_responses=True`` returns) is used
          directly.
        - The resulting text is parsed with ``json.loads``. Valid JSON is returned in
          parsed form (Pydantic models come back as ``dict``, lists as ``list``);
          anything else is returned as the plain string.
        - Any other type returned by the client is passed through unchanged.

        Args:
            key: The key to retrieve.
            default: Value to return if the key doesn't exist, the stored bytes are not
                valid UTF-8, or the Redis call fails.

        Returns:
            The decoded value if found, otherwise ``default``.
        """
        try:
            raw_value = self.cache.get(key)
        except (redis.RedisError, ConnectionError) as exc:
            # Best-effort cache: a Redis failure is treated as a miss, but it is logged.
            logging.getLogger(__name__).warning("RedisStore.get failed for key %r: %s", key, exc)
            return default

        if raw_value is None:
            return default

        if isinstance(raw_value, bytes):
            try:
                text = raw_value.decode("utf-8")
            except UnicodeDecodeError:
                return default
        elif isinstance(raw_value, str):
            text = raw_value
        else:
            # Backward compatibility for primitives
            return cast(Optional[T], raw_value)

        try:
            # Try to parse as JSON and return the parsed object
            return cast(Optional[T], json.loads(text))
        except json.JSONDecodeError:
            # If not valid JSON, return the decoded string.
            return cast(Optional[T], text)

    def set(self, key: str, value: T) -> None:
        """
        Store a value in the Redis cache.

        - Pydantic models are serialized with ``model_dump_json``.
        - Lists, tuples and dicts are serialized to JSON. Pydantic models nested inside
          them are converted with ``model_dump(mode="json")``, so fields such as
          datetimes or enums remain serializable.
        - Booleans are serialized to JSON, because the redis client rejects them as
          raw values.
        - Other values (strings, numbers, bytes) are passed to Redis as-is.

        Failures (Redis errors, values that cannot be serialized or that the client
        rejects) are logged and otherwise ignored, so a broken cache never breaks the
        caller.

        Args:
            key: The key to store the value under.
            value: The value to store.
        """
        try:
            payload: Any
            if isinstance(value, BaseModel):
                # Serialize Pydantic models to JSON
                payload = value.model_dump_json().encode("utf-8")
            elif isinstance(value, (bool, list, tuple, dict)):
                # Containers (possibly holding Pydantic models) and booleans go through JSON
                payload = json.dumps(self._to_jsonable(value)).encode("utf-8")
            else:
                # Backward compatibility for primitives
                payload = cast(Any, value)
            self.cache.set(key, payload)
        except (redis.RedisError, ConnectionError, UnicodeEncodeError, TypeError, ValueError) as exc:
            # Best-effort cache: never raise, but make the failure visible.
            logging.getLogger(__name__).warning("RedisStore.set failed for key %r: %s", key, exc)

    def _to_config(self) -> RedisStoreConfig:
        # Extract connection info from redis instance
        connection_pool = self.cache.connection_pool
        connection_kwargs: Dict[str, Any] = connection_pool.connection_kwargs  # type: ignore[reportUnknownMemberType]

        username = connection_kwargs.get("username")
        password = connection_kwargs.get("password")
        socket_timeout = connection_kwargs.get("socket_timeout")

        # redis-py does not keep an "ssl" flag in connection_kwargs. A TLS client is
        # recognisable by its pool using an SSLConnection connection class.
        connection_class = getattr(connection_pool, "connection_class", None)
        uses_ssl = bool(connection_kwargs.get("ssl", False)) or (
            isinstance(connection_class, type) and issubclass(connection_class, redis.SSLConnection)
        )

        # NOTE: the password (if any) is copied into the config in plain text.
        return RedisStoreConfig(
            host=str(connection_kwargs.get("host", "localhost")),
            port=int(connection_kwargs.get("port", 6379)),
            db=int(connection_kwargs.get("db", 0)),
            username=str(username) if username is not None else None,
            password=str(password) if password is not None else None,
            ssl=uses_ssl,
            socket_timeout=float(socket_timeout) if socket_timeout is not None else None,
        )

    @classmethod
    def _from_config(cls, config: RedisStoreConfig) -> Self:
        # Create new redis instance from config
        redis_instance = redis.Redis(
            host=config.host,
            port=config.port,
            db=config.db,
            username=config.username,
            password=config.password,
            ssl=config.ssl,
            socket_timeout=config.socket_timeout,
        )
        return cls(redis_instance=redis_instance)