diff --git a/generic_rag/graphs/ret_gen.py b/generic_rag/graphs/ret_gen.py index f7f1fae..c6f58d3 100644 --- a/generic_rag/graphs/ret_gen.py +++ b/generic_rag/graphs/ret_gen.py @@ -1,3 +1,4 @@ +from pathlib import Path from typing import Any, Union from langchain import hub @@ -32,7 +33,7 @@ class RetGenLangGraph: def invoke(self, message: str, config: dict) -> Union[dict[str, Any], Any]: self.last_invoke = self.graph.invoke({"question": message}, config=config) return self.last_invoke["answer"] - + def _retrieve(self, state: State) -> dict: retrieved_docs = self.vector_store.similarity_search(state["question"]) return {"context": retrieved_docs} @@ -49,19 +50,33 @@ class RetGenLangGraph: """ if self.last_invoke is None: return [] - + pdf_sources = {} for context in self.last_invoke["context"]: try: - if context.metadata["filetype"] == "application/pdf": - source = context.metadata["source"] - page_number = context.metadata["page_number"] - if source in pdf_sources: - pdf_sources[source].add(page_number) - else: - pdf_sources[source] = {page_number} + Path(context.metadata["source"]).suffix == ".pdf" + except KeyError: + continue + else: + source = context.metadata["source"] + + if source not in pdf_sources: + pdf_sources[source] = set() + + # The page numbers are in the `page_numer` and `page` fields. + try: + page_number = context.metadata["page_number"] except KeyError: pass + else: + pdf_sources[source].add(page_number) + + try: + page_number = context.metadata["page"] + except KeyError: + pass + else: + pdf_sources[source].add(page_number) return pdf_sources @@ -71,13 +86,14 @@ class RetGenLangGraph: """ if self.last_invoke is None: return set() - + web_sources = set() for context in self.last_invoke["context"]: try: - if context.metadata["filetype"] == "web": - web_sources.add(context.metadata["source"]) + context.metadata["filetype"] == "web" except KeyError: - pass + continue + else: + web_sources.add(context.metadata["source"]) return web_sources