"""Factories that build chat and embedding models for a selected backend."""

import os
from enum import Enum

from langchain.chat_models import init_chat_model
from langchain_aws import BedrockEmbeddings
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_google_vertexai import VertexAIEmbeddings
from langchain_ollama import ChatOllama, OllamaEmbeddings
from langchain_openai import (
    AzureChatOpenAI,
    AzureOpenAIEmbeddings,
    OpenAIEmbeddings,
)


class BackendType(Enum):
    """Closed set of backends the factory functions below know how to build."""

    azure = "azure"
    openai = "openai"
    google_vertex = "google_vertex"
    aws = "aws"
    local = "local"

    def __str__(self) -> str:
        """Return the raw member value so argparse prints the enum nicely."""
        return self.value


def get_chat_model(backend_type: BackendType) -> BaseChatModel:
    """Build the chat model for ``backend_type``.

    Configuration is read from environment variables, and only the variables
    belonging to the selected backend are looked up.

    Args:
        backend_type: Backend to build the chat model for.

    Returns:
        The chat model instance for the selected backend.

    Raises:
        KeyError: If an environment variable needed by the selected backend
            is not set.
        ValueError: If ``backend_type`` matches none of the known backends.
    """
    env = os.environ

    if backend_type == BackendType.azure:
        # Dict literals evaluate top to bottom, so the variables are read in
        # the order listed here.
        azure_settings = {
            "azure_endpoint": env["AZURE_LLM_ENDPOINT"],
            "azure_deployment": env["AZURE_LLM_DEPLOYMENT_NAME"],
            "openai_api_version": env["AZURE_LLM_API_VERSION"],
        }
        return AzureChatOpenAI(**azure_settings)
    elif backend_type == BackendType.openai:
        openai_model = env["OPENAI_CHAT_MODEL"]
        return init_chat_model(openai_model, model_provider="openai")
    elif backend_type == BackendType.google_vertex:
        vertex_model = env["GOOGLE_CHAT_MODEL"]
        return init_chat_model(vertex_model, model_provider="google_vertexai")
    elif backend_type == BackendType.aws:
        bedrock_model = env["AWS_CHAT_MODEL"]
        return init_chat_model(
            model=bedrock_model, model_provider="bedrock_converse"
        )
    elif backend_type == BackendType.local:
        local_model = env["LOCAL_CHAT_MODEL"]
        return ChatOllama(model=local_model)
    else:
        raise ValueError(f"Unknown backend type: {backend_type}")


def get_embedding_model(backend_type: BackendType) -> Embeddings:
    """Build the embedding model for ``backend_type``.

    Configuration is read from environment variables, and only the variables
    belonging to the selected backend are looked up.

    Args:
        backend_type: Backend to build the embedding model for.

    Returns:
        The embedding model instance for the selected backend.

    Raises:
        KeyError: If an environment variable needed by the selected backend
            is not set.
        ValueError: If ``backend_type`` matches none of the known backends.
    """
    env = os.environ

    if backend_type == BackendType.azure:
        # Dict literals evaluate top to bottom, so the variables are read in
        # the order listed here.
        azure_settings = {
            "azure_endpoint": env["AZURE_EMB_ENDPOINT"],
            "azure_deployment": env["AZURE_EMB_DEPLOYMENT_NAME"],
            "openai_api_version": env["AZURE_EMB_API_VERSION"],
        }
        return AzureOpenAIEmbeddings(**azure_settings)
    elif backend_type == BackendType.openai:
        openai_model = env["OPENAI_EMB_MODEL"]
        return OpenAIEmbeddings(model=openai_model)
    elif backend_type == BackendType.google_vertex:
        vertex_model = env["GOOGLE_EMB_MODEL"]
        return VertexAIEmbeddings(model=vertex_model)
    elif backend_type == BackendType.aws:
        bedrock_model = env["AWS_EMB_MODEL"]
        return BedrockEmbeddings(model_id=bedrock_model)
    elif backend_type == BackendType.local:
        local_model = env["LOCAL_EMB_MODEL"]
        return OllamaEmbeddings(model=local_model)
    else:
        raise ValueError(f"Unknown backend type: {backend_type}")