From 3fa0e315211178b7207f722cd7042bf471a06a05 Mon Sep 17 00:00:00 2001
From: Nielson Janné
Date: Mon, 17 Mar 2025 14:15:50 +0100
Subject: [PATCH] Extract the retrieval/generation LangGraph into a RetGenLangGraph class

Move the State definition, the retrieve/generate nodes, and the PDF/web
source-extraction helpers out of app.py into generic_rag/graphs/ret_gen.py,
wrapped in a RetGenLangGraph class that builds the graph once and exposes
the sources retrieved by the last invocation.
---
 generic_rag/app.py            | 108 ++++++-----------------------------
 generic_rag/graphs/ret_gen.py |  82 ++++++++++++++++++++++++++
 2 files changed, 98 insertions(+), 92 deletions(-)
 create mode 100644 generic_rag/graphs/ret_gen.py

diff --git a/generic_rag/app.py b/generic_rag/app.py
index 6f1736e..0448de2 100644
--- a/generic_rag/app.py
+++ b/generic_rag/app.py
@@ -6,14 +6,13 @@ from pathlib import Path
 
 import chainlit as cl
 from backend.models import BackendType, get_chat_model, get_embedding_model
+from graphs.ret_gen import RetGenLangGraph
 from chainlit.cli import run_chainlit
-from langchain import hub
+
 from langchain_chroma import Chroma
-from langchain_core.documents import Document
-from langgraph.graph import START, StateGraph
-from langgraph.pregel.io import AddableValuesDict
+
 from parsers.parser import add_pdf_files, add_urls
-from typing_extensions import List, TypedDict
+
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -65,62 +64,26 @@ parser.add_argument(
 parser.add_argument("-r", "--reset-chrome-db", action="store_true", help="Reset the Chroma DB.")
 args = parser.parse_args()
 
+vector_store = Chroma(
+    collection_name="generic_rag",
+    embedding_function=get_embedding_model(args.back_end),
+    persist_directory=str(args.chroma_db_location),
+)
 
-class State(TypedDict):
-    question: str
-    context: List[Document]
-    answer: str
-
-
-def retrieve(state: State):
-    vector_store = cl.user_session.get("vector_store")
-
-    retrieved_docs = vector_store.similarity_search(state["question"])
-
-    return {"context": retrieved_docs}
-
-
-def generate(state: State):
-    prompt = cl.user_session.get("prompt")
-    llm = cl.user_session.get("chat_model")
-
-    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
-    messages = prompt.invoke({"question": state["question"], "context": docs_content})
-    response = llm.invoke(messages)
-
-    return {"answer": response.content}
-
-
-@cl.on_chat_start
-async def on_chat_start():
-    vector_store = Chroma(
-        collection_name="generic_rag",
-        embedding_function=get_embedding_model(args.back_end),
-        persist_directory=str(args.chroma_db_location),
-    )
-
-    cl.user_session.set("vector_store", vector_store)
-    cl.user_session.set("emb_model", get_embedding_model(args.back_end))
-    cl.user_session.set("chat_model", get_chat_model(args.back_end))
-    cl.user_session.set("prompt", hub.pull("rlm/rag-prompt"))
-
-    graph_builder = StateGraph(State).add_sequence([retrieve, generate])
-    graph_builder.add_edge(START, "retrieve")
-    graph = graph_builder.compile()
-
-    cl.user_session.set("graph", graph)
+ret_gen_graph = RetGenLangGraph(
+    vector_store, chat_model=get_chat_model(args.back_end), embedding_model=get_embedding_model(args.back_end)
+)
 
 
 @cl.on_message
 async def on_message(message: cl.Message):
-    graph = cl.user_session.get("graph")
-    response = graph.invoke({"question": message.content})
+    response = ret_gen_graph.invoke(message.content)
     answer = response["answer"]
     answer += "\n\n"
 
-    pdf_sources = get_pdf_sources(response)
-    web_sources = get_web_sources(response)
+    pdf_sources = ret_gen_graph.get_last_pdf_sources()
+    web_sources = ret_gen_graph.get_last_web_sources()
 
     elements = []
     if len(pdf_sources) > 0:
@@ -128,7 +91,7 @@ async def on_message(message: cl.Message):
     for source, page_numbers in pdf_sources.items():
         page_numbers = list(page_numbers)
         page_numbers.sort()
-        # display="side" seems to be not supported by chainlit for PDF's, so we use "inline" instead
+        # display="side" does not seem to be supported by Chainlit for PDFs, so we use "inline" instead.
         elements.append(cl.Pdf(name="pdf", display="inline", path=source, page=page_numbers[0]))
 
         answer += f"'{source}' on page(s): {page_numbers}\n"
@@ -138,39 +101,6 @@ async def on_message(message: cl.Message):
     await cl.Message(content=answer, elements=elements).send()
 
 
-def get_pdf_sources(response: AddableValuesDict) -> dict[str, list[int]]:
-    """
-    Function that retrieves the PDF sources with page numbers from a response.
-    """
-    pdf_sources = {}
-    for context in response["context"]:
-        try:
-            if context.metadata["filetype"] == "application/pdf":
-                source = context.metadata["source"]
-                page_number = context.metadata["page_number"]
-                if source in pdf_sources:
-                    pdf_sources[source].add(page_number)
-                else:
-                    pdf_sources[source] = {page_number}
-        except KeyError:
-            pass
-    return pdf_sources
-
-
-def get_web_sources(response: AddableValuesDict) -> set:
-    """
-    Function that retrieves the web sources from a response.
-    """
-    web_sources = set()
-    for context in response["context"]:
-        try:
-            if context.metadata["filetype"] == "web":
-                web_sources.add(context.metadata["source"])
-        except KeyError:
-            pass
-    return web_sources
-
-
 @cl.set_starters
 async def set_starters():
     chainlit_starters = os.environ["CHAINLIT_STARTERS"]
@@ -193,12 +123,6 @@ async def set_starters():
 if __name__ == "__main__":
-    vector_store = Chroma(
-        collection_name="generic_rag",
-        embedding_function=get_embedding_model(args.back_end),
-        persist_directory=str(args.chroma_db_location),
-    )
-
     if args.reset_chrome_db:
         vector_store.reset_collection()

diff --git a/generic_rag/graphs/ret_gen.py b/generic_rag/graphs/ret_gen.py
new file mode 100644
index 0000000..63605d3
--- /dev/null
+++ b/generic_rag/graphs/ret_gen.py
@@ -0,0 +1,82 @@
+from typing import Any
+from langchain import hub
+from langchain_core.documents import Document
+from langgraph.graph import END, START, StateGraph
+from typing_extensions import List, TypedDict
+
+
+class State(TypedDict):
+    question: str
+    context: List[Document]
+    answer: str
+
+
+class RetGenLangGraph:
+    def __init__(self, vector_store, chat_model, embedding_model):
+        self.vector_store = vector_store
+        self.chat_model = chat_model
+        self.embedding_model = embedding_model  # not used by the graph itself; the vector store already embeds queries
+        self.prompt = hub.pull("rlm/rag-prompt")
+
+        graph_builder = StateGraph(State).add_sequence([self._retrieve, self._generate])
+        # add_sequence already wires _retrieve -> _generate, so only the entry and exit edges are added here
+        graph_builder.add_edge(START, "_retrieve")
+        graph_builder.add_edge("_generate", END)
+
+        self.graph = graph_builder.compile()
+        self.last_invoke = None
+
+    def invoke(self, message: str) -> dict[str, Any]:
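+        """Run the retrieve -> generate graph on ``message`` and cache the raw
+        result so the get_last_* source helpers can inspect the retrieved context."""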
+        self.last_invoke = self.graph.invoke({"question": message})
+        return self.last_invoke
+
+    def _retrieve(self, state: State) -> dict:
+        retrieved_docs = self.vector_store.similarity_search(state["question"])
+        return {"context": retrieved_docs}
+
+    def _generate(self, state: State) -> dict:
+        docs_content = "\n\n".join(doc.page_content for doc in state["context"])
+        messages = self.prompt.invoke({"question": state["question"], "context": docs_content})
+        response = self.chat_model.invoke(messages)
+        return {"answer": response.content}
+
+    def get_last_pdf_sources(self) -> dict[str, set[int]]:
+        """
+        Return the PDF sources retrieved during the last invoke, mapped to their page numbers.
+        """
+        if self.last_invoke is None:
+            return {}
+
+        pdf_sources = {}
+        for context in self.last_invoke["context"]:
+            try:
+                if context.metadata["filetype"] == "application/pdf":
+                    source = context.metadata["source"]
+                    page_number = context.metadata["page_number"]
+                    if source in pdf_sources:
+                        pdf_sources[source].add(page_number)
+                    else:
+                        pdf_sources[source] = {page_number}
+            except KeyError:
+                pass
+
+        return pdf_sources
+
+    def get_last_web_sources(self) -> set[str]:
+        """
+        Return the web sources retrieved during the last invoke.
+        """
+        if self.last_invoke is None:
+            return set()
+
+        web_sources = set()
+        for context in self.last_invoke["context"]:
+            try:
+                if context.metadata["filetype"] == "web":
+                    web_sources.add(context.metadata["source"])
+            except KeyError:
+                pass
+
+        return web_sources
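-- 
Usage sketch (after the signature marker, so it is not part of the applied
patch): a minimal sketch of how the extracted class is driven end to end.
BackendType.OPENAI is an assumed enum member and the persist directory is a
placeholder; substitute whatever backend.models actually exposes in this repo.

    from backend.models import BackendType, get_chat_model, get_embedding_model
    from graphs.ret_gen import RetGenLangGraph
    from langchain_chroma import Chroma

    backend = BackendType.OPENAI  # assumed member; pick your configured backend
    vector_store = Chroma(
        collection_name="generic_rag",
        embedding_function=get_embedding_model(backend),
        persist_directory="./chroma_db",  # illustrative location
    )

    # Constructing the wrapper pulls the "rlm/rag-prompt" template from
    # LangChain Hub and compiles the retrieve -> generate graph once, up front.
    graph = RetGenLangGraph(
        vector_store,
        chat_model=get_chat_model(backend),
        embedding_model=get_embedding_model(backend),
    )

    response = graph.invoke("What do the indexed documents say about X?")
    print(response["answer"])
    print(graph.get_last_pdf_sources())  # e.g. {"report.pdf": {1, 4}}
    print(graph.get_last_web_sources())  # e.g. {"https://example.com/article"}

Because the wrapper caches only the most recent result on self.last_invoke,
the get_last_* helpers must be called before the next invoke(); with the
module-level singleton in app.py that holds as long as nothing awaits between
invoke() and the helper calls.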