From f25770e3cef90b8932030a6cecc96e36b6b82195 Mon Sep 17 00:00:00 2001
From: Nielson Janné
Date: Mon, 17 Mar 2025 16:51:15 +0100
Subject: [PATCH] Add a Conditional Retrieve/Generator LangGraph

---
 generic_rag/graphs/cond_ret_gen.py | 106 +++++++++++++++++++++++++++++
 1 file changed, 106 insertions(+)
 create mode 100644 generic_rag/graphs/cond_ret_gen.py

diff --git a/generic_rag/graphs/cond_ret_gen.py b/generic_rag/graphs/cond_ret_gen.py
new file mode 100644
index 0000000..fbbf534
--- /dev/null
+++ b/generic_rag/graphs/cond_ret_gen.py
@@ -0,0 +1,106 @@
+import logging
+from typing import Any, Iterator, List
+
+from langchain_chroma import Chroma
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
+from langchain_core.tools import tool
+from langgraph.checkpoint.memory import MemorySaver
+from langgraph.graph import END, MessagesState, StateGraph
+from langgraph.prebuilt import InjectedStore, ToolNode, tools_condition
+from typing_extensions import Annotated
+
+
+class CondRetGenLangGraph:
+    def __init__(self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings):
+        self.chat_model = chat_model
+        self.embedding_model = embedding_model
+        self.system_prompt = (
+            "You are an assistant for question-answering tasks. "
+            "If the question is in Dutch, answer in Dutch. If the question is in English, answer in English. "
+            "Use the following pieces of retrieved context to answer the question. "
+            "If you don't know the answer, say that you don't know."
+        )
+
+        memory = MemorySaver()
+        tools = ToolNode([self._retrieve])
+
+        # Nodes added as bare callables are named after the callable, so the
+        # node names are "_query_or_respond", "tools", and "_generate".
+        graph_builder = StateGraph(MessagesState)
+        graph_builder.add_node(self._query_or_respond)
+        graph_builder.add_node(tools)
+        graph_builder.add_node(self._generate)
+
+        graph_builder.set_entry_point("_query_or_respond")
+        graph_builder.add_conditional_edges("_query_or_respond", tools_condition, {END: END, "tools": "tools"})
+        graph_builder.add_edge("tools", "_generate")
+        graph_builder.add_edge("_generate", END)
+
+        # The Chroma vector store doubles as the graph's store so that
+        # `_retrieve` can receive it through InjectedStore.
+        self.graph = graph_builder.compile(checkpointer=memory, store=vector_store)
+
+    def stream(self, message: str, config=None) -> Iterator[str]:
+        for llm_response, metadata in self.graph.stream(
+            {"messages": [{"role": "user", "content": message}]}, stream_mode="messages", config=config
+        ):
+            # Only forward tokens produced by the answer node, skipping the
+            # user's own message and any tool-call chatter.
+            if (
+                llm_response.content
+                and not isinstance(llm_response, HumanMessage)
+                and metadata["langgraph_node"] == "_generate"
+            ):
+                yield llm_response.content
+
+        # TODO: read the sources used in AIMessages and set an internal value with the sources used in the last received stream.
+
+    @tool(response_format="content_and_artifact")
+    def _retrieve(
+        query: str, full_user_content: str, vector_store: Annotated[Any, InjectedStore()]
+    ) -> tuple[str, List[Document]]:
+        """
+        Retrieve information related to a query and user content.
+        """
+        # This method is used as a tool in the graph.
+        # Its docstring is used for the Pydantic tool schema, so word it carefully.
+        # Furthermore, it cannot and should not take a `self` parameter.
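+        # `vector_store` is excluded from that schema: InjectedStore() marks it
+        # to be filled in at runtime with the object given to compile(store=...),
+        # so the LLM never sees or sets this parameter.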
+        # If you want to pass on state, please refer to:
+        # https://python.langchain.com/docs/concepts/tools/#special-type-annotations
+        logging.info(f"Query: {query}")
+        logging.info(f"User content: {full_user_content}")
+
+        # Search with both the LLM-generated query and the full user message,
+        # and merge the two result sets.
+        retrieved_docs = vector_store.similarity_search(query, k=4)
+        retrieved_docs += vector_store.similarity_search(full_user_content, k=4)
+        serialized = "\n\n".join(f"Source: {doc.metadata}\nContent: {doc.page_content}" for doc in retrieved_docs)
+
+        return serialized, retrieved_docs
+
+    def _query_or_respond(self, state: MessagesState) -> dict[str, list[BaseMessage]]:
+        """Generate a tool call for retrieval, or respond directly."""
+        llm_with_tools = self.chat_model.bind_tools([self._retrieve])
+        response = llm_with_tools.invoke(state["messages"])
+        return {"messages": [response]}
+
+    def _generate(self, state: MessagesState) -> dict[str, list[BaseMessage]]:
+        """Generate the answer from the retrieved context."""
+        # Collect the ToolMessages produced by the most recent retrieval.
+        recent_tool_messages = []
+        for message in reversed(state["messages"]):
+            if message.type == "tool":
+                recent_tool_messages.append(message)
+            else:
+                break
+        tool_messages = recent_tool_messages[::-1]
+
+        # Format the retrieved context into the system prompt.
+        docs_content = "\n\n".join(tool_message.content for tool_message in tool_messages)
+        system_message_content = self.system_prompt + f"\n\n{docs_content}"
+        conversation_messages = [
+            message
+            for message in state["messages"]
+            if message.type in ("human", "system") or (message.type == "ai" and not message.tool_calls)
+        ]
+        prompt = [SystemMessage(system_message_content)] + conversation_messages
+
+        # Run the final generation without tools bound.
+        response = self.chat_model.invoke(prompt)
+        return {"messages": [response]}
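
Not part of the patch: a minimal usage sketch of how the graph could be wired
up, assuming `langchain_openai` is installed and `init_chat_model` is available.
The model names, collection name, and thread id below are placeholders, not
values the module prescribes.

    from langchain.chat_models import init_chat_model
    from langchain_chroma import Chroma
    from langchain_openai import OpenAIEmbeddings

    from generic_rag.graphs.cond_ret_gen import CondRetGenLangGraph

    # Placeholder models; any BaseChatModel / Embeddings pair should work.
    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
    vector_store = Chroma(collection_name="docs", embedding_function=embeddings)
    chat_model = init_chat_model("gpt-4o-mini", model_provider="openai")

    rag = CondRetGenLangGraph(vector_store, chat_model, embeddings)

    # MemorySaver checkpoints per conversation thread, so a thread_id is required.
    config = {"configurable": {"thread_id": "demo"}}
    for token in rag.stream("What does the leave policy say?", config=config):
        print(token, end="", flush=True)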