forked from AI_team/Philosophy-RAG-demo
Change RetGenLangGraph to use streaming instead of invoking on the LLM
This commit is contained in:
parent 5d86ad6961
commit db3d1cfa20
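The change replaces a blocking `graph.invoke` call with token-level streaming: `RetGenLangGraph` gains a `stream` generator that drives the compiled LangGraph graph with `stream_mode="messages"`, which yields `(message_chunk, metadata)` tuples as the LLM produces them. A minimal sketch of that consumption pattern (the input dict and `thread_id` value are illustrative; `graph` stands for any compiled `StateGraph` with a checkpointer):

    # Sketch: token-level streaming from a compiled LangGraph graph.
    # With stream_mode="messages", stream() yields (message_chunk, metadata)
    # tuples; message_chunk.content holds the new text fragment.
    config = {"configurable": {"thread_id": "demo"}}  # illustrative thread id
    for message_chunk, _metadata in graph.stream(
        {"question": "What is epistemology?"},  # illustrative input
        stream_mode="messages",
        config=config,
    ):
        print(message_chunk.content, end="", flush=True)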
@@ -98,9 +98,8 @@ async def process_response(message):
     chainlit_response = cl.Message(content="")
 
-    response = graph.invoke(message.content, config=config)
-
-    await chainlit_response.stream_token(response)
+    for response in graph.stream(message.content, config=config):
+        await chainlit_response.stream_token(f"{response}\n")
 
     pdf_sources = graph.get_last_pdf_sources()
 
     if len(pdf_sources) > 0:
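On the Chainlit side, the handler now iterates the generator and pushes each chunk into the open message with `stream_token`. Chainlit messages are finalized with `send()`; assuming the rest of `process_response` follows the usual Chainlit pattern (the decorator and trailing `send()` are not visible in this hunk and are assumptions), the handler's shape is roughly:

    import chainlit as cl

    @cl.on_message  # assumed registration; not shown in the hunk
    async def process_response(message: cl.Message):
        chainlit_response = cl.Message(content="")
        # `graph` and `config` are module-level objects defined elsewhere in the app.
        for response in graph.stream(message.content, config=config):
            await chainlit_response.stream_token(f"{response}\n")
        await chainlit_response.send()  # assumed finalization step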
@@ -3,7 +3,10 @@ from pathlib import Path
 from typing import Any, Union
 
 from langchain import hub
+from langchain_chroma import Chroma
 from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.language_models.chat_models import BaseChatModel
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import END, START, StateGraph
 from typing_extensions import List, TypedDict
@@ -19,7 +22,7 @@ class State(TypedDict):
 
 
 class RetGenLangGraph:
-    def __init__(self, vector_store, chat_model, embedding_model):
+    def __init__(self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings):
        self.vector_store = vector_store
        self.chat_model = chat_model
        self.embedding_model = embedding_model
@@ -32,15 +35,15 @@ class RetGenLangGraph:
         graph_builder.add_edge("_generate", END)
 
         self.graph = graph_builder.compile(memory)
-        self.last_invoke = None
+        self.last_retrieved_docs = []
 
-    def invoke(self, message: str, config: dict) -> Union[dict[str, Any], Any]:
-        self.last_invoke = self.graph.invoke({"question": message}, config=config)
-        return self.last_invoke["answer"]
+    def stream(self, message: str, config: dict) -> Union[dict[str, Any], Any]:
+        for response, _ in self.graph.stream({"question": message}, stream_mode="messages", config=config):
+            yield response.content
 
     def _retrieve(self, state: State) -> dict:
-        retrieved_docs = self.vector_store.similarity_search(state["question"])
-        return {"context": retrieved_docs}
+        self.last_retrieved_docs = self.vector_store.similarity_search(state["question"])
+        return {"context": self.last_retrieved_docs}
 
     def _generate(self, state: State) -> dict:
         docs_content = "\n\n".join(doc.page_content for doc in state["context"])
@@ -54,30 +57,30 @@ class RetGenLangGraph:
         """
         pdf_sources = {}
 
-        if self.last_invoke is None:
+        if not self.last_retrieved_docs:
             return pdf_sources
 
-        for context in self.last_invoke["context"]:
+        for doc in self.last_retrieved_docs:
             try:
-                Path(context.metadata["source"]).suffix == ".pdf"
+                Path(doc.metadata["source"]).suffix == ".pdf"
             except KeyError:
                 continue
             else:
-                source = context.metadata["source"]
+                source = doc.metadata["source"]
 
                 if source not in pdf_sources:
                     pdf_sources[source] = set()
 
                 # The page numbers are in the `page_numer` and `page` fields.
                 try:
-                    page_number = context.metadata["page_number"]
+                    page_number = doc.metadata["page_number"]
                 except KeyError:
                     pass
                 else:
                     pdf_sources[source].add(page_number)
 
                 try:
-                    page_number = context.metadata["page"]
+                    page_number = doc.metadata["page"]
                 except KeyError:
                     pass
                 else:
@@ -94,15 +97,15 @@ class RetGenLangGraph:
         """
         web_sources = set()
 
-        if self.last_invoke is None:
+        if not self.last_retrieved_docs:
             return web_sources
 
-        for context in self.last_invoke["context"]:
+        for doc in self.last_retrieved_docs:
             try:
-                context.metadata["filetype"] == "web"
+                doc.metadata["filetype"] == "web"
             except KeyError:
                 continue
             else:
-                web_sources.add(context.metadata["source"])
+                web_sources.add(doc.metadata["source"])
 
         return web_sources
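One consequence of streaming worth noting: a generator has no final graph state to stash, so `last_invoke` disappears and the `_retrieve` node records its hits in `last_retrieved_docs` as a side effect; the PDF and web source helpers read from there. A usage sketch under that assumption (constructor arguments and the question are illustrative):

    # Sketch: stream an answer, then read the sources gathered by _retrieve.
    graph = RetGenLangGraph(vector_store, chat_model, embedding_model)
    config = {"configurable": {"thread_id": "demo"}}  # illustrative

    for token in graph.stream("What is epistemology?", config=config):
        print(token, end="", flush=True)

    # last_retrieved_docs was populated while the graph ran, so the
    # source helpers work even though stream() never returns final state.
    print(graph.get_last_pdf_sources())  # e.g. {"essays.pdf": {3, 7}} (illustrative)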