Fix pdf source retrieval information

This commit is contained in:
Nielson Janné 2025-03-28 15:08:32 +01:00
parent d1e9b3d8cf
commit cd14c8add2

View File

@ -1,3 +1,4 @@
from pathlib import Path
from typing import Any, Union from typing import Any, Union
from langchain import hub from langchain import hub
@ -32,7 +33,7 @@ class RetGenLangGraph:
def invoke(self, message: str, config: dict) -> Union[dict[str, Any], Any]: def invoke(self, message: str, config: dict) -> Union[dict[str, Any], Any]:
self.last_invoke = self.graph.invoke({"question": message}, config=config) self.last_invoke = self.graph.invoke({"question": message}, config=config)
return self.last_invoke["answer"] return self.last_invoke["answer"]
def _retrieve(self, state: State) -> dict: def _retrieve(self, state: State) -> dict:
retrieved_docs = self.vector_store.similarity_search(state["question"]) retrieved_docs = self.vector_store.similarity_search(state["question"])
return {"context": retrieved_docs} return {"context": retrieved_docs}
@ -49,19 +50,33 @@ class RetGenLangGraph:
""" """
if self.last_invoke is None: if self.last_invoke is None:
return [] return []
pdf_sources = {} pdf_sources = {}
for context in self.last_invoke["context"]: for context in self.last_invoke["context"]:
try: try:
if context.metadata["filetype"] == "application/pdf": Path(context.metadata["source"]).suffix == ".pdf"
source = context.metadata["source"] except KeyError:
page_number = context.metadata["page_number"] continue
if source in pdf_sources: else:
pdf_sources[source].add(page_number) source = context.metadata["source"]
else:
pdf_sources[source] = {page_number} if source not in pdf_sources:
pdf_sources[source] = set()
# The page numbers are in the `page_numer` and `page` fields.
try:
page_number = context.metadata["page_number"]
except KeyError: except KeyError:
pass pass
else:
pdf_sources[source].add(page_number)
try:
page_number = context.metadata["page"]
except KeyError:
pass
else:
pdf_sources[source].add(page_number)
return pdf_sources return pdf_sources
@ -71,13 +86,14 @@ class RetGenLangGraph:
""" """
if self.last_invoke is None: if self.last_invoke is None:
return set() return set()
web_sources = set() web_sources = set()
for context in self.last_invoke["context"]: for context in self.last_invoke["context"]:
try: try:
if context.metadata["filetype"] == "web": context.metadata["filetype"] == "web"
web_sources.add(context.metadata["source"])
except KeyError: except KeyError:
pass continue
else:
web_sources.add(context.metadata["source"])
return web_sources return web_sources