forked from AI_team/Philosophy-RAG-demo
Add memory in the RetGenGraph
parent ee0c731faf
commit bb1bf558f7
@@ -64,8 +64,7 @@ parser.add_argument(
     "-c",
     "--use-conditional-graph",
     action="store_true",
-    help="Use the conditional retrieve generate graph over the regular retrieve generate graph. "
-    "The conditional version has built-in (chat) memory and is capable of querying vectorstores on its own initiative.",
+    help="Use the conditional retrieve generate graph over the regular retrieve generate graph.",
 )
 args = parser.parse_args()
 
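In context, this flag decides which graph class the Chainlit app instantiates. A hypothetical sketch of that dispatch, assuming a conditional-graph class name and constructor arguments that this diff does not show (only RetGenLangGraph appears below):

# Hypothetical wiring; ConditionalRetGenLangGraph and the constructor
# arguments are assumptions, not code from this repository.
if args.use_conditional_graph:
    graph = ConditionalRetGenLangGraph(llm, embedding_model)
else:
    graph = RetGenLangGraph(llm, embedding_model)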
@@ -98,7 +97,8 @@ async def on_message(message: cl.Message):
         await chainlit_response.send()
 
     elif isinstance(graph, RetGenLangGraph):
-        response = graph.invoke(message.content)
+        config = {"configurable": {"thread_id": cl.user_session.get("id")}}
+        response = graph.invoke(message.content, config=config)
 
         answer = response["answer"]
         answer += "\n\n"
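The Chainlit session id is reused as the LangGraph thread_id, which is what scopes the new checkpointer's state to a single chat session. A minimal sketch of those semantics, with made-up session ids (graph stands for the RetGenLangGraph instance above):

# Different thread_ids read and write independent checkpoints.
config_a = {"configurable": {"thread_id": "session-a"}}
config_b = {"configurable": {"thread_id": "session-b"}}

graph.invoke("Who was Spinoza?", config=config_a)    # checkpointed under "session-a"
graph.invoke("Who was Descartes?", config=config_b)  # separate state, no crosstalk
graph.invoke("And Leibniz?", config=config_a)        # resumes the "session-a" checkpoint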
@@ -1,9 +1,11 @@
-from langgraph.graph import START, END, StateGraph
-from typing_extensions import List, TypedDict
-from langchain_core.documents import Document
-from langchain import hub
 from typing import Any, Union
+
+from langchain import hub
+from langchain_core.documents import Document
+from langgraph.checkpoint.memory import MemorySaver
+from langgraph.graph import END, START, StateGraph
+from typing_extensions import List, TypedDict
 
 
 class State(TypedDict):
     question: str
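Note that MemorySaver is LangGraph's in-process checkpointer: the memory added here lives in a Python dict and is lost when the process restarts. Swapping in a persistent checkpointer (for example a database-backed one) would be a separate change.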
@@ -18,16 +20,17 @@ class RetGenLangGraph:
         self.embedding_model = embedding_model
         self.prompt = hub.pull("rlm/rag-prompt")
 
+        memory = MemorySaver()
         graph_builder = StateGraph(State).add_sequence([self._retrieve, self._generate])
         graph_builder.add_edge(START, "_retrieve")
         graph_builder.add_edge("_retrieve", "_generate")
         graph_builder.add_edge("_generate", END)
 
-        self.graph = graph_builder.compile()
+        self.graph = graph_builder.compile(memory)
         self.last_invoke = None
 
-    def invoke(self, message: str) -> Union[dict[str, Any], Any]:
-        self.last_invoke = self.graph.invoke({"question": message})
+    def invoke(self, message: str, config: dict) -> Union[dict[str, Any], Any]:
+        self.last_invoke = self.graph.invoke({"question": message}, config=config)
         return self.last_invoke
 
     def _retrieve(self, state: State) -> dict:
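Taken together, the pattern this commit applies is: create a checkpointer, pass it to compile(), and route every invoke through a thread_id. A self-contained sketch of just that mechanism, using a toy turn-counter state instead of the repository's RAG state (all names here are illustrative, not the project's code):

from typing import TypedDict

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph


class ToyState(TypedDict):
    question: str
    turns: int


def track(state: ToyState) -> dict:
    # "question" arrives fresh on each invoke; "turns" is restored from the
    # thread's checkpoint, so it survives across invocations.
    return {"turns": state.get("turns", 0) + 1}


builder = StateGraph(ToyState)
builder.add_node("track", track)
builder.add_edge(START, "track")
builder.add_edge("track", END)
graph = builder.compile(checkpointer=MemorySaver())

config = {"configurable": {"thread_id": "demo"}}
print(graph.invoke({"question": "Who was Spinoza?"}, config=config)["turns"])    # 1
print(graph.invoke({"question": "Where was he born?"}, config=config)["turns"])  # 2
other = {"configurable": {"thread_id": "other"}}
print(graph.invoke({"question": "And Leibniz?"}, config=other)["turns"])         # 1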