Source code for langgraph_agent_toolkit.agents.blueprints.knowledge_base_agent.agent

import os
from typing import Any

from langchain_aws import AmazonKnowledgeBasesRetriever
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable
from langchain_core.runnables.base import RunnableSequence
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, MessagesState, StateGraph
from langgraph.managed import RemainingSteps

from langgraph_agent_toolkit.agents.agent import Agent
from langgraph_agent_toolkit.core import settings
from langgraph_agent_toolkit.core.models.factory import ModelFactory
from langgraph_agent_toolkit.helper.logging import logger
from langgraph_agent_toolkit.schema.models import ModelProvider


# Define the state
class AgentState(MessagesState, total=False):
    """Graph state shared by every node of the Knowledge Base agent.

    Builds on ``MessagesState`` (the conversation ``messages``) and adds the
    retrieval artefacts produced during a turn. Because the class is declared
    with ``total=False``, the keys added here are optional.
    """

    # LangGraph managed value (RemainingSteps) exposing the remaining step budget.
    remaining_steps: RemainingSteps
    # One dict per retrieved document, as built by retrieve_documents
    # (keys: id, source, title, content, relevance_score).
    retrieved_documents: list[dict[str, Any]]
    # The retrieved documents rendered as a single text block for the system prompt.
    kb_documents: str
# Create the retriever
def get_kb_retriever(*, number_of_results: int = 3):
    """Create and return an Amazon Bedrock Knowledge Base retriever.

    The Knowledge Base ID is read from the ``AWS_KB_ID`` environment variable.

    Args:
        number_of_results: Number of results requested from the vector search
            (``vectorSearchConfiguration.numberOfResults``). Defaults to 3, the
            value this agent has always used.

    Returns:
        A configured ``AmazonKnowledgeBasesRetriever``.

    Raises:
        ValueError: If ``AWS_KB_ID`` is unset or blank, or if
            ``number_of_results`` is less than 1.
    """
    # Strip so a value padded with whitespace (e.g. copied from a .env file) still
    # works, and a whitespace-only value is treated as "not set".
    kb_id = os.environ.get("AWS_KB_ID", "").strip()
    if not kb_id:
        raise ValueError("AWS_KB_ID environment variable must be set")
    if number_of_results < 1:
        raise ValueError(f"number_of_results must be at least 1, got {number_of_results}")

    return AmazonKnowledgeBasesRetriever(
        knowledge_base_id=kb_id,
        retrieval_config={
            "vectorSearchConfiguration": {
                "numberOfResults": number_of_results,
            }
        },
    )
def wrap_model(model: BaseChatModel) -> RunnableSerializable[AgentState, AIMessage]:
    """Wrap ``model`` so it is always called with the Knowledge Base system prompt.

    The returned runnable takes the agent state and builds a ``SystemMessage``
    holding the answering guidelines. When this turn's retrieval produced
    documents, the message also holds the formatted documents; otherwise it
    states that none were found. The ``SystemMessage`` is prepended to
    ``state["messages"]`` and the resulting list is passed to ``model``.

    Args:
        model: Chat model that generates the answer.

    Returns:
        A ``RunnableSequence`` of a ``StateModifier`` preprocessing step followed
        by ``model``.
    """
    # The guidelines never change between calls, so build the string once.
    base_prompt = (
        "You are a helpful assistant that provides accurate information based on retrieved documents.\n\n"
        "You will receive a query along with relevant documents retrieved from a knowledge base. "
        "Use these documents to inform your response.\n\n"
        "Follow these guidelines:\n"
        "1. Base your answer primarily on the retrieved documents\n"
        "2. If the documents contain the answer, provide it clearly and concisely\n"
        "3. If the documents are insufficient, state that you don't have enough information\n"
        "4. Never make up facts or information not present in the documents\n"
        "5. Always cite the source documents when referring to specific information\n"
        "6. If the documents contradict each other, acknowledge this and explain the different perspectives\n\n"
        "Format your response in a clear, conversational manner. Use markdown formatting when appropriate.\n"
    )

    def create_system_message(state):
        """Return ``[SystemMessage, *state["messages"]]`` for the current turn."""
        kb_documents = state.get("kb_documents")
        # ``kb_documents`` is a plain state key that is persisted in the checkpoint,
        # so a value written on an earlier turn can still be present when this
        # turn's retrieval found nothing. ``retrieved_documents`` is rewritten by
        # every retrieval, so require both before presenting documents as relevant.
        if state.get("retrieved_documents") and kb_documents:
            document_prompt = (
                "\n\nI've retrieved the following documents that may be relevant to the query:"
                f"\n\n{kb_documents}\n\n"
                "Please use these documents to inform your response to the user's query. "
                "Only use information from these documents and clearly indicate when you are unsure."
            )
            system_prompt = base_prompt + document_prompt
        else:
            no_docs_prompt = "\n\nNo relevant documents were found in the knowledge base for this query."
            system_prompt = base_prompt + no_docs_prompt
        return [SystemMessage(content=system_prompt)] + state["messages"]

    preprocessor = RunnableLambda(
        create_system_message,
        name="StateModifier",
    )
    return RunnableSequence(preprocessor, model)
def _message_text(content: Any) -> str:
    """Extract the plain text from a chat message's ``content``.

    ``content`` may be a string or a list of content blocks, each either a string
    or a dict. Only strings and ``{"type": "text", "text": ...}`` dicts contribute
    text; any other block (e.g. an image block) is ignored.
    """
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return "" if content is None else str(content)
    pieces = []
    for block in content:
        if isinstance(block, str):
            pieces.append(block)
        elif isinstance(block, dict) and block.get("type") == "text":
            pieces.append(str(block.get("text", "")))
    return " ".join(pieces)


async def retrieve_documents(state: AgentState, config: RunnableConfig) -> AgentState:
    """Query the Knowledge Base with the text of the latest human message.

    Retrieval is best-effort. A missing or empty query, or any retrieval error,
    yields an empty ``retrieved_documents`` list instead of failing the graph run.

    Args:
        state: Current graph state; ``messages`` is searched for the last
            ``HumanMessage``.
        config: Runnable config of this node. It is forwarded to the retriever
            so the retrieval run shares this node's callbacks and tracing.

    Returns:
        A partial state update. ``retrieved_documents`` is always set (possibly
        to ``[]``), so a previous turn's results are never reused. Each entry has
        the keys ``id``, ``source``, ``title``, ``content`` and
        ``relevance_score``. ``messages`` is ``[]``, so no message is appended.
    """
    human_messages = [msg for msg in state["messages"] if isinstance(msg, HumanMessage)]
    if not human_messages:
        return {"messages": [], "retrieved_documents": []}

    # Content can be a list of blocks (multimodal input); the retriever needs a string.
    query = _message_text(human_messages[-1].content).strip()
    if not query:
        logger.warning("Last human message has no text content; skipping knowledge base retrieval")
        return {"retrieved_documents": [], "messages": []}

    try:
        retriever = get_kb_retriever()
        retrieved_docs = await retriever.ainvoke(query, config)
    except Exception as e:
        # Deliberate best-effort: answer without documents rather than failing
        # the run, but keep the traceback in the log.
        logger.exception(f"Error retrieving documents: {str(e)}")
        return {"retrieved_documents": [], "messages": []}

    document_summaries = [
        {
            "id": doc.metadata.get("id", f"doc-{i}"),
            "source": doc.metadata.get("source", "Unknown"),
            "title": doc.metadata.get("title", f"Document {i}"),
            "content": doc.page_content,
            "relevance_score": doc.metadata.get("score", 0),
        }
        for i, doc in enumerate(retrieved_docs, 1)
    ]

    logger.info(f"Retrieved {len(document_summaries)} documents for query: {query[:50]}...")
    return {"retrieved_documents": document_summaries, "messages": []}
async def prepare_augmented_prompt(state: AgentState, config: RunnableConfig) -> AgentState:
    """Render the retrieved documents as one text block for the system prompt.

    Args:
        state: Current graph state; ``retrieved_documents`` is read (a missing
            key counts as no documents).
        config: Runnable config supplied by LangGraph; not used here.

    Returns:
        ``{"messages": []}`` when there is nothing to format, leaving
        ``kb_documents`` untouched. Otherwise ``kb_documents`` holds the
        numbered documents (source, title, then content) separated by blank
        lines.
    """
    retrieved = state.get("retrieved_documents", [])
    if not retrieved:
        return {"messages": []}

    sections = []
    for number, document in enumerate(retrieved, start=1):
        source = document.get("source", "Unknown")
        title = document.get("title", "Unknown")
        body = document.get("content", "")
        sections.append(f"--- Document {number} ---\nSource: {source}\nTitle: {title}\n\n{body}")

    return {"kb_documents": "\n\n".join(sections), "messages": []}
async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
    """Generate the assistant's answer from the conversation and retrieved documents.

    Model selection:
      1. If ``configurable["agent_config"]["model_config"]`` names an entry in
         ``settings.MODEL_CONFIGS``, that configuration is used.
      2. Otherwise a model is created from ``configurable["model_provider"]``
         (default ``ModelProvider.OPENAI``) and ``configurable["model_name"]``
         (default ``settings.OPENAI_MODEL_NAME``) with the OpenAI base URL and
         key from settings.

    A missing ``configurable`` section, or an ``agent_config`` that is ``None``,
    falls back to option 2 instead of raising.

    Args:
        state: Current graph state, passed through ``wrap_model``.
        config: Runnable config for this node; also forwarded to the model call.

    Returns:
        A partial state update appending the model's response to ``messages``.
    """
    configurable = config.get("configurable") or {}
    agent_config = configurable.get("agent_config") or {}
    model_config_key = agent_config.get("model_config")

    if model_config_key and model_config_key in settings.MODEL_CONFIGS:
        model = ModelFactory.get_model_from_config(settings.MODEL_CONFIGS[model_config_key])
    else:
        if model_config_key:
            # A requested but unknown config would otherwise be ignored silently.
            logger.warning(f"Unknown model_config '{model_config_key}'; falling back to default model settings")
        model = ModelFactory.create(
            model_provider=configurable.get("model_provider", ModelProvider.OPENAI),
            model_name=configurable.get("model_name", settings.OPENAI_MODEL_NAME),
            openai_api_base=settings.OPENAI_API_BASE_URL,
            openai_api_key=settings.OPENAI_API_KEY,
        )

    model_runnable = wrap_model(model)
    response = await model_runnable.ainvoke(state, config)
    return {"messages": [response]}
# Graph layout: retrieve_documents -> prepare_augmented_prompt -> model -> END
agent = StateGraph(AgentState)

# Register the three processing nodes.
agent.add_node("retrieve_documents", retrieve_documents)
agent.add_node("prepare_augmented_prompt", prepare_augmented_prompt)
agent.add_node("model", acall_model)

# Every run begins with retrieval.
agent.set_entry_point("retrieve_documents")

# Strictly linear flow: no conditional routing between nodes.
agent.add_edge("retrieve_documents", "prepare_augmented_prompt")
agent.add_edge("prepare_augmented_prompt", "model")
agent.add_edge("model", END)

# Compile with an in-memory checkpointer (MemorySaver) and expose the graph as
# a toolkit Agent.
kb_agent = Agent(
    name="kb-agent",
    description="A retrieval-augmented generation agent using Amazon Bedrock Knowledge Base.",
    graph=agent.compile(checkpointer=MemorySaver()),
)