forked from AI_team/Philosophy-RAG-demo
Fix pdf source retrieval information
This commit is contained in:
parent
d1e9b3d8cf
commit
cd14c8add2
@ -1,3 +1,4 @@
|
|||||||
|
from pathlib import Path
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
|
|
||||||
from langchain import hub
|
from langchain import hub
|
||||||
@ -32,7 +33,7 @@ class RetGenLangGraph:
|
|||||||
def invoke(self, message: str, config: dict) -> Union[dict[str, Any], Any]:
|
def invoke(self, message: str, config: dict) -> Union[dict[str, Any], Any]:
|
||||||
self.last_invoke = self.graph.invoke({"question": message}, config=config)
|
self.last_invoke = self.graph.invoke({"question": message}, config=config)
|
||||||
return self.last_invoke["answer"]
|
return self.last_invoke["answer"]
|
||||||
|
|
||||||
def _retrieve(self, state: State) -> dict:
|
def _retrieve(self, state: State) -> dict:
|
||||||
retrieved_docs = self.vector_store.similarity_search(state["question"])
|
retrieved_docs = self.vector_store.similarity_search(state["question"])
|
||||||
return {"context": retrieved_docs}
|
return {"context": retrieved_docs}
|
||||||
@ -49,19 +50,33 @@ class RetGenLangGraph:
|
|||||||
"""
|
"""
|
||||||
if self.last_invoke is None:
|
if self.last_invoke is None:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
pdf_sources = {}
|
pdf_sources = {}
|
||||||
for context in self.last_invoke["context"]:
|
for context in self.last_invoke["context"]:
|
||||||
try:
|
try:
|
||||||
if context.metadata["filetype"] == "application/pdf":
|
Path(context.metadata["source"]).suffix == ".pdf"
|
||||||
source = context.metadata["source"]
|
except KeyError:
|
||||||
page_number = context.metadata["page_number"]
|
continue
|
||||||
if source in pdf_sources:
|
else:
|
||||||
pdf_sources[source].add(page_number)
|
source = context.metadata["source"]
|
||||||
else:
|
|
||||||
pdf_sources[source] = {page_number}
|
if source not in pdf_sources:
|
||||||
|
pdf_sources[source] = set()
|
||||||
|
|
||||||
|
# The page numbers are in the `page_numer` and `page` fields.
|
||||||
|
try:
|
||||||
|
page_number = context.metadata["page_number"]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
else:
|
||||||
|
pdf_sources[source].add(page_number)
|
||||||
|
|
||||||
|
try:
|
||||||
|
page_number = context.metadata["page"]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
pdf_sources[source].add(page_number)
|
||||||
|
|
||||||
return pdf_sources
|
return pdf_sources
|
||||||
|
|
||||||
@ -71,13 +86,14 @@ class RetGenLangGraph:
|
|||||||
"""
|
"""
|
||||||
if self.last_invoke is None:
|
if self.last_invoke is None:
|
||||||
return set()
|
return set()
|
||||||
|
|
||||||
web_sources = set()
|
web_sources = set()
|
||||||
for context in self.last_invoke["context"]:
|
for context in self.last_invoke["context"]:
|
||||||
try:
|
try:
|
||||||
if context.metadata["filetype"] == "web":
|
context.metadata["filetype"] == "web"
|
||||||
web_sources.add(context.metadata["source"])
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
continue
|
||||||
|
else:
|
||||||
|
web_sources.add(context.metadata["source"])
|
||||||
|
|
||||||
return web_sources
|
return web_sources
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user