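"""A Sogeti Netherlands generic RAG demo served with Chainlit.

Loads settings from a YAML config file, builds a Chroma vector store and a
retrieval-generation LangGraph, and streams answers (with their sources) in a
Chainlit chat UI.
"""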
import argparse
import json
import logging
import os
import sys
from pathlib import Path

import chainlit as cl
from chainlit.cli import run_chainlit
from langchain_chroma import Chroma

from generic_rag.parsers.config import AppSettings, load_settings
from generic_rag.backend.models import get_chat_model, get_embedding_model
from generic_rag.graphs.cond_ret_gen import CondRetGenLangGraph
from generic_rag.graphs.ret_gen import RetGenLangGraph
from generic_rag.parsers.parser import add_pdf_files, add_urls

logger = logging.getLogger("sogeti-rag")
logger.setLevel(logging.DEBUG)

PROJECT_ROOT = Path(__file__).resolve().parent.parent

system_prompt = (
    "You are an assistant for question-answering tasks. "
"If the question is in Dutch, answer in Dutch. If the question is in English, answer in English."
|
|
"Use the following pieces of retrieved context to answer the question. "
|
|
"If you don't know the answer, say that you don't know."
|
|
)
|
|
|
|
parser = argparse.ArgumentParser(description="A Sogeti Netherlands Generic RAG demo.")
|
|
parser.add_argument(
    "-c",
    "--config",
    type=Path,
    default=PROJECT_ROOT / "config.yaml",
    help="Path to configuration file (YAML format). Defaults to 'config.yaml' in project root.",
)

args = parser.parse_args()

try:
    settings: AppSettings = load_settings(args.config)
except Exception as e:
    logger.error(f"Failed to load configuration from {args.config}: {e}. Exiting.")
    sys.exit(1)
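# Build the embedding model, chat model, and persistent Chroma vector store from the loaded settings.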
embedding_function = get_embedding_model(settings)

chat_function = get_chat_model(settings)

vector_store = Chroma(
    collection_name="generic_rag",
    embedding_function=embedding_function,
    persist_directory=str(settings.chroma_db.location),
)
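# Select the conditional retrieval-generation graph or the plain one, depending on the config flag.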
if settings.use_conditional_graph:
    graph = CondRetGenLangGraph(
        vector_store=vector_store,
        chat_model=chat_function,
        embedding_model=embedding_function,
        system_prompt=system_prompt,
    )
else:
    graph = RetGenLangGraph(
        vector_store=vector_store,
        chat_model=chat_function,
        embedding_model=embedding_function,
        system_prompt=system_prompt,
    )


@cl.on_message
async def on_message(message: cl.Message):
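    """Stream the graph's answer for an incoming chat message and append the consulted sources."""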
    config = {"configurable": {"thread_id": cl.user_session.get("id")}}

    chainlit_response = cl.Message(content="")

    async for response in graph.stream(message.content, config=config):
        await chainlit_response.stream_token(response)

    if isinstance(graph, RetGenLangGraph):
        await add_sources(chainlit_response, graph.get_last_pdf_sources(), graph.get_last_web_sources())
    if isinstance(graph, CondRetGenLangGraph):
        await add_sources(chainlit_response, graph.last_retrieved_docs, graph.last_retrieved_sources)

    await chainlit_response.send()


async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sources: set | list) -> None:
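    """List the consulted PDF and web sources under the response and attach the PDFs as side elements."""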
    if len(pdf_sources) > 0:
        await chainlit_response.stream_token("\n\nThe following PDF sources were consulted:\n")
        for source, page_numbers in pdf_sources.items():
            filename = Path(source).name
            await chainlit_response.stream_token(f"- {filename} on page(s): {sorted(page_numbers)}\n")
            chainlit_response.elements.append(
                cl.Pdf(name=filename, display="side", path=source, page=sorted(page_numbers)[0])
            )

    if len(web_sources) > 0:
        await chainlit_response.stream_token("\n\nThe following web sources were consulted:\n")
        for source in web_sources:
            await chainlit_response.stream_token(f"- {source}\n")


@cl.set_starters
async def set_starters():
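    """Build optional starter prompts from the CHAINLIT_STARTERS environment variable (a JSON list)."""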
    chainlit_starters = os.environ.get("CHAINLIT_STARTERS", None)

    if chainlit_starters is None:
        return

    dict_list = json.loads(chainlit_starters)

    starters = []
    for starter in dict_list:
        try:
            starters.append(cl.Starter(label=starter["label"], message=starter["message"]))
        except KeyError:
            logger.warning(
"CHAINLIT_STARTERS environment is not a list with dictionaries containing 'label' and 'message' keys."
|
|
)
|
|
|
|
return starters
|
|
|
|
|
|
if __name__ == "__main__":
|
|
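    # Optionally reset the collection, then ingest the configured PDFs and URLs before launching the Chainlit app.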
    if settings.chroma_db.reset:
        vector_store.reset_collection()

    add_pdf_files(
        vector_store,
        settings.pdf.data,
        settings.pdf.chunk_size,
        settings.pdf.chunk_overlap,
        settings.pdf.add_start_index,
        settings.pdf.unstructured,
    )
    add_urls(
        vector_store,
        settings.web.data,
        settings.web.chunk_size,
    )

    run_chainlit(__file__)