Only process files and websites if not already in Chroma DB.

Nielson Janné 2025-03-14 23:19:51 +01:00
parent 1f75264e96
commit b07eca8f9b
2 changed files with 75 additions and 52 deletions
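The gist of the change is a membership check against the "source" metadata already stored in Chroma before anything is parsed or embedded. A minimal sketch of that check (the helper name here is illustrative; the diff below inlines the same vector_store.get call directly in add_urls and add_pdf_files):

from langchain_chroma import Chroma

def already_ingested(vector_store: Chroma, source: str) -> bool:
    # Each chunk is stored with its originating file path or URL under the
    # "source" metadata key; one matching id is enough to skip re-ingestion.
    return len(vector_store.get(where={"source": source}, limit=1)["ids"]) > 0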

View File

@@ -5,17 +5,15 @@ import os
from pathlib import Path
import chainlit as cl
from backend.models import BackendType, get_chat_model, get_embedding_model
from chainlit.cli import run_chainlit
from langchain import hub
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langgraph.graph import START, StateGraph
from parsers.parser import add_pdf_files, add_urls
from typing_extensions import List, TypedDict
from backend.models import BackendType, get_embedding_model, get_chat_model
from parsers.parser import process_local_files, process_web_sites
from langchain_community.vectorstores.utils import filter_complex_metadata
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -39,6 +37,7 @@ parser.add_argument("--web-chunk-size", type=int, default=200,
help="The size of the chunks to split the text into.")
parser.add_argument("-c", "--chroma-db-location", type=Path, default=Path(".chroma_db"),
help="file path to store or load a Chroma DB from/to.")
parser.add_argument("-r", "--reset-chrome-db", action="store_true", help="Reset the Chroma DB.")
args = parser.parse_args()
@@ -111,16 +110,16 @@ async def set_starters():
if __name__ == "__main__":
pdf_splits = process_local_files(args.pdf_data, args.pdf_chunk_size, args.pdf_chunk_overlap,
args.pdf_add_start_index)
web_splits = process_web_sites(args.web_data, args.web_chunk_size)
vector_store = Chroma(
collection_name="generic_rag",
embedding_function=get_embedding_model(args.back_end),
persist_directory=str(args.chroma_db_location),
)
filtered_splits = filter_complex_metadata(pdf_splits + web_splits)
if args.reset_chrome_db:
vector_store.reset_collection()
vector_store = Chroma(collection_name="generic_rag",
embedding_function=get_embedding_model(args.back_end),
persist_directory=str(args.chroma_db_location))
_ = vector_store.add_documents(documents=filtered_splits)
del vector_store
add_pdf_files(vector_store, args.pdf_data, args.pdf_chunk_size, args.pdf_chunk_overlap, args.pdf_add_start_index)
add_urls(vector_store, args.web_data, args.web_chunk_size)
run_chainlit(__file__)
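Pieced together, the new entry point builds the persistent store once, optionally resets it, and hands it to the parser helpers, which now decide per source whether ingestion is needed. A sketch of that flow, inferred from the added lines above (formatting and exact line order may differ from the committed file; args comes from the argparse block earlier in this file):

from chainlit.cli import run_chainlit
from langchain_chroma import Chroma
from backend.models import get_embedding_model
from parsers.parser import add_pdf_files, add_urls

if __name__ == "__main__":
    # One persistent collection, reused across runs so earlier ingestions are remembered.
    vector_store = Chroma(collection_name="generic_rag",
                          embedding_function=get_embedding_model(args.back_end),
                          persist_directory=str(args.chroma_db_location))
    if args.reset_chrome_db:
        vector_store.reset_collection()
    # The helpers skip any PDF or URL whose "source" is already in the collection.
    add_pdf_files(vector_store, args.pdf_data, args.pdf_chunk_size,
                  args.pdf_chunk_overlap, args.pdf_add_start_index)
    add_urls(vector_store, args.web_data, args.web_chunk_size)
    run_chainlit(__file__)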

View File

@@ -2,20 +2,17 @@ import logging
from pathlib import Path
import requests
from bs4 import BeautifulSoup
from langchain_core.documents import Document
from langchain_text_splitters import HTMLSemanticPreservingSplitter
from langchain_text_splitters import RecursiveCharacterTextSplitter
from bs4 import BeautifulSoup, Tag
from langchain_chroma import Chroma
from langchain_community.vectorstores.utils import filter_complex_metadata
from langchain_text_splitters import HTMLSemanticPreservingSplitter, RecursiveCharacterTextSplitter
from langchain_unstructured import UnstructuredLoader
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
from bs4 import Tag
headers_to_split_on = [
("h1", "Header 1"),
("h2", "Header 2"),
]
headers_to_split_on = [("h1", "Header 1"), ("h2", "Header 2")]
def code_handler(element: Tag) -> str:
@@ -28,23 +25,17 @@ def code_handler(element: Tag) -> str:
return code_format
def process_web_sites(websites: list[str], chunk_size: int) -> list[Document]:
"""
Process one or more websites and return a list of LangChain Documents.
"""
if len(websites) == 0:
return []
def add_urls(vector_store: Chroma, urls: list[str], chunk_size: int) -> None:
all_splits = []
for url in urls:
if len(vector_store.get(where={"source": url}, limit=1)["ids"]) > 0:
continue
splits = []
for url in websites:
# Fetch the webpage
response = requests.get(url)
html_content = response.text
# Parse the HTML
soup = BeautifulSoup(html_content, "html.parser")
# split documents
web_splitter = HTMLSemanticPreservingSplitter(
headers_to_split_on=headers_to_split_on,
separators=["\n\n", "\n", ". ", "! ", "? "],
@@ -53,32 +44,65 @@ def process_web_sites(websites: list[str], chunk_size: int) -> list[Document]:
preserve_videos=True,
elements_to_preserve=["table", "ul", "ol", "code"],
denylist_tags=["script", "style", "head"],
custom_handlers={"code": code_handler})
custom_handlers={"code": code_handler},
)
splits.extend(web_splitter.split_text(str(soup)))
splits = web_splitter.split_text(str(soup))
return splits
for split in splits:
split.metadata["source"] = url
all_splits.extend(splits)
if len(all_splits) == 0:
return
filtered_splits = filter_complex_metadata(all_splits)
vector_store.add_documents(documents=filtered_splits)
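Two details in add_urls make the dedup work on later runs: every split is tagged with its source URL (that tag is what the where={"source": url} lookup matches), and filter_complex_metadata strips metadata values Chroma cannot store. A small standalone illustration of the second point (the example Document and its metadata are made up):

from langchain_community.vectorstores.utils import filter_complex_metadata
from langchain_core.documents import Document

doc = Document(
    page_content="example chunk",
    metadata={"source": "https://example.com", "headers": ["Header 1", "Header 2"]},
)
cleaned = filter_complex_metadata([doc])
# Chroma only persists str/int/float/bool metadata values, so the list-valued
# "headers" entry is dropped while the scalar "source" tag survives.
assert "headers" not in cleaned[0].metadata
assert cleaned[0].metadata["source"] == "https://example.com"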
def process_local_files(
local_paths: list[Path], chunk_size: int, chunk_overlap: int, add_start_index: bool
) -> list[Document]:
process_files = []
for path in local_paths:
if path.is_dir():
process_files.extend(list(path.glob("*.pdf")))
elif path.suffix == ".pdf":
process_files.append(path)
else:
logging.warning(f"Ignoring path {path} as it is not a folder or pdf file.")
def add_pdf_files(
vector_store: Chroma, file_paths: list[Path], chunk_size: int, chunk_overlap: int, add_start_index: bool
) -> None:
pdf_files = get_all_local_pdf_files(file_paths)
new_pdfs = []
for pdf_file in pdf_files:
if len(vector_store.get(where={"source": str(pdf_file)}, limit=1)["ids"]) == 0:
new_pdfs.append(pdf_file)
if len(new_pdfs) == 0:
return
loaded_document = []
for file in process_files:
for file in new_pdfs:
loader = UnstructuredLoader(file_path=file, strategy="hi_res")
for document in loader.lazy_load():
loaded_document.append(document)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
add_start_index=add_start_index)
return text_splitter.split_documents(loaded_document)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size, chunk_overlap=chunk_overlap, add_start_index=add_start_index
)
pdf_splits = text_splitter.split_documents(loaded_document)
vector_store.add_documents(documents=filter_complex_metadata(pdf_splits))
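The PDF path applies the same "source" check, then chunks the loaded documents with RecursiveCharacterTextSplitter. A tiny example of how chunk_size, chunk_overlap, and add_start_index behave (the sample text and sizes are made up; the real values come from the argparse flags):

from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20, add_start_index=True)
chunks = splitter.split_documents(
    [Document(page_content="lorem ipsum " * 50, metadata={"source": "sample.pdf"})]
)
# Each chunk is at most ~100 characters, overlaps its neighbour by up to 20 characters,
# keeps the original "source" tag, and gains a "start_index" entry pointing back
# into the original text.
print(len(chunks), chunks[0].metadata)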
def get_all_local_pdf_files(local_paths: list[Path]) -> list[Path]:
"""
Take a list of local paths, which may contain directory paths and/or direct file paths,
and return every path that is a PDF file, plus any PDF files found directly inside the given directories.
This function does not scan directories recursively.
"""
all_pdf_files = []
for path in local_paths:
if path.is_dir():
all_pdf_files.extend(list(path.glob("*.pdf")))
elif path.suffix == ".pdf":
all_pdf_files.append(path)
else:
logging.warning(f"Ignoring path {path} as it is not a folder or pdf file.")
return all_pdf_files
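For reference, a hypothetical call to the helper above (the paths are made up for illustration):

from pathlib import Path

paths = [Path("docs"), Path("manual.pdf"), Path("notes.txt")]
pdf_files = get_all_local_pdf_files(paths)
# -> every *.pdf directly inside docs/ plus manual.pdf; notes.txt only triggers the warning above.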