"""Chainlit entry point for the generic RAG application.

Loads configuration, builds the Chroma vector store and the LangGraph
pipeline, and wires Chainlit handlers that stream retrieval-augmented
answers (with their consulted sources) back to the user.
"""

import json
import logging
import os
import sys
from pathlib import Path

import chainlit as cl
from chainlit.cli import run_chainlit
from langchain_chroma import Chroma

from generic_rag.backend.models import get_chat_model, get_embedding_model
from generic_rag.graphs.cond_ret_gen import CondRetGenLangGraph
from generic_rag.graphs.ret_gen import RetGenLangGraph
from generic_rag.parsers.config import AppSettings, load_settings
from generic_rag.parsers.parser import add_pdf_files, add_urls

logger = logging.getLogger("sogeti-rag")
logger.setLevel(logging.DEBUG)

# NOTE: the pieces below are concatenated into a single string, so the
# trailing spaces are significant.
system_prompt = (
    "You are an assistant for question-answering tasks. "
    "If the question is in Dutch, answer in Dutch. If the question is in English, answer in English. "
    "Use the following pieces of retrieved context to answer the question. "
    "If you don't know the answer, say that you don't know."
)

CONFIG_FILE_PATH = Path("config.yaml")

try:
    settings: AppSettings = load_settings(CONFIG_FILE_PATH)
except Exception:
    # logger.exception records the traceback, so the root cause (missing
    # file, malformed YAML, validation error, ...) is preserved in the log.
    logger.exception("Failed to load configuration from %s. Exiting.", CONFIG_FILE_PATH)
    sys.exit(1)

embedding_function = get_embedding_model(settings)
chat_function = get_chat_model(settings)

vector_store = Chroma(
    collection_name="generic_rag",
    embedding_function=embedding_function,
    persist_directory=str(settings.chroma_db.location),
)

# Select the graph variant once at startup, driven by configuration.
if settings.use_conditional_graph:
    graph = CondRetGenLangGraph(
        vector_store=vector_store,
        chat_model=chat_function,
        embedding_model=embedding_function,
        system_prompt=system_prompt,
    )
else:
    graph = RetGenLangGraph(
        vector_store=vector_store,
        chat_model=chat_function,
        embedding_model=embedding_function,
        system_prompt=system_prompt,
    )


@cl.on_message
async def on_message(message: cl.Message) -> None:
    """Stream a graph-generated answer for an incoming chat message.

    The Chainlit session id is used as the LangGraph thread id so that
    conversation state is scoped per user session.
    """
    config = {"configurable": {"thread_id": cl.user_session.get("id")}}
    chainlit_response = cl.Message(content="")
    async for response in graph.stream(message.content, config=config):
        await chainlit_response.stream_token(response)

    # Append the consulted sources; the accessors differ per graph variant.
    if isinstance(graph, RetGenLangGraph):
        await add_sources(chainlit_response, graph.get_last_pdf_sources(), graph.get_last_web_sources())
    if isinstance(graph, CondRetGenLangGraph):
        await add_sources(chainlit_response, graph.last_retrieved_docs, graph.last_retrieved_sources)

    await chainlit_response.send()


async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sources: set | list) -> None:
    """Append PDF and web source listings to a streamed Chainlit response.

    Args:
        chainlit_response: The in-progress message to append to.
        pdf_sources: Mapping of PDF file path -> iterable of page numbers.
        web_sources: Collection of consulted web URLs.
    """
    if pdf_sources:
        await chainlit_response.stream_token("\n\nThe following PDF sources were consulted:\n")
        for source, page_numbers in pdf_sources.items():
            filename = Path(source).name
            pages = sorted(page_numbers)
            # Fix: interpolate the actual filename (it was previously dropped
            # from the listing even though it is used for the element below).
            await chainlit_response.stream_token(f"- {filename} on page(s): {pages}\n")
            # Attach the PDF as a side element, opened at the first consulted page.
            chainlit_response.elements.append(
                cl.Pdf(name=filename, display="side", path=source, page=pages[0])
            )
    if web_sources:
        await chainlit_response.stream_token("\n\nThe following web sources were consulted:\n")
        for source in web_sources:
            await chainlit_response.stream_token(f"- {source}\n")


@cl.set_starters
async def set_starters():
    """Build Chainlit starter suggestions from the CHAINLIT_STARTERS env var.

    The variable is expected to hold a JSON list of objects with 'label'
    and 'message' keys; entries missing either key are skipped with a
    warning. Returns None when the variable is unset.
    """
    chainlit_starters = os.environ.get("CHAINLIT_STARTERS", None)
    if chainlit_starters is None:
        return
    dict_list = json.loads(chainlit_starters)
    starters = []
    for starter in dict_list:
        try:
            starters.append(cl.Starter(label=starter["label"], message=starter["message"]))
        except KeyError:
            logger.warning(
                "CHAINLIT_STARTERS environment is not a list with dictionaries containing 'label' and 'message' keys."
            )
    return starters


if __name__ == "__main__":
    # Optionally wipe the collection, then (re)ingest configured PDF and
    # web data before starting the Chainlit server.
    if settings.chroma_db.reset:
        vector_store.reset_collection()
    add_pdf_files(
        vector_store,
        settings.pdf.data,
        settings.pdf.chunk_size,
        settings.pdf.chunk_overlap,
        settings.pdf.add_start_index,
        settings.pdf.unstructured,
    )
    add_urls(
        vector_store,
        settings.web.data,
        settings.web.chunk_size,
    )
    run_chainlit(__file__)