import json
import logging
import secrets
import traceback
import warnings
from typing import Annotated, Any, AsyncGenerator, Optional
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from langchain_core._api import LangChainBetaWarning
from langgraph_agent_toolkit.agents.agent import Agent
from langgraph_agent_toolkit.agents.agent_executor import AgentExecutor
from langgraph_agent_toolkit.core import settings
from langgraph_agent_toolkit.helper.logging import InterceptHandler, logger
from langgraph_agent_toolkit.schema import ChatMessage, StreamInput
[docs]
def verify_bearer(
http_auth: Annotated[
HTTPAuthorizationCredentials | None,
Depends(HTTPBearer(description="Please provide AUTH_SECRET api key.", auto_error=False)),
],
) -> None:
if not settings.AUTH_SECRET:
return
auth_secret = settings.AUTH_SECRET.get_secret_value()
if not http_auth or not secrets.compare_digest(http_auth.credentials, auth_secret):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
[docs]
def get_agent_executor(request: Request) -> AgentExecutor:
"""Get the AgentExecutor instance that was initialized in lifespan."""
app = request.app
if not hasattr(app.state, "agent_executor"):
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Agent executor not initialized. Service might be starting up.",
)
return app.state.agent_executor
[docs]
def get_agent(request: Request, agent_id: str) -> Agent:
"""Get an agent by its ID from the initialized AgentExecutor."""
executor = get_agent_executor(request)
try:
return executor.get_agent(agent_id)
except KeyError:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Agent '{agent_id}' not found",
)
[docs]
def get_all_agent_info(request: Request):
"""Get information about all available agents from the initialized AgentExecutor."""
executor = get_agent_executor(request)
return executor.get_all_agent_info()
def _validate_thread_or_user_id(thread_id: Optional[str], user_id: Optional[str]) -> None:
"""Validate that either thread_id or user_id is provided."""
if thread_id is None and user_id is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Either thread_id or user_id must be provided.",
)
[docs]
async def message_generator(
stream_input: StreamInput,
request: Request,
agent_id: str,
) -> AsyncGenerator[str, None]:
"""Generate messages from an agent."""
executor = get_agent_executor(request)
try:
async for message in executor.stream(
agent_id=agent_id,
input=stream_input.input,
thread_id=stream_input.thread_id,
user_id=stream_input.user_id,
model_name=stream_input.model_name,
model_provider=stream_input.model_provider,
model_config_key=stream_input.model_config_key,
stream_tokens=stream_input.stream_tokens,
agent_config=stream_input.agent_config,
recursion_limit=stream_input.recursion_limit,
):
if isinstance(message, str):
# Token output
yield f"data: {json.dumps({'type': 'token', 'content': message})}\n\n"
elif isinstance(message, ChatMessage):
# Complete message
yield f"data: {json.dumps({'type': 'message', 'content': message.model_dump()})}\n\n"
except Exception as e:
tb_str = traceback.format_exc()
logger.error(f"Error in message_generator: {e}\n\nFull traceback:\n{tb_str}")
yield f"data: {json.dumps({'type': 'error', 'content': f'Internal server error: {e}'})}\n\n"
finally:
yield "data: [DONE]\n\n"
def _sse_response_example() -> dict[int, Any]:
return {
status.HTTP_200_OK: {
"description": "Server Sent Event Response",
"content": {
"text/event-stream": {
"example": (
"data: {'type': 'token', 'content': 'Hello'}\n\n"
"data: {'type': 'token', 'content': ' World'}\n\n"
"data: [DONE]\n\n"
),
"schema": {"type": "string"},
}
},
}
}
[docs]
def setup_logging():
"""Configure application logging to use loguru."""
# Setup logging once - redirect standard library logging to loguru
logging.basicConfig(handlers=[InterceptHandler()], level=0)
# Reduce noise from uvicorn logs in production
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
logging.getLogger("uvicorn").handlers = [InterceptHandler()]
# Additional loggers that should be quieter in production
logging.getLogger("asyncio").setLevel(logging.WARNING)
logging.getLogger("fastapi").setLevel(logging.WARNING)
# Suppress LangChain beta warnings
warnings.filterwarnings("ignore", category=LangChainBetaWarning)
return logger