Merge pull request 'Change RetGenLangGraph to use streaming instead of invoking on the LLM' (#17) from reg_gen_stream into main

Reviewed-on: AI_team/generic-RAG-demo#17
This commit is contained in:
rubenl 2025-04-09 11:21:13 +02:00
commit 6ad6ac4a34
2 changed files with 27 additions and 26 deletions

View File

@@ -98,9 +98,8 @@ async def process_response(message):
chainlit_response = cl.Message(content="") chainlit_response = cl.Message(content="")
response = graph.invoke(message.content, config=config) async for response in graph.stream(message.content, config=config):
await chainlit_response.stream_token(response)
await chainlit_response.stream_token(f"{response}\n")
pdf_sources = graph.get_last_pdf_sources() pdf_sources = graph.get_last_pdf_sources()
if len(pdf_sources) > 0: if len(pdf_sources) > 0:

View File

@@ -1,9 +1,12 @@
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Any, Union from typing import Any, AsyncGenerator
from langchain import hub from langchain import hub
from langchain_chroma import Chroma
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langgraph.checkpoint.memory import MemorySaver from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph from langgraph.graph import END, START, StateGraph
from typing_extensions import List, TypedDict from typing_extensions import List, TypedDict
@@ -19,7 +22,7 @@ class State(TypedDict):
class RetGenLangGraph: class RetGenLangGraph:
def __init__(self, vector_store, chat_model, embedding_model): def __init__(self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings):
self.vector_store = vector_store self.vector_store = vector_store
self.chat_model = chat_model self.chat_model = chat_model
self.embedding_model = embedding_model self.embedding_model = embedding_model
@@ -32,21 +35,21 @@ class RetGenLangGraph:
graph_builder.add_edge("_generate", END) graph_builder.add_edge("_generate", END)
self.graph = graph_builder.compile(memory) self.graph = graph_builder.compile(memory)
self.last_invoke = None self.last_retrieved_docs = []
def invoke(self, message: str, config: dict) -> Union[dict[str, Any], Any]: async def stream(self, message: str, config: dict) -> AsyncGenerator[Any, Any]:
self.last_invoke = self.graph.invoke({"question": message}, config=config) async for response, _ in self.graph.astream({"question": message}, stream_mode="messages", config=config):
return self.last_invoke["answer"] yield response.content
def _retrieve(self, state: State) -> dict: def _retrieve(self, state: State) -> dict:
retrieved_docs = self.vector_store.similarity_search(state["question"]) self.last_retrieved_docs = self.vector_store.similarity_search(state["question"])
return {"context": retrieved_docs} return {"context": self.last_retrieved_docs}
def _generate(self, state: State) -> dict: async def _generate(self, state: State) -> AsyncGenerator[Any, Any]:
docs_content = "\n\n".join(doc.page_content for doc in state["context"]) docs_content = "\n\n".join(doc.page_content for doc in state["context"])
messages = self.prompt.invoke({"question": state["question"], "context": docs_content}) messages = await self.prompt.ainvoke({"question": state["question"], "context": docs_content})
response = self.chat_model.invoke(messages) async for response in self.chat_model.astream(messages):
return {"answer": response.content} yield {"answer": response.content}
def get_last_pdf_sources(self) -> dict[str, list[int]]: def get_last_pdf_sources(self) -> dict[str, list[int]]:
""" """
@@ -54,30 +57,30 @@ class RetGenLangGraph:
""" """
pdf_sources = {} pdf_sources = {}
if self.last_invoke is None: if not self.last_retrieved_docs:
return pdf_sources return pdf_sources
for context in self.last_invoke["context"]: for doc in self.last_retrieved_docs:
try: try:
Path(context.metadata["source"]).suffix == ".pdf" Path(doc.metadata["source"]).suffix == ".pdf"
except KeyError: except KeyError:
continue continue
else: else:
source = context.metadata["source"] source = doc.metadata["source"]
if source not in pdf_sources: if source not in pdf_sources:
pdf_sources[source] = set() pdf_sources[source] = set()
# The page numbers are in the `page_number` and `page` fields. # The page numbers are in the `page_number` and `page` fields.
try: try:
page_number = context.metadata["page_number"] page_number = doc.metadata["page_number"]
except KeyError: except KeyError:
pass pass
else: else:
pdf_sources[source].add(page_number) pdf_sources[source].add(page_number)
try: try:
page_number = context.metadata["page"] page_number = doc.metadata["page"]
except KeyError: except KeyError:
pass pass
else: else:
@@ -94,15 +97,14 @@ class RetGenLangGraph:
""" """
web_sources = set() web_sources = set()
if self.last_invoke is None: if not self.last_retrieved_docs:
return web_sources return web_sources
for context in self.last_invoke["context"]: for doc in self.last_retrieved_docs:
try: try:
context.metadata["filetype"] == "web" if doc.metadata["filetype"] == "web":
web_sources.add(doc.metadata["source"])
except KeyError: except KeyError:
continue continue
else:
web_sources.add(context.metadata["source"])
return web_sources return web_sources