forked from AI_team/Philosophy-RAG-demo
Add memory in the RetGenGraph
This commit is contained in:
parent
ee0c731faf
commit
bb1bf558f7
@ -64,8 +64,7 @@ parser.add_argument(
|
||||
"-c",
|
||||
"--use-conditional-graph",
|
||||
action="store_true",
|
||||
help="Use the conditial retrieve generate graph over the regular retrieve generate graph. "
|
||||
"The conditional version has build in (chat) memory and is capable of quering vectorstores on its own insight.",
|
||||
help="Use the conditial retrieve generate graph over the regular retrieve generate graph.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -98,7 +97,8 @@ async def on_message(message: cl.Message):
|
||||
await chainlit_response.send()
|
||||
|
||||
elif isinstance(graph, RetGenLangGraph):
|
||||
response = graph.invoke(message.content)
|
||||
config = {"configurable": {"thread_id": cl.user_session.get("id")}}
|
||||
response = graph.invoke(message.content, config=config)
|
||||
|
||||
answer = response["answer"]
|
||||
answer += "\n\n"
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
from langgraph.graph import START, END, StateGraph
|
||||
from typing_extensions import List, TypedDict
|
||||
from langchain_core.documents import Document
|
||||
from langchain import hub
|
||||
from typing import Any, Union
|
||||
|
||||
from langchain import hub
|
||||
from langchain_core.documents import Document
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from typing_extensions import List, TypedDict
|
||||
|
||||
|
||||
class State(TypedDict):
|
||||
question: str
|
||||
@ -18,16 +20,17 @@ class RetGenLangGraph:
|
||||
self.embedding_model = embedding_model
|
||||
self.prompt = hub.pull("rlm/rag-prompt")
|
||||
|
||||
memory = MemorySaver()
|
||||
graph_builder = StateGraph(State).add_sequence([self._retrieve, self._generate])
|
||||
graph_builder.add_edge(START, "_retrieve")
|
||||
graph_builder.add_edge("_retrieve", "_generate")
|
||||
graph_builder.add_edge("_generate", END)
|
||||
|
||||
self.graph = graph_builder.compile()
|
||||
self.graph = graph_builder.compile(memory)
|
||||
self.last_invoke = None
|
||||
|
||||
def invoke(self, message: str) -> Union[dict[str, Any], Any]:
|
||||
self.last_invoke = self.graph.invoke({"question": message})
|
||||
def invoke(self, message: str, config: dict) -> Union[dict[str, Any], Any]:
|
||||
self.last_invoke = self.graph.invoke({"question": message}, config=config)
|
||||
return self.last_invoke
|
||||
|
||||
def _retrieve(self, state: State) -> dict:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user