Philosophy-RAG-demo/generic_rag/app.py
2025-03-26 11:08:54 +01:00

170 lines
5.7 KiB
Python

import argparse
import json
import logging
import os
from pathlib import Path
import chainlit as cl
from backend.models import BackendType, get_chat_model, get_embedding_model
from chainlit.cli import run_chainlit
from graphs.cond_ret_gen import CondRetGenLangGraph
from graphs.ret_gen import RetGenLangGraph
from langchain_chroma import Chroma
from parsers.parser import add_pdf_files, add_urls
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
parser = argparse.ArgumentParser(description="A Sogeti Nederland Generic RAG demo.")
parser.add_argument(
"-b",
"--backend",
type=BackendType,
choices=list(BackendType),
default=BackendType.azure,
help="Cloud provider to use as backend. In the case of local, Ollama needs to be installed..",
)
parser.add_argument(
"-p",
"--pdf-data",
type=Path,
nargs="+",
default=[],
help="One or multiple paths to folders or files to use for retrieval. "
"If a path is a folder, all files in the folder will be used. "
"If a path is a file, only that file will be used. "
"If the path is relative it will be relative to the current working directory.",
)
parser.add_argument(
"-u",
"--unstructured-pdf",
action="store_true",
help="Use an unstructered PDF loader. "
"An unstructured PDF loader might be usefull for PDF files "
"that contain a lot of images with text, tables or (scanned) text as images. "
"Please use '-r' when switching parsers on already indexed data.",
)
parser.add_argument("--pdf-chunk_size", type=int, default=1000, help="The size of the chunks to split the text into.")
parser.add_argument("--pdf-chunk_overlap", type=int, default=200, help="The overlap between the chunks.")
parser.add_argument(
"--pdf-add-start-index", action="store_true", help="Add the start index to the metadata of the chunks."
)
parser.add_argument(
"-w", "--web-data", type=str, nargs="*", default=[], help="One or multiple URLs to use for retrieval."
)
parser.add_argument("--web-chunk-size", type=int, default=200, help="The size of the chunks to split the text into.")
parser.add_argument(
"-d",
"--chroma-db-location",
type=Path,
default=Path(".chroma_db"),
help="File path to store or load a Chroma DB from/to.",
)
parser.add_argument("-r", "--reset-chrome-db", action="store_true", help="Reset the Chroma DB.")
parser.add_argument(
"-c",
"--use-conditional-graph",
action="store_true",
help="Use the conditial retrieve generate graph over the regular retrieve generate graph.",
)
args = parser.parse_args()
vector_store = Chroma(
collection_name="generic_rag",
embedding_function=get_embedding_model(args.backend),
persist_directory=str(args.chroma_db_location),
)
if args.use_conditional_graph:
graph = CondRetGenLangGraph(
vector_store, chat_model=get_chat_model(args.backend), embedding_model=get_embedding_model(args.backend)
)
else:
graph = RetGenLangGraph(
vector_store, chat_model=get_chat_model(args.backend), embedding_model=get_embedding_model(args.backend)
)
@cl.on_message
async def on_message(message: cl.Message):
if isinstance(graph, CondRetGenLangGraph):
await process_cond_response(message)
elif isinstance(graph, RetGenLangGraph):
await process_response(message)
async def process_response(message):
config = {"configurable": {"thread_id": cl.user_session.get("id")}}
chainlit_response = cl.Message(content="")
response = graph.invoke(message.content, config=config)
await chainlit_response.stream_token(f"{response}\n\n")
pdf_sources = graph.get_last_pdf_sources()
if len(pdf_sources) > 0:
await chainlit_response.stream_token("The following PDF source were consulted:\n")
for source, page_numbers in pdf_sources.items():
page_numbers = list(page_numbers)
page_numbers.sort()
# display="side" seems to be not supported by chainlit for PDF's, so we use "inline" instead.
chainlit_response.elements.append(cl.Pdf(name="pdf", display="inline", path=source, page=page_numbers[0]))
await chainlit_response.update()
await chainlit_response.stream_token(f"'{source}' on page(s): {page_numbers}\n")
web_sources = graph.get_last_web_sources()
if len(web_sources) > 0:
await chainlit_response.stream_token(f"The following web sources were consulted: {web_sources}\n")
await chainlit_response.send()
async def process_cond_response(message):
config = {"configurable": {"thread_id": cl.user_session.get("id")}}
chainlit_response = cl.Message(content="")
for response in graph.stream(message.content, config=config):
await chainlit_response.stream_token(response)
await chainlit_response.send()
@cl.set_starters
async def set_starters():
chainlit_starters = os.environ["CHAINLIT_STARTERS"]
if chainlit_starters is None:
return
dict_list = json.loads(chainlit_starters)
starters = []
for starter in dict_list:
try:
starters.append(cl.Starter(label=starter["label"], message=starter["message"]))
except KeyError:
logging.warning(
"CHAINLIT_STARTERS environment is not a list with dictionaries containing 'label' and 'message' keys."
)
return starters
if __name__ == "__main__":
if args.reset_chrome_db:
vector_store.reset_collection()
add_pdf_files(
vector_store,
args.pdf_data,
args.pdf_chunk_size,
args.pdf_chunk_overlap,
args.pdf_add_start_index,
args.unstructured_pdf,
)
add_urls(vector_store, args.web_data, args.web_chunk_size)
run_chainlit(__file__)