From b07eca8f9b88616b64531112478ef7ce9312f6a9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nielson=20Jann=C3=A9?=
Date: Fri, 14 Mar 2025 23:19:51 +0100
Subject: [PATCH] Only process files and websites if not already in Chroma DB.

---
 generic_rag/app.py            |  25 ++++-----
 generic_rag/parsers/parser.py | 102 +++++++++++++++++++++-------------
 2 files changed, 75 insertions(+), 52 deletions(-)

diff --git a/generic_rag/app.py b/generic_rag/app.py
index 4c16e50..9432ebd 100644
--- a/generic_rag/app.py
+++ b/generic_rag/app.py
@@ -5,17 +5,15 @@ import os
 from pathlib import Path
 
 import chainlit as cl
+from backend.models import BackendType, get_chat_model, get_embedding_model
 from chainlit.cli import run_chainlit
 from langchain import hub
 from langchain_chroma import Chroma
 from langchain_core.documents import Document
 from langgraph.graph import START, StateGraph
+from parsers.parser import add_pdf_files, add_urls
 from typing_extensions import List, TypedDict
 
-from backend.models import BackendType, get_embedding_model, get_chat_model
-from parsers.parser import process_local_files, process_web_sites
-from langchain_community.vectorstores.utils import filter_complex_metadata
-
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
@@ -39,6 +37,7 @@ parser.add_argument("--web-chunk-size", type=int, default=200,
                     help="The size of the chunks to split the text into.")
 parser.add_argument("-c", "--chroma-db-location", type=Path, default=Path(".chroma_db"),
                     help="file path to store or load a Chroma DB from/to.")
+parser.add_argument("-r", "--reset-chroma-db", action="store_true", help="Reset the Chroma DB.")
 
 args = parser.parse_args()
 
@@ -111,16 +110,16 @@ async def set_starters():
 
 
 if __name__ == "__main__":
-    pdf_splits = process_local_files(args.pdf_data, args.pdf_chunk_size, args.pdf_chunk_overlap,
-                                     args.pdf_add_start_index)
-    web_splits = process_web_sites(args.web_data, args.web_chunk_size)
+    vector_store = Chroma(
+        collection_name="generic_rag",
+        embedding_function=get_embedding_model(args.back_end),
+        persist_directory=str(args.chroma_db_location),
+    )
 
-    filtered_splits = filter_complex_metadata(pdf_splits + web_splits)
+    if args.reset_chroma_db:
+        vector_store.reset_collection()
 
-    vector_store = Chroma(collection_name="generic_rag",
-                          embedding_function=get_embedding_model(args.back_end),
-                          persist_directory=str(args.chroma_db_location))
-    _ = vector_store.add_documents(documents=filtered_splits)
-    del vector_store
+    add_pdf_files(vector_store, args.pdf_data, args.pdf_chunk_size, args.pdf_chunk_overlap, args.pdf_add_start_index)
+    add_urls(vector_store, args.web_data, args.web_chunk_size)
 
     run_chainlit(__file__)
diff --git a/generic_rag/parsers/parser.py b/generic_rag/parsers/parser.py
index 92f2fa1..ceca313 100644
--- a/generic_rag/parsers/parser.py
+++ b/generic_rag/parsers/parser.py
@@ -2,20 +2,17 @@ import logging
 from pathlib import Path
 
 import requests
-from bs4 import BeautifulSoup
-from langchain_core.documents import Document
-from langchain_text_splitters import HTMLSemanticPreservingSplitter
-from langchain_text_splitters import RecursiveCharacterTextSplitter
+from bs4 import BeautifulSoup, Tag
+from langchain_chroma import Chroma
+from langchain_community.vectorstores.utils import filter_complex_metadata
+from langchain_text_splitters import HTMLSemanticPreservingSplitter, RecursiveCharacterTextSplitter
 from langchain_unstructured import UnstructuredLoader
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-from bs4 import Tag
-headers_to_split_on = [
-    ("h1", "Header 1"),
-    ("h2", "Header 2"),
-]
+
+headers_to_split_on = [("h1", "Header 1"), ("h2", "Header 2")]
 
 
 def code_handler(element: Tag) -> str:
@@ -28,23 +25,17 @@ def code_handler(element: Tag) -> str:
     return code_format
 
 
-def process_web_sites(websites: list[str], chunk_size: int) -> list[Document]:
-    """
-    Process one or more websites and returns a list of langchain Document's.
-    """
-    if len(websites) == 0:
-        return []
+def add_urls(vector_store: Chroma, urls: list[str], chunk_size: int) -> None:
+    all_splits = []
+    for url in urls:
+        if len(vector_store.get(where={"source": url}, limit=1)["ids"]) > 0:
+            continue
 
-    splits = []
-    for url in websites:
-        # Fetch the webpage
         response = requests.get(url)
         html_content = response.text
 
-        # Parse the HTML
         soup = BeautifulSoup(html_content, "html.parser")
 
-        # split documents
         web_splitter = HTMLSemanticPreservingSplitter(
             headers_to_split_on=headers_to_split_on,
             separators=["\n\n", "\n", ". ", "! ", "? "],
@@ -53,32 +44,65 @@ def process_web_sites(websites: list[str], chunk_size: int) -> list[Document]:
             preserve_videos=True,
             elements_to_preserve=["table", "ul", "ol", "code"],
             denylist_tags=["script", "style", "head"],
-            custom_handlers={"code": code_handler})
+            custom_handlers={"code": code_handler},
+        )
 
-        splits.extend(web_splitter.split_text(str(soup)))
+        splits = web_splitter.split_text(str(soup))
 
-    return splits
+        for split in splits:
+            split.metadata["source"] = url
+
+        all_splits.extend(splits)
+
+    if len(all_splits) == 0:
+        return
+
+    filtered_splits = filter_complex_metadata(all_splits)
+    vector_store.add_documents(documents=filtered_splits)
 
 
-def process_local_files(
-    local_paths: list[Path], chunk_size: int, chunk_overlap: int, add_start_index: bool
-) -> list[Document]:
-    process_files = []
-    for path in local_paths:
-        if path.is_dir():
-            process_files.extend(list(path.glob("*.pdf")))
-        elif path.suffix == ".pdf":
-            process_files.append(path)
-        else:
-            logging.warning(f"Ignoring path {path} as it is not a folder or pdf file.")
+def add_pdf_files(
+    vector_store: Chroma, file_paths: list[Path], chunk_size: int, chunk_overlap: int, add_start_index: bool
+) -> None:
+    pdf_files = get_all_local_pdf_files(file_paths)
+
+    new_pdfs = []
+    for pdf_file in pdf_files:
+        if len(vector_store.get(where={"source": str(pdf_file)}, limit=1)["ids"]) == 0:
+            new_pdfs.append(pdf_file)
+
+    if len(new_pdfs) == 0:
+        return
 
     loaded_document = []
-    for file in process_files:
+    for file in new_pdfs:
         loader = UnstructuredLoader(file_path=file, strategy="hi_res")
         for document in loader.lazy_load():
             loaded_document.append(document)
 
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,
-                                                   chunk_overlap=chunk_overlap,
-                                                   add_start_index=add_start_index)
-    return text_splitter.split_documents(loaded_document)
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=chunk_size, chunk_overlap=chunk_overlap, add_start_index=add_start_index
+    )
+
+    pdf_splits = text_splitter.split_documents(loaded_document)
+
+    vector_store.add_documents(documents=filter_complex_metadata(pdf_splits))
+
+
+def get_all_local_pdf_files(local_paths: list[Path]) -> list[Path]:
+    """
+    Take a list of local paths, which may contain directory paths and/or
+    direct file paths, and return every path that is a PDF file, plus any
+    PDF files found directly inside the given directories.
+    This function does not scan directories recursively.
+    """
+    all_pdf_files = []
+    for path in local_paths:
+        if path.is_dir():
+            all_pdf_files.extend(list(path.glob("*.pdf")))
+        elif path.suffix == ".pdf":
+            all_pdf_files.append(path)
+        else:
+            logging.warning(f"Ignoring path {path} as it is not a folder or pdf file.")
+
+    return all_pdf_files