diff --git a/generic_rag/graphs/ret_gen.py b/generic_rag/graphs/ret_gen.py index e93cc89..b293e03 100644 --- a/generic_rag/graphs/ret_gen.py +++ b/generic_rag/graphs/ret_gen.py @@ -7,6 +7,7 @@ from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import BaseMessage, SystemMessage +from langchain_core.runnables.config import RunnableConfig from langgraph.checkpoint.memory import MemorySaver from langgraph.graph import END, START, StateGraph from typing_extensions import List, TypedDict @@ -38,11 +39,11 @@ class RetGenLangGraph: self.graph = graph_builder.compile(memory) self.last_retrieved_docs = [] - async def stream(self, message: str, config: dict) -> AsyncGenerator[Any, Any]: + async def stream(self, message: str, config: RunnableConfig | None = None) -> AsyncGenerator[Any, Any]: async for response, _ in self.graph.astream({"question": message}, stream_mode="messages", config=config): yield response.content - def _retrieve(self, state: State) -> dict: + def _retrieve(self, state: State) -> dict[str, list]: logger.debug(f"querying VS for: {state["question"]}") self.last_retrieved_docs = self.vector_store.similarity_search(state["question"]) return {"context": self.last_retrieved_docs}