forked from AI_team/Philosophy-RAG-demo
Added support for (local) huggingface embedding models
parent ab78fdc0c7
commit 1d4e99459e
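In short, the single -b/--backend option is split into -c/--chat-backend and -e/--emb-backend, and a huggingface embedding backend (the new default) is added alongside the Ollama-based local one. A minimal sketch of the selection this enables, assuming the environment variables referenced in backend/models.py are set; the model names below are example values, not part of this commit:

import os

from backend.models import ChatBackend, EmbeddingBackend, get_chat_model, get_embedding_model

# Example values only; the 'local' chat backend still requires a running Ollama install.
os.environ.setdefault("LOCAL_CHAT_MODEL", "llama3.1")
os.environ.setdefault("HUGGINGFACE_EMB_MODEL", "sentence-transformers/all-MiniLM-L6-v2")

chat_model = get_chat_model(ChatBackend.local)                        # ChatOllama
embedding_model = get_embedding_model(EmbeddingBackend.huggingface)   # HuggingFaceEmbeddings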
@@ -5,7 +5,7 @@ import os
 from pathlib import Path

 import chainlit as cl
-from backend.models import BackendType, get_chat_model, get_embedding_model
+from backend.models import ChatBackend, EmbeddingBackend, get_chat_model, get_embedding_model
 from chainlit.cli import run_chainlit
 from graphs.cond_ret_gen import CondRetGenLangGraph
 from graphs.ret_gen import RetGenLangGraph
@@ -17,12 +17,20 @@ logger = logging.getLogger(__name__)

 parser = argparse.ArgumentParser(description="A Sogeti Nederland Generic RAG demo.")
 parser.add_argument(
-    "-b",
-    "--backend",
-    type=BackendType,
-    choices=list(BackendType),
-    default=BackendType.azure,
-    help="Cloud provider to use as backend. In the case of local, Ollama needs to be installed..",
+    "-c",
+    "--chat-backend",
+    type=ChatBackend,
+    choices=list(ChatBackend),
+    default=ChatBackend.local,
+    help="Cloud provider or local LLM to use as backend. In the case of 'local', Ollama needs to be installed.",
+)
+parser.add_argument(
+    "-e",
+    "--emb-backend",
+    type=EmbeddingBackend,
+    choices=list(EmbeddingBackend),
+    default=EmbeddingBackend.huggingface,
+    help="Cloud provider or local embedding to use as backend. In the case of 'local', Ollama needs to be installed. ",
 )
 parser.add_argument(
     "-p",
@@ -62,7 +70,6 @@ parser.add_argument(
 )
 parser.add_argument("-r", "--reset-chrome-db", action="store_true", help="Reset the Chroma DB.")
 parser.add_argument(
-    "-c",
     "--use-conditional-graph",
     action="store_true",
     help="Use the conditial retrieve generate graph over the regular retrieve generate graph.",
@@ -71,17 +78,21 @@ args = parser.parse_args()

 vector_store = Chroma(
     collection_name="generic_rag",
-    embedding_function=get_embedding_model(args.backend),
+    embedding_function=get_embedding_model(args.emb_backend),
     persist_directory=str(args.chroma_db_location),
 )

 if args.use_conditional_graph:
     graph = CondRetGenLangGraph(
-        vector_store, chat_model=get_chat_model(args.backend), embedding_model=get_embedding_model(args.backend)
+        vector_store,
+        chat_model=get_chat_model(args.chat_backend),
+        embedding_model=get_embedding_model(args.emb_backend),
     )
 else:
     graph = RetGenLangGraph(
-        vector_store, chat_model=get_chat_model(args.backend), embedding_model=get_embedding_model(args.backend)
+        vector_store,
+        chat_model=get_chat_model(args.chat_backend),
+        embedding_model=get_embedding_model(args.emb_backend),
     )


@@ -95,7 +106,7 @@ async def on_message(message: cl.Message):

 async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sources: set | list):
     if len(pdf_sources) > 0:
-        await chainlit_response.stream_token("\nThe following PDF source were consulted:\n")
+        await chainlit_response.stream_token("\n\nThe following PDF source were consulted:\n")
         for source, page_numbers in pdf_sources.items():
             page_numbers = list(page_numbers)
             page_numbers.sort()
@@ -104,7 +115,7 @@ async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sour
             await chainlit_response.update()
             await chainlit_response.stream_token(f"- '{source}' on page(s): {page_numbers}\n")
     if len(web_sources) > 0:
-        await chainlit_response.stream_token("\nThe following web sources were consulted:\n")
+        await chainlit_response.stream_token("\n\nThe following web sources were consulted:\n")
         for source in web_sources:
             await chainlit_response.stream_token(f"- {source}\n")

@@ -6,13 +6,12 @@ from langchain_aws import BedrockEmbeddings
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_google_vertexai import VertexAIEmbeddings
-from langchain_ollama import ChatOllama
-from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
-from langchain_openai import OpenAIEmbeddings
-from langchain_ollama import OllamaEmbeddings
+from langchain_huggingface import HuggingFaceEmbeddings
+from langchain_ollama import ChatOllama, OllamaEmbeddings
+from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings, OpenAIEmbeddings


-class BackendType(Enum):
+class ChatBackend(Enum):
     azure = "azure"
     openai = "openai"
     google_vertex = "google_vertex"
@@ -24,47 +23,63 @@ class BackendType(Enum):
         return self.value


-def get_chat_model(backend_type: BackendType) -> BaseChatModel:
-    if backend_type == BackendType.azure:
+class EmbeddingBackend(Enum):
+    azure = "azure"
+    openai = "openai"
+    google_vertex = "google_vertex"
+    aws = "aws"
+    local = "local"
+    huggingface = "huggingface"
+
+    # make the enum pretty printable for argparse
+    def __str__(self):
+        return self.value
+
+
+def get_chat_model(backend_type: ChatBackend) -> BaseChatModel:
+    if backend_type == ChatBackend.azure:
         return AzureChatOpenAI(
             azure_endpoint=os.environ["AZURE_LLM_ENDPOINT"],
             azure_deployment=os.environ["AZURE_LLM_DEPLOYMENT_NAME"],
             openai_api_version=os.environ["AZURE_LLM_API_VERSION"],
         )

-    if backend_type == BackendType.openai:
+    if backend_type == ChatBackend.openai:
         return init_chat_model(os.environ["OPENAI_CHAT_MODEL"], model_provider="openai")

-    if backend_type == BackendType.google_vertex:
+    if backend_type == ChatBackend.google_vertex:
         return init_chat_model(os.environ["GOOGLE_CHAT_MODEL"], model_provider="google_vertexai")

-    if backend_type == BackendType.aws:
+    if backend_type == ChatBackend.aws:
         return init_chat_model(model=os.environ["AWS_CHAT_MODEL"], model_provider="bedrock_converse")

-    if backend_type == BackendType.local:
+    if backend_type == ChatBackend.local:
         return ChatOllama(model=os.environ["LOCAL_CHAT_MODEL"])

     raise ValueError(f"Unknown backend type: {backend_type}")


-def get_embedding_model(backend_type: BackendType) -> Embeddings:
-    if backend_type == BackendType.azure:
+def get_embedding_model(backend_type: EmbeddingBackend) -> Embeddings:
+    if backend_type == EmbeddingBackend.azure:
         return AzureOpenAIEmbeddings(
             azure_endpoint=os.environ["AZURE_EMB_ENDPOINT"],
             azure_deployment=os.environ["AZURE_EMB_DEPLOYMENT_NAME"],
             openai_api_version=os.environ["AZURE_EMB_API_VERSION"],
         )

-    if backend_type == BackendType.openai:
+    if backend_type == EmbeddingBackend.openai:
         return OpenAIEmbeddings(model=os.environ["OPENAI_EMB_MODEL"])

-    if backend_type == BackendType.google_vertex:
+    if backend_type == EmbeddingBackend.google_vertex:
         return VertexAIEmbeddings(model=os.environ["GOOGLE_EMB_MODEL"])

-    if backend_type == BackendType.aws:
+    if backend_type == EmbeddingBackend.aws:
         return BedrockEmbeddings(model_id=os.environ["AWS_EMB_MODEL"])

-    if backend_type == BackendType.local:
+    if backend_type == EmbeddingBackend.local:
         return OllamaEmbeddings(model=os.environ["LOCAL_EMB_MODEL"])

+    if backend_type == EmbeddingBackend.huggingface:
+        return HuggingFaceEmbeddings(model_name=os.environ["HUGGINGFACE_EMB_MODEL"])
+
     raise ValueError(f"Unknown backend type: {backend_type}")
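For reference, a minimal sketch of what the new huggingface branch of get_embedding_model returns: langchain-huggingface wraps sentence-transformers, so the model is downloaded once and embeddings are computed locally. The model name below is an example value, not part of this commit:

from langchain_huggingface import HuggingFaceEmbeddings

# The commit reads the model name from the HUGGINGFACE_EMB_MODEL environment variable;
# "sentence-transformers/all-MiniLM-L6-v2" is only an example value.
emb = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

query_vec = emb.embed_query("What is virtue ethics?")            # one vector
doc_vecs = emb.embed_documents(["first chunk", "second chunk"])  # one vector per chunk
print(len(query_vec))  # embedding dimension (384 for all-MiniLM-L6-v2)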
@@ -13,6 +13,7 @@ dependencies = [
     "langchain-chroma>=0.2.2",
     "langchain-community>=0.3.19",
     "langchain-google-vertexai>=2.0.15",
+    "langchain-huggingface>=0.1.2",
     "langchain-ollama>=0.2.3",
     "langchain-openai>=0.3.7",
     "langchain-text-splitters>=0.3.6",