From 92b57224fa189e1414de2ad728ea6fb5c3cb2df3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nielson=20Jann=C3=A9?=
Date: Fri, 11 Apr 2025 17:08:26 +0200
Subject: [PATCH 1/5] Use a single system prompt everywhere

---
 generic_rag/app.py                 |  9 +++++++++
 generic_rag/graphs/cond_ret_gen.py | 11 ++++-------
 generic_rag/graphs/ret_gen.py      | 21 +++++++++++++--------
 3 files changed, 26 insertions(+), 15 deletions(-)

diff --git a/generic_rag/app.py b/generic_rag/app.py
index 73bb3f5..dfe5fa4 100644
--- a/generic_rag/app.py
+++ b/generic_rag/app.py
@@ -15,6 +15,13 @@ from parsers.parser import add_pdf_files, add_urls
 logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)
 
+system_prompt = (
+    "You are an assistant for question-answering tasks. "
+    "If the question is in Dutch, answer in Dutch. If the question is in English, answer in English. "
+    "Use the following pieces of retrieved context to answer the question. "
+    "If you don't know the answer, say that you don't know."
+)
+
 parser = argparse.ArgumentParser(description="A Sogeti Nederland Generic RAG demo.")
 parser.add_argument(
     "-c",
@@ -87,12 +94,14 @@ if args.use_conditional_graph:
         vector_store,
         chat_model=get_chat_model(args.chat_backend),
         embedding_model=get_embedding_model(args.emb_backend),
+        system_prompt=system_prompt,
     )
 else:
     graph = RetGenLangGraph(
         vector_store,
         chat_model=get_chat_model(args.chat_backend),
         embedding_model=get_embedding_model(args.emb_backend),
+        system_prompt=system_prompt,
     )
 
 
diff --git a/generic_rag/graphs/cond_ret_gen.py b/generic_rag/graphs/cond_ret_gen.py
index 8fb554c..027a21b 100644
--- a/generic_rag/graphs/cond_ret_gen.py
+++ b/generic_rag/graphs/cond_ret_gen.py
@@ -19,15 +19,12 @@ from typing_extensions import Annotated
 logger = logging.getLogger(__name__)
 
 class CondRetGenLangGraph:
-    def __init__(self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings):
+    def __init__(
+        self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings, system_prompt: str
+    ):
         self.chat_model = chat_model
         self.embedding_model = embedding_model
-        self.system_prompt = (
-            "You are an assistant for question-answering tasks. "
-            "If the question is in Dutch, answer in Dutch. If the question is in English, answer in English."
-            "Use the following pieces of retrieved context to answer the question. "
-            "If you don't know the answer, say that you don't know."
-        )
+        self.system_prompt = system_prompt
 
         memory = MemorySaver()
         tools = ToolNode([self._retrieve])
diff --git a/generic_rag/graphs/ret_gen.py b/generic_rag/graphs/ret_gen.py
index c4262c3..e93cc89 100644
--- a/generic_rag/graphs/ret_gen.py
+++ b/generic_rag/graphs/ret_gen.py
@@ -2,11 +2,11 @@ import logging
 from pathlib import Path
 from typing import Any, AsyncGenerator
 
-from langchain import hub
 from langchain_chroma import Chroma
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import BaseMessage, SystemMessage
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import END, START, StateGraph
 from typing_extensions import List, TypedDict
@@ -17,15 +17,17 @@ logger = logging.getLogger(__name__)
 class State(TypedDict):
     question: str
     context: List[Document]
-    answer: str
+    answer: List[BaseMessage]
 
 
 class RetGenLangGraph:
-    def __init__(self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings):
+    def __init__(
+        self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings, system_prompt: str
+    ):
         self.vector_store = vector_store
         self.chat_model = chat_model
         self.embedding_model = embedding_model
-        self.prompt = hub.pull("rlm/rag-prompt")
+        self.system_prompt = system_prompt
 
         memory = MemorySaver()
         graph_builder = StateGraph(State).add_sequence([self._retrieve, self._generate])
@@ -45,11 +47,14 @@ class RetGenLangGraph:
         self.last_retrieved_docs = self.vector_store.similarity_search(state["question"])
         return {"context": self.last_retrieved_docs}
 
-    async def _generate(self, state: State) -> AsyncGenerator[Any, Any]:
+    def _generate(self, state: State) -> dict[str, list]:
         docs_content = "\n\n".join(doc.page_content for doc in state["context"])
-        messages = await self.prompt.ainvoke({"question": state["question"], "context": docs_content})
-        async for response in self.chat_model.astream(messages):
-            yield {"answer": response.content}
+        system_message_content = self.system_prompt + f"\n\n{docs_content}"
+
+        prompt = [SystemMessage(system_message_content), state["question"]]
+
+        response = self.chat_model.invoke(prompt)
+        return {"answer": [response]}
 
     def get_last_pdf_sources(self) -> dict[str, list[int]]:
         """

From ec4edf9c92f9de2c73861524e74689efb2120aa4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nielson=20Jann=C3=A9?=
Date: Fri, 11 Apr 2025 17:13:14 +0200
Subject: [PATCH 2/5] Improve type hinting

---
 generic_rag/graphs/ret_gen.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/generic_rag/graphs/ret_gen.py b/generic_rag/graphs/ret_gen.py
index e93cc89..b293e03 100644
--- a/generic_rag/graphs/ret_gen.py
+++ b/generic_rag/graphs/ret_gen.py
@@ -7,6 +7,7 @@ from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import BaseMessage, SystemMessage
+from langchain_core.runnables.config import RunnableConfig
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import END, START, StateGraph
 from typing_extensions import List, TypedDict
@@ -38,11 +39,11 @@ class RetGenLangGraph:
         self.graph = graph_builder.compile(memory)
         self.last_retrieved_docs = []
 
-    async def stream(self, message: str, config: dict) -> AsyncGenerator[Any, Any]:
+    async def stream(self, message: str, config: RunnableConfig | None = None) -> AsyncGenerator[Any, Any]:
         async for response, _ in self.graph.astream({"question": message}, stream_mode="messages", config=config):
             yield response.content
 
-    def _retrieve(self, state: State) -> dict:
+    def _retrieve(self, state: State) -> dict[str, list]:
         logger.debug(f"querying VS for: {state["question"]}")
         self.last_retrieved_docs = self.vector_store.similarity_search(state["question"])
         return {"context": self.last_retrieved_docs}

From 7277cd6ff9b5b930e7c24cfc46144acf722a19d2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nielson=20Jann=C3=A9?=
Date: Fri, 11 Apr 2025 17:13:44 +0200
Subject: [PATCH 3/5] Factor out some duplication

---
 generic_rag/app.py                 | 46 +++++++++---------------------
 generic_rag/graphs/cond_ret_gen.py |  4 +--
 2 files changed, 16 insertions(+), 34 deletions(-)

diff --git a/generic_rag/app.py b/generic_rag/app.py
index dfe5fa4..aabbee1 100644
--- a/generic_rag/app.py
+++ b/generic_rag/app.py
@@ -107,13 +107,22 @@ else:
 
 @cl.on_message
 async def on_message(message: cl.Message):
+    config = {"configurable": {"thread_id": cl.user_session.get("id")}}
+
+    chainlit_response = cl.Message(content="")
+
+    async for response in graph.stream(message.content, config=config):
+        await chainlit_response.stream_token(response)
+
+    if isinstance(graph, RetGenLangGraph):
+        await add_sources(chainlit_response, graph.get_last_pdf_sources(), graph.get_last_web_sources())
     if isinstance(graph, CondRetGenLangGraph):
-        await process_cond_response(message)
-    elif isinstance(graph, RetGenLangGraph):
-        await process_response(message)
+        await add_sources(chainlit_response, graph.last_retrieved_docs, graph.last_retrieved_sources)
+
+    await chainlit_response.send()
 
 
-async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sources: set | list):
+async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sources: set | list) -> None:
     if len(pdf_sources) > 0:
         await chainlit_response.stream_token("\n\nThe following PDF source were consulted:\n")
         for source, page_numbers in pdf_sources.items():
@@ -123,40 +132,13 @@ async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sour
             chainlit_response.elements.append(cl.Pdf(name="pdf", display="inline", path=source, page=page_numbers[0]))
             await chainlit_response.update()
             await chainlit_response.stream_token(f"- '{source}' on page(s): {page_numbers}\n")
+
     if len(web_sources) > 0:
         await chainlit_response.stream_token("\n\nThe following web sources were consulted:\n")
         for source in web_sources:
             await chainlit_response.stream_token(f"- {source}\n")
 
 
-async def process_response(message):
-    config = {"configurable": {"thread_id": cl.user_session.get("id")}}
-
-    chainlit_response = cl.Message(content="")
-
-    async for response in graph.stream(message.content, config=config):
-        await chainlit_response.stream_token(response)
-
-    pdf_sources = graph.get_last_pdf_sources()
-    web_sources = graph.get_last_web_sources()
-    await add_sources(chainlit_response, pdf_sources, web_sources)
-
-    await chainlit_response.send()
-
-
-async def process_cond_response(message):
-    config = {"configurable": {"thread_id": cl.user_session.get("id")}}
-
-    chainlit_response = cl.Message(content="")
-
-    for response in graph.stream(message.content, config=config):
-        await chainlit_response.stream_token(response)
-
-    await add_sources(chainlit_response, graph.last_retrieved_docs, graph.last_retrieved_sources)
-
-    await chainlit_response.send()
-
-
 @cl.set_starters
 async def set_starters():
     chainlit_starters = os.environ.get("CHAINLIT_STARTERS", None)
diff --git a/generic_rag/graphs/cond_ret_gen.py b/generic_rag/graphs/cond_ret_gen.py
index 027a21b..65845b2 100644
--- a/generic_rag/graphs/cond_ret_gen.py
+++ b/generic_rag/graphs/cond_ret_gen.py
@@ -49,8 +49,8 @@ class CondRetGenLangGraph:
         self.last_retrieved_docs = {}
         self.last_retrieved_sources = set()
 
-    def stream(self, message: str, config: RunnableConfig | None = None) -> Iterator[str]:
-        for llm_response, metadata in self.graph.stream(
+    async def stream(self, message: str, config: RunnableConfig | None = None) -> AsyncGenerator[Any, Any]:
+        async for llm_response, metadata in self.graph.astream(
             {"messages": [{"role": "user", "content": message}]}, stream_mode="messages", config=config
         ):
             if llm_response.content and metadata["langgraph_node"] == "_generate":

From 939a7044b53329d3c6401e87f369845cd1d253dc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nielson=20Jann=C3=A9?=
Date: Fri, 11 Apr 2025 17:13:55 +0200
Subject: [PATCH 4/5] Organize imports

---
 generic_rag/graphs/cond_ret_gen.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/generic_rag/graphs/cond_ret_gen.py b/generic_rag/graphs/cond_ret_gen.py
index 65845b2..b2057be 100644
--- a/generic_rag/graphs/cond_ret_gen.py
+++ b/generic_rag/graphs/cond_ret_gen.py
@@ -1,16 +1,16 @@
-import logging
-from typing import Any, Iterator
-import re
 import ast
+import logging
+import re
 from pathlib import Path
+from typing import Any, AsyncGenerator
 
 from langchain_chroma import Chroma
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import BaseMessage, SystemMessage
-from langchain_core.tools import tool
 from langchain_core.runnables.config import RunnableConfig
+from langchain_core.tools import tool
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import END, MessagesState, StateGraph
 from langgraph.prebuilt import InjectedStore, ToolNode, tools_condition

From 5a15f1bd161080d74866f63bbeccc114d76883c8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nielson=20Jann=C3=A9?=
Date: Fri, 11 Apr 2025 17:14:16 +0200
Subject: [PATCH 5/5] Pass vector_store as an explicit keyword argument

---
 generic_rag/app.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/generic_rag/app.py b/generic_rag/app.py
index aabbee1..999b910 100644
--- a/generic_rag/app.py
+++ b/generic_rag/app.py
@@ -91,14 +91,14 @@ vector_store = Chroma(
 
 if args.use_conditional_graph:
     graph = CondRetGenLangGraph(
-        vector_store,
+        vector_store=vector_store,
        chat_model=get_chat_model(args.chat_backend),
         embedding_model=get_embedding_model(args.emb_backend),
         system_prompt=system_prompt,
     )
 else:
     graph = RetGenLangGraph(
-        vector_store,
+        vector_store=vector_store,
         chat_model=get_chat_model(args.chat_backend),
         embedding_model=get_embedding_model(args.emb_backend),
         system_prompt=system_prompt,
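-- 
After this series, both graphs are built with the same constructor shape and
expose the same async streaming interface. A minimal usage sketch, assuming
the import path and backend name below (the real source of get_chat_model and
get_embedding_model is not shown in these hunks):

    import asyncio

    from langchain_chroma import Chroma

    from backends import get_chat_model, get_embedding_model  # assumed import path
    from graphs.ret_gen import RetGenLangGraph

    system_prompt = (
        "You are an assistant for question-answering tasks. "
        "If the question is in Dutch, answer in Dutch. If the question is in English, answer in English. "
        "Use the following pieces of retrieved context to answer the question. "
        "If you don't know the answer, say that you don't know."
    )

    # "openai" is a placeholder backend name; pass whatever values
    # args.chat_backend / args.emb_backend accept in app.py.
    graph = RetGenLangGraph(
        vector_store=Chroma(collection_name="demo", embedding_function=get_embedding_model("openai")),
        chat_model=get_chat_model("openai"),
        embedding_model=get_embedding_model("openai"),
        system_prompt=system_prompt,
    )

    async def main() -> None:
        # One thread id per chat session, mirroring on_message's use of the
        # Chainlit session id for the MemorySaver checkpointer.
        config = {"configurable": {"thread_id": "demo-thread"}}
        async for token in graph.stream("What do the indexed documents cover?", config=config):
            print(token, end="", flush=True)

    asyncio.run(main())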