Merge pull request 'Added support for (local) huggingface embedding models' (#21) from fix_embedding into main

Reviewed-on: AI_team/generic-RAG-demo#21
rubenl 2025-04-11 13:01:29 +02:00
commit 2929ed1e27
3 changed files with 57 additions and 30 deletions

View File

@@ -5,7 +5,7 @@ import os
 from pathlib import Path
 import chainlit as cl
-from backend.models import BackendType, get_chat_model, get_embedding_model
+from backend.models import ChatBackend, EmbeddingBackend, get_chat_model, get_embedding_model
 from chainlit.cli import run_chainlit
 from graphs.cond_ret_gen import CondRetGenLangGraph
 from graphs.ret_gen import RetGenLangGraph
@@ -17,12 +17,20 @@ logger = logging.getLogger(__name__)

 parser = argparse.ArgumentParser(description="A Sogeti Nederland Generic RAG demo.")
 parser.add_argument(
-    "-b",
-    "--backend",
-    type=BackendType,
-    choices=list(BackendType),
-    default=BackendType.azure,
-    help="Cloud provider to use as backend. In the case of local, Ollama needs to be installed..",
+    "-c",
+    "--chat-backend",
+    type=ChatBackend,
+    choices=list(ChatBackend),
+    default=ChatBackend.local,
+    help="Cloud provider or local LLM to use as backend. In the case of 'local', Ollama needs to be installed.",
+)
+parser.add_argument(
+    "-e",
+    "--emb-backend",
+    type=EmbeddingBackend,
+    choices=list(EmbeddingBackend),
+    default=EmbeddingBackend.huggingface,
+    help="Cloud provider or local embedding to use as backend. In the case of 'local', Ollama needs to be installed.",
 )
 parser.add_argument(
     "-p",
@@ -62,7 +70,6 @@ parser.add_argument(
 )
 parser.add_argument("-r", "--reset-chrome-db", action="store_true", help="Reset the Chroma DB.")
 parser.add_argument(
-    "-c",
     "--use-conditional-graph",
     action="store_true",
     help="Use the conditial retrieve generate graph over the regular retrieve generate graph.",
@@ -71,17 +78,21 @@ args = parser.parse_args()

 vector_store = Chroma(
     collection_name="generic_rag",
-    embedding_function=get_embedding_model(args.backend),
+    embedding_function=get_embedding_model(args.emb_backend),
     persist_directory=str(args.chroma_db_location),
 )

 if args.use_conditional_graph:
     graph = CondRetGenLangGraph(
-        vector_store, chat_model=get_chat_model(args.backend), embedding_model=get_embedding_model(args.backend)
+        vector_store,
+        chat_model=get_chat_model(args.chat_backend),
+        embedding_model=get_embedding_model(args.emb_backend),
     )
 else:
     graph = RetGenLangGraph(
-        vector_store, chat_model=get_chat_model(args.backend), embedding_model=get_embedding_model(args.backend)
+        vector_store,
+        chat_model=get_chat_model(args.chat_backend),
+        embedding_model=get_embedding_model(args.emb_backend),
     )
@@ -95,7 +106,7 @@ async def on_message(message: cl.Message):

 async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sources: set | list):
     if len(pdf_sources) > 0:
-        await chainlit_response.stream_token("\nThe following PDF source were consulted:\n")
+        await chainlit_response.stream_token("\n\nThe following PDF source were consulted:\n")
         for source, page_numbers in pdf_sources.items():
             page_numbers = list(page_numbers)
             page_numbers.sort()
@@ -104,7 +115,7 @@ async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sources: set | list):
                 await chainlit_response.update()
             await chainlit_response.stream_token(f"- '{source}' on page(s): {page_numbers}\n")
     if len(web_sources) > 0:
-        await chainlit_response.stream_token("\nThe following web sources were consulted:\n")
+        await chainlit_response.stream_token("\n\nThe following web sources were consulted:\n")
         for source in web_sources:
             await chainlit_response.stream_token(f"- {source}\n")
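
With this change the chat backend and the embedding backend are chosen independently. Below is a minimal sketch of how the two new flags resolve to models; the argument definitions are copied from the diff above, while the chosen flag values and environment variables are only examples (running it still requires the relevant credentials and, for the local chat backend, a running Ollama instance).

import argparse

from backend.models import ChatBackend, EmbeddingBackend, get_chat_model, get_embedding_model

parser = argparse.ArgumentParser()
parser.add_argument("-c", "--chat-backend", type=ChatBackend, choices=list(ChatBackend), default=ChatBackend.local)
parser.add_argument("-e", "--emb-backend", type=EmbeddingBackend, choices=list(EmbeddingBackend), default=EmbeddingBackend.huggingface)

# Mix a cloud chat model with local embeddings; the embedding flag keeps its new default.
args = parser.parse_args(["--chat-backend", "azure"])
print(args.chat_backend, args.emb_backend)  # -> azure huggingface

chat_model = get_chat_model(args.chat_backend)            # needs the AZURE_LLM_* variables
embedding_model = get_embedding_model(args.emb_backend)   # needs HUGGINGFACE_EMB_MODEL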

View File

@@ -6,13 +6,12 @@ from langchain_aws import BedrockEmbeddings
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_google_vertexai import VertexAIEmbeddings
-from langchain_ollama import ChatOllama
-from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
-from langchain_openai import OpenAIEmbeddings
-from langchain_ollama import OllamaEmbeddings
+from langchain_huggingface import HuggingFaceEmbeddings
+from langchain_ollama import ChatOllama, OllamaEmbeddings
+from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings, OpenAIEmbeddings


-class BackendType(Enum):
+class ChatBackend(Enum):
     azure = "azure"
     openai = "openai"
     google_vertex = "google_vertex"
@@ -24,47 +23,63 @@ class BackendType(Enum):
         return self.value


-def get_chat_model(backend_type: BackendType) -> BaseChatModel:
-    if backend_type == BackendType.azure:
+class EmbeddingBackend(Enum):
+    azure = "azure"
+    openai = "openai"
+    google_vertex = "google_vertex"
+    aws = "aws"
+    local = "local"
+    huggingface = "huggingface"
+
+    # make the enum pretty printable for argparse
+    def __str__(self):
+        return self.value
+
+
+def get_chat_model(backend_type: ChatBackend) -> BaseChatModel:
+    if backend_type == ChatBackend.azure:
         return AzureChatOpenAI(
             azure_endpoint=os.environ["AZURE_LLM_ENDPOINT"],
             azure_deployment=os.environ["AZURE_LLM_DEPLOYMENT_NAME"],
             openai_api_version=os.environ["AZURE_LLM_API_VERSION"],
         )
-    if backend_type == BackendType.openai:
+    if backend_type == ChatBackend.openai:
         return init_chat_model(os.environ["OPENAI_CHAT_MODEL"], model_provider="openai")
-    if backend_type == BackendType.google_vertex:
+    if backend_type == ChatBackend.google_vertex:
         return init_chat_model(os.environ["GOOGLE_CHAT_MODEL"], model_provider="google_vertexai")
-    if backend_type == BackendType.aws:
+    if backend_type == ChatBackend.aws:
         return init_chat_model(model=os.environ["AWS_CHAT_MODEL"], model_provider="bedrock_converse")
-    if backend_type == BackendType.local:
+    if backend_type == ChatBackend.local:
         return ChatOllama(model=os.environ["LOCAL_CHAT_MODEL"])
     raise ValueError(f"Unknown backend type: {backend_type}")


-def get_embedding_model(backend_type: BackendType) -> Embeddings:
-    if backend_type == BackendType.azure:
+def get_embedding_model(backend_type: EmbeddingBackend) -> Embeddings:
+    if backend_type == EmbeddingBackend.azure:
         return AzureOpenAIEmbeddings(
             azure_endpoint=os.environ["AZURE_EMB_ENDPOINT"],
             azure_deployment=os.environ["AZURE_EMB_DEPLOYMENT_NAME"],
             openai_api_version=os.environ["AZURE_EMB_API_VERSION"],
         )
-    if backend_type == BackendType.openai:
+    if backend_type == EmbeddingBackend.openai:
         return OpenAIEmbeddings(model=os.environ["OPENAI_EMB_MODEL"])
-    if backend_type == BackendType.google_vertex:
+    if backend_type == EmbeddingBackend.google_vertex:
         return VertexAIEmbeddings(model=os.environ["GOOGLE_EMB_MODEL"])
-    if backend_type == BackendType.aws:
+    if backend_type == EmbeddingBackend.aws:
         return BedrockEmbeddings(model_id=os.environ["AWS_EMB_MODEL"])
-    if backend_type == BackendType.local:
+    if backend_type == EmbeddingBackend.local:
         return OllamaEmbeddings(model=os.environ["LOCAL_EMB_MODEL"])
+    if backend_type == EmbeddingBackend.huggingface:
+        return HuggingFaceEmbeddings(model_name=os.environ["HUGGINGFACE_EMB_MODEL"])
     raise ValueError(f"Unknown backend type: {backend_type}")
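
For reference, a minimal sketch of exercising the new Hugging Face path end to end; the model id is only an example (any sentence-transformers model id should work) and the first call downloads the model locally.

import os

from backend.models import EmbeddingBackend, get_embedding_model

# Example model id; the code above only requires HUGGINGFACE_EMB_MODEL to be set.
os.environ.setdefault("HUGGINGFACE_EMB_MODEL", "sentence-transformers/all-MiniLM-L6-v2")

embeddings = get_embedding_model(EmbeddingBackend.huggingface)
vector = embeddings.embed_query("What does this demo retrieve?")
print(len(vector))  # dimensionality depends on the model, e.g. 384 for all-MiniLM-L6-v2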

View File

@@ -13,6 +13,7 @@ dependencies = [
     "langchain-chroma>=0.2.2",
     "langchain-community>=0.3.19",
     "langchain-google-vertexai>=2.0.15",
+    "langchain-huggingface>=0.1.2",
     "langchain-ollama>=0.2.3",
     "langchain-openai>=0.3.7",
     "langchain-text-splitters>=0.3.6",