From 3295bb8992bdca8c11f223ddd31a97fd57e3fa60 Mon Sep 17 00:00:00 2001 From: Ruben Lucas Date: Wed, 9 Apr 2025 15:23:54 +0200 Subject: [PATCH 1/4] =?UTF-8?q?=E2=9C=A8=20Find=20and=20add=20found=20sour?= =?UTF-8?q?ces?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- generic_rag/app.py | 15 +++++++++++ generic_rag/graphs/cond_ret_gen.py | 42 ++++++++++++++++++++++-------- 2 files changed, 46 insertions(+), 11 deletions(-) diff --git a/generic_rag/app.py b/generic_rag/app.py index 56b7162..64ec26d 100644 --- a/generic_rag/app.py +++ b/generic_rag/app.py @@ -129,6 +129,21 @@ async def process_cond_response(message): for response in graph.stream(message.content, config=config): await chainlit_response.stream_token(response) + if len(graph.last_retrieved_docs) > 0: + await chainlit_response.stream_token("\nThe following PDF source were consulted:\n") + for source, page_numbers in graph.last_retrieved_docs.items(): + page_numbers = list(page_numbers) + page_numbers.sort() + # display="side" seems to be not supported by chainlit for PDF's, so we use "inline" instead. + chainlit_response.elements.append(cl.Pdf(name="pdf", display="inline", path=source, page=page_numbers[0])) + await chainlit_response.update() + await chainlit_response.stream_token(f"- '{source}' on page(s): {page_numbers}\n") + + if len(graph.last_retrieved_sources) > 0: + await chainlit_response.stream_token("\nThe following web sources were consulted:\n") + for source in graph.last_retrieved_sources: + await chainlit_response.stream_token(f"- {source}\n") + await chainlit_response.send() diff --git a/generic_rag/graphs/cond_ret_gen.py b/generic_rag/graphs/cond_ret_gen.py index fbbf534..9b0cdd3 100644 --- a/generic_rag/graphs/cond_ret_gen.py +++ b/generic_rag/graphs/cond_ret_gen.py @@ -1,5 +1,8 @@ import logging -from typing import Any, Iterator, List +from typing import Any, Iterator +import re +import ast +from pathlib import Path from langchain_chroma import Chroma from langchain_core.documents import Document @@ -7,6 +10,7 @@ from langchain_core.embeddings import Embeddings from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage from langchain_core.tools import tool +from langchain_core.runnables.config import RunnableConfig from langgraph.checkpoint.memory import MemorySaver from langgraph.graph import END, MessagesState, StateGraph from langgraph.prebuilt import InjectedStore, ToolNode, tools_condition @@ -39,28 +43,44 @@ class CondRetGenLangGraph: self.graph = graph_builder.compile(checkpointer=memory, store=vector_store) - def stream(self, message: str, config=None) -> Iterator[str]: + self.file_path_pattern = r"'file_path'\s*:\s*'((?:[^'\\]|\\.)*)'" + self.source_pattern = r"'source'\s*:\s*'((?:[^'\\]|\\.)*)'" + self.page_pattern = r"'page'\s*:\s*(\d+)" + self.pattern = r"Source:\s*(\{.*?\})" + + self.last_retrieved_docs = {} + self.last_retrieved_sources = set() + + def stream(self, message: str, config: RunnableConfig | None = None) -> Iterator[str]: for llm_response, metadata in self.graph.stream( {"messages": [{"role": "user", "content": message}]}, stream_mode="messages", config=config ): - if ( - llm_response.content - and not isinstance(llm_response, HumanMessage) - and metadata["langgraph_node"] == "_generate" - ): + if llm_response.content and metadata["langgraph_node"] == "_generate": yield llm_response.content - - # TODO: read souces used in AIMessages and set internal value sources used in last received stream. + elif llm_response.name == "_retrieve": + dictionary_strings = re.findall( + self.pattern, llm_response.content, re.DOTALL + ) # Use re.DOTALL if dicts might span newlines + for dict_str in dictionary_strings: + parsed_dict = ast.literal_eval(dict_str) + print(parsed_dict) + if "filetype" in parsed_dict and parsed_dict["filetype"] == "web": + self.last_retrieved_sources.add(parsed_dict["source"]) + elif Path(parsed_dict["source"]).suffix == ".pdf": + if parsed_dict["source"] in self.last_retrieved_docs: + self.last_retrieved_docs[parsed_dict["source"]].add(parsed_dict["page"]) + else: + self.last_retrieved_docs[parsed_dict["source"]] = {parsed_dict["page"]} @tool(response_format="content_and_artifact") def _retrieve( query: str, full_user_content: str, vector_store: Annotated[Any, InjectedStore()] - ) -> tuple[str, List[Document]]: + ) -> tuple[str, list[Document]]: """ Retrieve information related to a query and user content. """ # This method is used as a tool in the graph. - # It's doc-string is used for the pydentic model, please consider doc-string text carefully. + # It's doc-string is used for the pydantic model, please consider doc-string text carefully. # Furthermore, it can not and should not have the `self` parameter. # If you want to pass on state, please refer to: # https://python.langchain.com/docs/concepts/tools/#special-type-annotations From ab1235bd28942323d8ba55eaffcfd29f273848fa Mon Sep 17 00:00:00 2001 From: Ruben Lucas Date: Wed, 9 Apr 2025 16:03:35 +0200 Subject: [PATCH 2/4] =?UTF-8?q?=E2=9C=A8=20Create=20single=20source=20aggr?= =?UTF-8?q?egation=20definition?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- generic_rag/app.py | 40 +++++++++++------------------- generic_rag/graphs/cond_ret_gen.py | 1 - 2 files changed, 15 insertions(+), 26 deletions(-) diff --git a/generic_rag/app.py b/generic_rag/app.py index 64ec26d..4ecb79a 100644 --- a/generic_rag/app.py +++ b/generic_rag/app.py @@ -93,15 +93,7 @@ async def on_message(message: cl.Message): await process_response(message) -async def process_response(message): - config = {"configurable": {"thread_id": cl.user_session.get("id")}} - - chainlit_response = cl.Message(content="") - - async for response in graph.stream(message.content, config=config): - await chainlit_response.stream_token(response) - - pdf_sources = graph.get_last_pdf_sources() +async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sources: set | list): if len(pdf_sources) > 0: await chainlit_response.stream_token("\nThe following PDF source were consulted:\n") for source, page_numbers in pdf_sources.items(): @@ -111,13 +103,24 @@ async def process_response(message): chainlit_response.elements.append(cl.Pdf(name="pdf", display="inline", path=source, page=page_numbers[0])) await chainlit_response.update() await chainlit_response.stream_token(f"- '{source}' on page(s): {page_numbers}\n") - - web_sources = graph.get_last_web_sources() if len(web_sources) > 0: await chainlit_response.stream_token("\nThe following web sources were consulted:\n") for source in web_sources: await chainlit_response.stream_token(f"- {source}\n") + +async def process_response(message): + config = {"configurable": {"thread_id": cl.user_session.get("id")}} + + chainlit_response = cl.Message(content="") + + async for response in graph.stream(message.content, config=config): + await chainlit_response.stream_token(response) + + pdf_sources = graph.get_last_pdf_sources() + web_sources = graph.get_last_web_sources() + await add_sources(chainlit_response, pdf_sources, web_sources) + await chainlit_response.send() @@ -129,20 +132,7 @@ async def process_cond_response(message): for response in graph.stream(message.content, config=config): await chainlit_response.stream_token(response) - if len(graph.last_retrieved_docs) > 0: - await chainlit_response.stream_token("\nThe following PDF source were consulted:\n") - for source, page_numbers in graph.last_retrieved_docs.items(): - page_numbers = list(page_numbers) - page_numbers.sort() - # display="side" seems to be not supported by chainlit for PDF's, so we use "inline" instead. - chainlit_response.elements.append(cl.Pdf(name="pdf", display="inline", path=source, page=page_numbers[0])) - await chainlit_response.update() - await chainlit_response.stream_token(f"- '{source}' on page(s): {page_numbers}\n") - - if len(graph.last_retrieved_sources) > 0: - await chainlit_response.stream_token("\nThe following web sources were consulted:\n") - for source in graph.last_retrieved_sources: - await chainlit_response.stream_token(f"- {source}\n") + await add_sources(chainlit_response, graph.last_retrieved_docs, graph.last_retrieved_sources) await chainlit_response.send() diff --git a/generic_rag/graphs/cond_ret_gen.py b/generic_rag/graphs/cond_ret_gen.py index 9b0cdd3..c884a8e 100644 --- a/generic_rag/graphs/cond_ret_gen.py +++ b/generic_rag/graphs/cond_ret_gen.py @@ -63,7 +63,6 @@ class CondRetGenLangGraph: ) # Use re.DOTALL if dicts might span newlines for dict_str in dictionary_strings: parsed_dict = ast.literal_eval(dict_str) - print(parsed_dict) if "filetype" in parsed_dict and parsed_dict["filetype"] == "web": self.last_retrieved_sources.add(parsed_dict["source"]) elif Path(parsed_dict["source"]).suffix == ".pdf": From 67ba306d3e02ac704c41797b2726ddcf9f6f661c Mon Sep 17 00:00:00 2001 From: Ruben Lucas Date: Wed, 9 Apr 2025 16:06:15 +0200 Subject: [PATCH 3/4] =?UTF-8?q?=F0=9F=90=9B=20If=20check=20for=20pdf=20sou?= =?UTF-8?q?rce=20because=20try=20doesn't=20fail?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- generic_rag/graphs/ret_gen.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/generic_rag/graphs/ret_gen.py b/generic_rag/graphs/ret_gen.py index 9cb69bb..9a05349 100644 --- a/generic_rag/graphs/ret_gen.py +++ b/generic_rag/graphs/ret_gen.py @@ -61,12 +61,11 @@ class RetGenLangGraph: return pdf_sources for doc in self.last_retrieved_docs: - try: - Path(doc.metadata["source"]).suffix == ".pdf" - except KeyError: - continue - else: + source_candidate = doc.metadata["source"] + if "source" in doc.metadata and Path(doc.metadata["source"]).suffix.lower() == ".pdf": source = doc.metadata["source"] + else: + continue if source not in pdf_sources: pdf_sources[source] = set() From df2afd73cb67e5bead9614a317ca3042464f08a5 Mon Sep 17 00:00:00 2001 From: Ruben Lucas Date: Wed, 9 Apr 2025 16:30:39 +0200 Subject: [PATCH 4/4] =?UTF-8?q?=F0=9F=8E=A8=20Reset=20retrieved=20sources?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- generic_rag/graphs/cond_ret_gen.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/generic_rag/graphs/cond_ret_gen.py b/generic_rag/graphs/cond_ret_gen.py index c884a8e..5d3a6ed 100644 --- a/generic_rag/graphs/cond_ret_gen.py +++ b/generic_rag/graphs/cond_ret_gen.py @@ -95,6 +95,10 @@ class CondRetGenLangGraph: def _query_or_respond(self, state: MessagesState) -> dict[str, BaseMessage]: """Generate tool call for retrieval or respond.""" + # Reset last retrieved docs + self.last_retrieved_docs = {} + self.last_retrieved_sources = set() + llm_with_tools = self.chat_model.bind_tools([self._retrieve]) response = llm_with_tools.invoke(state["messages"]) return {"messages": [response]}