From ec4edf9c92f9de2c73861524e74689efb2120aa4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nielson=20Jann=C3=A9?=
Date: Fri, 11 Apr 2025 17:13:14 +0200
Subject: [PATCH] Improve type hinting

---
 generic_rag/graphs/ret_gen.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/generic_rag/graphs/ret_gen.py b/generic_rag/graphs/ret_gen.py
index e93cc89..b293e03 100644
--- a/generic_rag/graphs/ret_gen.py
+++ b/generic_rag/graphs/ret_gen.py
@@ -7,6 +7,7 @@ from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import BaseMessage, SystemMessage
+from langchain_core.runnables.config import RunnableConfig
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import END, START, StateGraph
 from typing_extensions import List, TypedDict
@@ -38,11 +39,11 @@ class RetGenLangGraph:
         self.graph = graph_builder.compile(memory)
         self.last_retrieved_docs = []

-    async def stream(self, message: str, config: dict) -> AsyncGenerator[Any, Any]:
+    async def stream(self, message: str, config: RunnableConfig | None = None) -> AsyncGenerator[Any, Any]:
         async for response, _ in self.graph.astream({"question": message}, stream_mode="messages", config=config):
             yield response.content

-    def _retrieve(self, state: State) -> dict:
+    def _retrieve(self, state: State) -> dict[str, list]:
         logger.debug(f"querying VS for: {state["question"]}")
         self.last_retrieved_docs = self.vector_store.similarity_search(state["question"])
         return {"context": self.last_retrieved_docs}
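
The sketch below is not part of the patch; it is a hypothetical illustration of how a caller might use the updated stream() signature, passing a RunnableConfig (a TypedDict whose "configurable" entry carries the thread_id used by the MemorySaver checkpointer). The constructor arguments of RetGenLangGraph and the question text are assumptions; check generic_rag/graphs/ret_gen.py for the real signature.

# Hypothetical usage sketch, not part of the patch.
import asyncio

from langchain_core.runnables.config import RunnableConfig

from generic_rag.graphs.ret_gen import RetGenLangGraph


async def main() -> None:
    # Constructor arguments (embeddings, chat model, vector store, ...) are
    # omitted here; they are assumptions and must match RetGenLangGraph.__init__.
    graph = RetGenLangGraph(...)

    # Graphs compiled with a MemorySaver checkpointer are keyed by thread_id,
    # which is passed through the RunnableConfig "configurable" mapping.
    config: RunnableConfig = {"configurable": {"thread_id": "demo-thread"}}

    # stream() yields message chunks as they are produced by the graph.
    async for chunk in graph.stream("example question", config=config):
        print(chunk, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(main())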