Add support for both LangGraph graphs

This commit is contained in:
Nielson Janné 2025-03-17 17:40:54 +01:00
parent f25770e3ce
commit ee0c731faf

View File

@ -6,14 +6,12 @@ from pathlib import Path
import chainlit as cl
from backend.models import BackendType, get_chat_model, get_embedding_model
from graphs.ret_gen import RetGenLangGraph
from chainlit.cli import run_chainlit
from graphs.cond_ret_gen import CondRetGenLangGraph
from graphs.ret_gen import RetGenLangGraph
from langchain_chroma import Chroma
from parsers.parser import add_pdf_files, add_urls
# Configure the root logger to emit INFO and above, then create this module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@ -62,6 +60,13 @@ parser.add_argument(
help="File path to store or load a Chroma DB from/to.",
)
# NOTE(review): the long flag is spelled "chrome" although the help text refers to the Chroma DB.
# Renaming it would change the command-line interface, so it is intentionally left as-is here.
parser.add_argument("-r", "--reset-chrome-db", action="store_true", help="Reset the Chroma DB.")
# Opt-in switch: when present, args.use_conditional_graph is True and the conditional
# retrieve-generate graph is selected instead of the regular one.
parser.add_argument(
    "-c",
    "--use-conditional-graph",
    action="store_true",
    help="Use the conditional retrieve generate graph over the regular retrieve generate graph. "
    "The conditional version has built-in (chat) memory and is capable of querying vectorstores on its own insight.",
)
args = parser.parse_args()
vector_store = Chroma(
@ -70,35 +75,51 @@ vector_store = Chroma(
persist_directory=str(args.chroma_db_location),
)
# Builds the regular retrieve-generate graph from the vector store plus the chat and
# embedding models selected by args.backend.
# NOTE(review): this page appears to be a rendered diff with its +/- markers stripped; this
# assignment looks like the pre-change line that the `graph` selection right below
# replaces - confirm against the repository before relying on `ret_gen_graph`.
ret_gen_graph = RetGenLangGraph(
vector_store, chat_model=get_chat_model(args.backend), embedding_model=get_embedding_model(args.backend)
)
# Pick the graph implementation from the -c/--use-conditional-graph flag. Both classes are
# constructed with the same vector store, chat model and embedding model, so only the class
# itself depends on the flag.
graph = (CondRetGenLangGraph if args.use_conditional_graph else RetGenLangGraph)(
    vector_store,
    chat_model=get_chat_model(args.backend),
    embedding_model=get_embedding_model(args.backend),
)
# Chainlit message handler: answers an incoming chat message using the configured graph.
# NOTE(review): this span reads like a rendered diff whose +/- markers and indentation were
# stripped. Statements that use `ret_gen_graph` are interleaved with statements that use
# `graph`; presumably the former are the removed lines and the latter the added ones.
# Confirm against the repository before editing any logic here.
@cl.on_message
async def on_message(message: cl.Message):
response = ret_gen_graph.invoke(message.content)
# Conditional graph: the value stored under "id" in the Chainlit user session is passed as
# the configurable thread_id (presumably this keys the graph's chat memory - confirm).
if isinstance(graph, CondRetGenLangGraph):
config = {"configurable": {"thread_id": cl.user_session.get("id")}}
answer = response["answer"]
answer += "\n\n"
# Start from an empty message so that streamed chunks can be appended to it.
chainlit_response = cl.Message(content="")
pdf_sources = ret_gen_graph.get_last_pdf_sources()
web_sources = ret_gen_graph.get_last_web_sources()
# Forward every chunk yielded by graph.stream() to the message as it arrives.
for response in graph.stream(message.content, config=config):
await chainlit_response.stream_token(response)
elements = []
if len(pdf_sources) > 0:
answer += "The following PDF source were consulted:\n"
for source, page_numbers in pdf_sources.items():
page_numbers = list(page_numbers)
page_numbers.sort()
# display="side" seems to be not supported by chainlit for PDF's, so we use "inline" instead.
elements.append(cl.Pdf(name="pdf", display="inline", path=source, page=page_numbers[0]))
answer += f"'{source}' on page(s): {page_numbers}\n"
# Send the message once the stream has been fully consumed.
await chainlit_response.send()
if len(web_sources) > 0:
answer += f"The following web sources were consulted: {web_sources}\n"
# Regular graph: a single invoke() call whose result carries the text under "answer".
elif isinstance(graph, RetGenLangGraph):
response = graph.invoke(message.content)
await cl.Message(content=answer, elements=elements).send()
answer = response["answer"]
answer += "\n\n"
# Sources recorded by the graph for the last call; the PDF sources are iterated below as a
# mapping of source path -> page numbers.
pdf_sources = graph.get_last_pdf_sources()
web_sources = graph.get_last_web_sources()
elements = []
if len(pdf_sources) > 0:
answer += "The following PDF source were consulted:\n"
for source, page_numbers in pdf_sources.items():
# Sort the pages so the inline PDF element is given the lowest consulted page number.
page_numbers = list(page_numbers)
page_numbers.sort()
# display="side" seems to be not supported by chainlit for PDF's, so we use "inline" instead.
elements.append(cl.Pdf(name="pdf", display="inline", path=source, page=page_numbers[0]))
answer += f"'{source}' on page(s): {page_numbers}\n"
if len(web_sources) > 0:
answer += f"The following web sources were consulted: {web_sources}\n"
# Send the answer text together with the collected PDF elements in a single message.
await cl.Message(content=answer, elements=elements).send()
@cl.set_starters