# Source code for langgraph_agent_toolkit.core.memory.postgres

from collections.abc import AsyncGenerator, Callable
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from typing import TypeVar

from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.store.postgres.aio import AsyncPostgresStore
from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool, PoolTimeout

from langgraph_agent_toolkit.core.memory.base import BaseMemoryBackend
from langgraph_agent_toolkit.core.settings import settings
from langgraph_agent_toolkit.helper.logging import logger


# Generic return type for the object a pool-factory function produces
# (e.g. AsyncPostgresSaver or AsyncPostgresStore).
T = TypeVar("T")


class PostgresMemoryBackend(BaseMemoryBackend):
    """PostgreSQL implementation of memory backend.

    Provides LangGraph checkpoint savers and stores backed by a shared
    asynchronous psycopg connection pool configured from application settings.
    """

    def validate_config(self) -> bool:
        """Validate that all required PostgreSQL configuration is present.

        Returns:
            bool: True when all required settings are present and consistent.

        Raises:
            ValueError: If any required setting is missing, or if the pool
                sizing is inconsistent (min size greater than max size).
        """
        required_vars = [
            "POSTGRES_USER",
            "POSTGRES_PASSWORD",
            "POSTGRES_HOST",
            "POSTGRES_PORT",
            "POSTGRES_DB",
        ]
        missing = [var for var in required_vars if not getattr(settings, var, None)]
        if missing:
            raise ValueError(
                f"Missing required PostgreSQL configuration: {', '.join(missing)}. "
                "These environment variables must be set to use PostgreSQL persistence."
            )

        if settings.POSTGRES_MIN_SIZE > settings.POSTGRES_POOL_SIZE:
            raise ValueError(
                f"POSTGRES_MIN_SIZE ({settings.POSTGRES_MIN_SIZE}) must be less than or equal to "
                f"POSTGRES_POOL_SIZE ({settings.POSTGRES_POOL_SIZE})"
            )

        return True

    @staticmethod
    def get_connection_string() -> str:
        """Build and return the PostgreSQL connection string from settings."""
        return (
            f"postgresql://{settings.POSTGRES_USER}:"
            f"{settings.POSTGRES_PASSWORD.get_secret_value()}@"
            f"{settings.POSTGRES_HOST}:{settings.POSTGRES_PORT}/"
            f"{settings.POSTGRES_DB}"
        )

    @asynccontextmanager
    async def _get_connection_context(
        self,
        factory_func: Callable[[AsyncConnectionPool], T],
        app_prefix: str,
    ) -> AsyncGenerator[T, None]:
        """Yield the result of the factory function.

        Opens an :class:`AsyncConnectionPool` configured from settings, hands it
        to ``factory_func``, and closes the pool automatically on exit.

        Args:
            factory_func: Function that creates the appropriate object from the
                connection pool.
            app_prefix: Prefix for the application name in the connection pool
                (appended to ``POSTGRES_APPLICATION_NAME``).

        Yields:
            The object created by ``factory_func``.

        Raises:
            PoolTimeout: Re-raised (after logging pool statistics) when a
                connection cannot be acquired within the pool timeout.
        """
        application_name = f"{settings.POSTGRES_APPLICATION_NAME}-{app_prefix}"

        logger.info(
            f"Creating PostgreSQL connection pool: min_size={settings.POSTGRES_MIN_SIZE}, "
            f"max_size={settings.POSTGRES_POOL_SIZE}, max_idle={settings.POSTGRES_MAX_IDLE}, "
            f"timeout={settings.POSTGRES_POOL_TIMEOUT}s, reconnect_timeout={settings.POSTGRES_RECONNECT_TIMEOUT}s, "
            f"max_lifetime={settings.POSTGRES_MAX_LIFETIME}s, "
            f"statement_timeout={settings.POSTGRES_STATEMENT_TIMEOUT}ms, "
            f"lock_timeout={settings.POSTGRES_LOCK_TIMEOUT}ms, "
            f"idle_in_transaction_timeout={settings.POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT}ms, "
            f"schema={settings.POSTGRES_SCHEMA}, application_name={application_name}"
        )

        # Prepare connection kwargs with schema setting
        connection_kwargs = {
            "autocommit": True,
            "prepare_threshold": 0,
            "row_factory": dict_row,
            "application_name": application_name,
        }

        # Build PostgreSQL options string with timeouts and schema
        pg_options = []

        # Set search_path if schema is specified and not default
        if settings.POSTGRES_SCHEMA and settings.POSTGRES_SCHEMA != "public":
            pg_options.append(f"-c search_path={settings.POSTGRES_SCHEMA}")

        # Set statement timeout - kills queries running too long
        if settings.POSTGRES_STATEMENT_TIMEOUT > 0:
            pg_options.append(f"-c statement_timeout={settings.POSTGRES_STATEMENT_TIMEOUT}")

        # Set lock timeout - prevents waiting forever for locks
        if settings.POSTGRES_LOCK_TIMEOUT > 0:
            pg_options.append(f"-c lock_timeout={settings.POSTGRES_LOCK_TIMEOUT}")

        # Set idle_in_transaction_session_timeout - kills idle transactions
        if settings.POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT > 0:
            pg_options.append(
                f"-c idle_in_transaction_session_timeout={settings.POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT}"
            )

        if pg_options:
            connection_kwargs["options"] = " ".join(pg_options)

        # Callback for logging reconnection attempts
        def on_reconnect_failed(pool: AsyncConnectionPool) -> None:
            logger.error(
                f"Failed to reconnect to PostgreSQL. Pool stats: "
                f"pool_size={pool.get_stats().get('pool_size', 'N/A')}, "
                f"pool_available={pool.get_stats().get('pool_available', 'N/A')}"
            )

        # Use AsyncConnectionPool as an async context manager
        async with AsyncConnectionPool(
            self.get_connection_string(),
            min_size=settings.POSTGRES_MIN_SIZE,
            max_size=settings.POSTGRES_POOL_SIZE,
            max_idle=settings.POSTGRES_MAX_IDLE,
            timeout=settings.POSTGRES_POOL_TIMEOUT,
            reconnect_timeout=settings.POSTGRES_RECONNECT_TIMEOUT,
            # Maximum time a connection can live before being recycled (prevents stale connections)
            max_lifetime=settings.POSTGRES_MAX_LIFETIME,
            # Number of background workers for connection maintenance
            num_workers=settings.POSTGRES_NUM_WORKERS,
            # Automatically check connection health before returning from pool
            check=AsyncConnectionPool.check_connection,
            # Callback when reconnection fails
            reconnect_failed=on_reconnect_failed,
            # Connection configuration
            kwargs=connection_kwargs,
            # Don't open immediately, we'll open manually after setup
            open=False,
        ) as pool:
            # Open the pool and wait for min_size connections
            await pool.open(wait=True, timeout=settings.POSTGRES_POOL_TIMEOUT)

            logger.info(
                f"PostgreSQL connection pool opened successfully. "
                f"Initial stats: pool_size={pool.get_stats().get('pool_size', 'N/A')}, "
                f"pool_available={pool.get_stats().get('pool_available', 'N/A')}"
            )

            try:
                yield factory_func(pool)
            except PoolTimeout:
                # Log pool statistics for debugging
                stats = pool.get_stats()
                logger.error(
                    f"Pool timeout occurred. Pool stats: "
                    f"pool_size={stats.get('pool_size', 'N/A')}, "
                    f"pool_available={stats.get('pool_available', 'N/A')}, "
                    f"requests_waiting={stats.get('requests_waiting', 'N/A')}, "
                    f"requests_num={stats.get('requests_num', 'N/A')}"
                )
                raise
            finally:
                logger.info("PostgreSQL connection pool will be closed automatically")

    @asynccontextmanager
    async def get_saver(self) -> AsyncGenerator[AsyncPostgresSaver, None]:
        """Asynchronous context manager for acquiring a PostgreSQL saver.

        Yields:
            AsyncPostgresSaver: The database saver instance
        """
        async with self._get_connection_context(
            lambda pool: AsyncPostgresSaver(conn=pool),
            app_prefix="saver",
        ) as saver:
            yield saver

    @asynccontextmanager
    async def get_store(self) -> AsyncGenerator[AsyncPostgresStore, None]:
        """Asynchronous context manager for acquiring a PostgreSQL store.

        Yields:
            AsyncPostgresStore: The database store instance
        """
        # BUG FIX: app_prefix was previously passed to AsyncPostgresStore(...)
        # instead of _get_connection_context(...), which raised a TypeError
        # (missing required 'app_prefix') — now mirrors get_saver above.
        async with self._get_connection_context(
            lambda pool: AsyncPostgresStore(conn=pool),
            app_prefix="store",
        ) as store:
            yield store

    def get_checkpoint_saver(self) -> AbstractAsyncContextManager[AsyncPostgresSaver]:
        """Initialize and return a PostgreSQL saver instance."""
        self.validate_config()
        return self.get_saver()

    def get_memory_store(self) -> AbstractAsyncContextManager[AsyncPostgresStore]:
        """Initialize and return a PostgreSQL store instance."""
        self.validate_config()
        return self.get_store()