Improve type hinting

This commit is contained in:
Nielson Janné 2025-04-11 17:13:14 +02:00
parent 92b57224fa
commit ec4edf9c92

View File

@ -7,6 +7,7 @@ from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage, SystemMessage from langchain_core.messages import BaseMessage, SystemMessage
from langchain_core.runnables.config import RunnableConfig
from langgraph.checkpoint.memory import MemorySaver from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph from langgraph.graph import END, START, StateGraph
from typing_extensions import List, TypedDict from typing_extensions import List, TypedDict
@ -38,11 +39,11 @@ class RetGenLangGraph:
self.graph = graph_builder.compile(memory) self.graph = graph_builder.compile(memory)
self.last_retrieved_docs = [] self.last_retrieved_docs = []
async def stream(self, message: str, config: dict) -> AsyncGenerator[Any, Any]: async def stream(self, message: str, config: RunnableConfig | None = None) -> AsyncGenerator[Any, Any]:
async for response, _ in self.graph.astream({"question": message}, stream_mode="messages", config=config): async for response, _ in self.graph.astream({"question": message}, stream_mode="messages", config=config):
yield response.content yield response.content
def _retrieve(self, state: State) -> dict: def _retrieve(self, state: State) -> dict[str, list]:
logger.debug(f"querying VS for: {state["question"]}") logger.debug(f"querying VS for: {state["question"]}")
self.last_retrieved_docs = self.vector_store.similarity_search(state["question"]) self.last_retrieved_docs = self.vector_store.similarity_search(state["question"])
return {"context": self.last_retrieved_docs} return {"context": self.last_retrieved_docs}