forked from AI_team/Philosophy-RAG-demo
Merge pull request 'Added support for (local) huggingface embedding models' (#21) from fix_embedding into main
Reviewed-on: AI_team/generic-RAG-demo#21
This commit is contained in:
commit
2929ed1e27
@ -5,7 +5,7 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import chainlit as cl
|
||||
from backend.models import BackendType, get_chat_model, get_embedding_model
|
||||
from backend.models import ChatBackend, EmbeddingBackend, get_chat_model, get_embedding_model
|
||||
from chainlit.cli import run_chainlit
|
||||
from graphs.cond_ret_gen import CondRetGenLangGraph
|
||||
from graphs.ret_gen import RetGenLangGraph
|
||||
@ -17,12 +17,20 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
parser = argparse.ArgumentParser(description="A Sogeti Nederland Generic RAG demo.")
|
||||
parser.add_argument(
|
||||
"-b",
|
||||
"--backend",
|
||||
type=BackendType,
|
||||
choices=list(BackendType),
|
||||
default=BackendType.azure,
|
||||
help="Cloud provider to use as backend. In the case of local, Ollama needs to be installed..",
|
||||
"-c",
|
||||
"--chat-backend",
|
||||
type=ChatBackend,
|
||||
choices=list(ChatBackend),
|
||||
default=ChatBackend.local,
|
||||
help="Cloud provider or local LLM to use as backend. In the case of 'local', Ollama needs to be installed.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--emb-backend",
|
||||
type=EmbeddingBackend,
|
||||
choices=list(EmbeddingBackend),
|
||||
default=EmbeddingBackend.huggingface,
|
||||
help="Cloud provider or local embedding to use as backend. In the case of 'local', Ollama needs to be installed. ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
@ -62,7 +70,6 @@ parser.add_argument(
|
||||
)
|
||||
parser.add_argument("-r", "--reset-chrome-db", action="store_true", help="Reset the Chroma DB.")
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--use-conditional-graph",
|
||||
action="store_true",
|
||||
help="Use the conditial retrieve generate graph over the regular retrieve generate graph.",
|
||||
@ -71,17 +78,21 @@ args = parser.parse_args()
|
||||
|
||||
vector_store = Chroma(
|
||||
collection_name="generic_rag",
|
||||
embedding_function=get_embedding_model(args.backend),
|
||||
embedding_function=get_embedding_model(args.emb_backend),
|
||||
persist_directory=str(args.chroma_db_location),
|
||||
)
|
||||
|
||||
if args.use_conditional_graph:
|
||||
graph = CondRetGenLangGraph(
|
||||
vector_store, chat_model=get_chat_model(args.backend), embedding_model=get_embedding_model(args.backend)
|
||||
vector_store,
|
||||
chat_model=get_chat_model(args.chat_backend),
|
||||
embedding_model=get_embedding_model(args.emb_backend),
|
||||
)
|
||||
else:
|
||||
graph = RetGenLangGraph(
|
||||
vector_store, chat_model=get_chat_model(args.backend), embedding_model=get_embedding_model(args.backend)
|
||||
vector_store,
|
||||
chat_model=get_chat_model(args.chat_backend),
|
||||
embedding_model=get_embedding_model(args.emb_backend),
|
||||
)
|
||||
|
||||
|
||||
@ -95,7 +106,7 @@ async def on_message(message: cl.Message):
|
||||
|
||||
async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sources: set | list):
|
||||
if len(pdf_sources) > 0:
|
||||
await chainlit_response.stream_token("\nThe following PDF source were consulted:\n")
|
||||
await chainlit_response.stream_token("\n\nThe following PDF source were consulted:\n")
|
||||
for source, page_numbers in pdf_sources.items():
|
||||
page_numbers = list(page_numbers)
|
||||
page_numbers.sort()
|
||||
@ -104,7 +115,7 @@ async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sour
|
||||
await chainlit_response.update()
|
||||
await chainlit_response.stream_token(f"- '{source}' on page(s): {page_numbers}\n")
|
||||
if len(web_sources) > 0:
|
||||
await chainlit_response.stream_token("\nThe following web sources were consulted:\n")
|
||||
await chainlit_response.stream_token("\n\nThe following web sources were consulted:\n")
|
||||
for source in web_sources:
|
||||
await chainlit_response.stream_token(f"- {source}\n")
|
||||
|
||||
|
||||
@ -6,13 +6,12 @@ from langchain_aws import BedrockEmbeddings
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_google_vertexai import VertexAIEmbeddings
|
||||
from langchain_ollama import ChatOllama
|
||||
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from langchain_ollama import OllamaEmbeddings
|
||||
from langchain_huggingface import HuggingFaceEmbeddings
|
||||
from langchain_ollama import ChatOllama, OllamaEmbeddings
|
||||
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings, OpenAIEmbeddings
|
||||
|
||||
|
||||
class BackendType(Enum):
|
||||
class ChatBackend(Enum):
|
||||
azure = "azure"
|
||||
openai = "openai"
|
||||
google_vertex = "google_vertex"
|
||||
@ -24,47 +23,63 @@ class BackendType(Enum):
|
||||
return self.value
|
||||
|
||||
|
||||
def get_chat_model(backend_type: BackendType) -> BaseChatModel:
|
||||
if backend_type == BackendType.azure:
|
||||
class EmbeddingBackend(Enum):
|
||||
azure = "azure"
|
||||
openai = "openai"
|
||||
google_vertex = "google_vertex"
|
||||
aws = "aws"
|
||||
local = "local"
|
||||
huggingface = "huggingface"
|
||||
|
||||
# make the enum pretty printable for argparse
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
def get_chat_model(backend_type: ChatBackend) -> BaseChatModel:
|
||||
if backend_type == ChatBackend.azure:
|
||||
return AzureChatOpenAI(
|
||||
azure_endpoint=os.environ["AZURE_LLM_ENDPOINT"],
|
||||
azure_deployment=os.environ["AZURE_LLM_DEPLOYMENT_NAME"],
|
||||
openai_api_version=os.environ["AZURE_LLM_API_VERSION"],
|
||||
)
|
||||
|
||||
if backend_type == BackendType.openai:
|
||||
if backend_type == ChatBackend.openai:
|
||||
return init_chat_model(os.environ["OPENAI_CHAT_MODEL"], model_provider="openai")
|
||||
|
||||
if backend_type == BackendType.google_vertex:
|
||||
if backend_type == ChatBackend.google_vertex:
|
||||
return init_chat_model(os.environ["GOOGLE_CHAT_MODEL"], model_provider="google_vertexai")
|
||||
|
||||
if backend_type == BackendType.aws:
|
||||
if backend_type == ChatBackend.aws:
|
||||
return init_chat_model(model=os.environ["AWS_CHAT_MODEL"], model_provider="bedrock_converse")
|
||||
|
||||
if backend_type == BackendType.local:
|
||||
if backend_type == ChatBackend.local:
|
||||
return ChatOllama(model=os.environ["LOCAL_CHAT_MODEL"])
|
||||
|
||||
raise ValueError(f"Unknown backend type: {backend_type}")
|
||||
|
||||
|
||||
def get_embedding_model(backend_type: BackendType) -> Embeddings:
|
||||
if backend_type == BackendType.azure:
|
||||
def get_embedding_model(backend_type: EmbeddingBackend) -> Embeddings:
|
||||
if backend_type == EmbeddingBackend.azure:
|
||||
return AzureOpenAIEmbeddings(
|
||||
azure_endpoint=os.environ["AZURE_EMB_ENDPOINT"],
|
||||
azure_deployment=os.environ["AZURE_EMB_DEPLOYMENT_NAME"],
|
||||
openai_api_version=os.environ["AZURE_EMB_API_VERSION"],
|
||||
)
|
||||
|
||||
if backend_type == BackendType.openai:
|
||||
if backend_type == EmbeddingBackend.openai:
|
||||
return OpenAIEmbeddings(model=os.environ["OPENAI_EMB_MODEL"])
|
||||
|
||||
if backend_type == BackendType.google_vertex:
|
||||
if backend_type == EmbeddingBackend.google_vertex:
|
||||
return VertexAIEmbeddings(model=os.environ["GOOGLE_EMB_MODEL"])
|
||||
|
||||
if backend_type == BackendType.aws:
|
||||
if backend_type == EmbeddingBackend.aws:
|
||||
return BedrockEmbeddings(model_id=os.environ["AWS_EMB_MODEL"])
|
||||
|
||||
if backend_type == BackendType.local:
|
||||
if backend_type == EmbeddingBackend.local:
|
||||
return OllamaEmbeddings(model=os.environ["LOCAL_EMB_MODEL"])
|
||||
|
||||
if backend_type == EmbeddingBackend.huggingface:
|
||||
return HuggingFaceEmbeddings(model_name=os.environ["HUGGINGFACE_EMB_MODEL"])
|
||||
|
||||
raise ValueError(f"Unknown backend type: {backend_type}")
|
||||
|
||||
@ -13,6 +13,7 @@ dependencies = [
|
||||
"langchain-chroma>=0.2.2",
|
||||
"langchain-community>=0.3.19",
|
||||
"langchain-google-vertexai>=2.0.15",
|
||||
"langchain-huggingface>=0.1.2",
|
||||
"langchain-ollama>=0.2.3",
|
||||
"langchain-openai>=0.3.7",
|
||||
"langchain-text-splitters>=0.3.6",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user