# Philosophy-RAG-demo/generic_rag/app.py
# Chainlit entry point for the generic RAG application.
import json
import logging
import os
from pathlib import Path
import sys
import chainlit as cl
from chainlit.cli import run_chainlit
from langchain_chroma import Chroma
from generic_rag.parsers.config import AppSettings, load_settings
from generic_rag.backend.models import get_chat_model, get_embedding_model
from generic_rag.graphs.cond_ret_gen import CondRetGenLangGraph
from generic_rag.graphs.ret_gen import RetGenLangGraph
from generic_rag.parsers.parser import add_pdf_files, add_urls
# Application-wide logger; named "sogeti-rag" so all app logs share one channel.
logger = logging.getLogger("sogeti-rag")
logger.setLevel(logging.DEBUG)

# System prompt steering answer generation. NOTE: each literal ends with a
# trailing space so adjacent-string concatenation yields proper sentences
# (the original omitted it after "English.", producing "English.Use").
system_prompt = (
    "You are an assistant for question-answering tasks. "
    "If the question is in Dutch, answer in Dutch. If the question is in English, answer in English. "
    "Use the following pieces of retrieved context to answer the question. "
    "If you don't know the answer, say that you don't know."
)

# config.yaml is expected one level above this module (the project root).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
CONFIG_FILE_PATH = PROJECT_ROOT / "config.yaml"
try:
    settings: AppSettings = load_settings(CONFIG_FILE_PATH)
except Exception:
    # `(FileNotFoundError, Exception)` was redundant — Exception already covers
    # FileNotFoundError. Use logger.exception so the traceback (previously
    # captured as `e` but never logged) is preserved before exiting.
    logger.exception("Failed to load configuration from %s. Exiting.", CONFIG_FILE_PATH)
    sys.exit(1)
embedding_function = get_embedding_model(settings)
chat_function = get_chat_model(settings)

# Persistent Chroma collection backing retrieval.
vector_store = Chroma(
    collection_name="generic_rag",
    embedding_function=embedding_function,
    persist_directory=str(settings.chroma_db.location),
)

# Both graph flavours take identical constructor arguments; only the class
# differs, selected by the `use_conditional_graph` setting.
_graph_cls = CondRetGenLangGraph if settings.use_conditional_graph else RetGenLangGraph
graph = _graph_cls(
    vector_store=vector_store,
    chat_model=chat_function,
    embedding_model=embedding_function,
    system_prompt=system_prompt,
)
@cl.on_message
async def on_message(message: cl.Message):
    """Stream the graph's answer for an incoming chat message, then append sources."""
    # Thread the LangGraph checkpoint to this Chainlit session.
    thread_config = {"configurable": {"thread_id": cl.user_session.get("id")}}
    reply = cl.Message(content="")
    async for token in graph.stream(message.content, config=thread_config):
        await reply.stream_token(token)
    # Each graph flavour exposes its consulted sources differently.
    if isinstance(graph, RetGenLangGraph):
        await add_sources(reply, graph.get_last_pdf_sources(), graph.get_last_web_sources())
    if isinstance(graph, CondRetGenLangGraph):
        await add_sources(reply, graph.last_retrieved_docs, graph.last_retrieved_sources)
    await reply.send()
async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sources: set | list) -> None:
    """Append the consulted PDF and web sources to a streamed Chainlit response.

    Args:
        chainlit_response: The in-progress Chainlit message to append to.
        pdf_sources: Mapping of PDF file path -> collection of page numbers consulted.
        web_sources: Iterable of consulted web source URLs.
    """
    if pdf_sources:
        # Fixed: message grammar ("sources"), and the line previously printed the
        # literal "(unknown)" while the computed `filename` went unused.
        await chainlit_response.stream_token("\n\nThe following PDF sources were consulted:\n")
        for source, page_numbers in pdf_sources.items():
            filename = Path(source).name
            pages = sorted(page_numbers)
            await chainlit_response.stream_token(f"- {filename} on page(s): {pages}\n")
            # Attach the PDF as a side element, opened at the first consulted page.
            chainlit_response.elements.append(
                cl.Pdf(name=filename, display="side", path=source, page=pages[0])
            )
    if web_sources:
        await chainlit_response.stream_token("\n\nThe following web sources were consulted:\n")
        for source in web_sources:
            await chainlit_response.stream_token(f"- {source}\n")
@cl.set_starters
async def set_starters():
    """Build Chainlit starter prompts from the CHAINLIT_STARTERS env var.

    The variable must hold a JSON list of objects with "label" and "message"
    keys. Returns None when the variable is unset or not valid JSON, otherwise
    a (possibly empty) list of cl.Starter objects.
    """
    chainlit_starters = os.environ.get("CHAINLIT_STARTERS")
    if chainlit_starters is None:
        return None
    try:
        dict_list = json.loads(chainlit_starters)
    except json.JSONDecodeError:
        # Previously a malformed value crashed at startup; degrade gracefully.
        logger.warning("CHAINLIT_STARTERS environment variable is not valid JSON.")
        return None
    starters = []
    for starter in dict_list:
        try:
            starters.append(cl.Starter(label=starter["label"], message=starter["message"]))
        except (KeyError, TypeError):
            # TypeError covers list items that are not dicts (e.g. plain strings).
            logger.warning(
                "CHAINLIT_STARTERS environment is not a list with dictionaries containing 'label' and 'message' keys."
            )
    return starters
if __name__ == "__main__":
    # Optionally wipe the collection, (re)ingest configured sources, then serve.
    if settings.chroma_db.reset:
        vector_store.reset_collection()

    pdf_cfg = settings.pdf
    add_pdf_files(
        vector_store,
        pdf_cfg.data,
        pdf_cfg.chunk_size,
        pdf_cfg.chunk_overlap,
        pdf_cfg.add_start_index,
        pdf_cfg.unstructured,
    )

    web_cfg = settings.web
    add_urls(vector_store, web_cfg.data, web_cfg.chunk_size)

    run_chainlit(__file__)