From bb1bf558f7e7b467c3993fbe8471bc8ceee31c17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nielson=20Jann=C3=A9?= Date: Mon, 17 Mar 2025 17:48:30 +0100 Subject: [PATCH] Add memory to the RetGenLangGraph --- generic_rag/app.py | 6 +++--- generic_rag/graphs/ret_gen.py | 17 ++++++++++------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/generic_rag/app.py b/generic_rag/app.py index 94d9932..c09c5e8 100644 --- a/generic_rag/app.py +++ b/generic_rag/app.py @@ -64,8 +64,7 @@ parser.add_argument( "-c", "--use-conditional-graph", action="store_true", - help="Use the conditial retrieve generate graph over the regular retrieve generate graph. " - "The conditional version has build in (chat) memory and is capable of quering vectorstores on its own insight.", + help="Use the conditional retrieve generate graph over the regular retrieve generate graph.", ) args = parser.parse_args() @@ -98,7 +97,8 @@ async def on_message(message: cl.Message): await chainlit_response.send() elif isinstance(graph, RetGenLangGraph): - response = graph.invoke(message.content) + config = {"configurable": {"thread_id": cl.user_session.get("id")}} + response = graph.invoke(message.content, config=config) answer = response["answer"] answer += "\n\n" diff --git a/generic_rag/graphs/ret_gen.py b/generic_rag/graphs/ret_gen.py index fd51259..523ad7c 100644 --- a/generic_rag/graphs/ret_gen.py +++ b/generic_rag/graphs/ret_gen.py @@ -1,9 +1,11 @@ -from langgraph.graph import START, END, StateGraph -from typing_extensions import List, TypedDict -from langchain_core.documents import Document -from langchain import hub from typing import Any, Union +from langchain import hub +from langchain_core.documents import Document +from langgraph.checkpoint.memory import MemorySaver +from langgraph.graph import END, START, StateGraph +from typing_extensions import List, TypedDict + class State(TypedDict): question: str @@ -18,16 +20,17 @@ class RetGenLangGraph: self.embedding_model = embedding_model self.prompt = 
hub.pull("rlm/rag-prompt") + memory = MemorySaver() graph_builder = StateGraph(State).add_sequence([self._retrieve, self._generate]) graph_builder.add_edge(START, "_retrieve") graph_builder.add_edge("_retrieve", "_generate") graph_builder.add_edge("_generate", END) - self.graph = graph_builder.compile() + self.graph = graph_builder.compile(memory) self.last_invoke = None - def invoke(self, message: str) -> Union[dict[str, Any], Any]: - self.last_invoke = self.graph.invoke({"question": message}) + def invoke(self, message: str, config: dict) -> Union[dict[str, Any], Any]: + self.last_invoke = self.graph.invoke({"question": message}, config=config) return self.last_invoke def _retrieve(self, state: State) -> dict: