"""Factories that build LangChain chat and embedding models from env-var config."""

import os
from enum import Enum
from typing import Union

from langchain.chat_models import init_chat_model
from langchain_aws import BedrockEmbeddings
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_google_vertexai import VertexAIEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_ollama import ChatOllama, OllamaEmbeddings
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings, OpenAIEmbeddings


class ChatBackend(Enum):
    """Providers that :func:`get_chat_model` can build a chat model for."""

    azure = "azure"
    openai = "openai"
    google_vertex = "google_vertex"
    aws = "aws"
    local = "local"

    # make the enum pretty printable for argparse
    def __str__(self):
        return self.value


class EmbeddingBackend(Enum):
    """Providers that :func:`get_embedding_model` can build an embedding model for."""

    azure = "azure"
    openai = "openai"
    google_vertex = "google_vertex"
    aws = "aws"
    local = "local"
    huggingface = "huggingface"

    # make the enum pretty printable for argparse
    def __str__(self):
        return self.value


def _coerce_backend(enum_cls, backend_type):
    """Return *backend_type* as a member of *enum_cls*.

    Accepts an existing member or its string value (e.g. ``"azure"``), so
    callers holding a config/env string need not convert it first.

    Raises:
        ValueError: if *backend_type* is not a member or value of *enum_cls*.
            The message matches the one the factories below always used.
    """
    try:
        # Enum(member) returns the member itself; Enum("value") looks it up.
        return enum_cls(backend_type)
    except ValueError:
        raise ValueError(f"Unknown backend type: {backend_type}") from None


def get_chat_model(backend_type: Union[ChatBackend, str]) -> BaseChatModel:
    """Build the chat model for *backend_type*, configured from env vars.

    Environment variables read per backend:
        azure:         AZURE_LLM_ENDPOINT, AZURE_LLM_DEPLOYMENT_NAME,
                       AZURE_LLM_API_VERSION
        openai:        OPENAI_CHAT_MODEL
        google_vertex: GOOGLE_CHAT_MODEL
        aws:           AWS_CHAT_MODEL (via the ``bedrock_converse`` provider)
        local:         LOCAL_CHAT_MODEL (served by Ollama)

    Args:
        backend_type: a :class:`ChatBackend` member or its string value.

    Returns:
        A LangChain chat model instance for the selected provider.

    Raises:
        ValueError: if *backend_type* is not a known chat backend.
        KeyError: if a required environment variable is not set.
    """
    backend_type = _coerce_backend(ChatBackend, backend_type)

    if backend_type == ChatBackend.azure:
        return AzureChatOpenAI(
            azure_endpoint=os.environ["AZURE_LLM_ENDPOINT"],
            azure_deployment=os.environ["AZURE_LLM_DEPLOYMENT_NAME"],
            openai_api_version=os.environ["AZURE_LLM_API_VERSION"],
        )
    if backend_type == ChatBackend.openai:
        return init_chat_model(os.environ["OPENAI_CHAT_MODEL"], model_provider="openai")
    if backend_type == ChatBackend.google_vertex:
        return init_chat_model(os.environ["GOOGLE_CHAT_MODEL"], model_provider="google_vertexai")
    if backend_type == ChatBackend.aws:
        return init_chat_model(model=os.environ["AWS_CHAT_MODEL"], model_provider="bedrock_converse")
    if backend_type == ChatBackend.local:
        return ChatOllama(model=os.environ["LOCAL_CHAT_MODEL"])
    # Only reachable if a ChatBackend member is added without a branch above.
    raise ValueError(f"Unknown backend type: {backend_type}")


def get_embedding_model(backend_type: Union[EmbeddingBackend, str]) -> Embeddings:
    """Build the embedding model for *backend_type*, configured from env vars.

    Environment variables read per backend:
        azure:         AZURE_EMB_ENDPOINT, AZURE_EMB_DEPLOYMENT_NAME,
                       AZURE_EMB_API_VERSION
        openai:        OPENAI_EMB_MODEL
        google_vertex: GOOGLE_EMB_MODEL
        aws:           AWS_EMB_MODEL (Bedrock model id)
        local:         LOCAL_EMB_MODEL (served by Ollama)
        huggingface:   HUGGINGFACE_EMB_MODEL

    Args:
        backend_type: an :class:`EmbeddingBackend` member or its string value.

    Returns:
        A LangChain ``Embeddings`` instance for the selected provider.

    Raises:
        ValueError: if *backend_type* is not a known embedding backend.
        KeyError: if a required environment variable is not set.
    """
    backend_type = _coerce_backend(EmbeddingBackend, backend_type)

    if backend_type == EmbeddingBackend.azure:
        return AzureOpenAIEmbeddings(
            azure_endpoint=os.environ["AZURE_EMB_ENDPOINT"],
            azure_deployment=os.environ["AZURE_EMB_DEPLOYMENT_NAME"],
            openai_api_version=os.environ["AZURE_EMB_API_VERSION"],
        )
    if backend_type == EmbeddingBackend.openai:
        return OpenAIEmbeddings(model=os.environ["OPENAI_EMB_MODEL"])
    if backend_type == EmbeddingBackend.google_vertex:
        return VertexAIEmbeddings(model=os.environ["GOOGLE_EMB_MODEL"])
    if backend_type == EmbeddingBackend.aws:
        return BedrockEmbeddings(model_id=os.environ["AWS_EMB_MODEL"])
    if backend_type == EmbeddingBackend.local:
        return OllamaEmbeddings(model=os.environ["LOCAL_EMB_MODEL"])
    if backend_type == EmbeddingBackend.huggingface:
        return HuggingFaceEmbeddings(model_name=os.environ["HUGGINGFACE_EMB_MODEL"])
    # Only reachable if an EmbeddingBackend member is added without a branch above.
    raise ValueError(f"Unknown backend type: {backend_type}")