From db3d1cfa20da13313d54a50a1dba8c25f30864e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nielson=20Jann=C3=A9?= Date: Tue, 8 Apr 2025 23:37:04 +0200 Subject: [PATCH 1/3] Change RetGenLangGraph to use streaming instead of invoking on the LLM --- generic_rag/app.py | 5 ++--- generic_rag/graphs/ret_gen.py | 37 +++++++++++++++++++---------------- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/generic_rag/app.py b/generic_rag/app.py index 274a7ac..c76a253 100644 --- a/generic_rag/app.py +++ b/generic_rag/app.py @@ -98,9 +98,8 @@ async def process_response(message): chainlit_response = cl.Message(content="") - response = graph.invoke(message.content, config=config) - - await chainlit_response.stream_token(f"{response}\n") + for response in graph.stream(message.content, config=config): + await chainlit_response.stream_token(response) pdf_sources = graph.get_last_pdf_sources() if len(pdf_sources) > 0: diff --git a/generic_rag/graphs/ret_gen.py b/generic_rag/graphs/ret_gen.py index c30d6e2..3c332ed 100644 --- a/generic_rag/graphs/ret_gen.py +++ b/generic_rag/graphs/ret_gen.py @@ -3,7 +3,10 @@ from pathlib import Path from typing import Any, Union from langchain import hub +from langchain_chroma import Chroma from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from langchain_core.language_models.chat_models import BaseChatModel from langgraph.checkpoint.memory import MemorySaver from langgraph.graph import END, START, StateGraph from typing_extensions import List, TypedDict @@ -19,7 +22,7 @@ class State(TypedDict): class RetGenLangGraph: - def __init__(self, vector_store, chat_model, embedding_model): + def __init__(self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings): self.vector_store = vector_store self.chat_model = chat_model self.embedding_model = embedding_model @@ -32,15 +35,15 @@ class RetGenLangGraph: graph_builder.add_edge("_generate", END) self.graph = 
graph_builder.compile(memory) - self.last_invoke = None + self.last_retrieved_docs = [] - def invoke(self, message: str, config: dict) -> Union[dict[str, Any], Any]: - self.last_invoke = self.graph.invoke({"question": message}, config=config) - return self.last_invoke["answer"] + def stream(self, message: str, config: dict) -> Union[dict[str, Any], Any]: + for response, _ in self.graph.stream({"question": message}, stream_mode="messages", config=config): + yield response.content def _retrieve(self, state: State) -> dict: - retrieved_docs = self.vector_store.similarity_search(state["question"]) - return {"context": retrieved_docs} + self.last_retrieved_docs = self.vector_store.similarity_search(state["question"]) + return {"context": self.last_retrieved_docs} def _generate(self, state: State) -> dict: docs_content = "\n\n".join(doc.page_content for doc in state["context"]) @@ -54,30 +57,30 @@ class RetGenLangGraph: """ pdf_sources = {} - if self.last_invoke is None: + if not self.last_retrieved_docs: return pdf_sources - for context in self.last_invoke["context"]: + for doc in self.last_retrieved_docs: try: - Path(context.metadata["source"]).suffix == ".pdf" + Path(doc.metadata["source"]).suffix == ".pdf" except KeyError: continue else: - source = context.metadata["source"] + source = doc.metadata["source"] if source not in pdf_sources: pdf_sources[source] = set() # The page numbers are in the `page_numer` and `page` fields. 
try: - page_number = context.metadata["page_number"] + page_number = doc.metadata["page_number"] except KeyError: pass else: pdf_sources[source].add(page_number) try: - page_number = context.metadata["page"] + page_number = doc.metadata["page"] except KeyError: pass else: @@ -94,15 +97,15 @@ class RetGenLangGraph: """ web_sources = set() - if self.last_invoke is None: + if not self.last_retrieved_docs: return web_sources - for context in self.last_invoke["context"]: + for doc in self.last_retrieved_docs: try: - context.metadata["filetype"] == "web" + doc.metadata["filetype"] == "web" except KeyError: continue else: - web_sources.add(context.metadata["source"]) + web_sources.add(doc.metadata["source"]) return web_sources From 9bb9f0ea226a84ed51e34be1ecc39baedcdd1595 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nielson=20Jann=C3=A9?= Date: Wed, 9 Apr 2025 11:16:55 +0200 Subject: [PATCH 2/3] Make RetGenLangGraph async for a smoother user experience --- generic_rag/app.py | 2 +- generic_rag/graphs/ret_gen.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/generic_rag/app.py b/generic_rag/app.py index c76a253..56b7162 100644 --- a/generic_rag/app.py +++ b/generic_rag/app.py @@ -98,7 +98,7 @@ async def process_response(message): chainlit_response = cl.Message(content="") - for response in graph.stream(message.content, config=config): + async for response in graph.stream(message.content, config=config): await chainlit_response.stream_token(response) pdf_sources = graph.get_last_pdf_sources() diff --git a/generic_rag/graphs/ret_gen.py b/generic_rag/graphs/ret_gen.py index 3c332ed..b12f7e1 100644 --- a/generic_rag/graphs/ret_gen.py +++ b/generic_rag/graphs/ret_gen.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import Any, Union +from typing import Any, AsyncGenerator from langchain import hub from langchain_chroma import Chroma @@ -37,19 +37,19 @@ class RetGenLangGraph: self.graph = graph_builder.compile(memory)
self.last_retrieved_docs = [] - def stream(self, message: str, config: dict) -> Union[dict[str, Any], Any]: - for response, _ in self.graph.stream({"question": message}, stream_mode="messages", config=config): + async def stream(self, message: str, config: dict) -> AsyncGenerator[Any, Any]: + async for response, _ in self.graph.astream({"question": message}, stream_mode="messages", config=config): yield response.content def _retrieve(self, state: State) -> dict: self.last_retrieved_docs = self.vector_store.similarity_search(state["question"]) return {"context": self.last_retrieved_docs} - def _generate(self, state: State) -> dict: + async def _generate(self, state: State) -> AsyncGenerator[Any, Any]: docs_content = "\n\n".join(doc.page_content for doc in state["context"]) - messages = self.prompt.invoke({"question": state["question"], "context": docs_content}) - response = self.chat_model.invoke(messages) - return {"answer": response.content} + messages = await self.prompt.ainvoke({"question": state["question"], "context": docs_content}) + async for response in self.chat_model.astream(messages): + yield {"answer": response.content} def get_last_pdf_sources(self) -> dict[str, list[int]]: """ From 9baa7b0ef660e1d1e1dee25ce828221d1316bace Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nielson=20Jann=C3=A9?= Date: Wed, 9 Apr 2025 11:17:09 +0200 Subject: [PATCH 3/3] Fix a small bug in reporting web page sources --- generic_rag/graphs/ret_gen.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/generic_rag/graphs/ret_gen.py b/generic_rag/graphs/ret_gen.py index b12f7e1..9cb69bb 100644 --- a/generic_rag/graphs/ret_gen.py +++ b/generic_rag/graphs/ret_gen.py @@ -102,10 +102,9 @@ class RetGenLangGraph: for doc in self.last_retrieved_docs: try: - doc.metadata["filetype"] == "web" + if doc.metadata["filetype"] == "web": + web_sources.add(doc.metadata["source"]) except KeyError: continue - else: - web_sources.add(doc.metadata["source"]) return web_sources