import warnings
from functools import cache
from typing import (
    Any,
    Dict,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
    cast,
)

from langchain.chat_models.base import _ConfigurableModel, _init_chat_model_helper
from langchain_core.language_models import BaseChatModel
from langchain_core.runnables import Runnable, RunnableConfig
from typing_extensions import TypeAlias

from langgraph_agent_toolkit.core.models import ChatOpenAIPatched, FakeToolModel
from langgraph_agent_toolkit.helper.constants import (
    DEFAULT_CONFIG_PREFIX,
    DEFAULT_CONFIGURABLE_FIELDS,
    DEFAULT_MODEL_PARAMETER_VALUES,
)
from langgraph_agent_toolkit.schema.models import ModelProvider

ModelT: TypeAlias = FakeToolModel | _ConfigurableModel | BaseChatModel


class _ConfigurableModelCustom(_ConfigurableModel):
    def _model(self, config: Optional[RunnableConfig] = None) -> Runnable:
        # Merge the default config with any per-invocation overrides, rebuild
        # the underlying model, then replay the queued declarative operations
        # (e.g. bind_tools) on the fresh instance.
        params = {**self._default_config, **self._model_params(config)}
        model = ModelFactory._init_chat_model_helper(**params)
        for name, args, kwargs in self._queued_declarative_operations:
            model = getattr(model, name)(*args, **kwargs)
        return model


class ModelFactory:
    """Factory for creating model instances."""

    @staticmethod
    def _init_chat_model_helper(model: str, *, model_provider: Optional[str] = None, **kwargs: Any) -> BaseChatModel:
        # Route OpenAI models through the patched client; defer everything
        # else to langchain's stock helper.
        if model_provider == "openai":
            return ChatOpenAIPatched(model_name=model, **kwargs)
        else:
            return _init_chat_model_helper(model, model_provider=model_provider, **kwargs)

    @staticmethod
    def init_chat_model(
        model: Optional[str] = None,
        *,
        model_provider: Optional[str] = None,
        configurable_fields: Optional[Union[Literal["any"], List[str], Tuple[str, ...]]] = None,
        config_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> Union[BaseChatModel, _ConfigurableModel]:
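        """Initialize a chat model, optionally wrapping it for runtime configuration.

        Mirrors ``langchain.chat_models.init_chat_model`` but routes construction
        through :meth:`_init_chat_model_helper`, so OpenAI models get
        ``ChatOpenAIPatched``. If ``configurable_fields`` is given (or neither
        ``model`` nor ``configurable_fields`` is set), the returned
        :class:`_ConfigurableModelCustom` resolves those fields from
        ``config["configurable"]`` at invocation time.

        Example (a hedged sketch; ``temperature`` is an illustrative field name):

            >>> m = ModelFactory.init_chat_model(
            ...     "gpt-4o", model_provider="openai", configurable_fields=("temperature",)
            ... )
            >>> m.invoke("hi", config={"configurable": {"temperature": 0.2}})
        """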
        if not model and not configurable_fields:
            configurable_fields = ("model", "model_provider")
        config_prefix = config_prefix or ""
        if config_prefix and not configurable_fields:
            warnings.warn(
                f"{config_prefix=} has been set but no fields are configurable. Set "
                f"`configurable_fields=(...)` to specify the model params that are "
                f"configurable."
            )

        if not configurable_fields:
            return ModelFactory._init_chat_model_helper(cast(str, model), model_provider=model_provider, **kwargs)
        else:
            if model:
                kwargs["model"] = model
            if model_provider:
                kwargs["model_provider"] = model_provider
            return _ConfigurableModelCustom(
                default_config=kwargs,
                config_prefix=config_prefix,
                configurable_fields=configurable_fields,
            )

    @staticmethod
    @cache
    def create(
        model_provider: ModelProvider,
        model_name: Optional[str] = None,
        configurable_fields: Optional[Union[Literal["any"], List[str], Tuple[str, ...]]] = None,
        config_prefix: Optional[str] = None,
        model_parameter_values: Optional[Tuple[Tuple[str, Any], ...]] = None,
        **kwargs: Any,
    ) -> ModelT:
        """Create and return a model instance.

        Results are memoized with ``@cache``, which is why ``model_parameter_values``
        is a hashable tuple of (key, value) pairs rather than a dict.

        Args:
            model_provider: The model provider to use. Must be one of the supported model providers.
            model_name: The name of the model to use. If not provided, the default model name is used.
            configurable_fields: The fields that are configurable. If not provided, the default fields are used.
            config_prefix: The prefix to use for the configuration. If not provided, the default prefix is used.
            model_parameter_values: The values for the model parameters as a tuple of (key, value) pairs.
                If not provided, the default values are used.
            **kwargs: Additional keyword arguments to pass to the model.

        Returns:
            An instance of the requested model.

        Raises:
            ValueError: If the requested model is not supported, or if no model name
                is given for a non-fake provider.
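
        Example:
            A minimal sketch. The fake provider needs no credentials; the second
            call assumes a ``ModelProvider.OPENAI`` member exists and that API
            credentials are configured in the environment:

            >>> fake = ModelFactory.create(ModelProvider.FAKE)
            >>> model = ModelFactory.create(
            ...     ModelProvider.OPENAI,
            ...     model_name="gpt-4o",
            ...     model_parameter_values=(("temperature", 0.5),),
            ... )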
""" # noqa: E501
_configurable_fields = DEFAULT_CONFIGURABLE_FIELDS if configurable_fields is None else configurable_fields
_config_prefix = DEFAULT_CONFIG_PREFIX if config_prefix is None else config_prefix
_model_parameter_values = (
DEFAULT_MODEL_PARAMETER_VALUES if model_parameter_values is None else dict(model_parameter_values)
)
match model_provider:
case ModelProvider.FAKE:
return FakeToolModel(responses=["This is a test response from the fake model."])
case _:
if model_name is None:
raise ValueError("Model name must be provided for non-fake models.")
return ModelFactory.init_chat_model(
model=model_name,
model_provider=model_provider,
configurable_fields=_configurable_fields,
config_prefix=_config_prefix,
**_model_parameter_values,
**kwargs,
)

    @classmethod
    def get_model_from_config(cls, config: Dict[str, Any], **override_params) -> BaseChatModel:
        """Create a model from a configuration dictionary.

        Args:
            config: Model configuration dictionary.
            **override_params: Parameters that override values from the configuration.

        Returns:
            A BaseChatModel instance.

        Example:
            >>> config = {"provider": "openai", "name": "gpt-4", "temperature": 0.7}
            >>> model = ModelFactory.get_model_from_config(config)
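
            Keyword overrides take precedence over values in ``config`` (an
            illustrative sketch reusing the ``temperature`` parameter above):

            >>> cold = ModelFactory.get_model_from_config(config, temperature=0.0)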
"""
if not config:
raise ValueError("Model configuration cannot be empty")
# Extract basic parameters
provider = config.get("provider", "openai")
model_name = config.get("name") or config.get("model_name")
if not model_name:
raise ValueError("Model name must be specified in the configuration")
# Copy the config to avoid modifying the original
params = dict(config)
# Remove some keys that are handled separately
params.pop("provider", None)
params.pop("name", None)
params.pop("model_name", None)
# Apply overrides
params.update(override_params)
# Create and return the model
return cls.create(model_provider=provider, model_name=model_name, model_parameter_values=(), **params)
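

# A minimal smoke-test sketch (assumes FakeToolModel follows the standard
# BaseChatModel interface and returns its canned response from ``invoke``):
if __name__ == "__main__":
    fake = ModelFactory.create(ModelProvider.FAKE)
    print(fake.invoke("ping"))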