import os
from typing import Any
from langchain_aws import AmazonKnowledgeBasesRetriever
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable
from langchain_core.runnables.base import RunnableSequence
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, MessagesState, StateGraph
from langgraph.managed import RemainingSteps
from langgraph_agent_toolkit.agents.agent import Agent
from langgraph_agent_toolkit.core import settings
from langgraph_agent_toolkit.core.models.factory import ModelFactory
from langgraph_agent_toolkit.helper.logging import logger
from langgraph_agent_toolkit.schema.models import ModelProvider
# Define the state
class AgentState(MessagesState, total=False):
    """Graph state for the Knowledge Base RAG agent.

    Extends ``MessagesState`` (which supplies the ``messages`` history);
    ``total=False`` makes every key declared here optional.
    """

    # LangGraph managed value; declared here but not written by any node in this file.
    remaining_steps: RemainingSteps
    # One dict per retrieved document (id, source, title, content, relevance_score);
    # retrieve_documents writes this key on every turn, possibly as an empty list.
    retrieved_documents: list[dict[str, Any]]
    # Prompt-ready text rendering of retrieved_documents, written by prepare_augmented_prompt.
    kb_documents: str
# Create the retriever
def get_kb_retriever(number_of_results: int = 3):
    """Create an Amazon Bedrock Knowledge Base retriever.

    The Knowledge Base ID is read from the ``AWS_KB_ID`` environment variable;
    surrounding whitespace is ignored.

    Args:
        number_of_results: Value passed as
            ``retrieval_config["vectorSearchConfiguration"]["numberOfResults"]``.
            Defaults to 3, the value this function previously hard-coded.

    Returns:
        An ``AmazonKnowledgeBasesRetriever`` bound to that Knowledge Base.

    Raises:
        ValueError: If ``AWS_KB_ID`` is unset or blank, or if
            ``number_of_results`` is less than 1.
    """
    # Strip so a whitespace-only value is rejected here instead of being sent to AWS.
    kb_id = os.environ.get("AWS_KB_ID", "").strip()
    if not kb_id:
        raise ValueError("AWS_KB_ID environment variable must be set")
    if number_of_results < 1:
        raise ValueError(f"number_of_results must be at least 1, got {number_of_results}")

    return AmazonKnowledgeBasesRetriever(
        knowledge_base_id=kb_id,
        retrieval_config={
            "vectorSearchConfiguration": {
                "numberOfResults": number_of_results,
            }
        },
    )
def wrap_model(model: BaseChatModel) -> RunnableSerializable[AgentState, AIMessage]:
    """Wrap ``model`` so each call is preceded by the Knowledge Base system prompt.

    The returned runnable takes the agent state and prepends a ``SystemMessage``
    to ``state["messages"]``. The system message holds the base RAG
    instructions plus either the formatted retrieved documents or a note that
    nothing was found. It then feeds the result to ``model``.

    Args:
        model: Chat model that produces the final answer.

    Returns:
        A ``RunnableSequence`` of the prompt-building step followed by ``model``.
    """

    def create_system_message(state):
        """Return ``[SystemMessage, *state["messages"]]`` for the current turn."""
        base_prompt = """You are a helpful assistant that provides accurate information based on retrieved documents.
You will receive a query along with relevant documents retrieved from a knowledge base.
Use these documents to inform your response.
Follow these guidelines:
1. Base your answer primarily on the retrieved documents
2. If the documents contain the answer, provide it clearly and concisely
3. If the documents are insufficient, state that you don't have enough information
4. Never make up facts or information not present in the documents
5. Always cite the source documents when referring to specific information
6. If the documents contradict each other, acknowledge this and explain the different perspectives
Format your response in a clear, conversational manner. Use markdown formatting when appropriate.
"""
        kb_documents = state.get("kb_documents")
        # retrieve_documents rewrites "retrieved_documents" on every turn, but
        # prepare_augmented_prompt only writes "kb_documents" when something was
        # found. Because the graph is compiled with a checkpointer, an earlier
        # turn's "kb_documents" survives into later turns; an empty
        # "retrieved_documents" therefore means that text is stale and must not
        # be shown to the model. A state without the key is taken at face value.
        documents_are_current = bool(state.get("retrieved_documents", True))

        if kb_documents and documents_are_current:
            document_prompt = (
                f"\n\nI've retrieved the following documents that may be relevant to the query:"
                f"\n\n{kb_documents}\n\n"
                "Please use these documents to inform your response to the user's query. "
                "Only use information from these documents and clearly indicate when you are unsure."
            )
            return [SystemMessage(content=base_prompt + document_prompt)] + state["messages"]

        # No (current) documents for this turn.
        no_docs_prompt = "\n\nNo relevant documents were found in the knowledge base for this query."
        return [SystemMessage(content=base_prompt + no_docs_prompt)] + state["messages"]

    preprocessor = RunnableLambda(
        create_system_message,
        name="StateModifier",
    )
    return RunnableSequence(preprocessor, model)
def _message_text(content: Any) -> str:
    """Flatten a chat message ``content`` value into plain text usable as a query.

    Handles both a plain string and a list of content parts. String parts and
    ``{"type": "text", "text": ...}`` dict parts are joined with spaces. Any
    other part (for example an image block) is skipped because it cannot serve
    as a text query.
    """
    if isinstance(content, str):
        return content
    pieces = []
    for part in content or ():
        if isinstance(part, str):
            pieces.append(part)
        elif isinstance(part, dict) and part.get("type") == "text":
            pieces.append(str(part.get("text", "")))
    return " ".join(pieces)


async def retrieve_documents(state: AgentState, config: RunnableConfig) -> AgentState:
    """Query the Knowledge Base with the text of the latest human message.

    Args:
        state: Current graph state; only ``state["messages"]`` is read.
        config: Runnable config (unused; part of the node signature).

    Returns:
        A partial state update that always sets ``retrieved_documents`` and
        appends no messages. The list is empty when there is no usable query
        or retrieval fails. Retrieval errors are logged, not raised.
    """
    last_human = next(
        (msg for msg in reversed(state["messages"]) if isinstance(msg, HumanMessage)),
        None,
    )
    if last_human is None:
        # Nothing to search for; add no messages and record an empty result.
        return {"messages": [], "retrieved_documents": []}

    # Content may be a list of content parts rather than a str; the retriever needs text.
    query = _message_text(last_human.content)
    if not query.strip():
        logger.info("Latest human message has no text content; skipping knowledge base retrieval")
        return {"messages": [], "retrieved_documents": []}

    try:
        retriever = get_kb_retriever()
        retrieved_docs = await retriever.ainvoke(query)
    except Exception as e:
        # Best effort: a misconfigured or unreachable Knowledge Base must not abort
        # the turn; the graph continues with an empty document list.
        logger.exception(f"Error retrieving documents: {e}")
        return {"retrieved_documents": [], "messages": []}

    document_summaries = [
        {
            "id": doc.metadata.get("id", f"doc-{i}"),
            "source": doc.metadata.get("source", "Unknown"),
            "title": doc.metadata.get("title", f"Document {i}"),
            "content": doc.page_content,
            "relevance_score": doc.metadata.get("score", 0),
        }
        for i, doc in enumerate(retrieved_docs, 1)
    ]
    logger.info(f"Retrieved {len(document_summaries)} documents for query: {query[:50]}...")
    return {"retrieved_documents": document_summaries, "messages": []}
async def prepare_augmented_prompt(state: AgentState, config: RunnableConfig) -> AgentState:
    """Render the retrieved documents into one prompt-ready string.

    Reads ``state["retrieved_documents"]``. When it is non-empty, the function
    returns the numbered, formatted text under ``kb_documents``; otherwise it
    returns no ``kb_documents`` key. No messages are added in either case.
    ``config`` is unused.
    """
    retrieved = state.get("retrieved_documents", [])
    if retrieved:
        sections = []
        for position, record in enumerate(retrieved, start=1):
            header = f"--- Document {position} ---\n"
            source_line = f"Source: {record.get('source', 'Unknown')}\n"
            title_line = f"Title: {record.get('title', 'Unknown')}\n\n"
            body = f"{record.get('content', '')}"
            sections.append(header + source_line + title_line + body)
        return {"kb_documents": "\n\n".join(sections), "messages": []}
    return {"messages": []}
async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
    """Generate the assistant reply from the conversation and retrieved documents.

    Model selection works in two steps:
    - If ``configurable["agent_config"]["model_config"]`` names an entry of
      ``settings.MODEL_CONFIGS``, that configuration is used.
    - Otherwise the model is built from ``configurable["model_provider"]`` and
      ``configurable["model_name"]``. These default to
      ``ModelProvider.OPENAI`` and ``settings.OPENAI_MODEL_NAME``.

    Args:
        state: Current graph state, passed through the wrapped model.
        config: Runnable config; ``configurable`` and ``agent_config`` may be
            missing or ``None``.

    Returns:
        A partial state update appending the model's response to ``messages``.
    """
    # Tolerate absent or None sections instead of raising KeyError/AttributeError.
    configurable = config.get("configurable") or {}
    agent_config = configurable.get("agent_config") or {}
    model_config_key = agent_config.get("model_config")

    if model_config_key and model_config_key in settings.MODEL_CONFIGS:
        model = ModelFactory.get_model_from_config(settings.MODEL_CONFIGS[model_config_key])
    else:
        if model_config_key:
            # Without this, a mistyped key silently falls back to the default model.
            logger.warning(
                f"Unknown model_config '{model_config_key}'; "
                "falling back to model_provider/model_name from the runnable config"
            )
        model = ModelFactory.create(
            model_provider=configurable.get("model_provider", ModelProvider.OPENAI),
            model_name=configurable.get("model_name", settings.OPENAI_MODEL_NAME),
            openai_api_base=settings.OPENAI_API_BASE_URL,
            openai_api_key=settings.OPENAI_API_KEY,
        )

    response = await wrap_model(model).ainvoke(state, config)
    return {"messages": [response]}
# Define the graph: a fixed retrieve -> augment -> generate pipeline.
def _build_kb_graph() -> StateGraph:
    """Return the uncompiled graph wiring the three agent nodes in sequence."""
    graph = StateGraph(AgentState)
    pipeline = (
        ("retrieve_documents", retrieve_documents),
        ("prepare_augmented_prompt", prepare_augmented_prompt),
        ("model", acall_model),
    )
    for node_name, node_fn in pipeline:
        graph.add_node(node_name, node_fn)

    stage_names = [node_name for node_name, _ in pipeline]
    # Execution starts at the first stage...
    graph.set_entry_point(stage_names[0])
    # ...and each stage hands off to the next; the final stage ends the run.
    for upstream, downstream in zip(stage_names, stage_names[1:] + [END]):
        graph.add_edge(upstream, downstream)
    return graph


agent = _build_kb_graph()

# Compile with an in-memory MemorySaver checkpointer and register the agent.
kb_agent = Agent(
    name="kb-agent",
    description="A retrieval-augmented generation agent using Amazon Bedrock Knowledge Base.",
    graph=agent.compile(checkpointer=MemorySaver()),
)