forked from AI_team/Philosophy-RAG-demo
81 lines
2.8 KiB
Python
81 lines
2.8 KiB
Python
from langgraph.graph import START, END, StateGraph
|
|
from typing_extensions import List, TypedDict
|
|
from langchain_core.documents import Document
|
|
from langchain import hub
|
|
from typing import Any, Union
|
|
|
|
|
|
class State(TypedDict):
    """State dictionary shared by the nodes of the retrieve/generate graph."""

    # Question text; read by both the retrieval and the generation step.
    question: str
    # Documents written by the retrieval step and read by the generation step.
    context: List[Document]
    # Answer text written by the generation step.
    answer: str
|
|
|
|
|
|
class RetGenLangGraph:
    """Retrieve-then-generate pipeline assembled as a LangGraph state graph.

    The graph has two nodes, ``_retrieve`` and ``_generate``, executed in that
    order over a ``State`` dictionary. The result of the most recent run is
    kept in ``last_invoke`` so the sources used for the answer can be listed
    afterwards.
    """

    def __init__(self, vector_store, chat_model, embedding_model):
        """Store the collaborators, pull the RAG prompt and compile the graph.

        Args:
            vector_store: Object exposing ``similarity_search(query)``.
            chat_model: Object exposing ``invoke(messages)`` whose result has
                a ``content`` attribute.
            embedding_model: Stored on the instance; not used by this class.
        """
        self.vector_store = vector_store
        self.chat_model = chat_model
        self.embedding_model = embedding_model
        self.prompt = hub.pull("rlm/rag-prompt")

        # The edges below address the nodes by the method names, so they rely
        # on add_sequence registering each callable under its function name.
        graph_builder = StateGraph(State).add_sequence([self._retrieve, self._generate])
        graph_builder.add_edge(START, "_retrieve")
        # NOTE(review): add_sequence is documented to chain its nodes in order,
        # which would make this explicit edge redundant - confirm before removing.
        graph_builder.add_edge("_retrieve", "_generate")
        graph_builder.add_edge("_generate", END)

        self.graph = graph_builder.compile()
        # Result of the most recent invoke(); None until the first call.
        self.last_invoke = None

    def invoke(self, message: Union[str, dict[str, Any]]) -> Union[dict[str, Any], Any]:
        """Run the graph for one question and remember the result.

        Args:
            message: Either the question text, or an already-built state
                dictionary (e.g. ``{"question": "..."}``).

        Returns:
            Whatever the compiled graph returns; also stored in
            ``self.last_invoke``.
        """
        # The graph is built over the State schema, so a bare question string
        # has to be wrapped; dict inputs are forwarded unchanged so existing
        # callers keep working.
        graph_input = {"question": message} if isinstance(message, str) else message
        self.last_invoke = self.graph.invoke(graph_input)
        return self.last_invoke

    def _retrieve(self, state: "State") -> dict:
        """Graph node: look up documents similar to the question."""
        retrieved_docs = self.vector_store.similarity_search(state["question"])
        return {"context": retrieved_docs}

    def _generate(self, state: "State") -> dict:
        """Graph node: build the prompt from the context and ask the chat model."""
        docs_content = "\n\n".join(doc.page_content for doc in state["context"])
        messages = self.prompt.invoke({"question": state["question"], "context": docs_content})
        response = self.chat_model.invoke(messages)
        return {"answer": response.content}

    def _last_metadata_of_type(self, filetype: str) -> list:
        """Return the metadata of last-invoke context documents of ``filetype``.

        Returns an empty list when nothing has been invoked yet. Documents
        whose metadata has no ``"filetype"`` key are skipped.
        """
        if self.last_invoke is None:
            return []
        return [
            document.metadata
            for document in self.last_invoke["context"]
            if document.metadata.get("filetype") == filetype
        ]

    def get_last_pdf_sources(self) -> dict[str, set[int]]:
        """
        Method that retrieves the PDF sources used during the last invoke.

        Returns:
            Mapping of PDF source to the set of page numbers that were used.
            Empty when nothing has been invoked yet. Documents missing a
            ``"source"`` or ``"page_number"`` entry are skipped.
        """
        pdf_sources: dict[str, set[int]] = {}
        for metadata in self._last_metadata_of_type("application/pdf"):
            # Incomplete entries are ignored rather than treated as errors.
            if "source" not in metadata or "page_number" not in metadata:
                continue
            pdf_sources.setdefault(metadata["source"], set()).add(metadata["page_number"])
        return pdf_sources

    def get_last_web_sources(self) -> set:
        """
        Method that retrieves the web sources used during the last invoke.

        Returns:
            Set of web sources. Empty when nothing has been invoked yet.
            Documents without a ``"source"`` entry are skipped.
        """
        return {
            metadata["source"]
            for metadata in self._last_metadata_of_type("web")
            if "source" in metadata
        }
|