forked from AI_team/Philosophy-RAG-demo
Add a Conditional Retrieve/Generator LangGraph
This commit is contained in:
parent
3965ce0fb2
commit
f25770e3ce
106
generic_rag/graphs/cond_ret_gen.py
Normal file
@@ -0,0 +1,106 @@
import logging
from typing import Any, Iterator, List

from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.tools import tool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, MessagesState, StateGraph
from langgraph.prebuilt import InjectedStore, ToolNode, tools_condition
from typing_extensions import Annotated


class CondRetGenLangGraph:
    def __init__(self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings):
        self.chat_model = chat_model
        self.embedding_model = embedding_model
        self.system_prompt = (
            "You are an assistant for question-answering tasks. "
            "If the question is in Dutch, answer in Dutch. If the question is in English, answer in English. "
            "Use the following pieces of retrieved context to answer the question. "
            "If you don't know the answer, say that you don't know."
        )

        memory = MemorySaver()
        tools = ToolNode([self._retrieve])

        graph_builder = StateGraph(MessagesState)
        graph_builder.add_node(self._query_or_respond)
        graph_builder.add_node(tools)
        graph_builder.add_node(self._generate)
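
        # Control flow: _query_or_respond either answers directly (tools_condition
        # routes to END) or emits a retrieval tool call; the tools node then runs
        # _retrieve, and _generate composes the final answer from the retrieved context.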
        graph_builder.set_entry_point("_query_or_respond")
        graph_builder.add_conditional_edges("_query_or_respond", tools_condition, {END: END, "tools": "tools"})
        graph_builder.add_edge("tools", "_generate")
        graph_builder.add_edge("_generate", END)

        self.graph = graph_builder.compile(checkpointer=memory, store=vector_store)

    def stream(self, message: str, config=None) -> Iterator[str]:
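        # Note: the graph is compiled with a MemorySaver checkpointer, so callers are
        # expected to pass a config carrying a thread id, e.g.
        # config={"configurable": {"thread_id": "chat-1"}} ("chat-1" is just an
        # illustrative value).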
        for llm_response, metadata in self.graph.stream(
            {"messages": [{"role": "user", "content": message}]}, stream_mode="messages", config=config
        ):
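            # Only surface tokens emitted by the _generate node; this skips the
            # tool-call planning output of _query_or_respond and echoed human messages.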
            if (
                llm_response.content
                and not isinstance(llm_response, HumanMessage)
                and metadata["langgraph_node"] == "_generate"
            ):
                yield llm_response.content

        # TODO: read the sources used in the AIMessages and store them internally as the
        # sources used in the last received stream.

    @tool(response_format="content_and_artifact")
    def _retrieve(
        query: str, full_user_content: str, vector_store: Annotated[Any, InjectedStore()]
    ) -> tuple[str, List[Document]]:
        """
        Retrieve information related to a query and user content.
        """
        # This method is used as a tool in the graph.
        # Its docstring is used for the pydantic tool schema, so word it carefully.
        # Furthermore, it cannot and should not have the `self` parameter.
        # If you want to pass on state, please refer to:
        # https://python.langchain.com/docs/concepts/tools/#special-type-annotations
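        # Note: `vector_store` is not an argument the model supplies; it is injected
        # at runtime from the `store=` object passed to graph_builder.compile(),
        # via the InjectedStore annotation above.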
        logging.info(f"Query: {query}")
        logging.info(f"user content: {full_user_content}")

        # Search on both the distilled query and the full user content, and merge the results.
        retrieved_docs = vector_store.similarity_search(query, k=4)
        retrieved_docs += vector_store.similarity_search(full_user_content, k=4)
        serialized = "\n\n".join(f"Source: {doc.metadata}\nContent: {doc.page_content}" for doc in retrieved_docs)

        return serialized, retrieved_docs

    def _query_or_respond(self, state: MessagesState) -> dict[str, BaseMessage]:
        """Generate a tool call for retrieval, or respond directly."""
        llm_with_tools = self.chat_model.bind_tools([self._retrieve])
        response = llm_with_tools.invoke(state["messages"])
        return {"messages": [response]}

    def _generate(self, state: MessagesState) -> dict[str, BaseMessage]:
        """Generate the answer from the retrieved context."""
        # Gather the ToolMessages generated in the current turn.
        recent_tool_messages = []
        for message in reversed(state["messages"]):
            if message.type == "tool":
                recent_tool_messages.append(message)
            else:
                break
        tool_messages = recent_tool_messages[::-1]
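        # Walking the messages in reverse stops at the first non-tool message, so only
        # this turn's retrieval results are used; [::-1] restores their original order.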

        # Format the retrieved context into the system prompt.
        docs_content = "\n\n".join(msg.content for msg in tool_messages)
        system_message_content = self.system_prompt + f"\n\n{docs_content}"
        conversation_messages = [
            message
            for message in state["messages"]
            if message.type in ("human", "system") or (message.type == "ai" and not message.tool_calls)
        ]
        prompt = [SystemMessage(system_message_content)] + conversation_messages

        # Run the chat model on the assembled prompt.
        response = self.chat_model.invoke(prompt)
        return {"messages": [response]}
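
# A minimal usage sketch (illustrative only; the collection name, models, and thread id
# below are placeholders, not part of this module):
#
#     from langchain_chroma import Chroma
#     from langchain_openai import ChatOpenAI, OpenAIEmbeddings
#
#     embeddings = OpenAIEmbeddings()
#     store = Chroma(collection_name="philosophy", embedding_function=embeddings)
#     rag = CondRetGenLangGraph(store, ChatOpenAI(model="gpt-4o-mini"), embeddings)
#     config = {"configurable": {"thread_id": "demo"}}
#     for chunk in rag.stream("Who was Baruch Spinoza?", config=config):
#         print(chunk, end="", flush=True)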