# Philosophy-RAG-demo/generic_rag/app.py
# Last modified: 2025-03-11 17:31:24 +01:00
# 118 lines, 5.2 KiB, Python
import argparse
import logging
from pathlib import Path
import chainlit as cl
from chainlit.cli import run_chainlit
from langchain import hub
from langchain_core.documents import Document
from langchain_core.vectorstores import InMemoryVectorStore
from langgraph.graph import START, StateGraph
from typing_extensions import List, TypedDict
from backend.model import BackendType, get_embedding_model, get_chat_model
from parsers.parser import process_local_files, process_web_sites
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Command-line configuration for the demo. Parsed once at import time so the
# Chainlit callbacks below can read the settings from the module-level `args`.
parser = argparse.ArgumentParser(description="A Sogeti Nederland Generic RAG demo.")
parser.add_argument("-b", "--back-end", type=BackendType, choices=list(BackendType), default=BackendType.azure,
                    help="(Cloud) back-end to use. In the case of local, a locally installed ollama will be used.")
parser.add_argument("-p", "--pdf-data", type=Path, required=True, nargs="+",
                    help="One or multiple paths to folders or files to use for retrieval. "
                         "If a path is a folder, all files in the folder will be used. "
                         "If a path is a file, only that file will be used. "
                         "If the path is relative it will be relative to the current working directory.")
# The underscore spellings are kept first for backward compatibility; the
# dash-style aliases match the convention of the other flags (--web-chunk-size).
parser.add_argument("--pdf-chunk_size", "--pdf-chunk-size", dest="pdf_chunk_size", type=int, default=1000,
                    help="The size of the chunks to split the text into.")
parser.add_argument("--pdf-chunk_overlap", "--pdf-chunk-overlap", dest="pdf_chunk_overlap", type=int, default=200,
                    help="The overlap between the chunks.")
parser.add_argument("--pdf-add-start-index", action="store_true",
                    help="Add the start index to the metadata of the chunks.")
parser.add_argument("-w", "--web-data", type=str, nargs="*", default=[],
                    help="One or multiple URLs to use for retrieval.")
parser.add_argument("--web-chunk-size", type=int, default=200,
                    help="The size of the chunks to split the text into.")
args = parser.parse_args()
class State(TypedDict):
    """State passed between the nodes of the RAG LangGraph pipeline."""
    # The user's question as typed into the chat.
    question: str
    # Documents retrieved from the vector store for this question.
    context: List[Document]
    # The generated answer text returned to the user.
    answer: str
def retrieve(state: State):
    """Graph node: look up the documents most similar to the question.

    Reads the session's vector store and returns the matches under "context".
    """
    store = cl.user_session.get("vector_store")
    return {"context": store.similarity_search(state["question"])}
def generate(state: State):
    """Graph node: answer the question from the retrieved context.

    Formats the RAG prompt with the question plus the concatenated document
    texts, invokes the session's chat model, and returns the answer text.
    """
    rag_prompt = cl.user_session.get("prompt")
    chat_model = cl.user_session.get("chat_model")
    # Flatten the retrieved documents into one context string for the prompt.
    context_text = "\n\n".join(document.page_content for document in state["context"])
    prompt_messages = rag_prompt.invoke({"question": state["question"], "context": context_text})
    reply = chat_model.invoke(prompt_messages)
    return {"answer": reply.content}
@cl.on_chat_start
async def on_chat_start():
    """Build the per-session RAG pipeline when a chat session opens.

    Embeds the configured PDF and web sources into an in-memory vector store,
    stashes the models and prompt in the user session, and compiles the
    retrieve -> generate LangGraph used by on_message.
    """
    await cl.Message(author="System", content="Starting up application").send()
    embedding_model = get_embedding_model(args.back_end)
    store = InMemoryVectorStore(embedding_model)
    await cl.Message(author="System", content="Processing PDF files.").send()
    # Parsing is blocking work, so run it off the event loop via make_async.
    local_chunks = await cl.make_async(process_local_files)(
        args.pdf_data, args.pdf_chunk_size, args.pdf_chunk_overlap, args.pdf_add_start_index)
    await cl.Message(author="System", content="Processing web sites.").send()
    web_chunks = await cl.make_async(process_web_sites)(args.web_data, args.web_chunk_size)
    store.add_documents(documents=local_chunks + web_chunks)
    # Everything the graph nodes read back comes from the user session.
    cl.user_session.set("emb_model", embedding_model)
    cl.user_session.set("vector_store", store)
    cl.user_session.set("chat_model", get_chat_model(args.back_end))
    cl.user_session.set("prompt", hub.pull("rlm/rag-prompt"))
    builder = StateGraph(State).add_sequence([retrieve, generate])
    builder.add_edge(START, "retrieve")
    cl.user_session.set("graph", builder.compile())
    await cl.Message(content="Ready for chatting!").send()
@cl.on_message
async def on_message(message: cl.Message):
    """Run the RAG graph on the user's message and send back the answer.

    Bug fix: graph.invoke returns the final State dict (question, context,
    answer), not a string. Previously the whole dict was passed as the message
    content; only the "answer" field is the text meant for the user.
    """
    graph = cl.user_session.get("graph")
    response = graph.invoke({"question": message.content})
    # Send the final answer only — the full state also carries the raw context.
    await cl.Message(content=response["answer"]).send()
@cl.set_starters
async def set_starters():
    """Return the canned conversation starters shown on the welcome screen."""
    suggestions = [
        ("Morning routine ideation",
         "Can you help me create a personalized morning routine that would help increase my "
         "productivity throughout the day? Start by asking me about my current habits and what "
         "activities energize me in the morning."),
        ("Explain superconductors",
         "Explain superconductors like I'm five years old."),
        ("Python script for daily email reports",
         "Write a script to automate sending daily email reports in Python, and walk me through "
         "how I would set it up."),
        ("Text inviting friend to wedding",
         "Write a text asking a friend to be my plus-one at a wedding next month. I want to keep "
         "it super short and casual, and offer an out."),
    ]
    return [cl.Starter(label=label, message=message) for label, message in suggestions]
if __name__ == "__main__":
    # Launch the Chainlit server hosting this module directly (instead of
    # requiring the user to run `chainlit run app.py`).
    run_chainlit(__file__)