Source code for langgraph_agent_toolkit.core.prompts.chat_prompt_template

import inspect
import re
import time
from typing import Any, Dict, List, Literal, Optional, Sequence, Union

from langchain_core.prompt_values import PromptValue
from langchain_core.prompts.chat import (
    AIMessagePromptTemplate,
    BaseChatPromptTemplate,
    BaseMessage,
    BaseMessagePromptTemplate,
    ChatPromptTemplate,
    ChatPromptValue,
    HumanMessagePromptTemplate,
    MessageLikeRepresentation,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
)
from langchain_core.prompts.string import get_template_variables
from pydantic import Field

from langgraph_agent_toolkit.core.observability.base import BaseObservabilityPlatform
from langgraph_agent_toolkit.core.observability.factory import ObservabilityFactory
from langgraph_agent_toolkit.core.observability.types import MessageRole, ObservabilityBackend, PromptReturnType
from langgraph_agent_toolkit.helper.constants import DEFAULT_CACHE_TTL_SECOND
from langgraph_agent_toolkit.helper.logging import logger


def _convert_template_format(content: str, target_format: str) -> str:
    """Convert template string between different formats."""
    if not content or not isinstance(content, str):
        return content

    if target_format == "jinja2" and "{" in content and "{{" not in content:
        # Convert f-string format to Jinja2 format
        return re.sub(r"{(\w+)}", r"{{ \1 }}", content)
    elif target_format == "f-string" and "{{" in content:
        # Convert Jinja2 format to f-string format
        return re.sub(r"{{\s*(\w+)\s*}}", r"{\1}", content)
    return content


[docs] class ObservabilityChatPromptTemplate(ChatPromptTemplate): """A chat prompt template that loads prompts from observability platforms.""" prompt_name: Optional[str] = Field(default=None, description="Name of the prompt to load") prompt_version: Optional[int] = Field(default=None, description="Version of the prompt") prompt_label: Optional[str] = Field(default=None, description="Label of the prompt") load_at_runtime: bool = Field(default=False, description="Whether to load prompt at runtime") observability_backend: Optional[ObservabilityBackend] = Field( default=None, description="Observability backend to use" ) cache_ttl_seconds: int = Field(default=DEFAULT_CACHE_TTL_SECOND, description="Cache TTL for prompts") template_format: str = Field(default="f-string", description="Format of the template") _observability_platform: Optional[BaseObservabilityPlatform] = None _loaded_prompt: Any = None _last_load_time: float = 0 model_config = {"extra": "allow"}
[docs] def __init__( self, messages: Optional[Sequence[MessageLikeRepresentation]] = None, *, prompt_name: Optional[str] = None, prompt_version: Optional[int] = None, prompt_label: Optional[str] = None, load_at_runtime: bool = False, observability_platform: Optional[BaseObservabilityPlatform] = None, observability_backend: Optional[Union[ObservabilityBackend, str]] = None, cache_ttl_seconds: int = DEFAULT_CACHE_TTL_SECOND, template_format: Literal["f-string", "mustache", "jinja2"] = "f-string", input_variables: Optional[List[str]] = None, partial_variables: Optional[Dict[str, Any]] = None, **kwargs: Any, ): """Initialize ObservabilityChatPromptTemplate.""" # Process observability platform/backend _observability_platform = observability_platform _observability_backend = observability_backend if not observability_platform and observability_backend: _observability_backend = ( ObservabilityBackend(observability_backend) if isinstance(observability_backend, str) else observability_backend ) _observability_platform = ObservabilityFactory.create(_observability_backend) # Load messages if needed _messages = messages or [] if not load_at_runtime and prompt_name and _observability_platform: try: self._loaded_prompt = loaded_prompt = self._load_prompt_from_platform( _observability_platform, prompt_name=prompt_name, prompt_version=prompt_version, prompt_label=prompt_label, cache_ttl_seconds=cache_ttl_seconds, template_format=template_format, ) # Process returned prompt based on type if not _messages: if hasattr(loaded_prompt, "messages"): _messages = self._process_messages_from_prompt(loaded_prompt.messages, template_format) elif isinstance(loaded_prompt, BaseChatPromptTemplate): _messages = loaded_prompt.messages elif isinstance(loaded_prompt, list): processed_messages = self._process_list_prompt(loaded_prompt, template_format) if processed_messages: _messages = processed_messages except Exception as e: logger.warning(f"Failed to load prompt {prompt_name}: {e}") # Save input variables and partial variables _input_variables = list(input_variables) if input_variables else [] _partial_variables = dict(partial_variables) if partial_variables else {} # Initialize parent class with messages only super().__init__(messages=_messages) # Set attributes self.prompt_name = prompt_name self.prompt_version = prompt_version self.prompt_label = prompt_label self.load_at_runtime = load_at_runtime self.observability_backend = _observability_backend self.cache_ttl_seconds = cache_ttl_seconds self.template_format = template_format # Explicitly set input variables and partial variables if _input_variables: self.input_variables = _input_variables if _partial_variables: self.partial_variables = _partial_variables # Set private attributes self._observability_platform = _observability_platform self._last_load_time = time.time()
@property def observability_platform(self) -> Optional[BaseObservabilityPlatform]: """Get the observability platform.""" return self._observability_platform @observability_platform.setter def observability_platform(self, platform: BaseObservabilityPlatform) -> None: """Set the observability platform.""" self._observability_platform = platform self._loaded_prompt = None self._last_load_time = 0 def _load_prompt_from_platform( self, platform: BaseObservabilityPlatform, prompt_name: str, prompt_version: Optional[int] = None, prompt_label: Optional[str] = None, cache_ttl_seconds: int = DEFAULT_CACHE_TTL_SECOND, template_format: Literal["f-string", "mustache", "jinja2"] = "f-string", ) -> PromptReturnType: """Load prompt from observability platform.""" kwargs = {} try: sig = inspect.signature(platform.pull_prompt).parameters if "cache_ttl_seconds" in sig: kwargs["cache_ttl_seconds"] = cache_ttl_seconds if "template_format" in sig: kwargs["template_format"] = template_format except (ValueError, TypeError): pass if prompt_version is not None: kwargs["version"] = prompt_version if prompt_label is not None: kwargs["label"] = prompt_label return platform.pull_prompt(prompt_name, **kwargs) def _load_prompt_from_observability(self) -> PromptReturnType: """Load prompt from observability platform.""" if not self._observability_platform: raise ValueError("No observability platform set") if not self.prompt_name: raise ValueError("No prompt name provided") return self._load_prompt_from_platform( self._observability_platform, prompt_name=self.prompt_name, prompt_version=self.prompt_version, prompt_label=self.prompt_label, cache_ttl_seconds=self.cache_ttl_seconds, template_format=self.template_format, ) def _update_messages_from_loaded_prompt(self) -> None: """Update the messages from the loaded prompt.""" if self._loaded_prompt is None: return MESSAGE_TYPE_MAP = { "system": SystemMessagePromptTemplate, "human": HumanMessagePromptTemplate, "ai": AIMessagePromptTemplate, "assistant": AIMessagePromptTemplate, } if hasattr(self._loaded_prompt, "messages"): processed_messages = [] for msg in self._loaded_prompt.messages: if isinstance(msg, BaseMessage) and msg.type in MESSAGE_TYPE_MAP: content = _convert_template_format(msg.content, self.template_format) template_class = MESSAGE_TYPE_MAP[msg.type] processed_messages.append( template_class.from_template(content, template_format=self.template_format) ) else: processed_messages.append(msg) self.messages = processed_messages elif isinstance(self._loaded_prompt, list): processed_messages = self._process_list_prompt(self._loaded_prompt, self.template_format) if processed_messages: self.messages = processed_messages elif isinstance(self._loaded_prompt, BaseChatPromptTemplate): self.messages = self._loaded_prompt.messages def _process_messages_from_prompt(self, messages: Any, template_format: str) -> List[MessageLikeRepresentation]: """Process messages from a loaded prompt.""" MESSAGE_TYPE_MAP = { MessageRole.SYSTEM: SystemMessagePromptTemplate, MessageRole.HUMAN: HumanMessagePromptTemplate, MessageRole.USER: HumanMessagePromptTemplate, MessageRole.AI: AIMessagePromptTemplate, MessageRole.ASSISTANT: AIMessagePromptTemplate, } processed_messages = [] for msg in messages: if isinstance(msg, MessagesPlaceholder): # Preserve MessagesPlaceholder objects processed_messages.append(msg) elif isinstance(msg, BaseMessage) and msg.type in MESSAGE_TYPE_MAP: content = _convert_template_format(msg.content, template_format) template_class = MESSAGE_TYPE_MAP[MessageRole(msg.type)] processed_messages.append(template_class.from_template(content, template_format=template_format)) elif isinstance(msg, tuple) and len(msg) == 2: role, content = msg if role in MESSAGE_TYPE_MAP: content = _convert_template_format(content, template_format) template_class = MESSAGE_TYPE_MAP[role] processed_messages.append(template_class.from_template(content, template_format=template_format)) elif isinstance(msg, dict) and "role" in msg and "content" in msg: role, content = msg["role"], msg["content"] if role in MESSAGE_TYPE_MAP: content = _convert_template_format(content, template_format) template_class = MESSAGE_TYPE_MAP[role] processed_messages.append(template_class.from_template(content, template_format=template_format)) else: processed_messages.append(msg) return processed_messages def _process_list_prompt( self, prompt_list: List[Any], template_format: str ) -> Optional[List[MessageLikeRepresentation]]: """Process a list prompt from an observability platform.""" MESSAGE_TYPE_MAP = { MessageRole.SYSTEM: SystemMessagePromptTemplate, MessageRole.HUMAN: HumanMessagePromptTemplate, MessageRole.USER: HumanMessagePromptTemplate, MessageRole.AI: AIMessagePromptTemplate, MessageRole.ASSISTANT: AIMessagePromptTemplate, } processed_messages = [] # Handle list of tuples (role, content) if all(isinstance(item, tuple) and len(item) == 2 for item in prompt_list): for role, content in prompt_list: if role in MESSAGE_TYPE_MAP: content = _convert_template_format(content, template_format) template_class = MESSAGE_TYPE_MAP[role] processed_messages.append(template_class.from_template(content, template_format=template_format)) return processed_messages # Handle list of dicts with role and content if all(isinstance(item, dict) and "role" in item and "content" in item for item in prompt_list): for item in prompt_list: role, content = item["role"], item["content"] if role in MESSAGE_TYPE_MAP: content = _convert_template_format(content, template_format) template_class = MESSAGE_TYPE_MAP[role] processed_messages.append(template_class.from_template(content, template_format=template_format)) # Handle MessagesPlaceholder elif role.lower() in (MessageRole.PLACEHOLDER, MessageRole.MESSAGES_PLACEHOLDER): # Create a MessagesPlaceholder with the content as variable name processed_messages.append(MessagesPlaceholder(variable_name=content)) return processed_messages return processed_messages or None def _format_message_with_input(self, msg: Any, input_dict: Dict[str, Any]) -> List[BaseMessage]: """Format a message with input.""" full_input_dict = dict(self.partial_variables or {}) full_input_dict.update(input_dict) try: # Special handling for MessagesPlaceholder if isinstance(msg, MessagesPlaceholder): var_name = msg.variable_name if var_name in full_input_dict: messages_value = full_input_dict[var_name] if isinstance(messages_value, list): return messages_value if all(isinstance(m, BaseMessage) for m in messages_value) else [msg] return [messages_value] if isinstance(messages_value, BaseMessage) else [msg] return [] if hasattr(msg, "format") and callable(msg.format): return [msg.format(**full_input_dict)] elif hasattr(msg, "format_messages") and callable(msg.format_messages): return msg.format_messages(**full_input_dict) except Exception as e: logger.warning(f"Error formatting message: {e}") return [msg] async def _aformat_message_with_input(self, msg: Any, input_dict: Dict[str, Any]) -> List[BaseMessage]: """Asynchronously format a message with input.""" full_input_dict = dict(self.partial_variables or {}) full_input_dict.update(input_dict) try: # Special handling for MessagesPlaceholder if isinstance(msg, MessagesPlaceholder): var_name = msg.variable_name if var_name in full_input_dict: messages_value = full_input_dict[var_name] if isinstance(messages_value, list): return messages_value if all(isinstance(m, BaseMessage) for m in messages_value) else [msg] return [messages_value] if isinstance(messages_value, BaseMessage) else [msg] return [] if hasattr(msg, "aformat") and callable(msg.aformat): return [await msg.aformat(**full_input_dict)] elif hasattr(msg, "aformat_messages") and callable(msg.aformat_messages): return await msg.aformat_messages(**full_input_dict) elif hasattr(msg, "format") and callable(msg.format): return [msg.format(**full_input_dict)] elif hasattr(msg, "format_messages") and callable(msg.format_messages): return msg.format_messages(**full_input_dict) except Exception: pass return [msg] def _ensure_messages_loaded(self) -> None: """Ensure messages are loaded from observability platform if needed.""" if not self.load_at_runtime or not self.prompt_name or not self._observability_platform: return current_time = time.time() if self._loaded_prompt is None or current_time - self._last_load_time > self.cache_ttl_seconds: try: self._loaded_prompt = self._load_prompt_from_observability() self._last_load_time = current_time self._update_messages_from_loaded_prompt() except Exception as e: logger.error(f"Failed to load prompt: {e}") if not self.messages: raise ValueError(f"Failed to load prompt and no fallback available: {e}")
[docs] def invoke(self, input: Any, config: Optional[Dict[str, Any]] = None, **kwargs: Any) -> PromptValue: """Invoke the prompt template.""" self._ensure_messages_loaded() if isinstance(input, dict): formatted_messages = [] for msg in self.messages: formatted_messages.extend(self._format_message_with_input(msg, input)) return ChatPromptValue(messages=formatted_messages) return super().invoke(input=input, config=config, **kwargs)
[docs] async def ainvoke(self, input: Any, config: Optional[Dict[str, Any]] = None, **kwargs: Any) -> PromptValue: """Asynchronously invoke the prompt template.""" self._ensure_messages_loaded() if isinstance(input, dict): formatted_messages = [] for msg in self.messages: formatted_messages.extend(await self._aformat_message_with_input(msg, input)) return ChatPromptValue(messages=formatted_messages) return await super().ainvoke(input=input, config=config, **kwargs)
[docs] @classmethod def from_observability_platform( cls, prompt_name: str, observability_platform: BaseObservabilityPlatform, *, prompt_version: Optional[int] = None, prompt_label: Optional[str] = None, load_at_runtime: bool = True, **kwargs: Any, ) -> "ObservabilityChatPromptTemplate": """Create a chat prompt template from an observability platform.""" return cls( prompt_name=prompt_name, prompt_version=prompt_version, prompt_label=prompt_label, load_at_runtime=load_at_runtime, observability_platform=observability_platform, **kwargs, )
[docs] @classmethod def from_observability_backend( cls, prompt_name: str, observability_backend: Union[ObservabilityBackend, str], *, prompt_version: Optional[int] = None, prompt_label: Optional[str] = None, load_at_runtime: bool = True, **kwargs: Any, ) -> "ObservabilityChatPromptTemplate": """Create a chat prompt template from an observability backend.""" backend = ( ObservabilityBackend(observability_backend) if isinstance(observability_backend, str) else observability_backend ) platform = ObservabilityFactory.create(backend) return cls( prompt_name=prompt_name, prompt_version=prompt_version, prompt_label=prompt_label, load_at_runtime=load_at_runtime, observability_platform=platform, observability_backend=backend, **kwargs, )
def __add__(self, other: Any) -> ChatPromptTemplate: """Combine two prompt templates.""" if isinstance(other, ChatPromptTemplate): # Create a copy of messages from both templates combined_messages = list(self.messages) # Process messages from the other template other_messages = [] for msg in other.messages: # Special handling for MessagesPlaceholder if isinstance(msg, MessagesPlaceholder): other_messages.append(msg) continue if isinstance(msg, BaseMessagePromptTemplate): other_messages.append(msg) elif isinstance(msg, BaseMessage): content = msg.content if isinstance(content, str): template_vars = get_template_variables(content, self.template_format) if template_vars: template_class = { MessageRole.SYSTEM: SystemMessagePromptTemplate, MessageRole.HUMAN: HumanMessagePromptTemplate, MessageRole.AI: AIMessagePromptTemplate, MessageRole.ASSISTANT: AIMessagePromptTemplate, }.get(MessageRole(msg.type)) if template_class: other_messages.append( template_class.from_template(content, template_format=self.template_format) ) continue other_messages.append(msg) else: other_messages.append(msg) combined_messages.extend(other_messages) # Collect all input variables all_vars = set(self.input_variables or []) other_vars = set(other.input_variables or []) all_vars.update(other_vars) # Get variables from MessagesPlaceholder for msg in combined_messages: if isinstance(msg, MessagesPlaceholder): all_vars.add(msg.variable_name) elif hasattr(msg, "input_variables"): all_vars.update(msg.input_variables) # Create new partial variables dict combined_partial_vars = dict(self.partial_variables or {}) if hasattr(other, "partial_variables") and other.partial_variables: for k, v in other.partial_variables.items(): if k not in combined_partial_vars: combined_partial_vars[k] = v # Create the combined template return ChatPromptTemplate( messages=combined_messages, input_variables=list(all_vars), partial_variables=combined_partial_vars, ) elif isinstance(other, (BaseMessagePromptTemplate, BaseMessage)): return self + ChatPromptTemplate.from_messages([other]) elif isinstance(other, (list, tuple)): return self + ChatPromptTemplate.from_messages(other) elif isinstance(other, str): return self + ChatPromptTemplate.from_template(other) else: raise NotImplementedError(f"Unsupported operand type for +: {type(other)}")
__all__ = ["ObservabilityChatPromptTemplate"]