Merge pull request 'Added support for (local) huggingface embedding models' (#21) from fix_embedding into main

Reviewed-on: AI_team/generic-RAG-demo#21
Author: rubenl
Date: 2025-04-11 13:01:29 +02:00
Commit: 2929ed1e27
3 changed files with 57 additions and 30 deletions


@@ -5,7 +5,7 @@ import os
 from pathlib import Path
 import chainlit as cl
-from backend.models import BackendType, get_chat_model, get_embedding_model
+from backend.models import ChatBackend, EmbeddingBackend, get_chat_model, get_embedding_model
 from chainlit.cli import run_chainlit
 from graphs.cond_ret_gen import CondRetGenLangGraph
 from graphs.ret_gen import RetGenLangGraph
@@ -17,12 +17,20 @@ logger = logging.getLogger(__name__)
 parser = argparse.ArgumentParser(description="A Sogeti Nederland Generic RAG demo.")
 parser.add_argument(
-    "-b",
-    "--backend",
-    type=BackendType,
-    choices=list(BackendType),
-    default=BackendType.azure,
-    help="Cloud provider to use as backend. In the case of local, Ollama needs to be installed..",
+    "-c",
+    "--chat-backend",
+    type=ChatBackend,
+    choices=list(ChatBackend),
+    default=ChatBackend.local,
+    help="Cloud provider or local LLM to use as backend. In the case of 'local', Ollama needs to be installed.",
 )
+parser.add_argument(
+    "-e",
+    "--emb-backend",
+    type=EmbeddingBackend,
+    choices=list(EmbeddingBackend),
+    default=EmbeddingBackend.huggingface,
+    help="Cloud provider or local embedding to use as backend. In the case of 'local', Ollama needs to be installed. ",
+)
 parser.add_argument(
     "-p",
@@ -62,7 +70,6 @@ parser.add_argument(
 )
 parser.add_argument("-r", "--reset-chrome-db", action="store_true", help="Reset the Chroma DB.")
 parser.add_argument(
-    "-c",
     "--use-conditional-graph",
     action="store_true",
     help="Use the conditial retrieve generate graph over the regular retrieve generate graph.",
@@ -71,17 +78,21 @@ args = parser.parse_args()
 vector_store = Chroma(
     collection_name="generic_rag",
-    embedding_function=get_embedding_model(args.backend),
+    embedding_function=get_embedding_model(args.emb_backend),
     persist_directory=str(args.chroma_db_location),
 )

 if args.use_conditional_graph:
     graph = CondRetGenLangGraph(
-        vector_store, chat_model=get_chat_model(args.backend), embedding_model=get_embedding_model(args.backend)
+        vector_store,
+        chat_model=get_chat_model(args.chat_backend),
+        embedding_model=get_embedding_model(args.emb_backend),
     )
 else:
     graph = RetGenLangGraph(
-        vector_store, chat_model=get_chat_model(args.backend), embedding_model=get_embedding_model(args.backend)
+        vector_store,
+        chat_model=get_chat_model(args.chat_backend),
+        embedding_model=get_embedding_model(args.emb_backend),
     )
@@ -95,7 +106,7 @@ async def on_message(message: cl.Message):
 async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sources: set | list):
     if len(pdf_sources) > 0:
-        await chainlit_response.stream_token("\nThe following PDF source were consulted:\n")
+        await chainlit_response.stream_token("\n\nThe following PDF source were consulted:\n")
         for source, page_numbers in pdf_sources.items():
             page_numbers = list(page_numbers)
             page_numbers.sort()
@@ -104,7 +115,7 @@ async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sour
             await chainlit_response.update()
             await chainlit_response.stream_token(f"- '{source}' on page(s): {page_numbers}\n")
     if len(web_sources) > 0:
-        await chainlit_response.stream_token("\nThe following web sources were consulted:\n")
+        await chainlit_response.stream_token("\n\nThe following web sources were consulted:\n")
         for source in web_sources:
             await chainlit_response.stream_token(f"- {source}\n")
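Taken together, the app now selects its chat and embedding backends independently, e.g. by running it with --chat-backend local --emb-backend huggingface. Below is a minimal hand-wired sketch of the new defaults; the names come from the diff above, everything else (the env-var values, the absence of CLI parsing) is illustrative only and not part of the PR.

# Sketch only, not part of the diff: wire the two separated backends by hand.
# ChatBackend.local assumes a running Ollama server and LOCAL_CHAT_MODEL set;
# EmbeddingBackend.huggingface assumes HUGGINGFACE_EMB_MODEL names an embedding model.
from backend.models import ChatBackend, EmbeddingBackend, get_chat_model, get_embedding_model

chat_model = get_chat_model(ChatBackend.local)
embedding_model = get_embedding_model(EmbeddingBackend.huggingface)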


@@ -6,13 +6,12 @@ from langchain_aws import BedrockEmbeddings
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_google_vertexai import VertexAIEmbeddings
-from langchain_ollama import ChatOllama
-from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
-from langchain_openai import OpenAIEmbeddings
-from langchain_ollama import OllamaEmbeddings
+from langchain_huggingface import HuggingFaceEmbeddings
+from langchain_ollama import ChatOllama, OllamaEmbeddings
+from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings, OpenAIEmbeddings

-class BackendType(Enum):
+class ChatBackend(Enum):
     azure = "azure"
     openai = "openai"
     google_vertex = "google_vertex"
@@ -24,47 +23,63 @@ class BackendType(Enum):
         return self.value

-def get_chat_model(backend_type: BackendType) -> BaseChatModel:
-    if backend_type == BackendType.azure:
+
+class EmbeddingBackend(Enum):
+    azure = "azure"
+    openai = "openai"
+    google_vertex = "google_vertex"
+    aws = "aws"
+    local = "local"
+    huggingface = "huggingface"
+
+    # make the enum pretty printable for argparse
+    def __str__(self):
+        return self.value
+
+
+def get_chat_model(backend_type: ChatBackend) -> BaseChatModel:
+    if backend_type == ChatBackend.azure:
         return AzureChatOpenAI(
             azure_endpoint=os.environ["AZURE_LLM_ENDPOINT"],
             azure_deployment=os.environ["AZURE_LLM_DEPLOYMENT_NAME"],
             openai_api_version=os.environ["AZURE_LLM_API_VERSION"],
         )
-    if backend_type == BackendType.openai:
+    if backend_type == ChatBackend.openai:
         return init_chat_model(os.environ["OPENAI_CHAT_MODEL"], model_provider="openai")
-    if backend_type == BackendType.google_vertex:
+    if backend_type == ChatBackend.google_vertex:
         return init_chat_model(os.environ["GOOGLE_CHAT_MODEL"], model_provider="google_vertexai")
-    if backend_type == BackendType.aws:
+    if backend_type == ChatBackend.aws:
         return init_chat_model(model=os.environ["AWS_CHAT_MODEL"], model_provider="bedrock_converse")
-    if backend_type == BackendType.local:
+    if backend_type == ChatBackend.local:
         return ChatOllama(model=os.environ["LOCAL_CHAT_MODEL"])
     raise ValueError(f"Unknown backend type: {backend_type}")

-def get_embedding_model(backend_type: BackendType) -> Embeddings:
-    if backend_type == BackendType.azure:
+
+def get_embedding_model(backend_type: EmbeddingBackend) -> Embeddings:
+    if backend_type == EmbeddingBackend.azure:
         return AzureOpenAIEmbeddings(
             azure_endpoint=os.environ["AZURE_EMB_ENDPOINT"],
             azure_deployment=os.environ["AZURE_EMB_DEPLOYMENT_NAME"],
             openai_api_version=os.environ["AZURE_EMB_API_VERSION"],
         )
-    if backend_type == BackendType.openai:
+    if backend_type == EmbeddingBackend.openai:
         return OpenAIEmbeddings(model=os.environ["OPENAI_EMB_MODEL"])
-    if backend_type == BackendType.google_vertex:
+    if backend_type == EmbeddingBackend.google_vertex:
         return VertexAIEmbeddings(model=os.environ["GOOGLE_EMB_MODEL"])
-    if backend_type == BackendType.aws:
+    if backend_type == EmbeddingBackend.aws:
         return BedrockEmbeddings(model_id=os.environ["AWS_EMB_MODEL"])
-    if backend_type == BackendType.local:
+    if backend_type == EmbeddingBackend.local:
         return OllamaEmbeddings(model=os.environ["LOCAL_EMB_MODEL"])
+    if backend_type == EmbeddingBackend.huggingface:
+        return HuggingFaceEmbeddings(model_name=os.environ["HUGGINGFACE_EMB_MODEL"])
     raise ValueError(f"Unknown backend type: {backend_type}")
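The only genuinely new branch here is the huggingface one. A standalone sketch of what it does, assuming langchain-huggingface is installed; the model id below is illustrative only, any sentence-transformers-compatible id placed in HUGGINGFACE_EMB_MODEL will do.

import os
from langchain_huggingface import HuggingFaceEmbeddings

# Illustrative default; the repo reads HUGGINGFACE_EMB_MODEL from the environment.
os.environ.setdefault("HUGGINGFACE_EMB_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
embeddings = HuggingFaceEmbeddings(model_name=os.environ["HUGGINGFACE_EMB_MODEL"])
print(len(embeddings.embed_query("test sentence")))  # embedding dimensionality, computed locally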


@@ -13,6 +13,7 @@ dependencies = [
     "langchain-chroma>=0.2.2",
     "langchain-community>=0.3.19",
     "langchain-google-vertexai>=2.0.15",
+    "langchain-huggingface>=0.1.2",
     "langchain-ollama>=0.2.3",
     "langchain-openai>=0.3.7",
     "langchain-text-splitters>=0.3.6",