Change RetGenLangGraph to use streaming instead of invoking on the LLM

This commit is contained in:
Nielson Janné 2025-04-08 23:37:04 +02:00
parent 5d86ad6961
commit db3d1cfa20
2 changed files with 22 additions and 20 deletions

View File

@@ -98,9 +98,8 @@ async def process_response(message):
     chainlit_response = cl.Message(content="")
-    response = graph.invoke(message.content, config=config)
-    await chainlit_response.stream_token(f"{response}\n")
+    for response in graph.stream(message.content, config=config):
+        await chainlit_response.stream_token(response)
     pdf_sources = graph.get_last_pdf_sources()
     if len(pdf_sources) > 0:

View File

@@ -3,7 +3,10 @@ from pathlib import Path
 from typing import Any, Union
 from langchain import hub
+from langchain_chroma import Chroma
 from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.language_models.chat_models import BaseChatModel
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import END, START, StateGraph
 from typing_extensions import List, TypedDict
@@ -19,7 +22,7 @@ class State(TypedDict):
 class RetGenLangGraph:
-    def __init__(self, vector_store, chat_model, embedding_model):
+    def __init__(self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings):
         self.vector_store = vector_store
         self.chat_model = chat_model
         self.embedding_model = embedding_model
@@ -32,15 +35,15 @@ class RetGenLangGraph:
         graph_builder.add_edge("_generate", END)
         self.graph = graph_builder.compile(memory)
-        self.last_invoke = None
+        self.last_retrieved_docs = []

-    def invoke(self, message: str, config: dict) -> Union[dict[str, Any], Any]:
-        self.last_invoke = self.graph.invoke({"question": message}, config=config)
-        return self.last_invoke["answer"]
+    def stream(self, message: str, config: dict) -> Union[dict[str, Any], Any]:
+        for response, _ in self.graph.stream({"question": message}, stream_mode="messages", config=config):
+            yield response.content

     def _retrieve(self, state: State) -> dict:
-        retrieved_docs = self.vector_store.similarity_search(state["question"])
-        return {"context": retrieved_docs}
+        self.last_retrieved_docs = self.vector_store.similarity_search(state["question"])
+        return {"context": self.last_retrieved_docs}

     def _generate(self, state: State) -> dict:
         docs_content = "\n\n".join(doc.page_content for doc in state["context"])
@@ -54,30 +57,30 @@ class RetGenLangGraph:
         """
         pdf_sources = {}
-        if self.last_invoke is None:
+        if not self.last_retrieved_docs:
             return pdf_sources

-        for context in self.last_invoke["context"]:
+        for doc in self.last_retrieved_docs:
             try:
-                Path(context.metadata["source"]).suffix == ".pdf"
+                Path(doc.metadata["source"]).suffix == ".pdf"
             except KeyError:
                 continue
             else:
-                source = context.metadata["source"]
+                source = doc.metadata["source"]
                 if source not in pdf_sources:
                     pdf_sources[source] = set()

                 # The page numbers are in the `page_numer` and `page` fields.
                 try:
-                    page_number = context.metadata["page_number"]
+                    page_number = doc.metadata["page_number"]
                 except KeyError:
                     pass
                 else:
                     pdf_sources[source].add(page_number)

                 try:
-                    page_number = context.metadata["page"]
+                    page_number = doc.metadata["page"]
                 except KeyError:
                     pass
                 else:
@@ -94,15 +97,15 @@ class RetGenLangGraph:
         """
         web_sources = set()
-        if self.last_invoke is None:
+        if not self.last_retrieved_docs:
             return web_sources

-        for context in self.last_invoke["context"]:
+        for doc in self.last_retrieved_docs:
             try:
-                context.metadata["filetype"] == "web"
+                doc.metadata["filetype"] == "web"
             except KeyError:
                 continue
             else:
-                web_sources.add(context.metadata["source"])
+                web_sources.add(doc.metadata["source"])

         return web_sources