Make RetGenGrap async for smoother user experience

This commit is contained in:
Nielson Janné 2025-04-09 11:16:55 +02:00
parent db3d1cfa20
commit 9bb9f0ea22
2 changed files with 8 additions and 8 deletions

View File

@ -98,7 +98,7 @@ async def process_response(message):
chainlit_response = cl.Message(content="") chainlit_response = cl.Message(content="")
for response in graph.stream(message.content, config=config): async for response in graph.stream(message.content, config=config):
await chainlit_response.stream_token(response) await chainlit_response.stream_token(response)
pdf_sources = graph.get_last_pdf_sources() pdf_sources = graph.get_last_pdf_sources()

View File

@ -1,6 +1,6 @@
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Any, Union from typing import Any, AsyncGenerator
from langchain import hub from langchain import hub
from langchain_chroma import Chroma from langchain_chroma import Chroma
@ -37,19 +37,19 @@ class RetGenLangGraph:
self.graph = graph_builder.compile(memory) self.graph = graph_builder.compile(memory)
self.last_retrieved_docs = [] self.last_retrieved_docs = []
def stream(self, message: str, config: dict) -> Union[dict[str, Any], Any]: async def stream(self, message: str, config: dict) -> AsyncGenerator[Any, Any]:
for response, _ in self.graph.stream({"question": message}, stream_mode="messages", config=config): async for response, _ in self.graph.astream({"question": message}, stream_mode="messages", config=config):
yield response.content yield response.content
def _retrieve(self, state: State) -> dict: def _retrieve(self, state: State) -> dict:
self.last_retrieved_docs = self.vector_store.similarity_search(state["question"]) self.last_retrieved_docs = self.vector_store.similarity_search(state["question"])
return {"context": self.last_retrieved_docs} return {"context": self.last_retrieved_docs}
def _generate(self, state: State) -> dict: async def _generate(self, state: State) -> AsyncGenerator[Any, Any]:
docs_content = "\n\n".join(doc.page_content for doc in state["context"]) docs_content = "\n\n".join(doc.page_content for doc in state["context"])
messages = self.prompt.invoke({"question": state["question"], "context": docs_content}) messages = await self.prompt.ainvoke({"question": state["question"], "context": docs_content})
response = self.chat_model.invoke(messages) async for response in self.chat_model.astream(messages):
return {"answer": response.content} yield {"answer": response.content}
def get_last_pdf_sources(self) -> dict[str, list[int]]: def get_last_pdf_sources(self) -> dict[str, list[int]]:
""" """