Add support for both LangGraphs

This commit is contained in:
Nielson Janné 2025-03-17 17:40:54 +01:00
parent f25770e3ce
commit ee0c731faf

View File

@ -6,14 +6,12 @@ from pathlib import Path
import chainlit as cl
from backend.models import BackendType, get_chat_model, get_embedding_model
from graphs.ret_gen import RetGenLangGraph
from chainlit.cli import run_chainlit
from graphs.cond_ret_gen import CondRetGenLangGraph
from graphs.ret_gen import RetGenLangGraph
from langchain_chroma import Chroma
from parsers.parser import add_pdf_files, add_urls
# Configure root logging once at script startup; use a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@ -62,6 +60,13 @@ parser.add_argument(
help="File path to store or load a Chroma DB from/to.",
)
# Command-line flags for vector-store maintenance and graph selection.
# NOTE: "--reset-chrome-db" misspells "chroma"; it is kept first so the
# argparse dest (args.reset_chrome_db) stays backward-compatible, and a
# correctly spelled alias is added for new callers.
parser.add_argument(
    "-r",
    "--reset-chrome-db",
    "--reset-chroma-db",
    action="store_true",
    help="Reset the Chroma DB.",
)
parser.add_argument(
    "-c",
    "--use-conditional-graph",
    action="store_true",
    help="Use the conditional retrieve-generate graph over the regular retrieve-generate graph. "
    "The conditional version has built-in (chat) memory and is capable of querying vector stores on its own insight.",
)
args = parser.parse_args()
vector_store = Chroma(
@ -70,20 +75,36 @@ vector_store = Chroma(
persist_directory=str(args.chroma_db_location),
)
# Select the LangGraph implementation. The conditional variant carries
# chat memory and decides on its own when to query the vector store;
# the regular variant retrieves on every message.
# (Removed the stale `ret_gen_graph = RetGenLangGraph(` diff remnant,
# which left an unclosed call and made the module unparsable.)
if args.use_conditional_graph:
    graph = CondRetGenLangGraph(
        vector_store, chat_model=get_chat_model(args.backend), embedding_model=get_embedding_model(args.backend)
    )
else:
    graph = RetGenLangGraph(
        vector_store, chat_model=get_chat_model(args.backend), embedding_model=get_embedding_model(args.backend)
    )
@cl.on_message
async def on_message(message: cl.Message):
response = ret_gen_graph.invoke(message.content)
if isinstance(graph, CondRetGenLangGraph):
config = {"configurable": {"thread_id": cl.user_session.get("id")}}
chainlit_response = cl.Message(content="")
for response in graph.stream(message.content, config=config):
await chainlit_response.stream_token(response)
await chainlit_response.send()
elif isinstance(graph, RetGenLangGraph):
response = graph.invoke(message.content)
answer = response["answer"]
answer += "\n\n"
pdf_sources = ret_gen_graph.get_last_pdf_sources()
web_sources = ret_gen_graph.get_last_web_sources()
pdf_sources = graph.get_last_pdf_sources()
web_sources = graph.get_last_web_sources()
elements = []
if len(pdf_sources) > 0: