# Philosophy-RAG-demo/generic_rag/backend/models.py
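"""Factory helpers that map the configured backend (Azure, OpenAI, Google
Vertex, AWS Bedrock, local Ollama, HuggingFace) to LangChain chat and
embedding model instances."""
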
import logging

# LangChain base types and per-backend model classes
from langchain_aws import BedrockEmbeddings, ChatBedrock
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_ollama import ChatOllama, OllamaEmbeddings
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings, ChatOpenAI, OpenAIEmbeddings

from generic_rag.parsers.config import AppSettings, ChatBackend, EmbeddingBackend

logger = logging.getLogger(__name__)

def get_chat_model(settings: AppSettings) -> BaseChatModel:
    """Initialize and return a chat model based on the backend type and configuration.

    Args:
        settings: The loaded AppSettings object containing configurations.

    Returns:
        An instance of BaseChatModel.

    Raises:
        ValueError: If the backend type is unknown or required configuration is missing.
    """
    logger.info(f"Initializing chat model for backend: {settings.chat_backend.value}")

    if settings.chat_backend == ChatBackend.azure:
        if not settings.azure:
            raise ValueError("Azure chat backend selected, but 'azure' configuration section is missing in config.")
        if (
            not settings.azure.llm_endpoint
            or not settings.azure.llm_deployment_name
            or not settings.azure.llm_api_version
        ):
            raise ValueError(
                "Azure configuration requires 'llm_endpoint', 'llm_deployment_name', and 'llm_api_version'."
            )
        return AzureChatOpenAI(
            azure_endpoint=settings.azure.llm_endpoint,
            azure_deployment=settings.azure.llm_deployment_name,
            openai_api_version=settings.azure.llm_api_version,
            openai_api_key=settings.azure.openai_api_key.get_secret_value() if settings.azure.openai_api_key else None,
        )

    if settings.chat_backend == ChatBackend.openai:
        if not settings.openai:
            raise ValueError("OpenAI chat backend selected, but 'openai' configuration section is missing.")
        if not settings.openai.api_key or not settings.openai.chat_model:
            raise ValueError("OpenAI configuration requires 'api_key' and 'chat_model'.")
        logger.info(f"Using OpenAI model: {settings.openai.chat_model}")
        return ChatOpenAI(model=settings.openai.chat_model, openai_api_key=settings.openai.api_key.get_secret_value())

    if settings.chat_backend == ChatBackend.google_vertex:
        if not settings.google:
            raise ValueError("Google Vertex chat backend selected, but 'google' configuration section is missing.")
        if not settings.google.chat_model:
            raise ValueError("Google configuration requires 'chat_model'.")
        logger.info(f"Using Google Vertex model: {settings.google.chat_model}")
        return ChatVertexAI(
            model_name=settings.google.chat_model,
            project=settings.google.project_id,
            location=settings.google.location,
        )

    if settings.chat_backend == ChatBackend.aws:
        if not settings.aws:
            raise ValueError("AWS Bedrock chat backend selected, but 'aws' configuration section is missing.")
        model_name = "anthropic.claude-v2"  # Example default
        if hasattr(settings.aws, "chat_model") and settings.aws.chat_model:
            model_name = settings.aws.chat_model
        logger.info(f"Using AWS Bedrock model: {model_name}")
        return ChatBedrock(
            model_id=model_name,
            region_name=settings.aws.region_name,
        )

    if settings.chat_backend == ChatBackend.local:
        if not settings.local or not settings.local.chat_model:
            raise ValueError("Local chat backend selected, but 'local.chat_model' is missing in config.")
        logger.info(f"Using Local Ollama model: {settings.local.chat_model}")
        # Base URL can also be configured, e.g., base_url=settings.local.ollama_base_url
        return ChatOllama(model=settings.local.chat_model)

    # This should not be reached if all Enum members are handled
    raise ValueError(f"Unknown or unhandled chat backend type: {settings.chat_backend}")

def get_embedding_model(settings: AppSettings) -> Embeddings:
    """Initialize and return an embedding model based on the backend type and configuration.

    Args:
        settings: The loaded AppSettings object containing configurations.

    Returns:
        An instance of Embeddings.

    Raises:
        ValueError: If the backend type is unknown or required configuration is missing.
    """
    logger.info(f"Initializing embedding model for backend: {settings.emb_backend.value}")

    if settings.emb_backend == EmbeddingBackend.azure:
        if not settings.azure:
            raise ValueError("Azure embedding backend selected, but 'azure' configuration section is missing.")
        if (
            not settings.azure.emb_endpoint
            or not settings.azure.emb_deployment_name
            or not settings.azure.emb_api_version
        ):
            raise ValueError(
                "Azure configuration requires 'emb_endpoint', 'emb_deployment_name', and 'emb_api_version'."
            )
        return AzureOpenAIEmbeddings(
            azure_endpoint=settings.azure.emb_endpoint,
            azure_deployment=settings.azure.emb_deployment_name,
            openai_api_version=settings.azure.emb_api_version,
            openai_api_key=settings.azure.openai_api_key.get_secret_value() if settings.azure.openai_api_key else None,
        )

    if settings.emb_backend == EmbeddingBackend.openai:
        if not settings.openai:
            raise ValueError("OpenAI embedding backend selected, but 'openai' configuration section is missing.")
        if not settings.openai.api_key:
            raise ValueError("OpenAI configuration requires 'api_key'.")
        model_name = "text-embedding-ada-002"  # Example default
        if hasattr(settings.openai, "emb_model") and settings.openai.emb_model:
            model_name = settings.openai.emb_model
        logger.info(f"Using OpenAI embedding model: {model_name}")
        return OpenAIEmbeddings(model=model_name, openai_api_key=settings.openai.api_key.get_secret_value())

    if settings.emb_backend == EmbeddingBackend.google_vertex:
        if not settings.google:
            raise ValueError("Google Vertex embedding backend selected, but 'google' configuration section is missing.")
        model_name = "textembedding-gecko@001"  # Example default
        if settings.google.emb_model:
            model_name = settings.google.emb_model
        logger.info(f"Using Google Vertex embedding model: {model_name}")
        return VertexAIEmbeddings(
            model_name=model_name, project=settings.google.project_id, location=settings.google.location
        )

    if settings.emb_backend == EmbeddingBackend.aws:
        if not settings.aws:
            raise ValueError("AWS Bedrock embedding backend selected, but 'aws' configuration section is missing.")
        model_name = "amazon.titan-embed-text-v1"  # Example default
        if hasattr(settings.aws, "emb_model") and settings.aws.emb_model:
            model_name = settings.aws.emb_model
        logger.info(f"Using AWS Bedrock embedding model: {model_name}")
        return BedrockEmbeddings(model_id=model_name, region_name=settings.aws.region_name)

    if settings.emb_backend == EmbeddingBackend.local:
        if not settings.local or not settings.local.emb_model:
            raise ValueError("Local embedding backend selected, but 'local.emb_model' is missing in config.")
        logger.info(f"Using Local Ollama embedding model: {settings.local.emb_model}")
        return OllamaEmbeddings(model=settings.local.emb_model)

    if settings.emb_backend == EmbeddingBackend.huggingface:
        if not settings.huggingface or not settings.huggingface.emb_model:
            # Fall back to the local embedding model if one is configured
            if settings.local and settings.local.emb_model:
                logger.warning(
                    "HuggingFace backend selected, but 'huggingface.emb_model' missing. Using 'local.emb_model'."
                )
                model_name = settings.local.emb_model
            else:
                raise ValueError(
                    "HuggingFace embedding backend selected, but 'huggingface.emb_model' (or 'local.emb_model') is missing in config."
                )
        else:
            model_name = settings.huggingface.emb_model
        logger.info(f"Using HuggingFace embedding model: {model_name}")
        return HuggingFaceEmbeddings(model_name=model_name)

    raise ValueError(f"Unknown or unhandled embedding backend type: {settings.emb_backend}")