forked from AI_team/Philosophy-RAG-demo
Only process files and websites if not already in Chroma DB.
parent 1f75264e96
commit b07eca8f9b
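In short: before ingesting a PDF or URL, the new parser functions ask the Chroma collection whether any stored chunk already carries that source in its metadata, and skip it if so; the new --reset-chrome-db flag clears the collection first when a full re-index is wanted. A minimal sketch of that existence check, assuming chunks are stored with a "source" metadata key (the already_indexed helper and the placeholder embeddings are illustrative, not part of this commit):

# Sketch only: "skip if already indexed" check against Chroma "source" metadata.
# FakeEmbeddings stands in for the project's get_embedding_model(...);
# already_indexed is a hypothetical helper, not code from the commit.
from langchain_chroma import Chroma
from langchain_core.embeddings import FakeEmbeddings

vector_store = Chroma(
    collection_name="generic_rag",
    embedding_function=FakeEmbeddings(size=8),
    persist_directory=".chroma_db",
)

def already_indexed(source: str) -> bool:
    # Chroma.get() accepts a metadata filter; one matching id is enough to skip re-processing.
    return len(vector_store.get(where={"source": source}, limit=1)["ids"]) > 0

if not already_indexed("https://example.com/article"):
    ...  # fetch, split, stamp each split's metadata["source"], then vector_store.add_documents(...)

For PDFs the same check is keyed on str(pdf_file), so deduplication hinges on every stored chunk carrying its originating path or URL under metadata["source"], exactly as the diff below stamps it onto web splits.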
@@ -5,17 +5,15 @@ import os
 from pathlib import Path
 
 import chainlit as cl
+from backend.models import BackendType, get_chat_model, get_embedding_model
 from chainlit.cli import run_chainlit
 from langchain import hub
 from langchain_chroma import Chroma
 from langchain_core.documents import Document
 from langgraph.graph import START, StateGraph
+from parsers.parser import add_pdf_files, add_urls
 from typing_extensions import List, TypedDict
 
-from backend.models import BackendType, get_embedding_model, get_chat_model
-from parsers.parser import process_local_files, process_web_sites
-from langchain_community.vectorstores.utils import filter_complex_metadata
-
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
@@ -39,6 +37,7 @@ parser.add_argument("--web-chunk-size", type=int, default=200,
                     help="The size of the chunks to split the text into.")
 parser.add_argument("-c", "--chroma-db-location", type=Path, default=Path(".chroma_db"),
                     help="file path to store or load a Chroma DB from/to.")
+parser.add_argument("-r", "--reset-chrome-db", action="store_true", help="Reset the Chroma DB.")
 args = parser.parse_args()
 
 
@@ -111,16 +110,16 @@ async def set_starters():
 
 
 if __name__ == "__main__":
-    pdf_splits = process_local_files(args.pdf_data, args.pdf_chunk_size, args.pdf_chunk_overlap,
-                                     args.pdf_add_start_index)
-    web_splits = process_web_sites(args.web_data, args.web_chunk_size)
-
-    filtered_splits = filter_complex_metadata(pdf_splits + web_splits)
-
-    vector_store = Chroma(collection_name="generic_rag",
-                          embedding_function=get_embedding_model(args.back_end),
-                          persist_directory=str(args.chroma_db_location))
-    _ = vector_store.add_documents(documents=filtered_splits)
-    del vector_store
+    vector_store = Chroma(
+        collection_name="generic_rag",
+        embedding_function=get_embedding_model(args.back_end),
+        persist_directory=str(args.chroma_db_location),
+    )
+
+    if args.reset_chrome_db:
+        vector_store.reset_collection()
+
+    add_pdf_files(vector_store, args.pdf_data, args.pdf_chunk_size, args.pdf_chunk_overlap, args.pdf_add_start_index)
+    add_urls(vector_store, args.web_data, args.web_chunk_size)
 
     run_chainlit(__file__)

parsers/parser.py
@@ -2,20 +2,17 @@ import logging
 from pathlib import Path
 
 import requests
-from bs4 import BeautifulSoup
-from langchain_core.documents import Document
-from langchain_text_splitters import HTMLSemanticPreservingSplitter
-from langchain_text_splitters import RecursiveCharacterTextSplitter
+from bs4 import BeautifulSoup, Tag
+from langchain_chroma import Chroma
+from langchain_community.vectorstores.utils import filter_complex_metadata
+from langchain_text_splitters import HTMLSemanticPreservingSplitter, RecursiveCharacterTextSplitter
 from langchain_unstructured import UnstructuredLoader
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
-from bs4 import Tag
 
-headers_to_split_on = [
-    ("h1", "Header 1"),
-    ("h2", "Header 2"),
-]
+
+headers_to_split_on = [("h1", "Header 1"), ("h2", "Header 2")]
 
 
 def code_handler(element: Tag) -> str:
@@ -28,23 +25,17 @@ def code_handler(element: Tag) -> str:
     return code_format
 
 
-def process_web_sites(websites: list[str], chunk_size: int) -> list[Document]:
-    """
-    Process one or more websites and returns a list of langchain Document's.
-    """
-    if len(websites) == 0:
-        return []
-
-    splits = []
-    for url in websites:
-        # Fetch the webpage
+def add_urls(vector_store: Chroma, urls: list[str], chunk_size: int) -> None:
+    all_splits = []
+    for url in urls:
+        if len(vector_store.get(where={"source": url}, limit=1)["ids"]) > 0:
+            continue
+
         response = requests.get(url)
         html_content = response.text
 
-        # Parse the HTML
         soup = BeautifulSoup(html_content, "html.parser")
 
-        # split documents
         web_splitter = HTMLSemanticPreservingSplitter(
             headers_to_split_on=headers_to_split_on,
             separators=["\n\n", "\n", ". ", "! ", "? "],
@@ -53,32 +44,65 @@ def process_web_sites(websites: list[str], chunk_size: int) -> list[Document]:
             preserve_videos=True,
             elements_to_preserve=["table", "ul", "ol", "code"],
             denylist_tags=["script", "style", "head"],
-            custom_handlers={"code": code_handler})
+            custom_handlers={"code": code_handler},
+        )
 
-        splits.extend(web_splitter.split_text(str(soup)))
+        splits = web_splitter.split_text(str(soup))
 
-    return splits
+        for split in splits:
+            split.metadata["source"] = url
+
+        all_splits.extend(splits)
+
+    if len(all_splits) == 0:
+        return
+
+    filtered_splits = filter_complex_metadata(all_splits)
+    vector_store.add_documents(documents=filtered_splits)
 
 
-def process_local_files(
-    local_paths: list[Path], chunk_size: int, chunk_overlap: int, add_start_index: bool
-) -> list[Document]:
-    process_files = []
-    for path in local_paths:
-        if path.is_dir():
-            process_files.extend(list(path.glob("*.pdf")))
-        elif path.suffix == ".pdf":
-            process_files.append(path)
-        else:
-            logging.warning(f"Ignoring path {path} as it is not a folder or pdf file.")
+def add_pdf_files(
+    vector_store: Chroma, file_paths: list[Path], chunk_size: int, chunk_overlap: int, add_start_index: bool
+) -> None:
+    pdf_files = get_all_local_pdf_files(file_paths)
+
+    new_pdfs = []
+    for pdf_file in pdf_files:
+        if len(vector_store.get(where={"source": str(pdf_file)}, limit=1)["ids"]) == 0:
+            new_pdfs.append(pdf_file)
+
+    if len(new_pdfs) == 0:
+        return
 
     loaded_document = []
-    for file in process_files:
+    for file in new_pdfs:
         loader = UnstructuredLoader(file_path=file, strategy="hi_res")
         for document in loader.lazy_load():
             loaded_document.append(document)
 
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,
-                                                   chunk_overlap=chunk_overlap,
-                                                   add_start_index=add_start_index)
-    return text_splitter.split_documents(loaded_document)
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=chunk_size, chunk_overlap=chunk_overlap, add_start_index=add_start_index
+    )
+
+    pdf_splits = text_splitter.split_documents(loaded_document)
+
+    vector_store.add_documents(documents=filter_complex_metadata(pdf_splits))
+
+
+def get_all_local_pdf_files(local_paths: list[Path]) -> list[Path]:
+    """
+    Function that takes a list of local paths,
+    which may contain directory paths and/or direct file paths,
+    and returns all paths that are PDF files, plus any PDF files found in the given directories.
+    This function does not scan directories recursively.
+    """
+    all_pdf_files = []
+    for path in local_paths:
+        if path.is_dir():
+            all_pdf_files.extend(list(path.glob("*.pdf")))
+        elif path.suffix == ".pdf":
+            all_pdf_files.append(path)
+        else:
+            logging.warning(f"Ignoring path {path} as it is not a folder or pdf file.")
+
+    return all_pdf_files