Philosophy-RAG-demo/generic_rag/app.py
2025-03-17 12:46:52 +01:00

209 lines
6.8 KiB
Python

import argparse
import json
import logging
import os
from pathlib import Path
import chainlit as cl
from backend.models import BackendType, get_chat_model, get_embedding_model
from chainlit.cli import run_chainlit
from langchain import hub
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langgraph.graph import START, StateGraph
from langgraph.pregel.io import AddableValuesDict
from parsers.parser import add_pdf_files, add_urls
from typing_extensions import List, TypedDict
# Configure logging once at import time; code in this module logs through `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Command-line interface of the demo. The parsed result is stored in the
# module-level `args`, which the functions and the `__main__` section read.
parser = argparse.ArgumentParser(description="A Sogeti Nederland Generic RAG demo.")
parser.add_argument(
    "-b",
    "--back-end",
    type=BackendType,
    choices=list(BackendType),
    default=BackendType.azure,
    help="(Cloud) back-end to use. In the case of local, a locally installed ollama will be used.",
)
parser.add_argument(
    "-p",
    "--pdf-data",
    type=Path,
    required=True,
    nargs="+",
    help="One or multiple paths to folders or files to use for retrieval. "
    "If a path is a folder, all files in the folder will be used. "
    "If a path is a file, only that file will be used. "
    "If the path is relative it will be relative to the current working directory.",
)
parser.add_argument(
    "--unstructured-pdf",
    action="store_true",
    help="Use an unstructured PDF loader. "
    "An unstructured PDF loader might be useful for PDF files "
    "that contain a lot of images with text, tables or (scanned) text as images. "
    "Please use '-r' when switching parsers on already indexed data.",
)
# The hyphenated spelling matches every other option (e.g. --web-chunk-size).
# The original mixed spelling is kept as an alias so existing invocations keep
# working; `dest` is pinned so `args.pdf_chunk_size` stays the attribute name.
parser.add_argument(
    "--pdf-chunk-size",
    "--pdf-chunk_size",
    dest="pdf_chunk_size",
    type=int,
    default=1000,
    help="The size of the chunks to split the text into.",
)
parser.add_argument(
    "--pdf-chunk-overlap",
    "--pdf-chunk_overlap",
    dest="pdf_chunk_overlap",
    type=int,
    default=200,
    help="The overlap between the chunks.",
)
parser.add_argument(
    "--pdf-add-start-index", action="store_true", help="Add the start index to the metadata of the chunks."
)
parser.add_argument(
    "-w", "--web-data", type=str, nargs="*", default=[], help="One or multiple URLs to use for retrieval."
)
parser.add_argument("--web-chunk-size", type=int, default=200, help="The size of the chunks to split the text into.")
parser.add_argument(
    "-c",
    "--chroma-db-location",
    type=Path,
    default=Path(".chroma_db"),
    help="File path to store or load a Chroma DB from/to.",
)
# NOTE: "chrome" is a misspelling of "chroma". The flag and its derived
# attribute `args.reset_chrome_db` are kept as-is because the attribute is read
# elsewhere in this file and a second long spelling would make abbreviations
# such as `--reset` ambiguous.
parser.add_argument("-r", "--reset-chrome-db", action="store_true", help="Reset the Chroma DB.")
args = parser.parse_args()
class State(TypedDict):
    """State dictionary handed from node to node in the RAG graph."""

    # Question text; read by both `retrieve` and `generate`.
    question: str
    # Documents returned by the similarity search in `retrieve`.
    context: list[Document]
    # Text of the chat model's reply, produced by `generate`.
    answer: str
def retrieve(state: State):
    """
    Graph node that looks up documents matching the question.

    Runs a similarity search for ``state["question"]`` on the vector store
    kept in the Chainlit user session and returns the hits as the
    ``"context"`` state update.
    """
    store = cl.user_session.get("vector_store")
    return {"context": store.similarity_search(state["question"])}
def generate(state: State):
    """
    Graph node that produces the answer from the retrieved context.

    The page contents of all context documents are joined with blank lines,
    filled into the session's prompt together with the question, and sent to
    the session's chat model. The reply's ``content`` is returned as the
    ``"answer"`` state update.
    """
    prompt_template = cl.user_session.get("prompt")
    chat_model = cl.user_session.get("chat_model")

    context_text = "\n\n".join([document.page_content for document in state["context"]])
    prompt_input = {"question": state["question"], "context": context_text}

    reply = chat_model.invoke(prompt_template.invoke(prompt_input))
    return {"answer": reply.content}
@cl.on_chat_start
async def on_chat_start():
    """
    Prepare the per-session resources when a new chat starts.

    Stores the following entries in the Chainlit user session:
    ``"vector_store"`` (Chroma collection ``generic_rag``), ``"emb_model"``,
    ``"chat_model"``, ``"prompt"`` (pulled from the hub as ``rlm/rag-prompt``)
    and ``"graph"`` (the compiled retrieve -> generate graph).
    """
    # Build the embedding model once and share it between the vector store and
    # the session entry instead of constructing two separate instances.
    embedding_model = get_embedding_model(args.back_end)
    vector_store = Chroma(
        collection_name="generic_rag",
        embedding_function=embedding_model,
        persist_directory=str(args.chroma_db_location),
    )

    cl.user_session.set("vector_store", vector_store)
    cl.user_session.set("emb_model", embedding_model)
    cl.user_session.set("chat_model", get_chat_model(args.back_end))
    cl.user_session.set("prompt", hub.pull("rlm/rag-prompt"))

    # `add_sequence` chains retrieve -> generate; the explicit START edge makes
    # "retrieve" the entry point of the graph.
    graph_builder = StateGraph(State).add_sequence([retrieve, generate])
    graph_builder.add_edge(START, "retrieve")
    cl.user_session.set("graph", graph_builder.compile())
@cl.on_message
async def on_message(message: cl.Message):
    """
    Answer an incoming chat message and list the sources that were consulted.

    The session's graph is invoked with the message content as the question.
    The reply sent back consists of the answer, followed by the PDF sources
    (with page numbers, each also attached as an inline PDF element opened on
    its lowest page) and the web sources found in the retrieved context.
    """
    graph = cl.user_session.get("graph")
    # NOTE(review): `graph.invoke` is a synchronous call made from inside this
    # coroutine - confirm whether it should be off-loaded so that it does not
    # block the event loop.
    response = graph.invoke({"question": message.content})

    answer = response["answer"]
    answer += "\n\n"

    pdf_sources = get_pdf_sources(response)
    web_sources = get_web_sources(response)

    elements = []
    if pdf_sources:
        answer += "The following PDF sources were consulted:\n"
        for source, page_numbers in pdf_sources.items():
            page_numbers = sorted(page_numbers)
            # display="side" seems to be not supported by chainlit for PDF's, so we use "inline" instead
            elements.append(cl.Pdf(name="pdf", display="inline", path=source, page=page_numbers[0]))
            answer += f"'{source}' on page(s): {page_numbers}\n"

    if web_sources:
        answer += f"The following web sources were consulted: {web_sources}\n"

    await cl.Message(content=answer, elements=elements).send()
def get_pdf_sources(response: AddableValuesDict) -> dict[str, set[int]]:
    """
    Function that retrieves the PDF sources with page numbers from a response.

    Args:
        response: Mapping whose ``"context"`` entry is an iterable of documents,
            each exposing a ``metadata`` mapping.

    Returns:
        A dictionary mapping every PDF ``source`` to the set of ``page_number``
        values found for it. Documents that are not of filetype
        ``application/pdf``, or that lack a ``source`` or ``page_number``
        metadata entry, are skipped.
    """
    pdf_sources: dict[str, set[int]] = {}
    for document in response["context"]:
        metadata = document.metadata
        if metadata.get("filetype") != "application/pdf":
            continue
        # Best effort: a PDF document without complete source information is
        # ignored rather than treated as an error.
        if "source" not in metadata or "page_number" not in metadata:
            continue
        pdf_sources.setdefault(metadata["source"], set()).add(metadata["page_number"])
    return pdf_sources
def get_web_sources(response: AddableValuesDict) -> set:
    """
    Gather the distinct sources of all web documents found in a response.

    Only context documents whose metadata has filetype "web" contribute;
    documents without a filetype or without a source entry are ignored.
    """
    return {
        document.metadata["source"]
        for document in response["context"]
        if document.metadata.get("filetype") == "web" and "source" in document.metadata
    }
@cl.set_starters
async def set_starters():
    """
    Build the chat starters from the CHAINLIT_STARTERS environment variable.

    The variable is expected to hold a JSON list of objects with a ``label``
    and a ``message`` key.

    Returns:
        A list of ``cl.Starter`` objects, one per well-formed entry, or
        ``None`` when the variable is unset or cannot be used. Malformed
        entries are logged and skipped.
    """
    # `.get()` instead of indexing: indexing raises KeyError for an unset
    # variable, which made the `is None` guard below unreachable.
    chainlit_starters = os.environ.get("CHAINLIT_STARTERS")
    if chainlit_starters is None:
        return None

    try:
        dict_list = json.loads(chainlit_starters)
    except json.JSONDecodeError:
        logger.warning("CHAINLIT_STARTERS environment variable does not contain valid JSON.")
        return None

    if not isinstance(dict_list, list):
        logger.warning(
            "CHAINLIT_STARTERS environment is not a list with dictionaries containing 'label' and 'message' keys."
        )
        return None

    starters = []
    for starter in dict_list:
        try:
            starters.append(cl.Starter(label=starter["label"], message=starter["message"]))
        except (KeyError, TypeError):
            # TypeError covers entries that are not dictionaries at all.
            logger.warning(
                "CHAINLIT_STARTERS environment is not a list with dictionaries containing 'label' and 'message' keys."
            )
    return starters
if __name__ == "__main__":
    # Script entry point: open the Chroma collection, optionally reset it,
    # add the PDF and web data given on the command line, then start Chainlit
    # on this file.
    embedding_model = get_embedding_model(args.back_end)
    db_directory = str(args.chroma_db_location)
    vector_store = Chroma(
        collection_name="generic_rag",
        embedding_function=embedding_model,
        persist_directory=db_directory,
    )

    # Requested with -r / --reset-chrome-db.
    if args.reset_chrome_db:
        vector_store.reset_collection()

    # NOTE(review): `args.unstructured_pdf` is parsed but not passed on here -
    # confirm whether `add_pdf_files` is supposed to receive it.
    add_pdf_files(
        vector_store,
        args.pdf_data,
        args.pdf_chunk_size,
        args.pdf_chunk_overlap,
        args.pdf_add_start_index,
    )
    add_urls(vector_store, args.web_data, args.web_chunk_size)

    run_chainlit(__file__)