From 92b57224fa189e1414de2ad728ea6fb5c3cb2df3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nielson=20Jann=C3=A9?=
Date: Fri, 11 Apr 2025 17:08:26 +0200
Subject: [PATCH] Use a single system prompt everywhere

Define the question-answering system prompt once in app.py and pass it
into both graph constructors. CondRetGenLangGraph drops its hard-coded
copy of the prompt, and RetGenLangGraph no longer pulls "rlm/rag-prompt"
from the LangChain hub: its _generate node now builds a SystemMessage
from the shared prompt plus the retrieved context and invokes the chat
model directly.
---
 generic_rag/app.py                 |  9 +++++++++
 generic_rag/graphs/cond_ret_gen.py | 11 ++++-------
 generic_rag/graphs/ret_gen.py      | 21 +++++++++++++--------
 3 files changed, 26 insertions(+), 15 deletions(-)

diff --git a/generic_rag/app.py b/generic_rag/app.py
index 73bb3f5..dfe5fa4 100644
--- a/generic_rag/app.py
+++ b/generic_rag/app.py
@@ -15,6 +15,13 @@ from parsers.parser import add_pdf_files, add_urls
 logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)
 
+system_prompt = (
+    "You are an assistant for question-answering tasks. "
+    "If the question is in Dutch, answer in Dutch. If the question is in English, answer in English. "
+    "Use the following pieces of retrieved context to answer the question. "
+    "If you don't know the answer, say that you don't know."
+)
+
 parser = argparse.ArgumentParser(description="A Sogeti Nederland Generic RAG demo.")
 parser.add_argument(
     "-c",
@@ -87,12 +94,14 @@ if args.use_conditional_graph:
         vector_store,
         chat_model=get_chat_model(args.chat_backend),
         embedding_model=get_embedding_model(args.emb_backend),
+        system_prompt=system_prompt,
     )
 else:
     graph = RetGenLangGraph(
         vector_store,
         chat_model=get_chat_model(args.chat_backend),
         embedding_model=get_embedding_model(args.emb_backend),
+        system_prompt=system_prompt,
     )
 
 
diff --git a/generic_rag/graphs/cond_ret_gen.py b/generic_rag/graphs/cond_ret_gen.py
index 8fb554c..027a21b 100644
--- a/generic_rag/graphs/cond_ret_gen.py
+++ b/generic_rag/graphs/cond_ret_gen.py
@@ -19,15 +19,12 @@ from typing_extensions import Annotated
 logger = logging.getLogger(__name__)
 
 class CondRetGenLangGraph:
-    def __init__(self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings):
+    def __init__(
+        self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings, system_prompt: str
+    ):
         self.chat_model = chat_model
         self.embedding_model = embedding_model
-        self.system_prompt = (
-            "You are an assistant for question-answering tasks. "
-            "If the question is in Dutch, answer in Dutch. If the question is in English, answer in English."
-            "Use the following pieces of retrieved context to answer the question. "
-            "If you don't know the answer, say that you don't know."
-        )
+        self.system_prompt = system_prompt
 
         memory = MemorySaver()
         tools = ToolNode([self._retrieve])
diff --git a/generic_rag/graphs/ret_gen.py b/generic_rag/graphs/ret_gen.py
index c4262c3..e93cc89 100644
--- a/generic_rag/graphs/ret_gen.py
+++ b/generic_rag/graphs/ret_gen.py
@@ -2,11 +2,11 @@ import logging
 from pathlib import Path
 from typing import Any, AsyncGenerator
 
-from langchain import hub
 from langchain_chroma import Chroma
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import BaseMessage, SystemMessage
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import END, START, StateGraph
 from typing_extensions import List, TypedDict
@@ -17,15 +17,17 @@ logger = logging.getLogger(__name__)
 class State(TypedDict):
     question: str
     context: List[Document]
-    answer: str
+    answer: List[BaseMessage]
 
 
 class RetGenLangGraph:
-    def __init__(self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings):
+    def __init__(
+        self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings, system_prompt: str
+    ):
         self.vector_store = vector_store
         self.chat_model = chat_model
         self.embedding_model = embedding_model
-        self.prompt = hub.pull("rlm/rag-prompt")
+        self.system_prompt = system_prompt
 
         memory = MemorySaver()
         graph_builder = StateGraph(State).add_sequence([self._retrieve, self._generate])
@@ -45,11 +47,14 @@ class RetGenLangGraph:
         self.last_retrieved_docs = self.vector_store.similarity_search(state["question"])
         return {"context": self.last_retrieved_docs}
 
-    async def _generate(self, state: State) -> AsyncGenerator[Any, Any]:
+    def _generate(self, state: State) -> dict[str, list]:
         docs_content = "\n\n".join(doc.page_content for doc in state["context"])
-        messages = await self.prompt.ainvoke({"question": state["question"], "context": docs_content})
-        async for response in self.chat_model.astream(messages):
-            yield {"answer": response.content}
+        system_message_content = self.system_prompt + f"\n\n{docs_content}"
+
+        prompt = [SystemMessage(system_message_content)] + [state["question"]]
+
+        response = self.chat_model.invoke(prompt)
+        return {"answer": [response]}
 
     def get_last_pdf_sources(self) -> dict[str, list[int]]:
         """
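
Note (illustrative, not part of the patch): _generate previously streamed
tokens as an async generator; it now returns one complete message, so
token-level streaming would move to the graph level. A minimal sketch of
how a caller could still stream tokens via LangGraph's message streaming
mode -- here `compiled_graph` stands in for the graph compiled in
RetGenLangGraph.__init__ (the attribute name is not shown in this patch),
and the thread id is a placeholder required by the MemorySaver
checkpointer:

    # Placeholder thread id for the MemorySaver checkpointer.
    config = {"configurable": {"thread_id": "demo"}}

    # stream_mode="messages" yields (message_chunk, metadata) tuples as the
    # chat model emits tokens inside the _generate node, even though the
    # node itself calls chat_model.invoke().
    for chunk, metadata in compiled_graph.stream(
        {"question": "What do the indexed documents cover?"},
        config=config,
        stream_mode="messages",
    ):
        print(chunk.content, end="", flush=True)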