Source code for langgraph_agent_toolkit.core.memory.postgres

from collections.abc import AsyncGenerator, Callable
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from typing import TypeVar

from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.store.postgres.aio import AsyncPostgresStore
from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool

from langgraph_agent_toolkit.core.memory.base import BaseMemoryBackend
from langgraph_agent_toolkit.core.settings import settings
from langgraph_agent_toolkit.helper.logging import logger


T = TypeVar("T")


[docs] class PostgresMemoryBackend(BaseMemoryBackend): """PostgreSQL implementation of memory backend."""
[docs] def validate_config(self) -> bool: """Validate that all required PostgreSQL configuration is present.""" required_vars = [ "POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_HOST", "POSTGRES_PORT", "POSTGRES_DB", ] missing = [var for var in required_vars if not getattr(settings, var, None)] if missing: raise ValueError( f"Missing required PostgreSQL configuration: {', '.join(missing)}. " "These environment variables must be set to use PostgreSQL persistence." ) return True
[docs] @staticmethod def get_connection_string() -> str: """Build and return the PostgreSQL connection string from settings.""" return ( f"postgresql://{settings.POSTGRES_USER}:" f"{settings.POSTGRES_PASSWORD.get_secret_value()}@" f"{settings.POSTGRES_HOST}:{settings.POSTGRES_PORT}/" f"{settings.POSTGRES_DB}" )
@asynccontextmanager async def _get_connection_context( self, factory_func: Callable[[AsyncConnectionPool], T] ) -> AsyncGenerator[T, None]: """Yield the result of the factory function. Args: factory_func: Function that creates the appropriate object from the connection pool Yields: The object created by the factory_func """ logger.info( f"Creating PostgreSQL connection pool: min_size={settings.POSTGRES_MIN_SIZE}, " f"max_size={settings.POSTGRES_POOL_SIZE}, max_idle={settings.POSTGRES_MAX_IDLE}" ) # Use AsyncConnectionPool as an async context manager async with AsyncConnectionPool( self.get_connection_string(), min_size=settings.POSTGRES_MIN_SIZE, max_size=settings.POSTGRES_POOL_SIZE, max_idle=settings.POSTGRES_MAX_IDLE, kwargs=dict(autocommit=True, prepare_threshold=0, row_factory=dict_row), ) as pool: logger.info("PostgreSQL connection pool opened successfully") try: yield factory_func(pool) finally: logger.info("PostgreSQL connection pool will be closed automatically")
[docs] @asynccontextmanager async def get_saver(self) -> AsyncGenerator[AsyncPostgresSaver, None]: """Asynchronous context manager for acquiring a PostgreSQL saver. Yields: AsyncPostgresSaver: The database saver instance """ async with self._get_connection_context(lambda pool: AsyncPostgresSaver(conn=pool)) as saver: yield saver
[docs] @asynccontextmanager async def get_store(self) -> AsyncGenerator[AsyncPostgresStore, None]: """Asynchronous context manager for acquiring a PostgreSQL store. Yields: AsyncPostgresStore: The database store instance """ async with self._get_connection_context(lambda pool: AsyncPostgresStore(conn=pool)) as store: yield store
[docs] def get_checkpoint_saver(self) -> AbstractAsyncContextManager[AsyncPostgresSaver]: """Initialize and return a PostgreSQL saver instance.""" self.validate_config() return self.get_saver()
[docs] def get_memory_store(self) -> AbstractAsyncContextManager[AsyncPostgresStore]: """Initialize and return a PostgreSQL store instance.""" self.validate_config() return self.get_store()