diff --git a/generic_rag/app.py b/generic_rag/app.py
index 4ecb79a..c3a3fb1 100644
--- a/generic_rag/app.py
+++ b/generic_rag/app.py
@@ -5,7 +5,7 @@ import os
 from pathlib import Path
 
 import chainlit as cl
-from backend.models import BackendType, get_chat_model, get_embedding_model
+from backend.models import ChatBackend, EmbeddingBackend, get_chat_model, get_embedding_model
 from chainlit.cli import run_chainlit
 from graphs.cond_ret_gen import CondRetGenLangGraph
 from graphs.ret_gen import RetGenLangGraph
@@ -17,12 +17,20 @@ logger = logging.getLogger(__name__)
 
 parser = argparse.ArgumentParser(description="A Sogeti Nederland Generic RAG demo.")
 parser.add_argument(
-    "-b",
-    "--backend",
-    type=BackendType,
-    choices=list(BackendType),
-    default=BackendType.azure,
-    help="Cloud provider to use as backend. In the case of local, Ollama needs to be installed..",
+    "-c",
+    "--chat-backend",
+    type=ChatBackend,
+    choices=list(ChatBackend),
+    default=ChatBackend.local,
+    help="Cloud provider or local LLM to use as backend. In the case of 'local', Ollama needs to be installed.",
+)
+parser.add_argument(
+    "-e",
+    "--emb-backend",
+    type=EmbeddingBackend,
+    choices=list(EmbeddingBackend),
+    default=EmbeddingBackend.huggingface,
+    help="Cloud provider or local embedding to use as backend. In the case of 'local', Ollama needs to be installed. ",
 )
 parser.add_argument(
     "-p",
@@ -62,7 +70,6 @@ parser.add_argument(
 )
 parser.add_argument("-r", "--reset-chrome-db", action="store_true", help="Reset the Chroma DB.")
 parser.add_argument(
-    "-c",
     "--use-conditional-graph",
     action="store_true",
     help="Use the conditial retrieve generate graph over the regular retrieve generate graph.",
@@ -71,17 +78,21 @@ args = parser.parse_args()
 
 vector_store = Chroma(
     collection_name="generic_rag",
-    embedding_function=get_embedding_model(args.backend),
+    embedding_function=get_embedding_model(args.emb_backend),
     persist_directory=str(args.chroma_db_location),
 )
 
 if args.use_conditional_graph:
     graph = CondRetGenLangGraph(
-        vector_store, chat_model=get_chat_model(args.backend), embedding_model=get_embedding_model(args.backend)
+        vector_store,
+        chat_model=get_chat_model(args.chat_backend),
+        embedding_model=get_embedding_model(args.emb_backend),
     )
 else:
     graph = RetGenLangGraph(
-        vector_store, chat_model=get_chat_model(args.backend), embedding_model=get_embedding_model(args.backend)
+        vector_store,
+        chat_model=get_chat_model(args.chat_backend),
+        embedding_model=get_embedding_model(args.emb_backend),
     )
 
 
@@ -95,7 +106,7 @@ async def on_message(message: cl.Message):
 
 async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sources: set | list):
     if len(pdf_sources) > 0:
-        await chainlit_response.stream_token("\nThe following PDF source were consulted:\n")
+        await chainlit_response.stream_token("\n\nThe following PDF source were consulted:\n")
         for source, page_numbers in pdf_sources.items():
            page_numbers = list(page_numbers)
            page_numbers.sort()
@@ -104,7 +115,7 @@ async def add_sources(chainlit_response: cl.Message, pdf_sour
             await chainlit_response.update()
             await chainlit_response.stream_token(f"- '{source}' on page(s): {page_numbers}\n")
     if len(web_sources) > 0:
-        await chainlit_response.stream_token("\nThe following web sources were consulted:\n")
+        await chainlit_response.stream_token("\n\nThe following web sources were consulted:\n")
         for source in web_sources:
             await chainlit_response.stream_token(f"- {source}\n")
 
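For review context, a rough sketch of the resulting CLI (flag names, choices, and defaults are taken from the argparse changes above; the "python generic_rag/app.py" entry point is an assumption, since the launch command itself is not part of this diff):

    # chat backend choices:      azure | openai | google_vertex | aws | local                (default: local, needs Ollama)
    # embedding backend choices: azure | openai | google_vertex | aws | local | huggingface  (default: huggingface)
    python generic_rag/app.py --chat-backend azure --emb-backend azure
    python generic_rag/app.py -c local -e huggingface --use-conditional-graph

Note that '-c' now selects the chat backend; the conditional-graph toggle keeps only its long form '--use-conditional-graph'.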
diff --git a/generic_rag/backend/models.py b/generic_rag/backend/models.py
index 15be5cd..1e100ef 100644
--- a/generic_rag/backend/models.py
+++ b/generic_rag/backend/models.py
@@ -6,13 +6,12 @@ from langchain_aws import BedrockEmbeddings
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_google_vertexai import VertexAIEmbeddings
-from langchain_ollama import ChatOllama
-from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
-from langchain_openai import OpenAIEmbeddings
-from langchain_ollama import OllamaEmbeddings
+from langchain_huggingface import HuggingFaceEmbeddings
+from langchain_ollama import ChatOllama, OllamaEmbeddings
+from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings, OpenAIEmbeddings
 
 
-class BackendType(Enum):
+class ChatBackend(Enum):
     azure = "azure"
     openai = "openai"
     google_vertex = "google_vertex"
@@ -24,47 +23,63 @@ class BackendType(Enum):
         return self.value
 
 
-def get_chat_model(backend_type: BackendType) -> BaseChatModel:
-    if backend_type == BackendType.azure:
+class EmbeddingBackend(Enum):
+    azure = "azure"
+    openai = "openai"
+    google_vertex = "google_vertex"
+    aws = "aws"
+    local = "local"
+    huggingface = "huggingface"
+
+    # make the enum pretty printable for argparse
+    def __str__(self):
+        return self.value
+
+
+def get_chat_model(backend_type: ChatBackend) -> BaseChatModel:
+    if backend_type == ChatBackend.azure:
         return AzureChatOpenAI(
             azure_endpoint=os.environ["AZURE_LLM_ENDPOINT"],
             azure_deployment=os.environ["AZURE_LLM_DEPLOYMENT_NAME"],
             openai_api_version=os.environ["AZURE_LLM_API_VERSION"],
         )
 
-    if backend_type == BackendType.openai:
+    if backend_type == ChatBackend.openai:
         return init_chat_model(os.environ["OPENAI_CHAT_MODEL"], model_provider="openai")
 
-    if backend_type == BackendType.google_vertex:
+    if backend_type == ChatBackend.google_vertex:
         return init_chat_model(os.environ["GOOGLE_CHAT_MODEL"], model_provider="google_vertexai")
 
-    if backend_type == BackendType.aws:
+    if backend_type == ChatBackend.aws:
         return init_chat_model(model=os.environ["AWS_CHAT_MODEL"], model_provider="bedrock_converse")
 
-    if backend_type == BackendType.local:
+    if backend_type == ChatBackend.local:
         return ChatOllama(model=os.environ["LOCAL_CHAT_MODEL"])
 
     raise ValueError(f"Unknown backend type: {backend_type}")
 
 
-def get_embedding_model(backend_type: BackendType) -> Embeddings:
-    if backend_type == BackendType.azure:
+def get_embedding_model(backend_type: EmbeddingBackend) -> Embeddings:
+    if backend_type == EmbeddingBackend.azure:
         return AzureOpenAIEmbeddings(
             azure_endpoint=os.environ["AZURE_EMB_ENDPOINT"],
             azure_deployment=os.environ["AZURE_EMB_DEPLOYMENT_NAME"],
             openai_api_version=os.environ["AZURE_EMB_API_VERSION"],
         )
 
-    if backend_type == BackendType.openai:
+    if backend_type == EmbeddingBackend.openai:
         return OpenAIEmbeddings(model=os.environ["OPENAI_EMB_MODEL"])
 
-    if backend_type == BackendType.google_vertex:
+    if backend_type == EmbeddingBackend.google_vertex:
         return VertexAIEmbeddings(model=os.environ["GOOGLE_EMB_MODEL"])
 
-    if backend_type == BackendType.aws:
+    if backend_type == EmbeddingBackend.aws:
         return BedrockEmbeddings(model_id=os.environ["AWS_EMB_MODEL"])
 
-    if backend_type == BackendType.local:
+    if backend_type == EmbeddingBackend.local:
         return OllamaEmbeddings(model=os.environ["LOCAL_EMB_MODEL"])
 
+    if backend_type == EmbeddingBackend.huggingface:
+        return HuggingFaceEmbeddings(model_name=os.environ["HUGGINGFACE_EMB_MODEL"])
+
     raise ValueError(f"Unknown backend type: {backend_type}")
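A minimal usage sketch of the split enums and factories above (the model names below are placeholders set only for illustration; the 'local' chat path assumes a running Ollama server, and the 'huggingface' embedding path downloads sentence-transformers weights on first use):

    import os

    from backend.models import ChatBackend, EmbeddingBackend, get_chat_model, get_embedding_model

    # Placeholder model names for the example; real deployments set these in the environment.
    os.environ.setdefault("LOCAL_CHAT_MODEL", "llama3.1")
    os.environ.setdefault("HUGGINGFACE_EMB_MODEL", "sentence-transformers/all-MiniLM-L6-v2")

    chat_model = get_chat_model(ChatBackend.local)                  # ChatOllama under the hood
    embeddings = get_embedding_model(EmbeddingBackend.huggingface)  # HuggingFaceEmbeddings under the hood

    print(len(embeddings.embed_query("retrieval smoke test")))     # embedding dimensionality
    print(chat_model.invoke("Reply with one word.").content)

This mirrors what app.py now does with args.chat_backend and args.emb_backend, so the two backends can be mixed freely (for example Azure chat with local embeddings).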
diff --git a/pyproject.toml b/pyproject.toml
index b073287..a0af033 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,6 +13,7 @@ dependencies = [
     "langchain-chroma>=0.2.2",
     "langchain-community>=0.3.19",
     "langchain-google-vertexai>=2.0.15",
+    "langchain-huggingface>=0.1.2",
     "langchain-ollama>=0.2.3",
     "langchain-openai>=0.3.7",
     "langchain-text-splitters>=0.3.6",
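The new langchain-huggingface dependency provides the HuggingFaceEmbeddings class imported in backend/models.py. A quick sanity check after syncing the environment (the model name is a placeholder; the first call downloads the weights from the Hugging Face Hub unless they are already cached):

    from langchain_huggingface import HuggingFaceEmbeddings

    # Runs the embedding model locally; no API key or cloud endpoint is required.
    emb = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    print(len(emb.embed_query("dependency smoke test")))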