"""A Sogeti Nederland generic RAG demo built with Chainlit, LangChain, and LangGraph."""

import argparse
import json
import logging
import os
from pathlib import Path

import chainlit as cl
from backend.models import BackendType, get_chat_model, get_embedding_model
from chainlit.cli import run_chainlit
from langchain import hub
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langgraph.graph import START, StateGraph
from langgraph.pregel.io import AddableValuesDict
from parsers.parser import add_pdf_files, add_urls
from typing_extensions import List, TypedDict

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser(description="A Sogeti Nederland Generic RAG demo.")
parser.add_argument(
    "-b",
    "--back-end",
    type=BackendType,
    choices=list(BackendType),
    default=BackendType.azure,
    help="(Cloud) back-end to use. In the case of local, a locally installed Ollama will be used.",
)
parser.add_argument(
    "-p",
    "--pdf-data",
    type=Path,
    required=True,
    nargs="+",
    help="One or multiple paths to folders or files to use for retrieval. "
    "If a path is a folder, all files in the folder will be used. "
    "If a path is a file, only that file will be used. "
    "If the path is relative, it is resolved against the current working directory.",
)
parser.add_argument("--pdf-chunk-size", type=int, default=1000, help="The size of the chunks to split the text into.")
parser.add_argument("--pdf-chunk-overlap", type=int, default=200, help="The overlap between the chunks.")
parser.add_argument(
    "--pdf-add-start-index", action="store_true", help="Add the start index to the metadata of the chunks."
)
parser.add_argument(
    "-w", "--web-data", type=str, nargs="*", default=[], help="One or multiple URLs to use for retrieval."
)
parser.add_argument("--web-chunk-size", type=int, default=200, help="The size of the chunks to split the text into.")
parser.add_argument(
    "-c",
    "--chroma-db-location",
    type=Path,
    default=Path(".chroma_db"),
    help="File path to store or load a Chroma DB from/to.",
)
parser.add_argument("-r", "--reset-chroma-db", action="store_true", help="Reset the Chroma DB.")
args = parser.parse_args()


class State(TypedDict):
    """State passed between the retrieve and generate steps of the RAG graph."""

    question: str
    context: List[Document]
    answer: str


def retrieve(state: State):
    """Retrieve the documents most similar to the question from the vector store."""
    vector_store = cl.user_session.get("vector_store")
    retrieved_docs = vector_store.similarity_search(state["question"])
    return {"context": retrieved_docs}


def generate(state: State):
    """Generate an answer from the retrieved context using the chat model."""
    prompt = cl.user_session.get("prompt")
    llm = cl.user_session.get("chat_model")
    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
    messages = prompt.invoke({"question": state["question"], "context": docs_content})
    response = llm.invoke(messages)
    return {"answer": response.content}


@cl.on_chat_start
async def on_chat_start():
    """Set up the vector store, models, prompt, and RAG graph for a new chat session."""
    vector_store = Chroma(
        collection_name="generic_rag",
        embedding_function=get_embedding_model(args.back_end),
        persist_directory=str(args.chroma_db_location),
    )
    cl.user_session.set("vector_store", vector_store)
    cl.user_session.set("emb_model", get_embedding_model(args.back_end))
    cl.user_session.set("chat_model", get_chat_model(args.back_end))
    cl.user_session.set("prompt", hub.pull("rlm/rag-prompt"))

    # Build the retrieve -> generate pipeline as a LangGraph state graph.
    graph_builder = StateGraph(State).add_sequence([retrieve, generate])
    graph_builder.add_edge(START, "retrieve")
    graph = graph_builder.compile()
    cl.user_session.set("graph", graph)


@cl.on_message
async def on_message(message: cl.Message):
    """Answer an incoming message and list the consulted PDF and web sources."""
    graph = cl.user_session.get("graph")
    response = graph.invoke({"question": message.content})
    answer = response["answer"]
    answer += "\n\n"

    pdf_sources = get_pdf_sources(response)
    web_sources = get_web_sources(response)

    elements = []
    if len(pdf_sources) > 0:
        answer += "The following PDF sources were consulted:\n"
        for source, page_numbers in pdf_sources.items():
            page_numbers = sorted(page_numbers)
            # display="side" does not seem to be supported by Chainlit for PDFs, so use "inline" instead.
            elements.append(cl.Pdf(name="pdf", display="inline", path=source, page=page_numbers[0]))
            answer += f"'{source}' on page(s): {page_numbers}\n"
    if len(web_sources) > 0:
        answer += f"The following web sources were consulted: {web_sources}\n"

    await cl.Message(content=answer, elements=elements).send()


def get_pdf_sources(response: AddableValuesDict) -> dict[str, set[int]]:
    """Retrieve the PDF sources and their page numbers from a graph response."""
    pdf_sources: dict[str, set[int]] = {}
    for context in response["context"]:
        try:
            if context.metadata["filetype"] == "application/pdf":
                source = context.metadata["source"]
                page_number = context.metadata["page_number"]
                if source in pdf_sources:
                    pdf_sources[source].add(page_number)
                else:
                    pdf_sources[source] = {page_number}
        except KeyError:
            pass
    return pdf_sources


def get_web_sources(response: AddableValuesDict) -> set:
    """Retrieve the web sources from a graph response."""
    web_sources = set()
    for context in response["context"]:
        try:
            if context.metadata["filetype"] == "web":
                web_sources.add(context.metadata["source"])
        except KeyError:
            pass
    return web_sources


@cl.set_starters
async def set_starters():
    """Build the chat starters from the CHAINLIT_STARTERS environment variable, if set."""
    chainlit_starters = os.environ.get("CHAINLIT_STARTERS")
    if chainlit_starters is None:
        return
    dict_list = json.loads(chainlit_starters)
    starters = []
    for starter in dict_list:
        try:
            starters.append(cl.Starter(label=starter["label"], message=starter["message"]))
        except KeyError:
            logger.warning(
                "CHAINLIT_STARTERS environment variable is not a JSON list of dictionaries "
                "containing 'label' and 'message' keys."
            )
    return starters


if __name__ == "__main__":
    # Populate the vector store from the command-line arguments before starting the Chainlit app.
    vector_store = Chroma(
        collection_name="generic_rag",
        embedding_function=get_embedding_model(args.back_end),
        persist_directory=str(args.chroma_db_location),
    )
    if args.reset_chroma_db:
        vector_store.reset_collection()
    add_pdf_files(vector_store, args.pdf_data, args.pdf_chunk_size, args.pdf_chunk_overlap, args.pdf_add_start_index)
    add_urls(vector_store, args.web_data, args.web_chunk_size)
    run_chainlit(__file__)
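# Example invocation (a sketch: the module filename and data paths below are placeholders,
# and the CHAINLIT_STARTERS value assumes the JSON shape read by set_starters() above):
#
#   export CHAINLIT_STARTERS='[{"label": "Summarize", "message": "Summarize the indexed documents."}]'
#   python generic_rag.py --pdf-data ./docs --web-data https://example.com --chroma-db-location .chroma_db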