# Source code for langgraph_agent_toolkit.agents.agent_executor

import asyncio
import functools
import importlib
import inspect
import os
import sys
from pathlib import Path
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Tuple, TypeVar
from uuid import UUID, uuid4

import joblib
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from langgraph.errors import GraphRecursionError
from langgraph.graph.state import CompiledStateGraph
from langgraph.pregel import Pregel
from langgraph.types import Command, Interrupt

from langgraph_agent_toolkit.agents.agent import Agent
from langgraph_agent_toolkit.core.settings import settings
from langgraph_agent_toolkit.helper.constants import DEFAULT_RECURSION_LIMIT, get_default_agent, set_default_agent
from langgraph_agent_toolkit.helper.logging import logger
from langgraph_agent_toolkit.helper.utils import (
    convert_message_content_to_string,
    create_ai_message,
    langchain_to_chat_message,
    remove_tool_calls,
)
from langgraph_agent_toolkit.schema import AgentInfo, ChatMessage


# Generic return type of the callable wrapped by AgentExecutor.handle_agent_errors,
# so the decorated function keeps its original return annotation.
T = TypeVar("T")


class AgentExecutor:
    """Handles the loading, execution and saving logic for different LangGraph agents."""

    def __init__(self, *args):
        """Initialize the AgentExecutor by importing agents.

        Args:
            *args: Variable length strings specifying the agents to import, e.g.,
                "langgraph_agent_toolkit.agents.blueprints.react.agent:react_agent".

        Raises:
            ValueError: If no agents are provided, or if none of them could be loaded.
        """
        self.agents: Dict[str, Agent] = {}

        if not args:
            raise ValueError("At least one agent must be provided to AgentExecutor.")

        # Load agents from "module.path:object_name" import strings.
        self.load_agents_from_imports(args)
        self._validate_default_agent_loaded()

    def load_agents_from_imports(self, args: tuple) -> None:
        """Dynamically import agents from "module.path:object_name" strings.

        Compiled graphs are wrapped in an ``Agent`` named after the imported object;
        ``Agent`` instances are registered under their own ``name``. Import failures and
        malformed strings are logged and skipped rather than raised.

        Args:
            args: Import strings of the form "module.path:object_name".
        """
        for import_str in args:
            try:
                # Exactly one ":" is expected; any other count raises ValueError (logged below).
                module_path, object_name = import_str.split(":")
                module = importlib.import_module(module_path)
                agent_obj = getattr(module, object_name)

                if isinstance(agent_obj, (CompiledStateGraph, Pregel)):
                    agent = Agent(name=object_name, description=f"Dynamically loaded {object_name}", graph=agent_obj)
                    self.agents[agent.name] = agent
                elif isinstance(agent_obj, Agent):
                    self.agents[agent_obj.name] = agent_obj
                else:
                    logger.warning(f"Object '{object_name}' is neither a graph nor an Agent instance")
            except (ImportError, AttributeError, ValueError) as e:
                logger.error(f"Error loading agent from '{import_str}': {e}")

    def _validate_default_agent_loaded(self) -> None:
        """Validate that a default agent is available and set it if needed.

        If the default agent from constants.py is not available, use the first loaded
        agent as the default and update the global value.

        Raises:
            ValueError: If no agents are loaded at all.
        """
        if not self.agents:
            raise ValueError("No agents were loaded. Please check your imports.")

        initial_default = get_default_agent()
        if initial_default not in self.agents:
            # Dicts preserve insertion order, so this is the first agent registered.
            new_default = next(iter(self.agents))
            logger.warning(
                f"Default agent '{initial_default}' not found in loaded agents. Using '{new_default}' as default."
            )
            set_default_agent(new_default)

    def get_agent(self, agent_id: str) -> Agent:
        """Get an agent by its ID.

        Args:
            agent_id: The ID of the agent to retrieve.

        Returns:
            The requested Agent instance.

        Raises:
            KeyError: If the agent_id is not found.
        """
        try:
            return self.agents[agent_id]
        except KeyError:
            raise KeyError(f"Agent '{agent_id}' not found") from None

    def get_all_agent_info(self) -> list[AgentInfo]:
        """Get information about all available agents.

        Returns:
            A list of AgentInfo objects containing agent IDs and descriptions.
        """
        return [AgentInfo(key=agent_id, description=agent.description) for agent_id, agent in self.agents.items()]

    def add_agent(self, agent_id: str, agent: Agent) -> None:
        """Add a new agent to the executor (replacing any agent with the same ID).

        Args:
            agent_id: The ID to assign to the agent.
            agent: The Agent instance to add.
        """
        self.agents[agent_id] = agent

    @staticmethod
    def handle_agent_errors(func: Callable[..., T]) -> Callable[..., T]:
        """Log errors raised during agent execution, then re-raise them.

        ``GraphRecursionError`` gets a dedicated log message; every other ``Exception``
        is logged generically. Supports plain functions, coroutine functions and
        async generator functions. Async generators are wrapped so that errors raised
        *while iterating* are logged, not just errors raised when the generator is created.

        Args:
            func: The function to decorate.

        Returns:
            The decorated function.
        """

        def _handle_error(e: Exception):
            """Log ``e`` with the active traceback and re-raise it (must be called inside ``except``)."""
            if isinstance(e, GraphRecursionError):
                logger.opt(exception=sys.exc_info()).error(f"GraphRecursionError occurred: {e}")
            else:
                logger.opt(exception=sys.exc_info()).error(f"Error during agent execution: {e}")
            raise e

        if inspect.isasyncgenfunction(func):

            @functools.wraps(func)
            async def asyncgen_wrapper(self, *args, **kwargs):
                agen = func(self, *args, **kwargs)
                try:
                    async for item in agen:
                        yield item
                except Exception as e:
                    _handle_error(e)
                finally:
                    # Ensure the inner generator is finalized even if the consumer stops early.
                    await agen.aclose()

            return asyncgen_wrapper

        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(self, *args, **kwargs):
                try:
                    return await func(self, *args, **kwargs)
                except Exception as e:
                    return _handle_error(e)

            return async_wrapper

        @functools.wraps(func)
        def sync_wrapper(self, *args, **kwargs):
            try:
                return func(self, *args, **kwargs)
            except Exception as e:
                return _handle_error(e)

        return sync_wrapper

    @staticmethod
    def _input_to_dict(user_input: Any) -> Dict[str, Any]:
        """Return a fresh dict built from the user input.

        Objects exposing ``model_dump()`` (e.g. pydantic models) are dumped; any other
        mapping is shallow-copied, so callers may mutate the result freely.
        """
        if hasattr(user_input, "model_dump"):
            return user_input.model_dump()
        return dict(user_input)

    async def _setup_agent_execution(
        self,
        agent_id: str,
        input: Any,
        thread_id: Optional[str] = None,
        user_id: Optional[str] = None,
        model_name: Optional[str] = None,
        model_provider: Optional[str] = None,
        model_config_key: Optional[str] = None,
        agent_config: Optional[Dict[str, Any]] = None,
        recursion_limit: Optional[int] = None,
    ) -> Tuple[Agent, Any, RunnableConfig, UUID]:
        """Apply common setup for agent execution that both invoke and stream methods share.

        Args:
            agent_id: ID of the agent to invoke.
            input: User input; either an object with ``model_dump()`` or a mapping. Its
                ``"message"`` entry becomes the HumanMessage; other entries are passed
                through as graph state (or as the resume payload after an interrupt).
            thread_id: Optional thread ID for conversation history (a new UUID if omitted).
            user_id: Optional user ID for the agent.
            model_name: Optional model name to override the default.
            model_provider: Optional model provider to override the default.
            model_config_key: Optional key into ``settings.MODEL_CONFIGS``; when valid it
                takes precedence over ``model_name``/``model_provider``.
            agent_config: Optional additional configuration merged into ``configurable``.
            recursion_limit: Optional recursion limit (``DEFAULT_RECURSION_LIMIT`` if falsy).

        Returns:
            Tuple containing:
                - agent: The Agent instance
                - input_data: A ``Command(resume=...)`` if the thread is interrupted,
                  otherwise a state dict with the new HumanMessage
                - config: The RunnableConfig for the agent
                - run_id: The UUID for this run

        Raises:
            KeyError: If ``agent_id`` is unknown.
        """
        agent = self.get_agent(agent_id)
        agent_graph = agent.graph

        run_id = uuid4()
        thread_id = thread_id or str(uuid4())
        recursion_limit = recursion_limit or DEFAULT_RECURSION_LIMIT

        configurable: Dict[str, Any] = {
            "thread_id": thread_id,
            "user_id": user_id,
        }

        # A valid model_config_key takes precedence over individual model settings.
        if model_config_key and model_config_key in settings.MODEL_CONFIGS:
            # Store the model_config_key so agents can use it if needed.
            configurable["model_config_key"] = model_config_key
            # Also expose basic model info for agents that don't check model_config_key.
            model_config = settings.MODEL_CONFIGS[model_config_key]
            if "provider" in model_config:
                configurable["model_provider"] = model_config["provider"]
            if "name" in model_config:
                configurable["model_name"] = model_config["name"]
        else:
            if model_config_key:
                logger.warning(
                    f"Model config key '{model_config_key}' not found in settings.MODEL_CONFIGS; "
                    "falling back to model_name/model_provider."
                )
            if model_name:
                configurable["model_name"] = model_name
            if model_provider:
                configurable["model_provider"] = model_provider

        if agent_config:
            configurable.update(agent_config)

        callback = agent.observability.get_callback_handler(
            session_id=thread_id,
            user_id=user_id,
            environment=settings.ENV_MODE,
            tags=[agent.name],
        )

        config = RunnableConfig(
            configurable=configurable,
            run_id=run_id,
            callbacks=[callback] if callback else None,
            recursion_limit=recursion_limit,
        )

        # If the thread is paused on an interrupt, the new input resumes it instead of
        # starting a fresh turn.
        state = await agent_graph.aget_state(config=config)
        interrupted_tasks = [task for task in state.tasks if getattr(task, "interrupts", None)]

        _input = self._input_to_dict(input)
        input_data: Command | dict[str, Any]
        if interrupted_tasks:
            input_data = Command(resume=_input)
        else:
            message = _input.pop("message", "")
            input_data = {"messages": [HumanMessage(content=message)], **_input}

        return agent, input_data, config, run_id

    @handle_agent_errors
    async def invoke(
        self,
        agent_id: str,
        input: Any,
        thread_id: Optional[str] = None,
        user_id: Optional[str] = None,
        model_name: Optional[str] = None,
        model_provider: Optional[str] = None,
        model_config_key: Optional[str] = None,
        agent_config: Optional[Dict[str, Any]] = None,
        recursion_limit: Optional[int] = None,
    ) -> ChatMessage:
        """Invoke an agent with a message and return the response.

        Args:
            agent_id: ID of the agent to invoke.
            input: User input (object with ``model_dump()`` or a mapping with a ``"message"`` key).
            thread_id: Optional thread ID for conversation history.
            user_id: Optional user ID for the agent.
            model_name: Optional model name to override the default.
            model_provider: Optional model provider to override the default.
            model_config_key: Optional model config key to override the default.
            agent_config: Optional additional configuration for the agent.
            recursion_limit: Optional recursion limit for the agent.

        Returns:
            ChatMessage: The agent's response (its final message or structured response,
            or the first interrupt value wrapped as an AI message).

        Raises:
            ValueError: If the graph yields no events or an unexpected final event.
        """
        agent, input_data, config, run_id = await self._setup_agent_execution(
            agent_id=agent_id,
            input=input,
            thread_id=thread_id,
            user_id=user_id,
            model_name=model_name,
            model_provider=model_provider,
            model_config_key=model_config_key,
            agent_config=agent_config,
            recursion_limit=recursion_limit,
        )

        # With several stream modes, ainvoke returns a list of (mode, payload) tuples.
        response_events: list[tuple[str, Any]] = await agent.graph.ainvoke(
            input=input_data,
            config=config,
            stream_mode=["updates", "values"],
        )
        if not response_events:
            raise ValueError("Agent produced no response events")

        response_type, response = response_events[-1]
        if response_type == "values":
            # Normal completion: prefer a structured response, else the last message.
            generated_message = response.get("structured_response")
            if not generated_message:
                generated_message = response["messages"][-1]
            output = langchain_to_chat_message(generated_message)
        elif response_type == "updates" and "__interrupt__" in response:
            # The run stopped on an interrupt: return the first interrupt value as an AIMessage.
            output = langchain_to_chat_message(AIMessage(content=response["__interrupt__"][0].value))
        else:
            raise ValueError(f"Unexpected response type: {response_type}")

        output.run_id = str(run_id)
        return output

    @staticmethod
    def _messages_from_updates(event: Dict[str, Any]) -> list[Any]:
        """Collect the messages to emit from one ``"updates"`` stream event (node name -> update)."""
        new_messages: list[Any] = []
        for node, updates in event.items():
            # Simple interrupt handling: surface each interrupt value as an AIMessage.
            if node == "__interrupt__":
                interrupt: Interrupt
                for interrupt in updates:
                    new_messages.append(AIMessage(content=interrupt.value))
                continue

            update_messages = (updates or {}).get("messages", [])

            # The supervisor re-emits the whole history; keep only its last AIMessage.
            if node == "supervisor":
                ai_messages = [m for m in update_messages if isinstance(m, AIMessage)]
                if ai_messages:
                    update_messages = [ai_messages[-1]]

            # Expert agents are shown in the UI as tool responses.
            if node in ("research_expert", "math_expert") and update_messages:
                update_messages = [
                    ToolMessage(
                        content=update_messages[0].content,
                        name=node,
                        tool_call_id="",
                    )
                ]

            new_messages.extend(update_messages)
        return new_messages

    @staticmethod
    def _assemble_messages(raw_messages: list[Any]) -> list[Any]:
        """Merge runs of ``(field_name, field_value)`` tuples into AI messages.

        LangGraph streaming may emit tuples such as ``('content', str)`` or
        ``('tool_calls', [...])``; consecutive tuples are accumulated and turned into a
        single message via ``create_ai_message``. Non-tuple items pass through unchanged.
        More info: https://langchain-ai.github.io/langgraph/cloud/how-tos/stream_messages/
        """
        processed: list[Any] = []
        parts: dict[str, Any] = {}
        for msg in raw_messages:
            if isinstance(msg, tuple):
                key, value = msg
                parts[key] = value
                continue
            if parts:
                processed.append(create_ai_message(parts))
                parts = {}
            processed.append(msg)
        if parts:
            processed.append(create_ai_message(parts))
        return processed

    @handle_agent_errors
    async def stream(
        self,
        agent_id: str,
        input: Any,
        thread_id: Optional[str] = None,
        user_id: Optional[str] = None,
        model_name: Optional[str] = None,
        model_provider: Optional[str] = None,
        model_config_key: Optional[str] = None,
        stream_tokens: bool = True,
        agent_config: Optional[Dict[str, Any]] = None,
        recursion_limit: Optional[int] = None,
    ) -> AsyncGenerator[str | ChatMessage, None]:
        """Stream an agent's response to a message, yielding either tokens or messages.

        Args:
            agent_id: ID of the agent to invoke.
            input: User input (object with ``model_dump()`` or a mapping with a ``"message"`` key).
            thread_id: Optional thread ID for conversation history.
            user_id: Optional user ID for the agent.
            model_name: Optional model name to override the default.
            model_provider: Optional model provider to override the default.
            model_config_key: Optional model config key to override the default.
            stream_tokens: Whether to stream individual tokens (also enables "custom" events).
            agent_config: Optional additional configuration for the agent.
            recursion_limit: Optional recursion limit for the agent.

        Yields:
            Either ChatMessage objects for full messages or strings for token chunks.
        """
        agent, input_data, config, run_id = await self._setup_agent_execution(
            agent_id=agent_id,
            input=input,
            thread_id=thread_id,
            user_id=user_id,
            model_name=model_name,
            model_provider=model_provider,
            model_config_key=model_config_key,
            agent_config=agent_config,
            recursion_limit=recursion_limit,
        )

        # The user's own message text; LangGraph may echo it back, and that echo is dropped below.
        sent_message = self._input_to_dict(input).get("message")

        requested_modes = ["updates", "messages", "custom"] if stream_tokens else ["updates"]
        async for stream_event in agent.graph.astream(input=input_data, config=config, stream_mode=requested_modes):
            if not isinstance(stream_event, tuple):
                continue
            mode, event = stream_event
            new_messages: list[Any] = []

            if mode == "updates":
                new_messages = self._messages_from_updates(event)
            elif mode == "custom":
                new_messages = [event]
            elif mode == "messages" and stream_tokens:
                msg, metadata = event
                if "skip_stream" in metadata.get("tags", []):
                    continue
                # Skip non-LLM nodes that might send messages.
                if not isinstance(msg, AIMessageChunk):
                    continue
                content = remove_tool_calls(msg.content)
                # Empty content in OpenAI context usually means the model is requesting a tool call.
                if content:
                    yield convert_message_content_to_string(content)

            for msg in self._assemble_messages(new_messages):
                try:
                    chat_message = langchain_to_chat_message(msg)
                    chat_message.run_id = str(run_id)
                except Exception as e:
                    logger.error(f"Error parsing message: {e}")
                    continue
                # Skip the input message if it's repeated by LangGraph.
                if chat_message.type == "human" and chat_message.content == sent_message:
                    continue
                yield chat_message

    def save(self, path: str, agent_ids: Optional[List[str]] = None) -> None:
        """Save agents to disk using joblib, one ``<agent_id>.joblib`` file per agent.

        Args:
            path: Directory path where to save agents (created if missing).
            agent_ids: List of agent IDs to save. If None, saves all agents; unknown IDs
                are logged and skipped.
        """
        _path = Path(path)
        _path.mkdir(exist_ok=True, parents=True)

        agents_to_save = self.agents
        if agent_ids is not None:
            wanted = set(agent_ids)
            missing = wanted - self.agents.keys()
            if missing:
                logger.warning(f"Cannot save unknown agents: {sorted(missing)}")
            agents_to_save = {k: v for k, v in self.agents.items() if k in wanted}

        for agent_id, agent in agents_to_save.items():
            joblib.dump(agent, _path / f"{agent_id}.joblib")

    def load_saved_agents(self, path: str) -> None:
        """Load saved agents from disk using joblib.

        Every ``*.joblib`` file in ``path`` is loaded (in sorted filename order) and
        registered under the loaded agent's ``name``, replacing existing agents with the
        same name. The default agent is re-validated afterwards.

        SECURITY: ``joblib.load`` unpickles the file and can execute arbitrary code.
        Only load from trusted directories.

        Args:
            path: Directory path from which to load agents.

        Raises:
            ValueError: If no agents are registered after loading.
        """
        for file_path in sorted(Path(path).glob("*.joblib")):
            if not file_path.is_file():
                continue
            agent = joblib.load(file_path)
            self.agents[agent.name] = agent

        self._validate_default_agent_loaded()