Philosophy-RAG-demo/generic_rag/graphs/ret_gen.py
2025-04-10 15:35:26 +02:00

109 lines
3.7 KiB
Python

import logging
from pathlib import Path
from typing import Any, AsyncGenerator
from langchain import hub
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from typing_extensions import List, TypedDict
# Configure root logging once at import time and expose a module-level logger
# for this file (named after the module path).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class State(TypedDict):
    """Shared state flowing between the LangGraph nodes.

    Keys:
        question: The user's input question.
        context: Documents retrieved from the vector store for this question.
        answer: The generated answer text.
    """

    question: str
    context: List[Document]
    answer: str
class RetGenLangGraph:
    """Retrieval-augmented generation pipeline built on a LangGraph state graph.

    Wires a two-step graph (retrieve -> generate) over a Chroma vector store
    and a chat model, checkpointing conversation state in memory, and keeps
    track of the documents used by the most recent invocation so their
    sources can be reported afterwards.
    """

    def __init__(self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings):
        """Build and compile the retrieve/generate graph.

        Args:
            vector_store: Chroma store queried for relevant documents.
            chat_model: Chat model used to generate the streamed answer.
            embedding_model: Embedding model (stored for callers; not used
                directly in this class).
        """
        self.vector_store = vector_store
        self.chat_model = chat_model
        self.embedding_model = embedding_model
        # Standard RAG prompt pulled from the LangChain hub (network fetch at init).
        self.prompt = hub.pull("rlm/rag-prompt")
        memory = MemorySaver()
        # add_sequence registers the two nodes (named after the methods); the
        # explicit edges below make the START -> retrieve -> generate -> END
        # wiring visible.
        graph_builder = StateGraph(State).add_sequence([self._retrieve, self._generate])
        graph_builder.add_edge(START, "_retrieve")
        graph_builder.add_edge("_retrieve", "_generate")
        graph_builder.add_edge("_generate", END)
        # Pass the checkpointer by keyword: `compile(checkpointer=...)` is the
        # documented signature and keyword-only in newer langgraph releases.
        self.graph = graph_builder.compile(checkpointer=memory)
        # Documents returned by the most recent retrieval; consumed by the
        # get_last_*_sources helpers.
        self.last_retrieved_docs: List[Document] = []

    async def stream(self, message: str, config: dict) -> AsyncGenerator[Any, Any]:
        """Yield the generated answer for `message` token by token.

        Args:
            message: The user's question.
            config: LangGraph run config (e.g. thread id for the checkpointer).
        """
        async for response, _ in self.graph.astream({"question": message}, stream_mode="messages", config=config):
            yield response.content

    def _retrieve(self, state: State) -> dict:
        """Graph node: fetch documents relevant to the question and remember them."""
        self.last_retrieved_docs = self.vector_store.similarity_search(state["question"])
        return {"context": self.last_retrieved_docs}

    async def _generate(self, state: State) -> AsyncGenerator[Any, Any]:
        """Graph node: stream an answer grounded in the retrieved context."""
        docs_content = "\n\n".join(doc.page_content for doc in state["context"])
        messages = await self.prompt.ainvoke({"question": state["question"], "context": docs_content})
        async for response in self.chat_model.astream(messages):
            yield {"answer": response.content}

    def get_last_pdf_sources(self) -> dict[str, set[int]]:
        """Return the PDF sources used during the last invoke.

        Returns:
            Mapping of PDF file path to the set of page numbers cited.
            (Annotation corrected: the values built here are sets, not lists.)
        """
        pdf_sources: dict[str, set[int]] = {}
        for doc in self.last_retrieved_docs:
            source = doc.metadata.get("source")
            if source is None or Path(source).suffix.lower() != ".pdf":
                continue
            pages = pdf_sources.setdefault(source, set())
            # Depending on the loader, the page number lives in `page_number`
            # or `page`; collect whichever fields are present.
            for key in ("page_number", "page"):
                if key in doc.metadata:
                    pages.add(doc.metadata[key])
            if not pages:
                # Use the module-level logger (the original called the root
                # logger here) with lazy %-formatting.
                logger.warning("PDF source %s has no page number. Please check the metadata of the document.", source)
        return pdf_sources

    def get_last_web_sources(self) -> set:
        """Return the set of web source URLs used during the last invoke."""
        return {
            doc.metadata["source"]
            for doc in self.last_retrieved_docs
            if doc.metadata.get("filetype") == "web" and "source" in doc.metadata
        }