from pathlib import Path
from typing import Any, Union

from langchain import hub
from langchain_core.documents import Document
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from typing_extensions import List, TypedDict


class State(TypedDict):
    """State passed between the nodes of the retrieve-then-generate graph."""

    # The user question the graph is invoked with.
    question: str
    # Documents returned by the retrieval node.
    context: List[Document]
    # Text produced by the generation node.
    answer: str


class RetGenLangGraph:
    """Two-step retrieval/generation pipeline built on a LangGraph ``StateGraph``.

    The graph first runs ``_retrieve`` (similarity search on the vector store)
    and then ``_generate`` (chat model call with the retrieved context). The
    full state of the most recent run is kept in ``last_invoke`` so the sources
    that backed the answer can be inspected afterwards.
    """

    def __init__(self, vector_store, chat_model, embedding_model):
        """Store the collaborators and compile the graph.

        Args:
            vector_store: Object exposing ``similarity_search(query)``.
            chat_model: Object exposing ``invoke(messages)`` whose result has
                a ``content`` attribute.
            embedding_model: Stored on the instance; not used by this class
                itself.
        """
        self.vector_store = vector_store
        self.chat_model = chat_model
        self.embedding_model = embedding_model
        self.prompt = hub.pull("rlm/rag-prompt")

        memory = MemorySaver()

        # The nodes are registered under the method names, which the
        # ``add_edge`` calls below refer to.
        graph_builder = StateGraph(State).add_sequence(
            [self._retrieve, self._generate]
        )
        graph_builder.add_edge(START, "_retrieve")
        # NOTE(review): ``add_sequence`` presumably already links the two
        # nodes; this explicit edge is kept as in the original -- confirm
        # whether it is redundant.
        graph_builder.add_edge("_retrieve", "_generate")
        graph_builder.add_edge("_generate", END)
        self.graph = graph_builder.compile(memory)

        # Final state of the most recent ``invoke`` call, or None before any.
        self.last_invoke = None

    def invoke(self, message: str, config: dict) -> Union[dict[str, Any], Any]:
        """Run the graph on ``message`` and return the generated answer.

        Args:
            message: The question to answer.
            config: Passed through unchanged as ``config`` to the compiled
                graph. Presumably must carry a ``configurable.thread_id``
                because the graph is compiled with a checkpointer -- verify
                against the caller.

        Returns:
            The ``"answer"`` entry of the final graph state. The full state is
            also stored in ``self.last_invoke``.
        """
        self.last_invoke = self.graph.invoke({"question": message}, config=config)
        return self.last_invoke["answer"]

    def _retrieve(self, state: State) -> dict:
        """Graph node: fetch documents similar to the question."""
        retrieved_docs = self.vector_store.similarity_search(state["question"])
        return {"context": retrieved_docs}

    def _generate(self, state: State) -> dict:
        """Graph node: answer the question from the retrieved documents."""
        docs_content = "\n\n".join(doc.page_content for doc in state["context"])
        messages = self.prompt.invoke(
            {"question": state["question"], "context": docs_content}
        )
        response = self.chat_model.invoke(messages)
        return {"answer": response.content}

    def get_last_pdf_sources(self) -> dict[str, set[int]]:
        """Return the PDF sources used during the last invoke.

        A context document counts as a PDF source when its ``source`` metadata
        has a ``.pdf`` suffix (compared case-insensitively). Documents without
        a ``source`` entry are skipped.

        Returns:
            A mapping from each PDF source to the set of page numbers found in
            the ``page_number`` and/or ``page`` metadata fields. The mapping is
            empty when ``invoke`` has not been called yet.
        """
        pdf_sources: dict[str, set[int]] = {}
        if self.last_invoke is None:
            return pdf_sources

        for context in self.last_invoke["context"]:
            metadata = context.metadata
            source = metadata.get("source")
            if source is None:
                continue
            # The suffix comparison must actually filter: without it every
            # document that has a source would be reported as a PDF.
            if Path(source).suffix.lower() != ".pdf":
                continue

            pages = pdf_sources.setdefault(source, set())
            # The page numbers are in the `page_number` and `page` fields.
            for key in ("page_number", "page"):
                if key in metadata:
                    pages.add(metadata[key])

        return pdf_sources

    def get_last_web_sources(self) -> set:
        """Return the web sources used during the last invoke.

        A context document counts as a web source when its ``filetype``
        metadata equals ``"web"``. Documents without a ``filetype`` or without
        a ``source`` entry are skipped.

        Returns:
            The set of ``source`` values of the web documents; empty when
            ``invoke`` has not been called yet.
        """
        web_sources = set()
        if self.last_invoke is None:
            return web_sources

        for context in self.last_invoke["context"]:
            metadata = context.metadata
            # Filter on the file type; merely having the key is not enough.
            if metadata.get("filetype") != "web":
                continue
            source = metadata.get("source")
            if source is not None:
                web_sources.add(source)

        return web_sources