From ee0c731faf7d0d6ae2e1eb641f9f8b1a2002ab3a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nielson=20Jann=C3=A9?=
Date: Mon, 17 Mar 2025 17:40:54 +0100
Subject: [PATCH] Add support for both LangGraph graphs

---
 generic_rag/app.py | 69 ++++++++++++++++++++++++++++++----------------
 1 file changed, 45 insertions(+), 24 deletions(-)

diff --git a/generic_rag/app.py b/generic_rag/app.py
index 0448de2..94d9932 100644
--- a/generic_rag/app.py
+++ b/generic_rag/app.py
@@ -6,14 +6,12 @@ from pathlib import Path
 
 import chainlit as cl
 from backend.models import BackendType, get_chat_model, get_embedding_model
-from graphs.ret_gen import RetGenLangGraph
 from chainlit.cli import run_chainlit
-
+from graphs.cond_ret_gen import CondRetGenLangGraph
+from graphs.ret_gen import RetGenLangGraph
 from langchain_chroma import Chroma
-
 from parsers.parser import add_pdf_files, add_urls
-
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
@@ -62,6 +60,13 @@ parser.add_argument(
     help="File path to store or load a Chroma DB from/to.",
 )
 parser.add_argument("-r", "--reset-chrome-db", action="store_true", help="Reset the Chroma DB.")
+parser.add_argument(
+    "-c",
+    "--use-conditional-graph",
+    action="store_true",
+    help="Use the conditional retrieve generate graph over the regular retrieve generate graph. "
+    "The conditional version has built-in (chat) memory and is capable of querying vectorstores on its own insight.",
+)
 args = parser.parse_args()
 
 vector_store = Chroma(
@@ -70,35 +75,51 @@ vector_store = Chroma(
     persist_directory=str(args.chroma_db_location),
 )
 
-ret_gen_graph = RetGenLangGraph(
-    vector_store, chat_model=get_chat_model(args.backend), embedding_model=get_embedding_model(args.backend)
-)
+if args.use_conditional_graph:
+    graph = CondRetGenLangGraph(
+        vector_store, chat_model=get_chat_model(args.backend), embedding_model=get_embedding_model(args.backend)
+    )
+else:
+    graph = RetGenLangGraph(
+        vector_store, chat_model=get_chat_model(args.backend), embedding_model=get_embedding_model(args.backend)
+    )
 
 
 @cl.on_message
 async def on_message(message: cl.Message):
-    response = ret_gen_graph.invoke(message.content)
+    if isinstance(graph, CondRetGenLangGraph):
+        config = {"configurable": {"thread_id": cl.user_session.get("id")}}
 
-    answer = response["answer"]
-    answer += "\n\n"
+        chainlit_response = cl.Message(content="")
 
-    pdf_sources = ret_gen_graph.get_last_pdf_sources()
-    web_sources = ret_gen_graph.get_last_web_sources()
+        for response in graph.stream(message.content, config=config):
+            await chainlit_response.stream_token(response)
 
-    elements = []
-    if len(pdf_sources) > 0:
-        answer += "The following PDF source were consulted:\n"
-        for source, page_numbers in pdf_sources.items():
-            page_numbers = list(page_numbers)
-            page_numbers.sort()
-            # display="side" seems to be not supported by chainlit for PDF's, so we use "inline" instead.
-            elements.append(cl.Pdf(name="pdf", display="inline", path=source, page=page_numbers[0]))
-            answer += f"'{source}' on page(s): {page_numbers}\n"
+        await chainlit_response.send()
 
-    if len(web_sources) > 0:
-        answer += f"The following web sources were consulted: {web_sources}\n"
+    elif isinstance(graph, RetGenLangGraph):
+        response = graph.invoke(message.content)
 
-    await cl.Message(content=answer, elements=elements).send()
+        answer = response["answer"]
+        answer += "\n\n"
+
+        pdf_sources = graph.get_last_pdf_sources()
+        web_sources = graph.get_last_web_sources()
+
+        elements = []
+        if len(pdf_sources) > 0:
+            answer += "The following PDF sources were consulted:\n"
+            for source, page_numbers in pdf_sources.items():
+                page_numbers = list(page_numbers)
+                page_numbers.sort()
+                # display="side" seems to be not supported by chainlit for PDF's, so we use "inline" instead.
+                elements.append(cl.Pdf(name="pdf", display="inline", path=source, page=page_numbers[0]))
+                answer += f"'{source}' on page(s): {page_numbers}\n"
+
+        if len(web_sources) > 0:
+            answer += f"The following web sources were consulted: {web_sources}\n"
+
+        await cl.Message(content=answer, elements=elements).send()
 
 
 @cl.set_starters