"""Chainlit entry point for the Sogeti Nederland generic RAG demo.

Parses CLI options, builds a Chroma vector store, indexes the requested PDF
and web sources, and serves a Chainlit chat UI backed by one of two LangGraph
retrieve/generate graphs (plain or conditional).
"""

import argparse
import json
import logging
import os
from pathlib import Path

import chainlit as cl
from chainlit.cli import run_chainlit
from langchain_chroma import Chroma

from backend.models import BackendType, get_chat_model, get_embedding_model
from graphs.cond_ret_gen import CondRetGenLangGraph
from graphs.ret_gen import RetGenLangGraph
from parsers.parser import add_pdf_files, add_urls

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser(description="A Sogeti Nederland Generic RAG demo.")
parser.add_argument(
    "-b",
    "--backend",
    type=BackendType,
    choices=list(BackendType),
    default=BackendType.azure,
    help="Cloud provider to use as backend. In the case of local, Ollama needs to be installed.",
)
parser.add_argument(
    "-p",
    "--pdf-data",
    type=Path,
    nargs="+",
    default=[],
    help="One or multiple paths to folders or files to use for retrieval. "
    "If a path is a folder, all files in the folder will be used. "
    "If a path is a file, only that file will be used. "
    "If the path is relative it will be relative to the current working directory.",
)
parser.add_argument(
    "-u",
    "--unstructured-pdf",
    action="store_true",
    help="Use an unstructured PDF loader. "
    "An unstructured PDF loader might be useful for PDF files "
    "that contain a lot of images with text, tables or (scanned) text as images. "
    "Please use '-r' when switching parsers on already indexed data.",
)
# NOTE(review): the mixed dash/underscore spelling of the two flags below is kept
# for CLI backward compatibility ('--web-chunk-size' uses dashes only).
parser.add_argument("--pdf-chunk_size", type=int, default=1000, help="The size of the chunks to split the text into.")
parser.add_argument("--pdf-chunk_overlap", type=int, default=200, help="The overlap between the chunks.")
parser.add_argument(
    "--pdf-add-start-index", action="store_true", help="Add the start index to the metadata of the chunks."
)
parser.add_argument(
    "-w", "--web-data", type=str, nargs="*", default=[], help="One or multiple URLs to use for retrieval."
)
parser.add_argument("--web-chunk-size", type=int, default=200, help="The size of the chunks to split the text into.")
parser.add_argument(
    "-d",
    "--chroma-db-location",
    type=Path,
    default=Path(".chroma_db"),
    help="File path to store or load a Chroma DB from/to.",
)
# '--reset-chrome-db' is a historical typo; it stays first so argparse keeps the
# dest 'reset_chrome_db', and the corrected '--reset-chroma-db' is accepted too.
parser.add_argument(
    "-r", "--reset-chrome-db", "--reset-chroma-db", action="store_true", help="Reset the Chroma DB."
)
parser.add_argument(
    "-c",
    "--use-conditional-graph",
    action="store_true",
    help="Use the conditional retrieve generate graph over the regular retrieve generate graph.",
)
args = parser.parse_args()

# Build the backend models once; the embedding model is shared between the
# vector store and the graph instead of being constructed multiple times.
embedding_model = get_embedding_model(args.backend)
chat_model = get_chat_model(args.backend)

vector_store = Chroma(
    collection_name="generic_rag",
    embedding_function=embedding_model,
    persist_directory=str(args.chroma_db_location),
)

if args.use_conditional_graph:
    graph = CondRetGenLangGraph(vector_store, chat_model=chat_model, embedding_model=embedding_model)
else:
    graph = RetGenLangGraph(vector_store, chat_model=chat_model, embedding_model=embedding_model)


@cl.on_message
async def on_message(message: cl.Message):
    """Dispatch an incoming chat message to the handler matching the graph type."""
    if isinstance(graph, CondRetGenLangGraph):
        await process_cond_response(message)
    elif isinstance(graph, RetGenLangGraph):
        await process_response(message)


async def process_response(message):
    """Answer via the plain retrieve/generate graph and list consulted sources.

    Invokes the graph once, streams the answer, then appends the PDF and web
    sources that were retrieved for this answer.
    """
    # Each Chainlit session gets its own LangGraph thread for memory isolation.
    config = {"configurable": {"thread_id": cl.user_session.get("id")}}
    chainlit_response = cl.Message(content="")
    response = graph.invoke(message.content, config=config)
    await chainlit_response.stream_token(f"{response}\n")

    pdf_sources = graph.get_last_pdf_sources()
    if len(pdf_sources) > 0:
        await chainlit_response.stream_token("\nThe following PDF sources were consulted:\n")
        for source, page_numbers in pdf_sources.items():
            page_numbers = sorted(page_numbers)
            # display="side" seems to be not supported by chainlit for PDF's, so we use "inline" instead.
            chainlit_response.elements.append(cl.Pdf(name="pdf", display="inline", path=source, page=page_numbers[0]))
            await chainlit_response.update()
            await chainlit_response.stream_token(f"- '{source}' on page(s): {page_numbers}\n")

    web_sources = graph.get_last_web_sources()
    if len(web_sources) > 0:
        await chainlit_response.stream_token("\nThe following web sources were consulted:\n")
        for source in web_sources:
            await chainlit_response.stream_token(f"- {source}\n")

    await chainlit_response.send()


async def process_cond_response(message):
    """Answer via the conditional graph, streaming tokens as they are produced."""
    config = {"configurable": {"thread_id": cl.user_session.get("id")}}
    chainlit_response = cl.Message(content="")
    for response in graph.stream(message.content, config=config):
        await chainlit_response.stream_token(response)
    await chainlit_response.send()


@cl.set_starters
async def set_starters():
    """Build Chainlit starter prompts from the CHAINLIT_STARTERS env variable.

    The variable must hold a JSON list of objects with 'label' and 'message'
    keys; malformed entries are skipped with a warning.
    """
    chainlit_starters = os.environ.get("CHAINLIT_STARTERS", None)
    if chainlit_starters is None:
        return
    dict_list = json.loads(chainlit_starters)
    starters = []
    for starter in dict_list:
        try:
            starters.append(cl.Starter(label=starter["label"], message=starter["message"]))
        except KeyError:
            # Use the module logger (was: logging.warning) for consistent output.
            logger.warning(
                "CHAINLIT_STARTERS environment is not a list with dictionaries containing 'label' and 'message' keys."
            )
    return starters


if __name__ == "__main__":
    if args.reset_chrome_db:
        vector_store.reset_collection()
    # Index the requested sources before the UI starts serving requests.
    add_pdf_files(
        vector_store,
        args.pdf_data,
        args.pdf_chunk_size,
        args.pdf_chunk_overlap,
        args.pdf_add_start_index,
        args.unstructured_pdf,
    )
    add_urls(vector_store, args.web_data, args.web_chunk_size)
    run_chainlit(__file__)