Add memory in the RetGenGraph

This commit is contained in:
Nielson Janné 2025-03-17 17:48:30 +01:00
parent ee0c731faf
commit bb1bf558f7
2 changed files with 13 additions and 10 deletions

View File

@ -64,8 +64,7 @@ parser.add_argument(
"-c", "-c",
"--use-conditional-graph", "--use-conditional-graph",
action="store_true", action="store_true",
help="Use the conditial retrieve generate graph over the regular retrieve generate graph. " help="Use the conditial retrieve generate graph over the regular retrieve generate graph.",
"The conditional version has build in (chat) memory and is capable of quering vectorstores on its own insight.",
) )
args = parser.parse_args() args = parser.parse_args()
@ -98,7 +97,8 @@ async def on_message(message: cl.Message):
await chainlit_response.send() await chainlit_response.send()
elif isinstance(graph, RetGenLangGraph): elif isinstance(graph, RetGenLangGraph):
response = graph.invoke(message.content) config = {"configurable": {"thread_id": cl.user_session.get("id")}}
response = graph.invoke(message.content, config=config)
answer = response["answer"] answer = response["answer"]
answer += "\n\n" answer += "\n\n"

View File

@ -1,9 +1,11 @@
from langgraph.graph import START, END, StateGraph
from typing_extensions import List, TypedDict
from langchain_core.documents import Document
from langchain import hub
from typing import Any, Union from typing import Any, Union
from langchain import hub
from langchain_core.documents import Document
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from typing_extensions import List, TypedDict
class State(TypedDict): class State(TypedDict):
question: str question: str
@ -18,16 +20,17 @@ class RetGenLangGraph:
self.embedding_model = embedding_model self.embedding_model = embedding_model
self.prompt = hub.pull("rlm/rag-prompt") self.prompt = hub.pull("rlm/rag-prompt")
memory = MemorySaver()
graph_builder = StateGraph(State).add_sequence([self._retrieve, self._generate]) graph_builder = StateGraph(State).add_sequence([self._retrieve, self._generate])
graph_builder.add_edge(START, "_retrieve") graph_builder.add_edge(START, "_retrieve")
graph_builder.add_edge("_retrieve", "_generate") graph_builder.add_edge("_retrieve", "_generate")
graph_builder.add_edge("_generate", END) graph_builder.add_edge("_generate", END)
self.graph = graph_builder.compile() self.graph = graph_builder.compile(memory)
self.last_invoke = None self.last_invoke = None
def invoke(self, message: str) -> Union[dict[str, Any], Any]: def invoke(self, message: str, config: dict) -> Union[dict[str, Any], Any]:
self.last_invoke = self.graph.invoke({"question": message}) self.last_invoke = self.graph.invoke({"question": message}, config=config)
return self.last_invoke return self.last_invoke
def _retrieve(self, state: State) -> dict: def _retrieve(self, state: State) -> dict: