diff --git a/generic_rag/app.py b/generic_rag/app.py index dfe5fa4..aabbee1 100644 --- a/generic_rag/app.py +++ b/generic_rag/app.py @@ -107,13 +107,22 @@ else: @cl.on_message async def on_message(message: cl.Message): + config = {"configurable": {"thread_id": cl.user_session.get("id")}} + + chainlit_response = cl.Message(content="") + + async for response in graph.stream(message.content, config=config): + await chainlit_response.stream_token(response) + + if isinstance(graph, RetGenLangGraph): + await add_sources(chainlit_response, graph.get_last_pdf_sources(), graph.get_last_web_sources()) if isinstance(graph, CondRetGenLangGraph): - await process_cond_response(message) - elif isinstance(graph, RetGenLangGraph): - await process_response(message) + await add_sources(chainlit_response, graph.last_retrieved_docs, graph.last_retrieved_sources) + + await chainlit_response.send() -async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sources: set | list): +async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sources: set | list) -> None: if len(pdf_sources) > 0: await chainlit_response.stream_token("\n\nThe following PDF source were consulted:\n") for source, page_numbers in pdf_sources.items(): @@ -123,40 +132,13 @@ async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sour chainlit_response.elements.append(cl.Pdf(name="pdf", display="inline", path=source, page=page_numbers[0])) await chainlit_response.update() await chainlit_response.stream_token(f"- '{source}' on page(s): {page_numbers}\n") + if len(web_sources) > 0: await chainlit_response.stream_token("\n\nThe following web sources were consulted:\n") for source in web_sources: await chainlit_response.stream_token(f"- {source}\n") -async def process_response(message): - config = {"configurable": {"thread_id": cl.user_session.get("id")}} - - chainlit_response = cl.Message(content="") - - async for response in graph.stream(message.content, config=config): - await chainlit_response.stream_token(response) - - pdf_sources = graph.get_last_pdf_sources() - web_sources = graph.get_last_web_sources() - await add_sources(chainlit_response, pdf_sources, web_sources) - - await chainlit_response.send() - - -async def process_cond_response(message): - config = {"configurable": {"thread_id": cl.user_session.get("id")}} - - chainlit_response = cl.Message(content="") - - for response in graph.stream(message.content, config=config): - await chainlit_response.stream_token(response) - - await add_sources(chainlit_response, graph.last_retrieved_docs, graph.last_retrieved_sources) - - await chainlit_response.send() - - @cl.set_starters async def set_starters(): chainlit_starters = os.environ.get("CHAINLIT_STARTERS", None) diff --git a/generic_rag/graphs/cond_ret_gen.py b/generic_rag/graphs/cond_ret_gen.py index 027a21b..65845b2 100644 --- a/generic_rag/graphs/cond_ret_gen.py +++ b/generic_rag/graphs/cond_ret_gen.py @@ -49,8 +49,8 @@ class CondRetGenLangGraph: self.last_retrieved_docs = {} self.last_retrieved_sources = set() - def stream(self, message: str, config: RunnableConfig | None = None) -> Iterator[str]: - for llm_response, metadata in self.graph.stream( + async def stream(self, message: str, config: RunnableConfig | None = None) -> AsyncGenerator[Any, Any]: + async for llm_response, metadata in self.graph.astream( {"messages": [{"role": "user", "content": message}]}, stream_mode="messages", config=config ): if llm_response.content and metadata["langgraph_node"] == "_generate":