diff --git a/generic_rag/app.py b/generic_rag/app.py index d11f996..1cf35d3 100644 --- a/generic_rag/app.py +++ b/generic_rag/app.py @@ -88,39 +88,47 @@ else: @cl.on_message async def on_message(message: cl.Message): if isinstance(graph, CondRetGenLangGraph): - config = {"configurable": {"thread_id": cl.user_session.get("id")}} - - chainlit_response = cl.Message(content="") - - for response in graph.stream(message.content, config=config): - await chainlit_response.stream_token(response) - - await chainlit_response.send() - + await process_cond_response(message) elif isinstance(graph, RetGenLangGraph): - config = {"configurable": {"thread_id": cl.user_session.get("id")}} - response = graph.invoke(message.content, config=config) + await process_response(message) - answer = response["answer"] - answer += "\n\n" - pdf_sources = graph.get_last_pdf_sources() - web_sources = graph.get_last_web_sources() +async def process_response(message): + config = {"configurable": {"thread_id": cl.user_session.get("id")}} - elements = [] - if len(pdf_sources) > 0: - answer += "The following PDF source were consulted:\n" - for source, page_numbers in pdf_sources.items(): - page_numbers = list(page_numbers) - page_numbers.sort() - # display="side" seems to be not supported by chainlit for PDF's, so we use "inline" instead. - elements.append(cl.Pdf(name="pdf", display="inline", path=source, page=page_numbers[0])) - answer += f"'{source}' on page(s): {page_numbers}\n" + chainlit_response = cl.Message(content="") - if len(web_sources) > 0: - answer += f"The following web sources were consulted: {web_sources}\n" + response = graph.invoke(message.content, config=config) - await cl.Message(content=answer, elements=elements).send() + await chainlit_response.stream_token(f"{response}\n\n") + + pdf_sources = graph.get_last_pdf_sources() + if len(pdf_sources) > 0: + await chainlit_response.stream_token("The following PDF source were consulted:\n") + for source, page_numbers in pdf_sources.items(): + page_numbers = list(page_numbers) + page_numbers.sort() + # display="side" seems to be not supported by chainlit for PDF's, so we use "inline" instead. + chainlit_response.elements.append(cl.Pdf(name="pdf", display="inline", path=source, page=page_numbers[0])) + await chainlit_response.update() + await chainlit_response.stream_token(f"'{source}' on page(s): {page_numbers}\n") + + web_sources = graph.get_last_web_sources() + if len(web_sources) > 0: + await chainlit_response.stream_token(f"The following web sources were consulted: {web_sources}\n") + + await chainlit_response.send() + + +async def process_cond_response(message): + config = {"configurable": {"thread_id": cl.user_session.get("id")}} + + chainlit_response = cl.Message(content="") + + for response in graph.stream(message.content, config=config): + await chainlit_response.stream_token(response) + + await chainlit_response.send() @cl.set_starters diff --git a/generic_rag/graphs/ret_gen.py b/generic_rag/graphs/ret_gen.py index 523ad7c..f7f1fae 100644 --- a/generic_rag/graphs/ret_gen.py +++ b/generic_rag/graphs/ret_gen.py @@ -31,7 +31,7 @@ class RetGenLangGraph: def invoke(self, message: str, config: dict) -> Union[dict[str, Any], Any]: self.last_invoke = self.graph.invoke({"question": message}, config=config) - return self.last_invoke + return self.last_invoke["answer"] def _retrieve(self, state: State) -> dict: retrieved_docs = self.vector_store.similarity_search(state["question"])