# Source code for langgraph_agent_toolkit.client.client

import json
import os
from collections.abc import AsyncGenerator, Generator
from typing import Any

import httpx

from langgraph_agent_toolkit.schema import (
    AddMessagesInput,
    AddMessagesResponse,
    ChatHistory,
    ChatHistoryInput,
    ChatMessage,
    ClearHistoryInput,
    ClearHistoryResponse,
    Feedback,
    FeedbackResponse,
    MessageInput,
    ServiceMetadata,
    StreamInput,
    UserComplexInput,
    UserInput,
)


class AgentClientError(Exception):
    """Raised when a request to the agent service fails."""


class AgentClient:
    """Client for interacting with the agent service."""

    def __init__(
        self,
        base_url: str = "http://0.0.0.0",
        agent: str | None = None,
        timeout: float | None = None,
        get_info: bool = True,
        verify: bool = False,
    ) -> None:
        """Initialize the client.

        Args:
            base_url (str): The base URL of the agent service.
            agent (str, optional): The name of the default agent to use.
            timeout (float, optional): The timeout for requests.
            get_info (bool, optional): Whether to fetch agent information on init.
                Default: True
            verify (bool, optional): Whether to verify that the agent exists on the
                service before selecting it. Default: False
        """
        self.base_url = base_url
        self.auth_secret = os.getenv("AUTH_SECRET")
        self.timeout = timeout
        self.info: ServiceMetadata | None = None
        self.agent: str | None = None
        if get_info:
            self.retrieve_info()
        if agent:
            self.update_agent(agent, verify=verify)
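
    # Illustrative usage sketch (not part of the original source): constructing a
    # client. The base URL and agent name below are assumptions for the example.
    #
    #     client = AgentClient(base_url="http://localhost:8080", agent="chatbot")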

    @property
    def _headers(self) -> dict[str, str]:
        headers = {}
        if self.auth_secret:
            headers["Authorization"] = f"Bearer {self.auth_secret}"
        return headers

    def retrieve_info(self) -> None:
        try:
            response = httpx.get(
                f"{self.base_url}/info",
                headers=self._headers,
                timeout=self.timeout,
            )
            response.raise_for_status()
        except httpx.HTTPError as e:
            raise AgentClientError(f"Error getting service info: {e}")

        self.info = ServiceMetadata.model_validate(response.json())
        if not self.agent or self.agent not in [a.key for a in self.info.agents]:
            self.agent = self.info.default_agent

    def update_agent(self, agent: str, verify: bool = True) -> None:
        if verify:
            if not self.info:
                self.retrieve_info()
            agent_keys = [a.key for a in self.info.agents]
            if agent not in agent_keys:
                raise AgentClientError(
                    f"Agent {agent} not found in available agents: {', '.join(agent_keys)}"
                )
        self.agent = agent
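
    # Usage sketch: switching the active agent after construction.
    # "research-assistant" is a hypothetical agent key.
    #
    #     client.update_agent("research-assistant")  # verifies against /info by default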

    async def ainvoke(
        self,
        input: dict[str, Any],
        model_name: str | None = None,
        model_provider: str | None = None,
        model_config_key: str | None = None,
        thread_id: str | None = None,
        user_id: str | None = None,
        agent_config: dict[str, Any] | None = None,
        recursion_limit: int | None = None,
    ) -> ChatMessage:
        """Invoke the agent asynchronously. Only the final message is returned.

        Args:
            input (dict[str, Any]): The input to send to the agent
            model_name (str, optional): LLM model to use for the agent
            model_provider (str, optional): LLM model provider to use for the agent
            model_config_key (str, optional): Key for predefined model configuration
            thread_id (str, optional): Thread ID for continuing a conversation
            user_id (str, optional): User ID for identifying the user
            agent_config (dict[str, Any], optional): Additional configuration to pass
                through to the agent
            recursion_limit (int, optional): Recursion limit for the agent

        Returns:
            ChatMessage: The response from the agent
        """
        if not self.agent:
            raise AgentClientError("No agent selected. Use update_agent() to select an agent.")
        request = UserInput(input=UserComplexInput(**input))
        if thread_id:
            request.thread_id = thread_id
        if model_name:
            request.model_name = model_name
        if model_provider:
            request.model_provider = model_provider
        if model_config_key:
            request.model_config_key = model_config_key
        if agent_config:
            request.agent_config = agent_config
        if user_id:
            request.user_id = user_id
        if recursion_limit is not None:
            request.recursion_limit = recursion_limit

        async with httpx.AsyncClient() as client:
            try:
                response = await client.post(
                    f"{self.base_url}/{self.agent}/invoke",
                    json=request.model_dump(),
                    headers=self._headers,
                    timeout=self.timeout,
                )
                response.raise_for_status()
            except httpx.HTTPError as e:
                raise AgentClientError(f"Error: {e}")
        return ChatMessage.model_validate(response.json())

    def invoke(
        self,
        input: dict[str, Any],
        model_name: str | None = None,
        model_provider: str | None = None,
        model_config_key: str | None = None,
        thread_id: str | None = None,
        user_id: str | None = None,
        agent_config: dict[str, Any] | None = None,
        recursion_limit: int | None = None,
    ) -> ChatMessage:
        """Invoke the agent synchronously. Only the final message is returned.

        Args:
            input (dict[str, Any]): The input to send to the agent
            model_name (str, optional): LLM model to use for the agent
            model_provider (str, optional): LLM model provider to use for the agent
            model_config_key (str, optional): Key for predefined model configuration
            thread_id (str, optional): Thread ID for continuing a conversation
            user_id (str, optional): User ID for identifying the user
            agent_config (dict[str, Any], optional): Additional configuration to pass
                through to the agent
            recursion_limit (int, optional): Recursion limit for the agent

        Returns:
            ChatMessage: The response from the agent
        """
        if not self.agent:
            raise AgentClientError("No agent selected. Use update_agent() to select an agent.")
        request = UserInput(input=UserComplexInput(**input))
        if thread_id:
            request.thread_id = thread_id
        if model_name:
            request.model_name = model_name
        if model_provider:
            request.model_provider = model_provider
        if model_config_key:
            request.model_config_key = model_config_key
        if agent_config:
            request.agent_config = agent_config
        if user_id:
            request.user_id = user_id
        if recursion_limit is not None:
            request.recursion_limit = recursion_limit

        try:
            response = httpx.post(
                f"{self.base_url}/{self.agent}/invoke",
                json=request.model_dump(),
                headers=self._headers,
                timeout=self.timeout,
            )
            response.raise_for_status()
        except httpx.HTTPError as e:
            raise AgentClientError(f"Error: {e}")
        return ChatMessage.model_validate(response.json())
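
    # Usage sketch for synchronous invocation. The input payload shape depends on
    # UserComplexInput; a "message" field is assumed here purely for illustration.
    #
    #     response = client.invoke(
    #         {"message": "Tell me a joke?"},
    #         thread_id="thread-123",  # hypothetical thread ID
    #     )
    #     print(response.content)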

    def _parse_stream_line(self, line: str) -> ChatMessage | str | None:
        line = line.strip()
        if line.startswith("data: "):
            data = line[6:]
            if data == "[DONE]":
                return None
            try:
                parsed = json.loads(data)
            except Exception as e:
                raise Exception(f"Error JSON parsing message from server: {e}")
            match parsed["type"]:
                case "message":
                    # Convert the JSON formatted message to a ChatMessage
                    try:
                        return ChatMessage.model_validate(parsed["content"])
                    except Exception as e:
                        raise Exception(f"Server returned invalid message: {e}")
                case "token":
                    # Yield the str token directly
                    return parsed["content"]
                case "error":
                    error_msg = "Error: " + parsed["content"]
                    return ChatMessage(type="ai", content=error_msg)
        return None
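
    # The parser above assumes the server emits Server-Sent-Events-style lines:
    #
    #     data: {"type": "token", "content": "Hel"}
    #     data: {"type": "message", "content": {...serialized ChatMessage...}}
    #     data: {"type": "error", "content": "something went wrong"}
    #     data: [DONE]
    #
    # Any other line, an unknown type, or the [DONE] sentinel parses to None.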

    def stream(
        self,
        input: dict[str, Any],
        model_name: str | None = None,
        model_provider: str | None = None,
        model_config_key: str | None = None,
        thread_id: str | None = None,
        user_id: str | None = None,
        agent_config: dict[str, Any] | None = None,
        recursion_limit: int | None = None,
        stream_tokens: bool = True,
    ) -> Generator[ChatMessage | str, None, None]:
        """Stream the agent's response synchronously.

        Each intermediate message of the agent process is yielded as a ChatMessage.
        If stream_tokens is True (the default value), the response will also yield
        content tokens from streaming models as they are generated.

        Args:
            input (dict[str, Any]): The input to send to the agent
            model_name (str, optional): LLM model to use for the agent
            model_provider (str, optional): LLM model provider to use for the agent
            model_config_key (str, optional): Key for predefined model configuration
            thread_id (str, optional): Thread ID for continuing a conversation
            user_id (str, optional): User ID for identifying the user
            agent_config (dict[str, Any], optional): Additional configuration to pass
                through to the agent
            recursion_limit (int, optional): Recursion limit for the agent
            stream_tokens (bool, optional): Stream tokens as they are generated.
                Default: True

        Returns:
            Generator[ChatMessage | str, None, None]: The response from the agent
        """
        if not self.agent:
            raise AgentClientError("No agent selected. Use update_agent() to select an agent.")
        request = StreamInput(input=UserComplexInput(**input), stream_tokens=stream_tokens)
        if thread_id:
            request.thread_id = thread_id
        if model_name:
            request.model_name = model_name
        if model_provider:
            request.model_provider = model_provider
        if model_config_key:
            request.model_config_key = model_config_key
        if agent_config:
            request.agent_config = agent_config
        if user_id:
            request.user_id = user_id
        if recursion_limit is not None:
            request.recursion_limit = recursion_limit

        try:
            with httpx.stream(
                "POST",
                f"{self.base_url}/{self.agent}/stream",
                json=request.model_dump(),
                headers=self._headers,
                timeout=self.timeout,
            ) as response:
                response.raise_for_status()
                for line in response.iter_lines():
                    if line.strip():
                        parsed = self._parse_stream_line(line)
                        if parsed is None:
                            break
                        yield parsed
        except httpx.HTTPError as e:
            raise AgentClientError(f"Error: {e}")
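
    # Usage sketch for synchronous streaming; tokens arrive as str chunks and
    # completed steps as ChatMessage objects (payload shape assumed as above).
    #
    #     for event in client.stream({"message": "Hi"}, stream_tokens=True):
    #         if isinstance(event, str):
    #             print(event, end="", flush=True)      # partial token
    #         else:
    #             print("\n[message]", event.content)   # full ChatMessage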

    async def astream(
        self,
        input: dict[str, Any],
        model_name: str | None = None,
        model_provider: str | None = None,
        model_config_key: str | None = None,
        thread_id: str | None = None,
        user_id: str | None = None,
        agent_config: dict[str, Any] | None = None,
        recursion_limit: int | None = None,
        stream_tokens: bool = True,
    ) -> AsyncGenerator[ChatMessage | str, None]:
        """Stream the agent's response asynchronously.

        Each intermediate message of the agent process is yielded as a ChatMessage.
        If stream_tokens is True (the default value), the response will also yield
        content tokens from streaming models as they are generated.

        Args:
            input (dict[str, Any]): The input to send to the agent
            model_name (str, optional): LLM model to use for the agent
            model_provider (str, optional): LLM model provider to use for the agent
            model_config_key (str, optional): Key for predefined model configuration
            thread_id (str, optional): Thread ID for continuing a conversation
            user_id (str, optional): User ID for identifying the user
            agent_config (dict[str, Any], optional): Additional configuration to pass
                through to the agent
            recursion_limit (int, optional): Recursion limit for the agent
            stream_tokens (bool, optional): Stream tokens as they are generated.
                Default: True

        Returns:
            AsyncGenerator[ChatMessage | str, None]: The response from the agent
        """
        if not self.agent:
            raise AgentClientError("No agent selected. Use update_agent() to select an agent.")
        request = StreamInput(input=UserComplexInput(**input), stream_tokens=stream_tokens)
        if thread_id:
            request.thread_id = thread_id
        if model_name:
            request.model_name = model_name
        if model_provider:
            request.model_provider = model_provider
        if model_config_key:
            request.model_config_key = model_config_key
        if agent_config:
            request.agent_config = agent_config
        if user_id:
            request.user_id = user_id
        if recursion_limit is not None:
            request.recursion_limit = recursion_limit

        async with httpx.AsyncClient() as client:
            try:
                async with client.stream(
                    "POST",
                    f"{self.base_url}/{self.agent}/stream",
                    json=request.model_dump(),
                    headers=self._headers,
                    timeout=self.timeout,
                ) as response:
                    response.raise_for_status()
                    async for line in response.aiter_lines():
                        if line.strip():
                            try:
                                parsed = self._parse_stream_line(line)
                                if parsed is None:
                                    break
                                yield parsed
                            except GeneratorExit:
                                # Handle GeneratorExit properly to close the stream gracefully
                                break
            except httpx.HTTPError as e:
                raise AgentClientError(f"Error: {e}")
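
    # Usage sketch for async streaming (assumed payload shape as above):
    #
    #     import asyncio
    #
    #     async def main() -> None:
    #         async for event in client.astream({"message": "Hi"}):
    #             if isinstance(event, str):
    #                 print(event, end="", flush=True)
    #
    #     asyncio.run(main())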

    async def acreate_feedback(
        self,
        run_id: str,
        key: str,
        score: float,
        kwargs: dict[str, Any] | None = None,
        user_id: str | None = None,
    ) -> None:
        """Create a feedback record for a run asynchronously.

        Args:
            run_id (str): The ID of the run to provide feedback for
            key (str): The key for the feedback
            score (float): The score for the feedback
            kwargs (dict[str, Any], optional): Additional metadata for the feedback
            user_id (str, optional): User ID for identifying the user
        """
        # Avoid a mutable default argument by substituting an empty dict here
        request = Feedback(run_id=run_id, key=key, score=score, user_id=user_id, kwargs=kwargs or {})
        async with httpx.AsyncClient() as client:
            try:
                response = await client.post(
                    f"{self.base_url}/feedback",
                    json=request.model_dump(),
                    headers=self._headers,
                    timeout=self.timeout,
                )
                response.raise_for_status()
                response.json()
            except httpx.HTTPError as e:
                raise AgentClientError(f"Error: {e}")

    def get_history(
        self,
        thread_id: str,
        user_id: str | None = None,
    ) -> ChatHistory:
        """Get chat history.

        Args:
            thread_id (str): Thread ID identifying the conversation
            user_id (str, optional): User ID for identifying the user

        Returns:
            ChatHistory: The chat history for the thread
        """
        request = ChatHistoryInput(thread_id=thread_id, user_id=user_id)
        try:
            response = httpx.get(
                f"{self.base_url}/{self.agent}/history" if self.agent else f"{self.base_url}/history",
                params=request.model_dump(),
                headers=self._headers,
                timeout=self.timeout,
            )
            response.raise_for_status()
        except httpx.HTTPError as e:
            raise AgentClientError(f"Error: {e}")
        return ChatHistory.model_validate(response.json())

    async def aget_history(
        self,
        thread_id: str,
        user_id: str | None = None,
    ) -> ChatHistory:
        """Get chat history asynchronously.

        Args:
            thread_id (str): Thread ID identifying the conversation
            user_id (str, optional): User ID for identifying the user

        Returns:
            ChatHistory: The chat history for the thread
        """
        request = ChatHistoryInput(thread_id=thread_id, user_id=user_id)
        async with httpx.AsyncClient() as client:
            try:
                response = await client.get(
                    f"{self.base_url}/{self.agent}/history" if self.agent else f"{self.base_url}/history",
                    params=request.model_dump(),
                    headers=self._headers,
                    timeout=self.timeout,
                )
                response.raise_for_status()
            except httpx.HTTPError as e:
                raise AgentClientError(f"Error: {e}")
        return ChatHistory.model_validate(response.json())
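
    # Usage sketch for retrieving history; "thread-123" is a hypothetical ID and
    # a `messages` attribute on ChatHistory is assumed for the loop below.
    #
    #     history = client.get_history(thread_id="thread-123")
    #     for message in history.messages:
    #         print(message.type, message.content)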

    def clear_history(
        self,
        thread_id: str | None = None,
        user_id: str | None = None,
    ) -> ClearHistoryResponse:
        """Clear chat history.

        Args:
            thread_id (str, optional): Thread ID for identifying a conversation
            user_id (str, optional): User ID for identifying the user

        Returns:
            ClearHistoryResponse: The response from the history endpoint
        """
        if not thread_id and not user_id:
            raise AgentClientError("At least one of thread_id or user_id must be provided")

        request = ClearHistoryInput(thread_id=thread_id, user_id=user_id)
        try:
            # httpx's delete() shortcut does not accept a JSON body, so issue the
            # DELETE via request() instead
            response = httpx.request(
                "DELETE",
                f"{self.base_url}/{self.agent}/history/clear"
                if self.agent
                else f"{self.base_url}/history/clear",
                json=request.model_dump(),
                headers=self._headers,
                timeout=self.timeout,
            )
            response.raise_for_status()
        except httpx.HTTPError as e:
            raise AgentClientError(f"Error: {e}")
        return ClearHistoryResponse.model_validate(response.json())

    async def aclear_history(
        self,
        thread_id: str | None = None,
        user_id: str | None = None,
    ) -> ClearHistoryResponse:
        """Clear chat history asynchronously.

        Args:
            thread_id (str, optional): Thread ID for identifying a conversation
            user_id (str, optional): User ID for identifying the user

        Returns:
            ClearHistoryResponse: The response from the history endpoint
        """
        if not thread_id and not user_id:
            raise AgentClientError("At least one of thread_id or user_id must be provided")

        request = ClearHistoryInput(thread_id=thread_id, user_id=user_id)
        async with httpx.AsyncClient() as client:
            try:
                # AsyncClient.delete() does not accept a JSON body, so issue the
                # DELETE via request() instead
                response = await client.request(
                    "DELETE",
                    f"{self.base_url}/{self.agent}/history/clear"
                    if self.agent
                    else f"{self.base_url}/history/clear",
                    json=request.model_dump(),
                    headers=self._headers,
                    timeout=self.timeout,
                )
                response.raise_for_status()
            except httpx.HTTPError as e:
                raise AgentClientError(f"Error: {e}")
        return ClearHistoryResponse.model_validate(response.json())
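
    # Usage sketch: clearing history by thread, by user, or both (at least one
    # identifier is required by the methods above); "thread-123" is hypothetical.
    #
    #     result = client.clear_history(thread_id="thread-123")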

    def add_messages(
        self,
        messages: list[dict[str, str]] | list[MessageInput],
        thread_id: str | None = None,
        user_id: str | None = None,
    ) -> AddMessagesResponse:
        """Add messages to chat history.

        Args:
            messages (list[dict[str, str]] | list[MessageInput]): Messages to add
            thread_id (str, optional): Thread ID for identifying a conversation
            user_id (str, optional): User ID for identifying the user

        Returns:
            AddMessagesResponse: The response from the history endpoint
        """
        if not thread_id and not user_id:
            raise AgentClientError("At least one of thread_id or user_id must be provided")

        # Convert dict messages to MessageInput if needed
        message_inputs = [
            m if isinstance(m, MessageInput) else MessageInput(type=m["type"], content=m["content"])
            for m in messages
        ]
        request = AddMessagesInput(thread_id=thread_id, user_id=user_id, messages=message_inputs)
        try:
            response = httpx.post(
                f"{self.base_url}/{self.agent}/history/add_messages"
                if self.agent
                else f"{self.base_url}/history/add_messages",
                json=request.model_dump(),
                headers=self._headers,
                timeout=self.timeout,
            )
            response.raise_for_status()
        except httpx.HTTPError as e:
            raise AgentClientError(f"Error: {e}")
        return AddMessagesResponse.model_validate(response.json())

    async def aadd_messages(
        self,
        messages: list[dict[str, str]] | list[MessageInput],
        thread_id: str | None = None,
        user_id: str | None = None,
    ) -> AddMessagesResponse:
        """Add messages to chat history asynchronously.

        Args:
            messages (list[dict[str, str]] | list[MessageInput]): Messages to add
            thread_id (str, optional): Thread ID for identifying a conversation
            user_id (str, optional): User ID for identifying the user

        Returns:
            AddMessagesResponse: The response from the history endpoint
        """
        if not thread_id and not user_id:
            raise AgentClientError("At least one of thread_id or user_id must be provided")

        # Convert dict messages to MessageInput if needed
        message_inputs = [
            m if isinstance(m, MessageInput) else MessageInput(type=m["type"], content=m["content"])
            for m in messages
        ]
        request = AddMessagesInput(thread_id=thread_id, user_id=user_id, messages=message_inputs)
        async with httpx.AsyncClient() as client:
            try:
                response = await client.post(
                    f"{self.base_url}/{self.agent}/history/add_messages"
                    if self.agent
                    else f"{self.base_url}/history/add_messages",
                    json=request.model_dump(),
                    headers=self._headers,
                    timeout=self.timeout,
                )
                response.raise_for_status()
            except httpx.HTTPError as e:
                raise AgentClientError(f"Error: {e}")
        return AddMessagesResponse.model_validate(response.json())
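
    # Usage sketch: seeding a conversation with prior messages. The "human"/"ai"
    # type values are assumptions about what MessageInput accepts.
    #
    #     client.add_messages(
    #         messages=[
    #             {"type": "human", "content": "What is LangGraph?"},
    #             {"type": "ai", "content": "LangGraph is a library for building agents."},
    #         ],
    #         thread_id="thread-123",  # hypothetical thread ID
    #     )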

    def create_feedback(
        self,
        run_id: str,
        key: str,
        score: float,
        kwargs: dict[str, Any] | None = None,
        user_id: str | None = None,
    ) -> FeedbackResponse:
        """Create a feedback record for a run.

        Args:
            run_id (str): The ID of the run to provide feedback for
            key (str): The key for the feedback
            score (float): The score for the feedback
            kwargs (dict[str, Any], optional): Additional metadata for the feedback
            user_id (str, optional): User ID for identifying the user

        Returns:
            FeedbackResponse: The response from the feedback endpoint
        """
        # Avoid a mutable default argument by substituting an empty dict here
        request = Feedback(run_id=run_id, key=key, score=score, user_id=user_id, kwargs=kwargs or {})
        try:
            response = httpx.post(
                f"{self.base_url}/{self.agent}/feedback" if self.agent else f"{self.base_url}/feedback",
                json=request.model_dump(),
                headers=self._headers,
                timeout=self.timeout,
            )
            response.raise_for_status()
            return FeedbackResponse.model_validate(response.json())
        except httpx.HTTPError as e:
            raise AgentClientError(f"Error: {e}")
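
    # Usage sketch: recording feedback for a run. The run_id and key below are
    # hypothetical values for illustration.
    #
    #     client.create_feedback(
    #         run_id="d9b3a1c0-0000-0000-0000-000000000000",
    #         key="human-feedback-stars",
    #         score=0.8,
    #     )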