Philosophy-RAG-demo/generic_rag/graphs/ret_gen.py
2025-03-17 17:48:30 +01:00

84 lines
2.9 KiB
Python

from typing import Any, Union
from langchain import hub
from langchain_core.documents import Document
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from typing_extensions import List, TypedDict
class State(TypedDict, total=True):
    """Graph state passed between the retrieve and generate nodes."""

    # The user question handed to RetGenLangGraph.invoke().
    question: str
    # Documents returned by the vector store's similarity search.
    context: list[Document]
    # Text content of the chat model's reply.
    answer: str
class RetGenLangGraph:
    """Two-step retrieve-then-generate RAG pipeline built on LangGraph.

    The compiled graph runs ``_retrieve`` (vector-store similarity search) and then
    ``_generate`` (hub prompt + chat model). It is compiled with an in-memory
    ``MemorySaver`` checkpointer, so ``invoke`` takes a LangGraph ``config``
    (presumably carrying a ``thread_id`` -- confirm against callers).
    """

    def __init__(self, vector_store, chat_model, embedding_model):
        """Build and compile the retrieval/generation graph.

        Args:
            vector_store: Object queried with ``similarity_search`` during retrieval.
            chat_model: Model whose ``invoke`` result's ``content`` is the answer.
            embedding_model: Stored on the instance; not used by this class itself.
        """
        self.vector_store = vector_store
        self.chat_model = chat_model
        self.embedding_model = embedding_model
        # RAG prompt from the LangChain hub; filled in _generate with the
        # "question" and "context" variables.
        self.prompt = hub.pull("rlm/rag-prompt")

        # add_sequence registers both nodes (named after the methods) AND the
        # _retrieve -> _generate edge, so only the entry and exit edges are added here.
        graph_builder = StateGraph(State).add_sequence([self._retrieve, self._generate])
        graph_builder.add_edge(START, "_retrieve")
        graph_builder.add_edge("_generate", END)
        self.graph = graph_builder.compile(checkpointer=MemorySaver())

        # Final state of the most recent invoke() on this instance; read by the
        # get_last_*_sources helpers. None until the first invoke().
        self.last_invoke = None

    def invoke(self, message: str, config: dict) -> Union[dict[str, Any], Any]:
        """Run the graph for ``message`` and remember the resulting state.

        Args:
            message: The user question.
            config: LangGraph run config forwarded unchanged to ``graph.invoke``.

        Returns:
            The final graph state (``question``, ``context`` and ``answer``).
        """
        self.last_invoke = self.graph.invoke({"question": message}, config=config)
        return self.last_invoke

    def _retrieve(self, state: "State") -> dict:
        """Graph node: fetch documents similar to the question into ``context``."""
        retrieved_docs = self.vector_store.similarity_search(state["question"])
        return {"context": retrieved_docs}

    def _generate(self, state: "State") -> dict:
        """Graph node: answer the question from the retrieved context.

        Page contents are joined with blank lines, substituted into the hub
        prompt, and sent to the chat model. The reply's ``content`` becomes ``answer``.
        """
        docs_content = "\n\n".join(doc.page_content for doc in state["context"])
        messages = self.prompt.invoke({"question": state["question"], "context": docs_content})
        response = self.chat_model.invoke(messages)
        return {"answer": response.content}

    def get_last_pdf_sources(self) -> dict[str, set[int]]:
        """Return the PDF sources used during the last invoke.

        Returns:
            A dict mapping each PDF ``source`` to the set of ``page_number`` values
            retrieved from it. The dict is empty (never a list) if ``invoke`` has
            not been called yet. Context documents whose ``filetype`` is not
            ``"application/pdf"``, or that lack ``source`` / ``page_number``
            metadata, are skipped.
        """
        pdf_sources: dict[str, set[int]] = {}
        if self.last_invoke is None:
            return pdf_sources
        for doc in self.last_invoke["context"]:
            metadata = doc.metadata
            if metadata.get("filetype") != "application/pdf":
                continue
            if "source" not in metadata or "page_number" not in metadata:
                # Best-effort: a PDF chunk without location info cannot be cited.
                continue
            pdf_sources.setdefault(metadata["source"], set()).add(metadata["page_number"])
        return pdf_sources

    def get_last_web_sources(self) -> set:
        """Return the web sources used during the last invoke.

        Returns:
            The set of ``source`` values of context documents whose ``filetype``
            is ``"web"``. The set is empty if ``invoke`` has not been called yet.
            Documents lacking ``filetype`` or ``source`` metadata are skipped.
        """
        if self.last_invoke is None:
            return set()
        return {
            doc.metadata["source"]
            for doc in self.last_invoke["context"]
            if doc.metadata.get("filetype") == "web" and "source" in doc.metadata
        }