forked from AI_team/Philosophy-RAG-demo

Initial chainlit + langchain commit

This commit is contained in:
parent 378c8e6243
commit d4ce22dc7e
.python-version (Normal file, 1 line)
@@ -0,0 +1 @@
3.12
generic_rag/app.py (Normal file, 117 lines)
@@ -0,0 +1,117 @@
import argparse
import logging
from pathlib import Path

import chainlit as cl
from chainlit.cli import run_chainlit
from langchain import hub
from langchain_core.documents import Document
from langchain_core.vectorstores import InMemoryVectorStore
from langgraph.graph import START, StateGraph
from typing_extensions import List, TypedDict

from backend.model import BackendType, get_embedding_model, get_chat_model
from parsers.parser import process_local_files, process_web_sites

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser(description="A Sogeti Nederland Generic RAG demo.")
parser.add_argument("-b", "--back-end", type=BackendType, choices=list(BackendType), default=BackendType.azure,
                    help="(Cloud) back-end to use. In the case of local, a locally installed ollama will be used.")
parser.add_argument("-p", "--pdf-data", type=Path, required=True, nargs="+",
                    help="One or multiple paths to folders or files to use for retrieval. "
                         "If a path is a folder, all files in the folder will be used. "
                         "If a path is a file, only that file will be used. "
                         "If the path is relative it will be relative to the current working directory.")
parser.add_argument("--pdf-chunk_size", type=int, default=1000,
                    help="The size of the chunks to split the text into.")
parser.add_argument("--pdf-chunk_overlap", type=int, default=200,
                    help="The overlap between the chunks.")
parser.add_argument("--pdf-add-start-index", action="store_true",
                    help="Add the start index to the metadata of the chunks.")
parser.add_argument("-w", "--web-data", type=str, nargs="*", default=[],
                    help="One or multiple URLs to use for retrieval.")
parser.add_argument("--web-chunk-size", type=int, default=200,
                    help="The size of the chunks to split the text into.")
args = parser.parse_args()


class State(TypedDict):
    question: str
    context: List[Document]
    answer: str


def retrieve(state: State):
    vector_store = cl.user_session.get("vector_store")
    retrieved_docs = vector_store.similarity_search(state["question"])
    return {"context": retrieved_docs}


def generate(state: State):
    prompt = cl.user_session.get("prompt")
    llm = cl.user_session.get("chat_model")

    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
    messages = prompt.invoke({"question": state["question"], "context": docs_content})
    response = llm.invoke(messages)

    return {"answer": response.content}


@cl.on_chat_start
async def on_chat_start():
    await cl.Message(author="System", content="Starting up application").send()

    embedding = get_embedding_model(args.back_end)
    vector_store = InMemoryVectorStore(embedding)

    await cl.Message(author="System", content="Processing PDF files.").send()
    pdf_splits = await cl.make_async(process_local_files)(args.pdf_data, args.pdf_chunk_size,
                                                          args.pdf_chunk_overlap, args.pdf_add_start_index)
    await cl.Message(author="System", content="Processing web sites.").send()
    web_splits = await cl.make_async(process_web_sites)(args.web_data, args.web_chunk_size)

    _ = vector_store.add_documents(documents=pdf_splits + web_splits)

    cl.user_session.set("emb_model", embedding)
    cl.user_session.set("vector_store", vector_store)
    cl.user_session.set("chat_model", get_chat_model(args.back_end))
    cl.user_session.set("prompt", hub.pull("rlm/rag-prompt"))

    graph_builder = StateGraph(State).add_sequence([retrieve, generate])
    graph_builder.add_edge(START, "retrieve")
    graph = graph_builder.compile()

    cl.user_session.set("graph", graph)

    await cl.Message(content="Ready for chatting!").send()


@cl.on_message
async def on_message(message: cl.Message):
    graph = cl.user_session.get("graph")
    response = graph.invoke({"question": message.content})
    # Send the final answer; invoke returns the full state dict.
    await cl.Message(content=response["answer"]).send()


@cl.set_starters
async def set_starters():
    return [cl.Starter(label="Morning routine ideation",
                       message="Can you help me create a personalized morning routine that would help increase my "
                               "productivity throughout the day? Start by asking me about my current habits and what "
                               "activities energize me in the morning."),
            cl.Starter(label="Explain superconductors",
                       message="Explain superconductors like I'm five years old."),
            cl.Starter(label="Python script for daily email reports",
                       message="Write a script to automate sending daily email reports in Python, and walk me through "
                               "how I would set it up."),
            cl.Starter(label="Text inviting friend to wedding",
                       message="Write a text asking a friend to be my plus-one at a wedding next month. I want to keep "
                               "it super short and casual, and offer an out.")]


if __name__ == "__main__":
    run_chainlit(__file__)
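
The app wires retrieval and generation into a two-node LangGraph sequence, which is why on_message has to read the "answer" field out of the state that graph.invoke returns. A minimal standalone sketch of that pattern with stub nodes (no Chainlit session; the stub contents and the question are placeholders, not part of the repo):

    # Standalone sketch of the add_sequence pattern used in app.py.
    from typing_extensions import List, TypedDict

    from langchain_core.documents import Document
    from langgraph.graph import START, StateGraph


    class State(TypedDict):
        question: str
        context: List[Document]
        answer: str


    def retrieve(state: State):
        # Stub: the real node queries the vector store.
        return {"context": [Document(page_content="stub context")]}


    def generate(state: State):
        # Stub: the real node calls the chat model.
        return {"answer": f"echo: {state['question']}"}


    builder = StateGraph(State).add_sequence([retrieve, generate])
    builder.add_edge(START, "retrieve")
    graph = builder.compile()

    result = graph.invoke({"question": "What is RAG?"})
    print(result["answer"])  # invoke returns the final state dict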
generic_rag/backend/model.py (Normal file, 64 lines)
@@ -0,0 +1,64 @@
import os
from enum import Enum

from langchain.chat_models import init_chat_model
from langchain_aws import BedrockEmbeddings
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_google_vertexai import VertexAIEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_ollama import ChatOllama
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
from langchain_openai import OpenAIEmbeddings


class BackendType(Enum):
    azure = "azure"
    openai = "openai"
    google = "google"
    aws = "aws"
    local = "local"


def get_chat_model(backend_type: BackendType) -> BaseChatModel:
    if backend_type == BackendType.azure:
        return AzureChatOpenAI(
            azure_endpoint=os.environ["AZURE_LLM_ENDPOINT"],
            azure_deployment=os.environ["AZURE_LLM_DEPLOYMENT_NAME"],
            openai_api_version=os.environ["AZURE_LLM_API_VERSION"])

    if backend_type == BackendType.openai:
        return init_chat_model(os.environ["OPENAI_CHAT_MODEL"], model_provider="openai")

    if backend_type == BackendType.google:
        return init_chat_model(os.environ["GOOGLE_CHAT_MODEL"], model_provider="google_vertexai")

    if backend_type == BackendType.aws:
        return init_chat_model(model=os.environ["AWS_CHAT_MODEL"], model_provider="bedrock_converse")

    if backend_type == BackendType.local:
        # ChatOllama returns chat messages with a .content attribute,
        # which generate() in app.py expects.
        return ChatOllama(model=os.environ["LOCAL_CHAT_MODEL"])

    raise ValueError(f"Unknown backend type: {backend_type}")


def get_embedding_model(backend_type: BackendType) -> Embeddings:
    if backend_type == BackendType.azure:
        return AzureOpenAIEmbeddings(
            azure_endpoint=os.environ["AZURE_EMB_ENDPOINT"],
            azure_deployment=os.environ["AZURE_EMB_DEPLOYMENT_NAME"],
            openai_api_version=os.environ["AZURE_EMB_API_VERSION"])

    if backend_type == BackendType.openai:
        return OpenAIEmbeddings(model=os.environ["OPENAI_EMB_MODEL"])

    if backend_type == BackendType.google:
        return VertexAIEmbeddings(model=os.environ["GOOGLE_EMB_MODEL"])

    if backend_type == BackendType.aws:
        return BedrockEmbeddings(model_id=os.environ["AWS_EMB_MODEL"])

    if backend_type == BackendType.local:
        return HuggingFaceEmbeddings(model_name=os.environ["LOCAL_EMB_MODEL"])

    raise ValueError(f"Unknown backend type: {backend_type}")
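
Every backend is configured purely through environment variables. A sketch of the set the default Azure backend reads (all values are placeholders; the provider's API key, typically AZURE_OPENAI_API_KEY for langchain-openai, must also be present):

    # Placeholder configuration for the Azure backend. The other backends read
    # OPENAI_CHAT_MODEL / OPENAI_EMB_MODEL, GOOGLE_*, AWS_*, or LOCAL_* instead.
    import os

    os.environ["AZURE_LLM_ENDPOINT"] = "https://my-llm.openai.azure.com/"   # placeholder
    os.environ["AZURE_LLM_DEPLOYMENT_NAME"] = "my-chat-deployment"          # placeholder
    os.environ["AZURE_LLM_API_VERSION"] = "2024-02-01"                      # placeholder
    os.environ["AZURE_EMB_ENDPOINT"] = "https://my-emb.openai.azure.com/"   # placeholder
    os.environ["AZURE_EMB_DEPLOYMENT_NAME"] = "my-embedding-deployment"     # placeholder
    os.environ["AZURE_EMB_API_VERSION"] = "2024-02-01"                      # placeholder

    from backend.model import BackendType, get_chat_model, get_embedding_model

    llm = get_chat_model(BackendType.azure)
    emb = get_embedding_model(BackendType.azure)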
generic_rag/parsers/parser.py (Normal file, 87 lines)
@@ -0,0 +1,87 @@
import logging
from pathlib import Path

import requests
from bs4 import BeautifulSoup, Tag
from langchain_core.documents import Document
from langchain_text_splitters import HTMLSemanticPreservingSplitter
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_unstructured import UnstructuredLoader

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

headers_to_split_on = [
    ("h1", "Header 1"),
    ("h2", "Header 2"),
]


def code_handler(element: Tag) -> str:
    """
    Custom handler for code elements.
    """
    data_lang = element.get("data-lang")
    code_format = f"<code:{data_lang}>{element.get_text()}</code>"

    return code_format


def process_web_sites(websites: list[str], chunk_size: int) -> list[Document]:
    """
    Process one or more websites and return a list of langchain Documents.
    """
    if len(websites) == 0:
        return []

    splits = []
    for url in websites:
        # Fetch the webpage.
        response = requests.get(url)
        html_content = response.text

        # Parse the HTML.
        soup = BeautifulSoup(html_content, "html.parser")

        # Split documents.
        web_splitter = HTMLSemanticPreservingSplitter(
            headers_to_split_on=headers_to_split_on,
            separators=["\n\n", "\n", ". ", "! ", "? "],
            max_chunk_size=chunk_size,
            preserve_images=True,
            preserve_videos=True,
            elements_to_preserve=["table", "ul", "ol", "code"],
            denylist_tags=["script", "style", "head"],
            custom_handlers={"code": code_handler})

        splits.extend(web_splitter.split_text(str(soup)))

    return splits


def process_local_files(
        local_paths: list[Path], chunk_size: int, chunk_overlap: int, add_start_index: bool
) -> list[Document]:
    """
    Process one or more local PDF paths and return a list of langchain Documents.
    """
    # Collect all PDF files.
    file_paths = []
    for path in local_paths:
        if path.is_dir():
            file_paths.extend(list(path.glob("*.pdf")))
        elif path.suffix == ".pdf":
            file_paths.append(path)
        else:
            logger.warning(f"Ignoring path {path} as it is not a pdf file.")

    # Parse the PDFs.
    documents = []
    for file_path in file_paths:
        loader = UnstructuredLoader(file_path=file_path, strategy="hi_res")
        for doc in loader.lazy_load():
            documents.append(doc)

    # Split documents.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,
                                                   chunk_overlap=chunk_overlap,
                                                   add_start_index=add_start_index)
    return text_splitter.split_documents(documents)
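
A short sketch of driving both parsers directly, mirroring how app.py calls them (the path and URL are placeholders):

    from pathlib import Path

    from parsers.parser import process_local_files, process_web_sites

    # Placeholder inputs; in the app these come from the CLI arguments.
    pdf_splits = process_local_files([Path("data/example.pdf")], chunk_size=1000,
                                     chunk_overlap=200, add_start_index=True)
    web_splits = process_web_sites(["https://example.com/article"], chunk_size=200)
    print(f"{len(pdf_splits)} PDF chunks, {len(web_splits)} web chunks")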
generic_rag/rendering/render_pdf_page.py (Normal file, 56 lines)
@@ -0,0 +1,56 @@
from pathlib import Path

import fitz
import matplotlib.patches as patches
import matplotlib.pyplot as plt
from PIL import Image
from langchain_core.documents import Document


def render_pdf_bound_box(file_path: str | Path, doc_list: list[Document], page_number: int) -> None:
    """
    Function that renders the bounding boxes of the segments on a PDF page.
    """
    pdf_page = fitz.open(file_path).load_page(page_number - 1)
    page_docs = [
        doc for doc in doc_list if doc.metadata.get("page_number") == page_number
    ]
    segments = [doc.metadata for doc in page_docs]

    pix = pdf_page.get_pixmap()
    pil_image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)

    fig, ax = plt.subplots(1, figsize=(10, 10))
    ax.imshow(pil_image)
    categories = set()
    category_to_color = {
        "Title": "orchid",
        "Image": "forestgreen",
        "Table": "tomato",
    }
    for segment in segments:
        points = segment["coordinates"]["points"]
        layout_width = segment["coordinates"]["layout_width"]
        layout_height = segment["coordinates"]["layout_height"]
        # Scale the layout coordinates to the rendered pixmap size.
        scaled_points = [
            (x * pix.width / layout_width, y * pix.height / layout_height)
            for x, y in points
        ]
        box_color = category_to_color.get(segment["category"], "deepskyblue")
        categories.add(segment["category"])
        rect = patches.Polygon(
            scaled_points, linewidth=1, edgecolor=box_color, facecolor="none"
        )
        ax.add_patch(rect)

    # Make legend.
    legend_handles = [patches.Patch(color="deepskyblue", label="Text")]
    for category in ["Title", "Image", "Table"]:
        if category in categories:
            legend_handles.append(
                patches.Patch(color=category_to_color[category], label=category)
            )
    ax.axis("off")
    ax.legend(handles=legend_handles, loc="upper right")
    plt.tight_layout()
    plt.savefig(f"test_{page_number}.png")
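
The function expects documents whose metadata still carries the page_number and coordinates fields that UnstructuredLoader emits in hi_res mode. A sketch with a placeholder path:

    from langchain_unstructured import UnstructuredLoader

    from rendering.render_pdf_page import render_pdf_bound_box

    # Placeholder path; "hi_res" keeps per-segment coordinates in the metadata.
    pdf = "data/example.pdf"
    docs = list(UnstructuredLoader(file_path=pdf, strategy="hi_res").lazy_load())
    render_pdf_bound_box(pdf, docs, page_number=1)  # writes test_1.png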
pyproject.toml (Normal file, 29 lines)
@@ -0,0 +1,29 @@
[project]
name = "Sogeti-generic-RAG-demo"
version = "0.1.0"
description = "A Sogeti generic RAG demo"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "beautifulsoup4>=4.13.3",
    "chainlit>=2.3.0",
    "dotenv>=0.9.9",
    "langchain>=0.3.20",
    "langchain-aws>=0.2.15",
    "langchain-community>=0.3.19",
    "langchain-google-vertexai>=2.0.15",
    "langchain-huggingface>=0.1.2",
    "langchain-ollama>=0.2.3",
    "langchain-openai>=0.3.7",
    "langchain-text-splitters>=0.3.6",
    "langchain-unstructured>=0.1.6",
    "langgraph>=0.3.5",
    "matplotlib>=3.10.1",
    "pillow>=11.1.0",
    "pymupdf>=1.25.3",
    "unstructured[pdf]>=0.16.23",
]

[tool.setuptools]
packages = ["."]
exclude = ["data"]