import logging
from pathlib import Path
from typing import Any, AsyncGenerator

from langchain import hub
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from typing_extensions import List, TypedDict

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class State(TypedDict):
    """State dictionary carried through the graph's nodes."""

    # The question passed in by `RetGenLangGraph.stream`.
    question: str
    # Documents written by the `_retrieve` node.
    context: List[Document]
    # Text written by the `_generate` node.
    answer: str


class RetGenLangGraph:
    """Retrieve-then-generate pipeline wired as a two-node LangGraph.

    The graph runs `_retrieve` (similarity search on the vector store) and
    then `_generate` (prompt + chat model). The documents returned by the most
    recent retrieval are kept on the instance so that their sources can be
    listed afterwards with `get_last_pdf_sources` / `get_last_web_sources`.
    """

    def __init__(
        self,
        vector_store: Chroma,
        chat_model: BaseChatModel,
        embedding_model: Embeddings,
    ):
        """Store the collaborators, pull the prompt and compile the graph.

        Args:
            vector_store: Store queried by the `_retrieve` node.
            chat_model: Model streamed by the `_generate` node.
            embedding_model: Kept on the instance; not used by the code
                visible in this class.
        """
        self.vector_store = vector_store
        self.chat_model = chat_model
        self.embedding_model = embedding_model
        self.prompt = hub.pull("rlm/rag-prompt")
        memory = MemorySaver()
        graph_builder = StateGraph(State).add_sequence([self._retrieve, self._generate])
        graph_builder.add_edge(START, "_retrieve")
        # NOTE(review): `add_sequence` presumably already links `_retrieve` to
        # `_generate`, which would make this explicit edge redundant - confirm
        # against the LangGraph documentation before removing it.
        graph_builder.add_edge("_retrieve", "_generate")
        graph_builder.add_edge("_generate", END)
        self.graph = graph_builder.compile(memory)
        # Documents returned by the most recent `_retrieve` call.
        self.last_retrieved_docs = []

    async def stream(self, message: str, config: dict) -> AsyncGenerator[Any, Any]:
        """Run the graph on `message` and yield the content of each streamed item.

        Args:
            message: The question placed in the graph state.
            config: Forwarded unchanged to `graph.astream`.
                NOTE(review): the graph is compiled with a `MemorySaver`, so
                this presumably has to carry a thread identifier - confirm.

        Yields:
            The `content` attribute of every item streamed by the graph in
            "messages" mode; the second element of each streamed pair is
            ignored.
        """
        async for response, _ in self.graph.astream({"question": message}, stream_mode="messages", config=config):
            yield response.content

    def _retrieve(self, state: State) -> dict:
        """Graph node: similarity-search the question and record the hits."""
        self.last_retrieved_docs = self.vector_store.similarity_search(state["question"])
        return {"context": self.last_retrieved_docs}

    async def _generate(self, state: State) -> AsyncGenerator[Any, Any]:
        """Graph node: fill the prompt with the context and stream the model.

        Yields one `{"answer": <chunk content>}` update per chunk produced by
        the chat model.
        """
        # Retrieved documents are concatenated, separated by a blank line.
        docs_content = "\n\n".join(doc.page_content for doc in state["context"])
        messages = await self.prompt.ainvoke({"question": state["question"], "context": docs_content})
        async for response in self.chat_model.astream(messages):
            yield {"answer": response.content}

    def get_last_pdf_sources(self) -> dict[str, set[int]]:
        """Return the PDF sources used during the last retrieval.

        Only documents whose `source` metadata has a `.pdf` suffix are
        considered; documents without a `source` entry are skipped.

        Returns:
            A mapping from each PDF source to the set of page numbers found
            in the `page_number` and `page` metadata fields of its documents.
            A source with no page information maps to an empty set (and a
            warning is logged).
        """
        pdf_sources: dict[str, set[int]] = {}
        if not self.last_retrieved_docs:
            return pdf_sources

        for doc in self.last_retrieved_docs:
            source = doc.metadata.get("source")
            # Skip documents without a source and anything that is not a PDF.
            # The suffix comparison is case-insensitive so "x.PDF" also counts.
            if not source or Path(source).suffix.lower() != ".pdf":
                continue

            pages = pdf_sources.setdefault(source, set())
            # The page numbers are in the `page_number` and `page` fields.
            for field in ("page_number", "page"):
                if field in doc.metadata:
                    pages.add(doc.metadata[field])

            if not pages:
                logger.warning(
                    "PDF source %s has no page number. Please check the metadata of the document.",
                    source,
                )
        return pdf_sources

    def get_last_web_sources(self) -> set:
        """Return the web sources used during the last retrieval.

        A document counts as a web source when its `filetype` metadata equals
        "web". Documents without a `source` entry are skipped.

        Returns:
            The set of `source` values of the matching documents.
        """
        web_sources = set()
        if not self.last_retrieved_docs:
            return web_sources

        for doc in self.last_retrieved_docs:
            if doc.metadata.get("filetype") != "web":
                continue
            source = doc.metadata.get("source")
            if source is not None:
                web_sources.add(source)
        return web_sources