diff --git a/generic_rag/app.py b/generic_rag/app.py index 274a7ac..c76a253 100644 --- a/generic_rag/app.py +++ b/generic_rag/app.py @@ -98,9 +98,8 @@ async def process_response(message): chainlit_response = cl.Message(content="") - response = graph.invoke(message.content, config=config) - - await chainlit_response.stream_token(f"{response}\n") + for response in graph.stream(message.content, config=config): + await chainlit_response.stream_token(response) pdf_sources = graph.get_last_pdf_sources() if len(pdf_sources) > 0: diff --git a/generic_rag/graphs/ret_gen.py b/generic_rag/graphs/ret_gen.py index c30d6e2..3c332ed 100644 --- a/generic_rag/graphs/ret_gen.py +++ b/generic_rag/graphs/ret_gen.py @@ -3,7 +3,10 @@ from pathlib import Path from typing import Any, Union from langchain import hub +from langchain_chroma import Chroma from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from langchain_core.language_models.chat_models import BaseChatModel from langgraph.checkpoint.memory import MemorySaver from langgraph.graph import END, START, StateGraph from typing_extensions import List, TypedDict @@ -19,7 +22,7 @@ class State(TypedDict): class RetGenLangGraph: - def __init__(self, vector_store, chat_model, embedding_model): + def __init__(self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings): self.vector_store = vector_store self.chat_model = chat_model self.embedding_model = embedding_model @@ -32,15 +35,15 @@ class RetGenLangGraph: graph_builder.add_edge("_generate", END) self.graph = graph_builder.compile(memory) - self.last_invoke = None + self.last_retrieved_docs = [] - def invoke(self, message: str, config: dict) -> Union[dict[str, Any], Any]: - self.last_invoke = self.graph.invoke({"question": message}, config=config) - return self.last_invoke["answer"] + def stream(self, message: str, config: dict) -> Union[dict[str, Any], Any]: + for response, _ in self.graph.stream({"question": message}, stream_mode="messages", config=config): + yield response.content def _retrieve(self, state: State) -> dict: - retrieved_docs = self.vector_store.similarity_search(state["question"]) - return {"context": retrieved_docs} + self.last_retrieved_docs = self.vector_store.similarity_search(state["question"]) + return {"context": self.last_retrieved_docs} def _generate(self, state: State) -> dict: docs_content = "\n\n".join(doc.page_content for doc in state["context"]) @@ -54,30 +57,30 @@ class RetGenLangGraph: """ pdf_sources = {} - if self.last_invoke is None: + if not self.last_retrieved_docs: return pdf_sources - for context in self.last_invoke["context"]: + for doc in self.last_retrieved_docs: try: - Path(context.metadata["source"]).suffix == ".pdf" + Path(doc.metadata["source"]).suffix == ".pdf" except KeyError: continue else: - source = context.metadata["source"] + source = doc.metadata["source"] if source not in pdf_sources: pdf_sources[source] = set() # The page numbers are in the `page_numer` and `page` fields. try: - page_number = context.metadata["page_number"] + page_number = doc.metadata["page_number"] except KeyError: pass else: pdf_sources[source].add(page_number) try: - page_number = context.metadata["page"] + page_number = doc.metadata["page"] except KeyError: pass else: @@ -94,15 +97,15 @@ class RetGenLangGraph: """ web_sources = set() - if self.last_invoke is None: + if not self.last_retrieved_docs: return web_sources - for context in self.last_invoke["context"]: + for doc in self.last_retrieved_docs: try: - context.metadata["filetype"] == "web" + doc.metadata["filetype"] == "web" except KeyError: continue else: - web_sources.add(context.metadata["source"]) + web_sources.add(doc.metadata["source"]) return web_sources