# Forked from AI_team/Philosophy-RAG-demo
import argparse
|
|
import json
|
|
import logging
|
|
import os
|
|
from pathlib import Path
|
|
|
|
import chainlit as cl
|
|
from chainlit.cli import run_chainlit
|
|
from langchain_chroma import Chroma
|
|
|
|
from generic_rag.backend.models import (
|
|
ChatBackend,
|
|
EmbeddingBackend,
|
|
get_chat_model,
|
|
get_embedding_model,
|
|
get_compression_model,
|
|
)
|
|
from generic_rag.graphs.cond_ret_gen import CondRetGenLangGraph
|
|
from generic_rag.graphs.ret_gen import RetGenLangGraph
|
|
from generic_rag.parsers.parser import add_pdf_files, add_urls
|
|
|
|
# Module-level logger for the demo; DEBUG so all diagnostics surface.
logger = logging.getLogger("sogeti-rag")
logger.setLevel(logging.DEBUG)

# System prompt prepended to every chat turn.  The pieces below are
# concatenated directly, so each sentence must end with a trailing space
# (the space after "answer in English." was previously missing, producing
# "…answer in English.Use the following…").
system_prompt = (
    "You are an assistant for question-answering tasks. "
    "If the question is in Dutch, answer in Dutch. If the question is in English, answer in English. "
    "Use the following pieces of retrieved context to answer the question. "
    "If you don't know the answer, say that you don't know."
)
|
|
|
|
# Command-line interface.  Flag names, types, and defaults are part of the
# public CLI contract and are kept unchanged; only help-text typos are fixed.
parser = argparse.ArgumentParser(description="A Sogeti Nederland Generic RAG demo.")
parser.add_argument(
    "-c",
    "--chat-backend",
    type=ChatBackend,
    choices=list(ChatBackend),
    default=ChatBackend.local,
    help="Cloud provider or local LLM to use as backend. In the case of 'local', Ollama needs to be installed.",
)
parser.add_argument(
    "-e",
    "--emb-backend",
    type=EmbeddingBackend,
    choices=list(EmbeddingBackend),
    default=EmbeddingBackend.huggingface,
    help="Cloud provider or local embedding to use as backend. In the case of 'local', Ollama needs to be installed. ",
)
parser.add_argument(
    "-p",
    "--pdf-data",
    type=Path,
    nargs="+",
    default=[],
    help="One or multiple paths to folders or files to use for retrieval. "
    "If a path is a folder, all files in the folder will be used. "
    "If a path is a file, only that file will be used. "
    "If the path is relative it will be relative to the current working directory.",
)
parser.add_argument(
    "-u",
    "--unstructured-pdf",
    action="store_true",
    help="Use an unstructured PDF loader. "
    "An unstructured PDF loader might be useful for PDF files "
    "that contain a lot of images with text, tables or (scanned) text as images. "
    "Please use '-r' when switching parsers on already indexed data.",
)
parser.add_argument("--pdf-chunk_size", type=int, default=1000, help="The size of the chunks to split the text into.")
parser.add_argument("--pdf-chunk_overlap", type=int, default=200, help="The overlap between the chunks.")
parser.add_argument(
    "--pdf-add-start-index", action="store_true", help="Add the start index to the metadata of the chunks."
)
parser.add_argument(
    "-w", "--web-data", type=str, nargs="*", default=[], help="One or multiple URLs to use for retrieval."
)
parser.add_argument("--web-chunk-size", type=int, default=200, help="The size of the chunks to split the text into.")
parser.add_argument(
    "-d",
    "--chroma-db-location",
    type=Path,
    default=Path(".chroma_db"),
    help="File path to store or load a Chroma DB from/to.",
)
# NOTE(review): "chrome" is a typo for "chroma", but the flag name is the
# public CLI interface (and args.reset_chrome_db is read below), so it is
# kept for backward compatibility.
parser.add_argument("-r", "--reset-chrome-db", action="store_true", help="Reset the Chroma DB.")
parser.add_argument(
    "--use-conditional-graph",
    action="store_true",
    help="Use the conditional retrieve generate graph over the regular retrieve generate graph.",
)
args = parser.parse_args()
|
|
|
|
# Build each backend model exactly once: the embedding model is needed by
# both the vector store and the graph, and model construction can be heavy
# (the original called get_embedding_model twice).
embedding_model = get_embedding_model(args.emb_backend)
chat_model = get_chat_model(args.chat_backend)

# Persistent Chroma vector store holding all indexed document chunks.
vector_store = Chroma(
    collection_name="generic_rag",
    embedding_function=embedding_model,
    persist_directory=str(args.chroma_db_location),
)

if args.use_conditional_graph:
    # Conditional graph: retrieval is performed only when deemed necessary.
    graph = CondRetGenLangGraph(
        vector_store=vector_store,
        chat_model=chat_model,
        embedding_model=embedding_model,
        system_prompt=system_prompt,
    )
else:
    # Plain retrieve-then-generate graph with a cross-encoder reranker.
    graph = RetGenLangGraph(
        vector_store=vector_store,
        chat_model=chat_model,
        embedding_model=embedding_model,
        system_prompt=system_prompt,
        compression_model=get_compression_model(
            "BAAI/bge-reranker-base", vector_store
        ),  # TODO: implement in config parser
    )
|
|
|
|
|
|
@cl.on_message
async def on_message(message: cl.Message):
    """Stream the graph's answer for an incoming chat message, then attach sources."""
    # The thread id ties the graph's conversation state to this Chainlit session.
    run_config = {"configurable": {"thread_id": cl.user_session.get("id")}}

    reply = cl.Message(content="")

    # Forward tokens to the UI as soon as the graph produces them.
    async for token in graph.stream(message.content, config=run_config):
        await reply.stream_token(token)

    # Append the consulted sources; the accessors differ per graph flavour.
    if isinstance(graph, RetGenLangGraph):
        await add_sources(reply, graph.get_last_pdf_sources(), graph.get_last_web_sources())
    if isinstance(graph, CondRetGenLangGraph):
        await add_sources(reply, graph.last_retrieved_docs, graph.last_retrieved_sources)

    await reply.send()
|
|
|
|
|
|
async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sources: set | list) -> None:
    """Append the consulted PDF and web sources to a streamed Chainlit reply.

    Args:
        chainlit_response: The in-progress reply message to extend.
        pdf_sources: Mapping of PDF file path -> collection of page numbers used.
        web_sources: URLs of the consulted web documents.
    """
    if len(pdf_sources) > 0:
        await chainlit_response.stream_token("\n\nThe following PDF sources were consulted:\n")
        for source, page_numbers in pdf_sources.items():
            filename = Path(source).name
            pages = sorted(page_numbers)  # sort once; reused for message and element
            # Bug fix: the message previously printed a literal "(unknown)"
            # instead of the computed filename.
            await chainlit_response.stream_token(f"- {filename} on page(s): {pages}\n")
            # Side-panel PDF element opened at the first cited page.
            chainlit_response.elements.append(
                cl.Pdf(name=filename, display="side", path=source, page=pages[0])
            )

    if len(web_sources) > 0:
        await chainlit_response.stream_token("\n\nThe following web sources were consulted:\n")
        for source in web_sources:
            await chainlit_response.stream_token(f"- {source}\n")
|
|
|
|
|
|
@cl.set_starters
async def set_starters():
    """Build Chainlit starter suggestions from the CHAINLIT_STARTERS env var.

    The variable must contain a JSON list of dictionaries with 'label' and
    'message' keys.  Returns None (no starters) when the variable is unset;
    malformed entries are skipped with a warning.
    """
    chainlit_starters = os.environ.get("CHAINLIT_STARTERS", None)

    if chainlit_starters is None:
        return

    # NOTE(review): invalid JSON still raises json.JSONDecodeError here,
    # which surfaces the misconfiguration at startup — presumed intentional.
    dict_list = json.loads(chainlit_starters)

    starters = []
    for starter in dict_list:
        try:
            starters.append(cl.Starter(label=starter["label"], message=starter["message"]))
        except (KeyError, TypeError):
            # KeyError: a dict missing 'label'/'message'.
            # TypeError: an entry that is not a dict (robustness fix; this
            # previously crashed the handler).
            logger.warning(
                "CHAINLIT_STARTERS environment is not a list with dictionaries containing 'label' and 'message' keys."
            )

    return starters
|
|
|
|
|
|
if __name__ == "__main__":
    # Optionally wipe the existing collection before (re)indexing, e.g.
    # after switching parsers or embedding backends (see the '-r' help text).
    if args.reset_chrome_db:
        vector_store.reset_collection()

    # Index the configured PDF files into the vector store.
    add_pdf_files(
        vector_store,
        args.pdf_data,
        args.pdf_chunk_size,
        args.pdf_chunk_overlap,
        args.pdf_add_start_index,
        args.unstructured_pdf,
    )
    # Index the configured web pages into the vector store.
    add_urls(vector_store, args.web_data, args.web_chunk_size)

    # Launch the Chainlit UI, serving this very file as the app module.
    run_chainlit(__file__)
|