diff --git a/.python-version b/.python-version
new file mode 100644
index 0000000..e4fba21
--- /dev/null
+++ b/.python-version
@@ -0,0 +1 @@
+3.12
diff --git a/generic_rag/app.py b/generic_rag/app.py
new file mode 100644
index 0000000..feb3218
--- /dev/null
+++ b/generic_rag/app.py
@@ -0,0 +1,117 @@
+import argparse
+import logging
+from pathlib import Path
+
+import chainlit as cl
+from chainlit.cli import run_chainlit
+from langchain import hub
+from langchain_core.documents import Document
+from langchain_core.vectorstores import InMemoryVectorStore
+from langgraph.graph import START, StateGraph
+from typing_extensions import List, TypedDict
+
+from backend.model import BackendType, get_embedding_model, get_chat_model
+from parsers.parser import process_local_files, process_web_sites
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+parser = argparse.ArgumentParser(description="A Sogeti Nederland Generic RAG demo.")
+parser.add_argument("-b", "--back-end", type=BackendType, choices=list(BackendType), default=BackendType.azure,
+                    help="(Cloud) back-end to use. For 'local', a locally installed Ollama is used.")
+parser.add_argument("-p", "--pdf-data", type=Path, required=True, nargs="+",
+                    help="One or more paths to folders or files to use for retrieval. "
+                         "If a path is a folder, all PDF files in that folder are used; "
+                         "if it is a file, only that file is used. "
+                         "Relative paths are resolved against the current working directory.")
+parser.add_argument("--pdf-chunk-size", type=int, default=1000,
+                    help="The size of the chunks the PDF text is split into.")
+parser.add_argument("--pdf-chunk-overlap", type=int, default=200,
+                    help="The overlap between consecutive PDF chunks.")
+parser.add_argument("--pdf-add-start-index", action="store_true",
+                    help="Add the start index to the metadata of the chunks.")
+parser.add_argument("-w", "--web-data", type=str, nargs="*", default=[],
+                    help="One or more URLs to use for retrieval.")
+parser.add_argument("--web-chunk-size", type=int, default=200,
+                    help="The maximum size of the chunks the web pages are split into.")
+args = parser.parse_args()
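+
+# Example invocation (illustrative paths and URLs; run_chainlit(__file__) at the
+# bottom of this file starts the Chainlit UI, so the script is launched directly):
+#   python app.py --back-end local --pdf-data ./data/reports ./manual.pdf \
+#       --web-data https://example.com/docs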
cl.user_session.set("vector_store", vector_store) + cl.user_session.set("chat_model", get_chat_model(args.back_end)) + cl.user_session.set("prompt", hub.pull("rlm/rag-prompt")) + + graph_builder = StateGraph(State).add_sequence([retrieve, generate]) + graph_builder.add_edge(START, "retrieve") + graph = graph_builder.compile() + + cl.user_session.set("graph", graph) + + await cl.Message(content="Ready for chatting!").send() + + +@cl.on_message +async def on_message(message: cl.Message): + graph = cl.user_session.get("graph") + response = graph.invoke({"question": message.content}) + # Send the final answer. + await cl.Message(content=response).send() + + +@cl.set_starters +async def set_starters(): + return [cl.Starter(label="Morning routine ideation", + message="Can you help me create a personalized morning routine that would help increase my " + "productivity throughout the day? Start by asking me about my current habits and what " + "activities energize me in the morning.", ), + cl.Starter(label="Explain superconductors", + message="Explain superconductors like I'm five years old.", ), + cl.Starter(label="Python script for daily email reports", + message="Write a script to automate sending daily email reports in Python, and walk me through " + "how I would set it up.", ), + cl.Starter(label="Text inviting friend to wedding", + message="Write a text asking a friend to be my plus-one at a wedding next month. I want to keep " + "it super short and casual, and offer an out.", )] + + +if __name__ == "__main__": + run_chainlit(__file__) diff --git a/generic_rag/backend/model.py b/generic_rag/backend/model.py new file mode 100644 index 0000000..ebcf45c --- /dev/null +++ b/generic_rag/backend/model.py @@ -0,0 +1,64 @@ +import os +from enum import Enum + +from langchain.chat_models import init_chat_model +from langchain_aws import BedrockEmbeddings +from langchain_core.embeddings import Embeddings +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_google_vertexai import VertexAIEmbeddings +from langchain_huggingface import HuggingFaceEmbeddings +from langchain_ollama import OllamaLLM +from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings +from langchain_openai import OpenAIEmbeddings + + +class BackendType(Enum): + azure = "azure" + openai = "openai" + google = "google" + aws = "aws" + local = "local" + + +def get_chat_model(backend_type: BackendType) -> BaseChatModel: + if backend_type == BackendType.azure: + return AzureChatOpenAI( + azure_endpoint=os.environ["AZURE_LLM_ENDPOINT"], + azure_deployment=os.environ["AZURE_LLM_DEPLOYMENT_NAME"], + openai_api_version=os.environ["AZURE_LLM_API_VERSION"]) + + if backend_type == BackendType.openai: + return init_chat_model(os.environ["OPENAI_CHAT_MODEL"], model_provider="openai") + + if backend_type == BackendType.google: + return init_chat_model(os.environ["GOOGLE_CHAT_MODEL"], model_provider="google_vertexai") + + if backend_type == BackendType.aws: + return init_chat_model(model=os.environ["AWS_CHAT_MODEL"], model_provider="bedrock_converse") + + if backend_type == BackendType.local: + return OllamaLLM(model=os.environ["LOCAL_CHAT_MODEL"]) + + raise ValueError(f"Unknown backend type: {backend_type}") + + +def get_embedding_model(backend_type: BackendType) -> Embeddings: + if backend_type == BackendType.azure: + return AzureOpenAIEmbeddings( + azure_endpoint=os.environ["AZURE_EMB_ENDPOINT"], + azure_deployment=os.environ["AZURE_EMB_DEPLOYMENT_NAME"], + 
+
+
+def get_embedding_model(backend_type: BackendType) -> Embeddings:
+    if backend_type == BackendType.azure:
+        return AzureOpenAIEmbeddings(
+            azure_endpoint=os.environ["AZURE_EMB_ENDPOINT"],
+            azure_deployment=os.environ["AZURE_EMB_DEPLOYMENT_NAME"],
+            openai_api_version=os.environ["AZURE_EMB_API_VERSION"])
+
+    if backend_type == BackendType.openai:
+        return OpenAIEmbeddings(model=os.environ["OPENAI_EMB_MODEL"])
+
+    if backend_type == BackendType.google:
+        return VertexAIEmbeddings(model=os.environ["GOOGLE_EMB_MODEL"])
+
+    if backend_type == BackendType.aws:
+        return BedrockEmbeddings(model_id=os.environ["AWS_EMB_MODEL"])
+
+    if backend_type == BackendType.local:
+        return HuggingFaceEmbeddings(model_name=os.environ["LOCAL_EMB_MODEL"])
+
+    raise ValueError(f"Unknown backend type: {backend_type}")
diff --git a/generic_rag/parsers/parser.py b/generic_rag/parsers/parser.py
new file mode 100644
index 0000000..b591ce6
--- /dev/null
+++ b/generic_rag/parsers/parser.py
@@ -0,0 +1,87 @@
+import logging
+from pathlib import Path
+
+import requests
+from bs4 import BeautifulSoup, Tag
+from langchain_core.documents import Document
+from langchain_text_splitters import HTMLSemanticPreservingSplitter, RecursiveCharacterTextSplitter
+from langchain_unstructured import UnstructuredLoader
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+headers_to_split_on = [
+    ("h1", "Header 1"),
+    ("h2", "Header 2"),
+]
+
+
+def code_handler(element: Tag) -> str:
+    """Custom handler that preserves code elements, tagged with their language."""
+    data_lang = element.get("data-lang")
+    return f"<code:{data_lang}>{element.get_text()}</code>"
+
+
+def process_web_sites(websites: list[str], chunk_size: int) -> list[Document]:
+    """Process one or more websites and return a list of LangChain Documents."""
+    if len(websites) == 0:
+        return []
+
+    # The splitter is stateless, so a single instance serves all URLs.
+    web_splitter = HTMLSemanticPreservingSplitter(
+        headers_to_split_on=headers_to_split_on,
+        separators=["\n\n", "\n", ". ", "! ", "? "],
+        max_chunk_size=chunk_size,
+        preserve_images=True,
+        preserve_videos=True,
+        elements_to_preserve=["table", "ul", "ol", "code"],
+        denylist_tags=["script", "style", "head"],
+        custom_handlers={"code": code_handler})
+
+    splits = []
+    for url in websites:
+        # Fetch and parse the web page.
+        response = requests.get(url, timeout=30)
+        response.raise_for_status()
+        soup = BeautifulSoup(response.text, "html.parser")
+
+        # Split the page into semantic chunks.
+        splits.extend(web_splitter.split_text(str(soup)))
+
+    return splits
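+
+# Example (illustrative): process_web_sites(["https://example.com/docs"], 200)
+# yields one Document per semantic chunk, with any matched "Header 1"/"Header 2"
+# text attached to Document.metadata by the splitter.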
"], + max_chunk_size=chunk_size, + preserve_images=True, + preserve_videos=True, + elements_to_preserve=["table", "ul", "ol", "code"], + denylist_tags=["script", "style", "head"], + custom_handlers={"code": code_handler}) + + splits.extend(web_splitter.split_text(str(soup))) + + return splits + + +def process_local_files( + local_paths: list[Path], chunk_size: int, chunk_overlap: int, add_start_index: bool +) -> list[Document]: + # get all files + file_paths = [] + for path in local_paths: + if path.is_dir(): + file_paths.extend(list(path.glob("*.pdf"))) + if path.suffix == ".pdf": + file_paths.append(path) + else: + logging.warning(f"Ignoring path {path} as it is not a pdf file.") + + # parse pdf's + documents = [] + for file_path in file_paths: + loader = UnstructuredLoader(file_path=file_path, strategy="hi_res") + for doc in loader.lazy_load(): + documents.append(doc) + + # split documents + text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + add_start_index=add_start_index) + return text_splitter.split_documents(documents) diff --git a/generic_rag/rendering/render_pdf_page.py b/generic_rag/rendering/render_pdf_page.py new file mode 100644 index 0000000..7c365a9 --- /dev/null +++ b/generic_rag/rendering/render_pdf_page.py @@ -0,0 +1,56 @@ +from pathlib import Path + +import fitz +import matplotlib.patches as patches +import matplotlib.pyplot as plt +from PIL import Image +from langchain_core.documents import Document + + +def render_pdf_bound_box(file_path: str | Path, doc_list: list[Document], page_number: int) -> None: + """ + Function that renders the bounding boxes of the segments on a PDF page. + """ + pdf_page = fitz.open(file_path).load_page(page_number - 1) + page_docs = [ + doc for doc in doc_list if doc.metadata.get("page_number") == page_number + ] + segments = [doc.metadata for doc in page_docs] + + pix = pdf_page.get_pixmap() + pil_image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples) + + fig, ax = plt.subplots(1, figsize=(10, 10)) + ax.imshow(pil_image) + categories = set() + category_to_color = { + "Title": "orchid", + "Image": "forestgreen", + "Table": "tomato", + } + for segment in segments: + points = segment["coordinates"]["points"] + layout_width = segment["coordinates"]["layout_width"] + layout_height = segment["coordinates"]["layout_height"] + scaled_points = [ + (x * pix.width / layout_width, y * pix.height / layout_height) + for x, y in points + ] + box_color = category_to_color.get(segment["category"], "deepskyblue") + categories.add(segment["category"]) + rect = patches.Polygon( + scaled_points, linewidth=1, edgecolor=box_color, facecolor="none" + ) + ax.add_patch(rect) + + # Make legend + legend_handles = [patches.Patch(color="deepskyblue", label="Text")] + for category in ["Title", "Image", "Table"]: + if category in categories: + legend_handles.append( + patches.Patch(color=category_to_color[category], label=category) + ) + ax.axis("off") + ax.legend(handles=legend_handles, loc="upper right") + plt.tight_layout() + plt.savefig(f"test_{page_number}.png") diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..fedb856 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,29 @@ +[project] +name = "Sogeti-generic-RAG-demo" +version = "0.1.0" +description = "A Sogeti generic RAG demo" +readme = "README.md" +requires-python = ">=3.12" +dependencies = [ + "beautifulsoup4>=4.13.3", + "chainlit>=2.3.0", + "dotenv>=0.9.9", + "langchain>=0.3.20", + "langchain-aws>=0.2.15", + 
"langchain-community>=0.3.19", + "langchain-google-vertexai>=2.0.15", + "langchain-huggingface>=0.1.2", + "langchain-ollama>=0.2.3", + "langchain-openai>=0.3.7", + "langchain-text-splitters>=0.3.6", + "langchain-unstructured>=0.1.6", + "langgraph>=0.3.5", + "matplotlib>=3.10.1", + "pillow>=11.1.0", + "pymupdf>=1.25.3", + "unstructured[pdf]>=0.16.23", +] + +[tool.setuptools] +packages = ["."] +exclude = ["data"]