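"""A Sogeti Nederland generic RAG demo.

Ingests PDF and web sources into a persistent Chroma vector store, then serves a
two-step LangGraph retrieve/generate pipeline through a Chainlit chat UI.
"""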
import argparse
import json
import logging
import os
from pathlib import Path

import chainlit as cl
from chainlit.cli import run_chainlit
from langchain import hub
from langchain_chroma import Chroma
from langchain_community.vectorstores.utils import filter_complex_metadata
from langchain_core.documents import Document
from langgraph.graph import START, StateGraph
from typing_extensions import List, TypedDict

from backend.models import BackendType, get_embedding_model, get_chat_model
from parsers.parser import process_local_files, process_web_sites

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser(description="A Sogeti Nederland Generic RAG demo.")
parser.add_argument("-b", "--back-end", type=BackendType, choices=list(BackendType), default=BackendType.azure,
                    help="(Cloud) back-end to use. In the case of 'local', a locally installed Ollama will be used.")
parser.add_argument("-p", "--pdf-data", type=Path, required=True, nargs="+",
                    help="One or more paths to folders or files to use for retrieval. "
                         "If a path is a folder, all files in the folder will be used. "
                         "If a path is a file, only that file will be used. "
                         "Relative paths are resolved against the current working directory.")
parser.add_argument("--pdf-chunk-size", type=int, default=1000,
                    help="The size of the chunks to split the PDF text into.")
parser.add_argument("--pdf-chunk-overlap", type=int, default=200,
                    help="The overlap between consecutive PDF chunks.")
parser.add_argument("--pdf-add-start-index", action="store_true",
                    help="Add the start index to the metadata of the chunks.")
parser.add_argument("-w", "--web-data", type=str, nargs="*", default=[],
                    help="One or more URLs to use for retrieval.")
parser.add_argument("--web-chunk-size", type=int, default=200,
                    help="The size of the chunks to split the web page text into.")
parser.add_argument("-c", "--chroma-db-location", type=Path, default=Path(".chroma_db"),
                    help="File path to store or load a Chroma DB from/to.")
args = parser.parse_args()
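
# Example invocation (illustrative only; the script name, paths, and URL are hypothetical):
#   python app.py -b azure -p ./docs ./extra/notes.pdf -w https://example.org/article \
#       --pdf-chunk-size 1000 --pdf-chunk-overlap 200 -c .chroma_db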


class State(TypedDict):
    """State flowing through the LangGraph pipeline; each node returns a partial update."""
    question: str
    context: List[Document]
    answer: str


def retrieve(state: State):
    """Fetch the documents most similar to the question from the session's vector store."""
    vector_store = cl.user_session.get("vector_store")
    retrieved_docs = vector_store.similarity_search(state["question"])
    return {"context": retrieved_docs}


def generate(state: State):
    """Answer the question with the chat model, using the retrieved documents as context."""
    prompt = cl.user_session.get("prompt")
    llm = cl.user_session.get("chat_model")

    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
    messages = prompt.invoke({"question": state["question"], "context": docs_content})
    response = llm.invoke(messages)

    return {"answer": response.content}


@cl.on_chat_start
async def on_chat_start():
    # Reopen the persisted Chroma collection that the __main__ block below populated.
    vector_store = Chroma(collection_name="generic_rag",
                          embedding_function=get_embedding_model(args.back_end),
                          persist_directory=str(args.chroma_db_location))

    cl.user_session.set("vector_store", vector_store)
    cl.user_session.set("emb_model", get_embedding_model(args.back_end))
    cl.user_session.set("chat_model", get_chat_model(args.back_end))
    cl.user_session.set("prompt", hub.pull("rlm/rag-prompt"))

    # Build the RAG pipeline as a linear graph: START -> retrieve -> generate.
    graph_builder = StateGraph(State).add_sequence([retrieve, generate])
    graph_builder.add_edge(START, "retrieve")
    graph = graph_builder.compile()

    cl.user_session.set("graph", graph)


@cl.on_message
async def on_message(message: cl.Message):
    graph = cl.user_session.get("graph")
    # graph.invoke returns the final State dict; send only the generated answer.
    response = graph.invoke({"question": message.content})
    await cl.Message(content=response["answer"]).send()


@cl.set_starters
async def set_starters():
    # Use .get() so a missing variable yields None instead of raising KeyError.
    chainlit_starters = os.environ.get("CHAINLIT_STARTERS")

    if chainlit_starters is None:
        return

    dict_list = json.loads(chainlit_starters)

    starters = []
    for starter in dict_list:
        try:
            starters.append(cl.Starter(label=starter["label"], message=starter["message"]))
        except KeyError:
            logger.warning("CHAINLIT_STARTERS environment variable is not a list of "
                           "dictionaries containing 'label' and 'message' keys.")

    return starters
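
# Example value for CHAINLIT_STARTERS (illustrative only; label and message are hypothetical):
#   export CHAINLIT_STARTERS='[{"label": "Summary", "message": "Summarize the documents."}]'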


if __name__ == "__main__":
    # Ingest and chunk the configured sources before the chat UI starts.
    pdf_splits = process_local_files(args.pdf_data, args.pdf_chunk_size, args.pdf_chunk_overlap,
                                     args.pdf_add_start_index)
    web_splits = process_web_sites(args.web_data, args.web_chunk_size)

    # Chroma only accepts simple metadata values (str, int, float, bool); drop the rest.
    filtered_splits = filter_complex_metadata(pdf_splits + web_splits)

    vector_store = Chroma(collection_name="generic_rag",
                          embedding_function=get_embedding_model(args.back_end),
                          persist_directory=str(args.chroma_db_location))
    _ = vector_store.add_documents(documents=filtered_splits)
    # Release the handle; on_chat_start reopens the persisted collection per session.
    del vector_store

    run_chainlit(__file__)