# Philosophy-RAG-demo/generic_rag/backend/models.py

import os
from enum import Enum

from langchain.chat_models import init_chat_model
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_aws import BedrockEmbeddings
from langchain_chroma import Chroma
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.retrievers import BaseRetriever
from langchain_google_vertexai import VertexAIEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_ollama import ChatOllama, OllamaEmbeddings
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings, OpenAIEmbeddings


class ChatBackend(Enum):
    azure = "azure"
    openai = "openai"
    google_vertex = "google_vertex"
    aws = "aws"
    local = "local"

    # make the enum pretty printable for argparse
    def __str__(self):
        return self.value


class EmbeddingBackend(Enum):
    azure = "azure"
    openai = "openai"
    google_vertex = "google_vertex"
    aws = "aws"
    local = "local"

    # make the enum pretty printable for argparse
    def __str__(self):
        return self.value
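

# A minimal CLI sketch showing why both enums override __str__: argparse can
# then render the backend choices as plain strings in its help text. The
# parser and flag names below are illustrative, not part of this module:
#
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser.add_argument("--chat-backend", type=ChatBackend,
#                       choices=list(ChatBackend), default=ChatBackend.local)
#   args = parser.parse_args()  # e.g. "--chat-backend azure" -> ChatBackend.azure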


def get_chat_model(backend_type: ChatBackend) -> BaseChatModel:
    """Instantiate the chat model for ``backend_type``, configured via environment variables."""
    if backend_type == ChatBackend.azure:
        return AzureChatOpenAI(
            azure_endpoint=os.environ["AZURE_LLM_ENDPOINT"],
            azure_deployment=os.environ["AZURE_LLM_DEPLOYMENT_NAME"],
            openai_api_version=os.environ["AZURE_LLM_API_VERSION"],
        )
    if backend_type == ChatBackend.openai:
        return init_chat_model(os.environ["OPENAI_CHAT_MODEL"], model_provider="openai")
    if backend_type == ChatBackend.google_vertex:
        return init_chat_model(os.environ["GOOGLE_CHAT_MODEL"], model_provider="google_vertexai")
    if backend_type == ChatBackend.aws:
        return init_chat_model(os.environ["AWS_CHAT_MODEL"], model_provider="bedrock_converse")
    if backend_type == ChatBackend.local:
        return ChatOllama(model=os.environ["LOCAL_CHAT_MODEL"])
    raise ValueError(f"Unknown chat backend: {backend_type}")
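

# Usage sketch for get_chat_model. The environment variable value below is an
# example (any model pulled into a local Ollama install would do), not a
# default shipped with this module:
#
#   os.environ["LOCAL_CHAT_MODEL"] = "llama3.1"
#   llm = get_chat_model(ChatBackend.local)
#   print(llm.invoke("Say hello.").content)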


def get_embedding_model(backend_type: EmbeddingBackend) -> Embeddings:
    """Instantiate the embedding model for ``backend_type``, configured via environment variables."""
    if backend_type == EmbeddingBackend.azure:
        return AzureOpenAIEmbeddings(
            azure_endpoint=os.environ["AZURE_EMB_ENDPOINT"],
            azure_deployment=os.environ["AZURE_EMB_DEPLOYMENT_NAME"],
            openai_api_version=os.environ["AZURE_EMB_API_VERSION"],
        )
    if backend_type == EmbeddingBackend.openai:
        return OpenAIEmbeddings(model=os.environ["OPENAI_EMB_MODEL"])
    if backend_type == EmbeddingBackend.google_vertex:
        return VertexAIEmbeddings(model=os.environ["GOOGLE_EMB_MODEL"])
    if backend_type == EmbeddingBackend.aws:
        return BedrockEmbeddings(model_id=os.environ["AWS_EMB_MODEL"])
    if backend_type == EmbeddingBackend.local:
        return OllamaEmbeddings(model=os.environ["LOCAL_EMB_MODEL"])
    if backend_type == EmbeddingBackend.huggingface:
        return HuggingFaceEmbeddings(model_name=os.environ["HUGGINGFACE_EMB_MODEL"])
    raise ValueError(f"Unknown embedding backend: {backend_type}")
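

# Usage sketch for get_embedding_model. The model name is an example of a
# sentence-transformers checkpoint, not a project default:
#
#   os.environ["HUGGINGFACE_EMB_MODEL"] = "sentence-transformers/all-MiniLM-L6-v2"
#   emb = get_embedding_model(EmbeddingBackend.huggingface)
#   vectors = emb.embed_documents(["Hello, world."])  # list of float vectors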


def get_compression_model(model_name: str, vector_store: Chroma) -> BaseRetriever:
    """Build a two-stage retriever: fetch 20 candidates, then cross-encoder rerank down to the top 4."""
    # Over-fetch from the vector store so the reranker has candidates to choose from.
    base_retriever = vector_store.as_retriever(search_kwargs={"k": 20})
    rerank_model = HuggingFaceCrossEncoder(model_name=model_name)
    compressor = CrossEncoderReranker(model=rerank_model, top_n=4)
    return ContextualCompressionRetriever(base_compressor=compressor, base_retriever=base_retriever)
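

# A minimal smoke test, runnable as `python models.py`. It is a sketch: the
# embedding checkpoint, cross-encoder checkpoint, collection name, and sample
# text below are illustrative assumptions, not values defined by this module.
if __name__ == "__main__":
    os.environ.setdefault("HUGGINGFACE_EMB_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
    emb = get_embedding_model(EmbeddingBackend.huggingface)

    # An in-memory store is enough for a smoke test; real ingestion happens elsewhere.
    store = Chroma(collection_name="demo", embedding_function=emb)
    store.add_texts(["Kant's categorical imperative is a universal moral law."])

    retriever = get_compression_model("cross-encoder/ms-marco-MiniLM-L-6-v2", store)
    print(retriever.invoke("What is the categorical imperative?"))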