Source code for langgraph_agent_toolkit.core.prompts.chat_prompt_template

import inspect
import re
import time
from typing import Any, Dict, List, Literal, Optional, Sequence, Union

from jinja2 import Environment
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.prompt_values import ChatPromptValue, PromptValue
from langchain_core.prompts.chat import (
    AIMessagePromptTemplate,
    BaseChatPromptTemplate,
    BaseMessage,
    BaseMessagePromptTemplate,
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessageLikeRepresentation,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
)
from langchain_core.prompts.string import get_template_variables
from pydantic import Field

from langgraph_agent_toolkit.core.observability.base import BaseObservabilityPlatform
from langgraph_agent_toolkit.core.observability.factory import ObservabilityFactory
from langgraph_agent_toolkit.core.observability.types import MessageRole, ObservabilityBackend, PromptReturnType
from langgraph_agent_toolkit.helper.constants import DEFAULT_CACHE_TTL_SECOND
from langgraph_agent_toolkit.helper.logging import logger


# Custom Jinja2 environment that allows attribute access on dictionaries.
# LangChain's default SandboxedEnvironment restricts such access.
_JINJA2_ENV = Environment(autoescape=False)
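# Illustrative sketch (not executed at import time): the unsandboxed
# environment resolves attribute-style access on plain dicts via Jinja2's
# getattr-then-getitem fallback, e.g.
#
#     _JINJA2_ENV.from_string("{{ item.name }}").render(item={"name": "x"})
#     # -> "x"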

# Map message roles to their corresponding prompt template classes
_MESSAGE_TYPE_MAP = {
    MessageRole.SYSTEM: SystemMessagePromptTemplate,
    MessageRole.HUMAN: HumanMessagePromptTemplate,
    MessageRole.USER: HumanMessagePromptTemplate,
    MessageRole.AI: AIMessagePromptTemplate,
    MessageRole.ASSISTANT: AIMessagePromptTemplate,
}

# Map string message types to prompt template classes (for BaseMessage.type)
_STRING_TYPE_MAP = {
    "system": SystemMessagePromptTemplate,
    "human": HumanMessagePromptTemplate,
    "ai": AIMessagePromptTemplate,
    "assistant": AIMessagePromptTemplate,
}


def _convert_template_format(content: str, target_format: str) -> str:
    """Convert template string between different formats."""
    if not content or not isinstance(content, str):
        return content

    if target_format == "jinja2" and "{" in content and "{{" not in content:
        # Convert f-string format to Jinja2 format
        return re.sub(r"{(\w+)}", r"{{ \1 }}", content)
    elif target_format == "f-string" and "{{" in content:
        # Convert Jinja2 format to f-string format
        return re.sub(r"{{\s*(\w+)\s*}}", r"{\1}", content)
    return content

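# Expected behavior of _convert_template_format (illustrative examples):
#
#     _convert_template_format("Hello {name}", "jinja2")        # -> "Hello {{ name }}"
#     _convert_template_format("Hello {{ name }}", "f-string")  # -> "Hello {name}"
#
# Only bare single-word placeholders are rewritten; expressions such as
# {{ user.name }} are left untouched by the f-string conversion.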

class ObservabilityChatPromptTemplate(ChatPromptTemplate):
    """A chat prompt template that loads prompts from observability platforms."""

    prompt_name: Optional[str] = Field(default=None, description="Name of the prompt to load")
    prompt_version: Optional[int] = Field(default=None, description="Version of the prompt")
    prompt_label: Optional[str] = Field(default=None, description="Label of the prompt")
    load_at_runtime: bool = Field(default=False, description="Whether to load prompt at runtime")
    observability_backend: Optional[ObservabilityBackend] = Field(
        default=None, description="Observability backend to use"
    )
    cache_ttl_seconds: int = Field(default=DEFAULT_CACHE_TTL_SECOND, description="Cache TTL for prompts")
    template_format: str = Field(default="f-string", description="Format of the template")

    _observability_platform: Optional[BaseObservabilityPlatform] = None
    _loaded_prompt: Any = None
    _last_load_time: float = 0

    model_config = {"extra": "allow"}

    def __init__(
        self,
        messages: Optional[Sequence[MessageLikeRepresentation]] = None,
        *,
        prompt_name: Optional[str] = None,
        prompt_version: Optional[int] = None,
        prompt_label: Optional[str] = None,
        load_at_runtime: bool = False,
        observability_platform: Optional[BaseObservabilityPlatform] = None,
        observability_backend: Optional[Union[ObservabilityBackend, str]] = None,
        cache_ttl_seconds: int = DEFAULT_CACHE_TTL_SECOND,
        template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
        input_variables: Optional[List[str]] = None,
        partial_variables: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ):
        """Initialize ObservabilityChatPromptTemplate."""
        # Process observability platform/backend
        _observability_platform = observability_platform
        _observability_backend = observability_backend

        if not observability_platform and observability_backend:
            _observability_backend = (
                ObservabilityBackend(observability_backend)
                if isinstance(observability_backend, str)
                else observability_backend
            )
            _observability_platform = ObservabilityFactory.create(_observability_backend)

        # Load messages if needed
        _messages = messages or []
        if not load_at_runtime and prompt_name and _observability_platform:
            try:
                self._loaded_prompt = loaded_prompt = self._load_prompt_from_platform(
                    _observability_platform,
                    prompt_name=prompt_name,
                    prompt_version=prompt_version,
                    prompt_label=prompt_label,
                    cache_ttl_seconds=cache_ttl_seconds,
                    template_format=template_format,
                )

                # Process returned prompt based on type
                if not _messages:
                    if hasattr(loaded_prompt, "messages"):
                        _messages = self._process_messages_from_prompt(loaded_prompt.messages, template_format)
                    elif isinstance(loaded_prompt, BaseChatPromptTemplate):
                        _messages = loaded_prompt.messages
                    elif isinstance(loaded_prompt, list):
                        processed_messages = self._process_list_prompt(loaded_prompt, template_format)
                        if processed_messages:
                            _messages = processed_messages
            except Exception as e:
                logger.warning(f"Failed to load prompt {prompt_name}: {e}")

        # Save input variables and partial variables
        _input_variables = list(input_variables) if input_variables else []
        _partial_variables = dict(partial_variables) if partial_variables else {}

        # Initialize parent class with messages only
        super().__init__(messages=_messages)

        # Set attributes
        self.prompt_name = prompt_name
        self.prompt_version = prompt_version
        self.prompt_label = prompt_label
        self.load_at_runtime = load_at_runtime
        self.observability_backend = _observability_backend
        self.cache_ttl_seconds = cache_ttl_seconds
        self.template_format = template_format

        # Explicitly set input variables and partial variables
        if _input_variables:
            self.input_variables = _input_variables
        if _partial_variables:
            self.partial_variables = _partial_variables

        # Set private attributes
        self._observability_platform = _observability_platform
        self._last_load_time = time.time()

    @property
    def observability_platform(self) -> Optional[BaseObservabilityPlatform]:
        """Get the observability platform."""
        return self._observability_platform

    @observability_platform.setter
    def observability_platform(self, platform: BaseObservabilityPlatform) -> None:
        """Set the observability platform."""
        self._observability_platform = platform
        self._loaded_prompt = None
        self._last_load_time = 0

    def _load_prompt_from_platform(
        self,
        platform: BaseObservabilityPlatform,
        prompt_name: str,
        prompt_version: Optional[int] = None,
        prompt_label: Optional[str] = None,
        cache_ttl_seconds: int = DEFAULT_CACHE_TTL_SECOND,
        template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
    ) -> PromptReturnType:
        """Load prompt from observability platform."""
        kwargs = {}

        try:
            sig = inspect.signature(platform.pull_prompt).parameters
            if "cache_ttl_seconds" in sig:
                kwargs["cache_ttl_seconds"] = cache_ttl_seconds
            if "template_format" in sig:
                kwargs["template_format"] = template_format
        except (ValueError, TypeError):
            pass

        if prompt_version is not None:
            kwargs["version"] = prompt_version
        if prompt_label is not None:
            kwargs["label"] = prompt_label

        return platform.pull_prompt(prompt_name, **kwargs)

    def _load_prompt_from_observability(self) -> PromptReturnType:
        """Load prompt from observability platform."""
        if not self._observability_platform:
            raise ValueError("No observability platform set")
        if not self.prompt_name:
            raise ValueError("No prompt name provided")

        return self._load_prompt_from_platform(
            self._observability_platform,
            prompt_name=self.prompt_name,
            prompt_version=self.prompt_version,
            prompt_label=self.prompt_label,
            cache_ttl_seconds=self.cache_ttl_seconds,
            template_format=self.template_format,
        )

    def _update_messages_from_loaded_prompt(self) -> None:
        """Update the messages from the loaded prompt."""
        if self._loaded_prompt is None:
            return

        if hasattr(self._loaded_prompt, "messages"):
            processed_messages = []
            for msg in self._loaded_prompt.messages:
                if isinstance(msg, BaseMessage) and msg.type in _STRING_TYPE_MAP:
                    content = _convert_template_format(msg.content, self.template_format)
                    template_class = _STRING_TYPE_MAP[msg.type]
                    processed_messages.append(
                        template_class.from_template(content, template_format=self.template_format)
                    )
                else:
                    processed_messages.append(msg)
            self.messages = processed_messages
        elif isinstance(self._loaded_prompt, list):
            processed_messages = self._process_list_prompt(self._loaded_prompt, self.template_format)
            if processed_messages:
                self.messages = processed_messages
        elif isinstance(self._loaded_prompt, BaseChatPromptTemplate):
            self.messages = self._loaded_prompt.messages

    def _process_messages_from_prompt(self, messages: Any, template_format: str) -> List[MessageLikeRepresentation]:
        """Process messages from a loaded prompt."""
        processed_messages = []

        for msg in messages:
            if isinstance(msg, MessagesPlaceholder):
                # Preserve MessagesPlaceholder objects
                processed_messages.append(msg)
            elif isinstance(msg, BaseMessage) and msg.type in _STRING_TYPE_MAP:
                content = _convert_template_format(msg.content, template_format)
                template_class = _STRING_TYPE_MAP[msg.type]
                processed_messages.append(template_class.from_template(content, template_format=template_format))
            elif isinstance(msg, tuple) and len(msg) == 2:
                role, content = msg
                if role in _MESSAGE_TYPE_MAP:
                    content = _convert_template_format(content, template_format)
                    template_class = _MESSAGE_TYPE_MAP[role]
                    processed_messages.append(template_class.from_template(content, template_format=template_format))
            elif isinstance(msg, dict) and "role" in msg and "content" in msg:
                role, content = msg["role"], msg["content"]
                if role in _MESSAGE_TYPE_MAP:
                    content = _convert_template_format(content, template_format)
                    template_class = _MESSAGE_TYPE_MAP[role]
                    processed_messages.append(template_class.from_template(content, template_format=template_format))
            else:
                processed_messages.append(msg)

        return processed_messages

    def _process_list_prompt(
        self, prompt_list: List[Any], template_format: str
    ) -> Optional[List[MessageLikeRepresentation]]:
        """Process a list prompt from an observability platform."""
        processed_messages = []

        # Handle list of tuples (role, content)
        if all(isinstance(item, tuple) and len(item) == 2 for item in prompt_list):
            for role, content in prompt_list:
                if role in _MESSAGE_TYPE_MAP:
                    content = _convert_template_format(content, template_format)
                    template_class = _MESSAGE_TYPE_MAP[role]
                    processed_messages.append(template_class.from_template(content, template_format=template_format))
            return processed_messages

        # Handle list of dicts with role and content
        if all(isinstance(item, dict) and "role" in item and "content" in item for item in prompt_list):
            for item in prompt_list:
                role, content = item["role"], item["content"]
                if role in _MESSAGE_TYPE_MAP:
                    content = _convert_template_format(content, template_format)
                    template_class = _MESSAGE_TYPE_MAP[role]
                    processed_messages.append(template_class.from_template(content, template_format=template_format))
                # Handle MessagesPlaceholder
                elif role.lower() in (MessageRole.PLACEHOLDER, MessageRole.MESSAGES_PLACEHOLDER):
                    # Create a MessagesPlaceholder with the content as variable name
                    processed_messages.append(MessagesPlaceholder(variable_name=content))
            return processed_messages

        return processed_messages or None

    def _ensure_messages_loaded(self) -> None:
        """Ensure messages are loaded from observability platform if needed."""
        if not self.load_at_runtime or not self.prompt_name or not self._observability_platform:
            return

        current_time = time.time()
        if self._loaded_prompt is None or current_time - self._last_load_time > self.cache_ttl_seconds:
            try:
                self._loaded_prompt = self._load_prompt_from_observability()
                self._last_load_time = current_time
                self._update_messages_from_loaded_prompt()
            except Exception as e:
                logger.error(f"Failed to load prompt: {e}")
                if not self.messages:
                    raise ValueError(f"Failed to load prompt and no fallback available: {e}")

    def _render_jinja2_template(self, template_content: str, variables: Dict[str, Any]) -> str:
        """Render a Jinja2 template using an unsandboxed environment.

        This allows attribute access on dictionaries (e.g., item.name instead of item['name']).
        """
        template = _JINJA2_ENV.from_string(template_content)
        return template.render(**variables)

    def _format_messages_with_jinja2(self, input_values: Dict[str, Any]) -> List[BaseMessage]:
        """Format messages using custom Jinja2 environment for unsandboxed rendering."""
        formatted_messages = []

        for msg in self.messages:
            if isinstance(msg, MessagesPlaceholder):
                # Get messages from input
                placeholder_messages = input_values.get(msg.variable_name, [])
                if isinstance(placeholder_messages, list):
                    formatted_messages.extend(placeholder_messages)
                continue

            if isinstance(msg, BaseMessagePromptTemplate):
                # Get the template content
                if hasattr(msg, "prompt") and hasattr(msg.prompt, "template"):
                    template_content = msg.prompt.template
                    # Render using custom Jinja2 environment
                    rendered_content = self._render_jinja2_template(template_content, input_values)

                    # Create the appropriate message type
                    if isinstance(msg, SystemMessagePromptTemplate):
                        formatted_messages.append(SystemMessage(content=rendered_content))
                    elif isinstance(msg, HumanMessagePromptTemplate):
                        formatted_messages.append(HumanMessage(content=rendered_content))
                    elif isinstance(msg, AIMessagePromptTemplate):
                        formatted_messages.append(AIMessage(content=rendered_content))
                    else:
                        # Fallback to parent formatting
                        formatted_messages.append(msg.format(**input_values))
                else:
                    # Fallback to parent formatting
                    formatted_messages.append(msg.format(**input_values))
            elif isinstance(msg, BaseMessage):
                formatted_messages.append(msg)
            else:
                # Try to format if it has a format method
                if hasattr(msg, "format"):
                    formatted_messages.append(msg.format(**input_values))
                else:
                    formatted_messages.append(msg)

        return formatted_messages

    def invoke(self, input: Any, config: Optional[Dict[str, Any]] = None, **kwargs: Any) -> PromptValue:
        """Invoke the prompt template."""
        self._ensure_messages_loaded()

        # For jinja2 templates, use custom rendering to avoid sandbox restrictions
        if self.template_format == "jinja2":
            # Merge input with partial variables
            input_values = {**self.partial_variables, **(input if isinstance(input, dict) else {})}
            # Format messages using custom Jinja2 environment
            formatted_messages = self._format_messages_with_jinja2(input_values)
            return ChatPromptValue(messages=formatted_messages)

        # Delegate to parent class for other formats
        return super().invoke(input=input, config=config, **kwargs)

    async def ainvoke(self, input: Any, config: Optional[Dict[str, Any]] = None, **kwargs: Any) -> PromptValue:
        """Asynchronously invoke the prompt template."""
        self._ensure_messages_loaded()

        # For jinja2 templates, use custom rendering to avoid sandbox restrictions
        if self.template_format == "jinja2":
            # Merge input with partial variables
            input_values = {**self.partial_variables, **(input if isinstance(input, dict) else {})}
            # Format messages using custom Jinja2 environment
            formatted_messages = self._format_messages_with_jinja2(input_values)
            return ChatPromptValue(messages=formatted_messages)

        # Delegate to parent class for other formats
        return await super().ainvoke(input=input, config=config, **kwargs)

    @classmethod
    def from_observability_platform(
        cls,
        prompt_name: str,
        observability_platform: BaseObservabilityPlatform,
        *,
        prompt_version: Optional[int] = None,
        prompt_label: Optional[str] = None,
        load_at_runtime: bool = True,
        **kwargs: Any,
    ) -> "ObservabilityChatPromptTemplate":
        """Create a chat prompt template from an observability platform."""
        return cls(
            prompt_name=prompt_name,
            prompt_version=prompt_version,
            prompt_label=prompt_label,
            load_at_runtime=load_at_runtime,
            observability_platform=observability_platform,
            **kwargs,
        )

    @classmethod
    def from_observability_backend(
        cls,
        prompt_name: str,
        observability_backend: Union[ObservabilityBackend, str],
        *,
        prompt_version: Optional[int] = None,
        prompt_label: Optional[str] = None,
        load_at_runtime: bool = True,
        **kwargs: Any,
    ) -> "ObservabilityChatPromptTemplate":
        """Create a chat prompt template from an observability backend."""
        backend = (
            ObservabilityBackend(observability_backend)
            if isinstance(observability_backend, str)
            else observability_backend
        )
        platform = ObservabilityFactory.create(backend)

        return cls(
            prompt_name=prompt_name,
            prompt_version=prompt_version,
            prompt_label=prompt_label,
            load_at_runtime=load_at_runtime,
            observability_platform=platform,
            observability_backend=backend,
            **kwargs,
        )

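    # Usage sketch (hypothetical prompt name; the backend string must match a
    # member of ObservabilityBackend in your deployment):
    #
    #     prompt = ObservabilityChatPromptTemplate.from_observability_backend(
    #         prompt_name="support-agent",        # hypothetical
    #         observability_backend="langfuse",   # assumed backend identifier
    #         prompt_label="production",
    #     )
    #     value = prompt.invoke({"question": "How do I reset my password?"})
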
    def __add__(self, other: Any) -> ChatPromptTemplate:
        """Combine two prompt templates."""
        if isinstance(other, ChatPromptTemplate):
            # Create a copy of messages from both templates
            combined_messages = list(self.messages)

            # Process messages from the other template
            other_messages = []
            for msg in other.messages:
                # Special handling for MessagesPlaceholder
                if isinstance(msg, MessagesPlaceholder):
                    other_messages.append(msg)
                    continue

                if isinstance(msg, BaseMessagePromptTemplate):
                    other_messages.append(msg)
                elif isinstance(msg, BaseMessage):
                    content = msg.content
                    if isinstance(content, str):
                        template_vars = get_template_variables(content, self.template_format)
                        if template_vars:
                            template_class = _STRING_TYPE_MAP.get(msg.type)
                            if template_class:
                                other_messages.append(
                                    template_class.from_template(content, template_format=self.template_format)
                                )
                                continue
                    other_messages.append(msg)
                else:
                    other_messages.append(msg)

            combined_messages.extend(other_messages)

            # Collect all input variables
            all_vars = set(self.input_variables or [])
            other_vars = set(other.input_variables or [])
            all_vars.update(other_vars)

            # Get variables from MessagesPlaceholder
            for msg in combined_messages:
                if isinstance(msg, MessagesPlaceholder):
                    all_vars.add(msg.variable_name)
                elif hasattr(msg, "input_variables"):
                    all_vars.update(msg.input_variables)

            # Create new partial variables dict
            combined_partial_vars = dict(self.partial_variables or {})
            if hasattr(other, "partial_variables") and other.partial_variables:
                for k, v in other.partial_variables.items():
                    if k not in combined_partial_vars:
                        combined_partial_vars[k] = v

            # Create the combined template
            return ChatPromptTemplate(
                messages=combined_messages,
                input_variables=list(all_vars),
                partial_variables=combined_partial_vars,
            )
        elif isinstance(other, (BaseMessagePromptTemplate, BaseMessage)):
            return self + ChatPromptTemplate.from_messages([other])
        elif isinstance(other, (list, tuple)):
            return self + ChatPromptTemplate.from_messages(other)
        elif isinstance(other, str):
            return self + ChatPromptTemplate.from_template(other)
        else:
            raise NotImplementedError(f"Unsupported operand type for +: {type(other)}")

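    # Combination sketch: `template + other` returns a plain ChatPromptTemplate
    # with merged messages, input variables, and partial variables (self's
    # partial variables win on key collisions), e.g.
    #
    #     combined = prompt + ChatPromptTemplate.from_template("{question}")
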
__all__ = ["ObservabilityChatPromptTemplate"]