From 9bb9f0ea226a84ed51e34be1ecc39baedcdd1595 Mon Sep 17 00:00:00 2001
From: Nielson Janné
Date: Wed, 9 Apr 2025 11:16:55 +0200
Subject: [PATCH] Make RetGenLangGraph async for smoother user experience

---
 generic_rag/app.py            |  2 +-
 generic_rag/graphs/ret_gen.py | 14 +++++++-------
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/generic_rag/app.py b/generic_rag/app.py
index c76a253..56b7162 100644
--- a/generic_rag/app.py
+++ b/generic_rag/app.py
@@ -98,7 +98,7 @@ async def process_response(message):
 
     chainlit_response = cl.Message(content="")
 
-    for response in graph.stream(message.content, config=config):
+    async for response in graph.stream(message.content, config=config):
         await chainlit_response.stream_token(response)
 
     pdf_sources = graph.get_last_pdf_sources()
diff --git a/generic_rag/graphs/ret_gen.py b/generic_rag/graphs/ret_gen.py
index 3c332ed..b12f7e1 100644
--- a/generic_rag/graphs/ret_gen.py
+++ b/generic_rag/graphs/ret_gen.py
@@ -1,6 +1,6 @@
 import logging
 from pathlib import Path
-from typing import Any, Union
+from typing import Any, AsyncGenerator
 
 from langchain import hub
 from langchain_chroma import Chroma
@@ -37,19 +37,19 @@ class RetGenLangGraph:
         self.graph = graph_builder.compile(memory)
         self.last_retrieved_docs = []
 
-    def stream(self, message: str, config: dict) -> Union[dict[str, Any], Any]:
-        for response, _ in self.graph.stream({"question": message}, stream_mode="messages", config=config):
+    async def stream(self, message: str, config: dict) -> AsyncGenerator[Any, Any]:
+        async for response, _ in self.graph.astream({"question": message}, stream_mode="messages", config=config):
             yield response.content
 
     def _retrieve(self, state: State) -> dict:
         self.last_retrieved_docs = self.vector_store.similarity_search(state["question"])
         return {"context": self.last_retrieved_docs}
 
-    def _generate(self, state: State) -> dict:
+    async def _generate(self, state: State) -> AsyncGenerator[Any, Any]:
         docs_content = "\n\n".join(doc.page_content for doc in state["context"])
-        messages = self.prompt.invoke({"question": state["question"], "context": docs_content})
-        response = self.chat_model.invoke(messages)
-        return {"answer": response.content}
+        messages = await self.prompt.ainvoke({"question": state["question"], "context": docs_content})
+        async for response in self.chat_model.astream(messages):
+            yield {"answer": response.content}
 
     def get_last_pdf_sources(self) -> dict[str, list[int]]:
         """
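
A quick way to exercise the new asynchronous interface outside of Chainlit
is a plain asyncio driver like the sketch below. This is only an
illustration, not part of the patch: the RetGenLangGraph constructor
arguments and the contents of `config` are assumptions (a graph compiled
with a checkpointer typically expects a "thread_id"), so adapt both to how
the class is actually instantiated in app.py.

    import asyncio

    from generic_rag.graphs.ret_gen import RetGenLangGraph


    async def main() -> None:
        # Hypothetical construction; the real constructor arguments are
        # not shown in this patch.
        graph = RetGenLangGraph()

        # LangGraph checkpointers key conversation memory on a thread id.
        config = {"configurable": {"thread_id": "demo"}}

        # stream() is now an async generator, so tokens are consumed with
        # `async for` and the event loop stays responsive between tokens.
        async for token in graph.stream("What do the indexed PDFs cover?", config=config):
            print(token, end="", flush=True)
        print()


    asyncio.run(main())

Because nothing in the loop body blocks, other coroutines (for example,
concurrent Chainlit sessions) can run between tokens, which is the
"smoother user experience" the subject line refers to.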