diff --git a/generic_rag/app.py b/generic_rag/app.py index 69938b1..76272d8 100644 --- a/generic_rag/app.py +++ b/generic_rag/app.py @@ -65,7 +65,9 @@ class State(TypedDict): def retrieve(state: State): vector_store = cl.user_session.get("vector_store") + retrieved_docs = vector_store.similarity_search(state["question"]) + return {"context": retrieved_docs} @@ -76,6 +78,7 @@ def generate(state: State): docs_content = "\n\n".join(doc.page_content for doc in state["context"]) messages = prompt.invoke({"question": state["question"], "context": docs_content}) response = llm.invoke(messages) + return {"answer": response.content}