# Philosophy-RAG-demo/generic_rag/graphs/cond_ret_gen.py

import logging
from typing import Any, AsyncGenerator
import re
import ast
from pathlib import Path
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage, SystemMessage
from langchain_core.tools import tool
from langchain_core.runnables.config import RunnableConfig
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, MessagesState, StateGraph
from langgraph.prebuilt import InjectedStore, ToolNode, tools_condition
from typing_extensions import Annotated

logger = logging.getLogger(__name__)


class CondRetGenLangGraph:
    """Conditional retrieve-and-generate graph: the chat model either answers the
    user directly or calls the retrieval tool, and the retrieved context is folded
    into the system prompt for the final answer."""

    def __init__(
        self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings, system_prompt: str
    ):
        self.chat_model = chat_model
        self.embedding_model = embedding_model
        self.system_prompt = system_prompt

        memory = MemorySaver()
        tools = ToolNode([self._retrieve])

        graph_builder = StateGraph(MessagesState)
        graph_builder.add_node(self._query_or_respond)
        graph_builder.add_node(tools)
        graph_builder.add_node(self._generate)
        graph_builder.set_entry_point("_query_or_respond")
        graph_builder.add_conditional_edges("_query_or_respond", tools_condition, {END: END, "tools": "tools"})
        graph_builder.add_edge("tools", "_generate")
        graph_builder.add_edge("_generate", END)
        self.graph = graph_builder.compile(checkpointer=memory, store=vector_store)

        # Regex patterns for extracting source metadata from serialized retrieval output.
        self.file_path_pattern = r"'file_path'\s*:\s*'((?:[^'\\]|\\.)*)'"
        self.source_pattern = r"'source'\s*:\s*'((?:[^'\\]|\\.)*)'"
        self.page_pattern = r"'page'\s*:\s*(\d+)"
        self.pattern = r"Source:\s*(\{.*?\})"

        self.last_retrieved_docs = {}
        self.last_retrieved_sources = set()

    async def stream(self, message: str, config: RunnableConfig | None = None) -> AsyncGenerator[Any, Any]:
        """Stream the generated answer while recording retrieved web sources and PDF pages."""
        async for llm_response, metadata in self.graph.astream(
            {"messages": [{"role": "user", "content": message}]}, stream_mode="messages", config=config
        ):
            if llm_response.content and metadata["langgraph_node"] == "_generate":
                yield llm_response.content
            elif llm_response.name == "_retrieve":
                # re.DOTALL in case the metadata dicts span newlines
                dictionary_strings = re.findall(self.pattern, llm_response.content, re.DOTALL)
                for dict_str in dictionary_strings:
                    parsed_dict = ast.literal_eval(dict_str)
                    if parsed_dict.get("filetype") == "web":
                        self.last_retrieved_sources.add(parsed_dict["source"])
                    elif Path(parsed_dict["source"]).suffix == ".pdf":
                        if parsed_dict["source"] in self.last_retrieved_docs:
                            self.last_retrieved_docs[parsed_dict["source"]].add(parsed_dict["page"])
                        else:
                            self.last_retrieved_docs[parsed_dict["source"]] = {parsed_dict["page"]}

    @tool(response_format="content_and_artifact")
    def _retrieve(
        query: str, full_user_content: str, vector_store: Annotated[Any, InjectedStore()]
    ) -> tuple[str, list[Document]]:
        """
        Retrieve information related to a query and user content.
        """
        # This method is used as a tool in the graph.
        # Its docstring is used for the pydantic model, so consider the docstring text carefully.
        # Furthermore, it cannot and should not have the `self` parameter.
        # If you want to pass on state, please refer to:
        # https://python.langchain.com/docs/concepts/tools/#special-type-annotations
        logger.debug(f"query: {query}")
        logger.debug(f"user content: {full_user_content}")
        # Search with both the focused query and the full user message, then combine the hits.
        retrieved_docs = vector_store.similarity_search(query, k=4)
        retrieved_docs += vector_store.similarity_search(full_user_content, k=4)
        serialized = "\n\n".join(f"Source: {doc.metadata}\nContent: {doc.page_content}" for doc in retrieved_docs)
        return serialized, retrieved_docs

    def _query_or_respond(self, state: MessagesState) -> dict[str, BaseMessage]:
        """Generate tool call for retrieval or respond."""
        # Reset last retrieved docs
        self.last_retrieved_docs = {}
        self.last_retrieved_sources = set()
        llm_with_tools = self.chat_model.bind_tools([self._retrieve])
        response = llm_with_tools.invoke(state["messages"])
        return {"messages": [response]}

    def _generate(self, state: MessagesState) -> dict[str, BaseMessage]:
        """Generate answer."""
        # get generated ToolMessages
        recent_tool_messages = []
        for message in reversed(state["messages"]):
            if message.type == "tool":
                recent_tool_messages.append(message)
            else:
                break
        tool_messages = recent_tool_messages[::-1]

        # format into prompt
        docs_content = "\n\n".join(tool_message.content for tool_message in tool_messages)
        system_message_content = self.system_prompt + f"\n\n{docs_content}"
        conversation_messages = [
            message
            for message in state["messages"]
            if message.type in ("human", "system") or (message.type == "ai" and not message.tool_calls)
        ]
        prompt = [SystemMessage(system_message_content)] + conversation_messages

        # run
        response = self.chat_model.invoke(prompt)
        return {"messages": [response]}
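

# Minimal usage sketch (an illustration, not part of the original module): the concrete
# chat model, embedding model, and Chroma collection below are assumptions; any LangChain
# chat/embedding pair that supports tool calling should work the same way.
if __name__ == "__main__":
    import asyncio

    from langchain_openai import ChatOpenAI, OpenAIEmbeddings

    embeddings = OpenAIEmbeddings()
    demo_store = Chroma(collection_name="philosophy", embedding_function=embeddings)
    rag_graph = CondRetGenLangGraph(
        vector_store=demo_store,
        chat_model=ChatOpenAI(model="gpt-4o-mini"),
        embedding_model=embeddings,
        system_prompt="Answer using only the retrieved context below.",
    )

    async def _demo() -> None:
        # The thread_id keys the MemorySaver checkpointer to a single conversation.
        config = {"configurable": {"thread_id": "demo"}}
        async for chunk in rag_graph.stream("What does Kant mean by the categorical imperative?", config):
            print(chunk, end="", flush=True)

    asyncio.run(_demo())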