forked from AI_team/Philosophy-RAG-demo
86 lines
2.9 KiB
Python
86 lines
2.9 KiB
Python
import os
|
|
from enum import Enum
|
|
|
|
from langchain.chat_models import init_chat_model
|
|
from langchain_aws import BedrockEmbeddings
|
|
from langchain_core.embeddings import Embeddings
|
|
from langchain_core.language_models.chat_models import BaseChatModel
|
|
from langchain_google_vertexai import VertexAIEmbeddings
|
|
from langchain_huggingface import HuggingFaceEmbeddings
|
|
from langchain_ollama import ChatOllama, OllamaEmbeddings
|
|
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings, OpenAIEmbeddings
|
|
|
|
|
|
class ChatBackend(Enum):
    """Chat-model providers that ``get_chat_model`` knows how to build.

    Each member's value is its identifier string: ``ChatBackend("aws")``
    looks a member up by it, and ``str()`` renders a member back to it.
    """

    azure = "azure"
    openai = "openai"
    google_vertex = "google_vertex"
    aws = "aws"
    local = "local"

    def __str__(self) -> str:
        # Show the bare value (e.g. "azure") instead of the default
        # "ChatBackend.azure" so argparse help and choices read cleanly.
        return str(self.value)
|
|
|
|
|
|
class EmbeddingBackend(Enum):
    """Embedding-model providers that ``get_embedding_model`` knows how to build.

    Each member's value is its identifier string: ``EmbeddingBackend("aws")``
    looks a member up by it, and ``str()`` renders a member back to it.
    """

    azure = "azure"
    openai = "openai"
    google_vertex = "google_vertex"
    aws = "aws"
    local = "local"
    huggingface = "huggingface"

    def __str__(self) -> str:
        # Show the bare value (e.g. "huggingface") instead of the default
        # "EmbeddingBackend.huggingface" so argparse help reads cleanly.
        return str(self.value)
|
|
|
|
|
|
def get_chat_model(backend_type: ChatBackend) -> BaseChatModel:
    """Instantiate the chat model for ``backend_type``.

    The concrete model/deployment comes from environment variables, and only
    the variables belonging to the selected backend are read:

    * ``azure``         -- AZURE_LLM_ENDPOINT, AZURE_LLM_DEPLOYMENT_NAME,
                           AZURE_LLM_API_VERSION
    * ``openai``        -- OPENAI_CHAT_MODEL
    * ``google_vertex`` -- GOOGLE_CHAT_MODEL
    * ``aws``           -- AWS_CHAT_MODEL (Bedrock Converse provider)
    * ``local``         -- LOCAL_CHAT_MODEL (served through Ollama)

    Args:
        backend_type: Provider to build a chat model for.

    Returns:
        A LangChain chat model for the chosen provider.

    Raises:
        KeyError: A required environment variable is not set.
        ValueError: ``backend_type`` does not match any ``ChatBackend`` member.
    """
    # (backend, factory) pairs, tried in order. Factories are lambdas so that
    # os.environ is consulted only for the backend that actually matches.
    builders = (
        (
            ChatBackend.azure,
            lambda: AzureChatOpenAI(
                azure_endpoint=os.environ["AZURE_LLM_ENDPOINT"],
                azure_deployment=os.environ["AZURE_LLM_DEPLOYMENT_NAME"],
                openai_api_version=os.environ["AZURE_LLM_API_VERSION"],
            ),
        ),
        (
            ChatBackend.openai,
            lambda: init_chat_model(
                model=os.environ["OPENAI_CHAT_MODEL"],
                model_provider="openai",
            ),
        ),
        (
            ChatBackend.google_vertex,
            lambda: init_chat_model(
                model=os.environ["GOOGLE_CHAT_MODEL"],
                model_provider="google_vertexai",
            ),
        ),
        (
            ChatBackend.aws,
            lambda: init_chat_model(
                model=os.environ["AWS_CHAT_MODEL"],
                model_provider="bedrock_converse",
            ),
        ),
        (
            ChatBackend.local,
            lambda: ChatOllama(model=os.environ["LOCAL_CHAT_MODEL"]),
        ),
    )

    for candidate, build in builders:
        if backend_type == candidate:
            return build()

    raise ValueError(f"Unknown backend type: {backend_type}")
|
|
|
|
|
|
def get_embedding_model(backend_type: EmbeddingBackend) -> Embeddings:
    """Instantiate the embedding model for ``backend_type``.

    The concrete model/deployment comes from environment variables, and only
    the variables belonging to the selected backend are read:

    * ``azure``         -- AZURE_EMB_ENDPOINT, AZURE_EMB_DEPLOYMENT_NAME,
                           AZURE_EMB_API_VERSION
    * ``openai``        -- OPENAI_EMB_MODEL
    * ``google_vertex`` -- GOOGLE_EMB_MODEL
    * ``aws``           -- AWS_EMB_MODEL (Bedrock)
    * ``local``         -- LOCAL_EMB_MODEL (served through Ollama)
    * ``huggingface``   -- HUGGINGFACE_EMB_MODEL

    Args:
        backend_type: Provider to build an embedding model for.

    Returns:
        A LangChain ``Embeddings`` implementation for the chosen provider.

    Raises:
        KeyError: A required environment variable is not set.
        ValueError: ``backend_type`` does not match any ``EmbeddingBackend``
            member.
    """
    # (backend, factory) pairs, tried in order. Factories are lambdas so that
    # os.environ is consulted only for the backend that actually matches.
    builders = (
        (
            EmbeddingBackend.azure,
            lambda: AzureOpenAIEmbeddings(
                azure_endpoint=os.environ["AZURE_EMB_ENDPOINT"],
                azure_deployment=os.environ["AZURE_EMB_DEPLOYMENT_NAME"],
                openai_api_version=os.environ["AZURE_EMB_API_VERSION"],
            ),
        ),
        (
            EmbeddingBackend.openai,
            lambda: OpenAIEmbeddings(model=os.environ["OPENAI_EMB_MODEL"]),
        ),
        (
            EmbeddingBackend.google_vertex,
            lambda: VertexAIEmbeddings(model=os.environ["GOOGLE_EMB_MODEL"]),
        ),
        (
            EmbeddingBackend.aws,
            lambda: BedrockEmbeddings(model_id=os.environ["AWS_EMB_MODEL"]),
        ),
        (
            EmbeddingBackend.local,
            lambda: OllamaEmbeddings(model=os.environ["LOCAL_EMB_MODEL"]),
        ),
        (
            EmbeddingBackend.huggingface,
            lambda: HuggingFaceEmbeddings(
                model_name=os.environ["HUGGINGFACE_EMB_MODEL"]
            ),
        ),
    )

    for candidate, build in builders:
        if backend_type == candidate:
            return build()

    raise ValueError(f"Unknown backend type: {backend_type}")
|