Use a single system prompt everywhere

Nielson Janné 2025-04-11 17:08:26 +02:00
parent b1e8f19f00
commit 92b57224fa
3 changed files with 26 additions and 15 deletions


@@ -15,6 +15,13 @@ from parsers.parser import add_pdf_files, add_urls
 logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)
 
+system_prompt = (
+    "You are an assistant for question-answering tasks. "
+    "If the question is in Dutch, answer in Dutch. If the question is in English, answer in English. "
+    "Use the following pieces of retrieved context to answer the question. "
+    "If you don't know the answer, say that you don't know."
+)
+
 parser = argparse.ArgumentParser(description="A Sogeti Nederland Generic RAG demo.")
 parser.add_argument(
     "-c",
@@ -87,12 +94,14 @@ if args.use_conditional_graph:
         vector_store,
         chat_model=get_chat_model(args.chat_backend),
         embedding_model=get_embedding_model(args.emb_backend),
+        system_prompt=system_prompt,
     )
 else:
     graph = RetGenLangGraph(
         vector_store,
         chat_model=get_chat_model(args.chat_backend),
         embedding_model=get_embedding_model(args.emb_backend),
+        system_prompt=system_prompt,
     )
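On the entry-point side this is plain constructor injection: the prompt string is defined once at module level (note the added trailing space after "English." so the concatenated pieces don't run together) and passed to whichever graph class `args.use_conditional_graph` selects. A minimal, self-contained sketch of the pattern; the stub classes below are illustrative stand-ins for `CondRetGenLangGraph` and `RetGenLangGraph`, not the repo's real wiring:

system_prompt = "You are an assistant for question-answering tasks."

class ConditionalGraphStub:
    def __init__(self, system_prompt: str):
        # Injected rather than hard-coded, so every graph sees the same text.
        self.system_prompt = system_prompt

class SimpleGraphStub:
    def __init__(self, system_prompt: str):
        self.system_prompt = system_prompt

use_conditional_graph = True  # stands in for args.use_conditional_graph
graph_cls = ConditionalGraphStub if use_conditional_graph else SimpleGraphStub
graph = graph_cls(system_prompt)
assert graph.system_prompt == system_prompt  # both variants share one prompt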


@@ -19,15 +19,12 @@ from typing_extensions import Annotated
 logger = logging.getLogger(__name__)
 
 class CondRetGenLangGraph:
-    def __init__(self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings):
+    def __init__(
+        self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings, system_prompt: str
+    ):
         self.chat_model = chat_model
         self.embedding_model = embedding_model
-        self.system_prompt = (
-            "You are an assistant for question-answering tasks. "
-            "If the question is in Dutch, answer in Dutch. If the question is in English, answer in English."
-            "Use the following pieces of retrieved context to answer the question. "
-            "If you don't know the answer, say that you don't know."
-        )
+        self.system_prompt = system_prompt
 
         memory = MemorySaver()
         tools = ToolNode([self._retrieve])
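Before this commit the two graph classes used different prompts: `CondRetGenLangGraph` hard-coded the instructions above while `RetGenLangGraph` pulled `rlm/rag-prompt` from the LangChain hub. Deleting the hard-coded copy removes that drift, but note that `system_prompt` becomes a required positional parameter. If callers outside this repo construct the class directly, a keyword default would keep them working; a hypothetical variant, not what this commit does:

class CondRetGenLangGraph:
    # Hypothetical backward-compatible signature with a fallback prompt;
    # the commit itself makes system_prompt required with no default.
    def __init__(self, vector_store, chat_model, embedding_model,
                 system_prompt: str = "You are an assistant for question-answering tasks."):
        self.chat_model = chat_model
        self.embedding_model = embedding_model
        self.system_prompt = system_prompt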


@@ -2,11 +2,11 @@ import logging
 from pathlib import Path
 from typing import Any, AsyncGenerator
 
-from langchain import hub
 from langchain_chroma import Chroma
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import BaseMessage, SystemMessage
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import END, START, StateGraph
 from typing_extensions import List, TypedDict
@@ -17,15 +17,17 @@ logger = logging.getLogger(__name__)
 class State(TypedDict):
     question: str
     context: List[Document]
-    answer: str
+    answer: BaseMessage
 
 
 class RetGenLangGraph:
-    def __init__(self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings):
+    def __init__(
+        self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings, system_prompt: str
+    ):
         self.vector_store = vector_store
         self.chat_model = chat_model
         self.embedding_model = embedding_model
-        self.prompt = hub.pull("rlm/rag-prompt")
+        self.system_prompt = system_prompt
 
         memory = MemorySaver()
         graph_builder = StateGraph(State).add_sequence([self._retrieve, self._generate])
@@ -45,11 +47,14 @@ class RetGenLangGraph:
         self.last_retrieved_docs = self.vector_store.similarity_search(state["question"])
         return {"context": self.last_retrieved_docs}
 
-    async def _generate(self, state: State) -> AsyncGenerator[Any, Any]:
+    def _generate(self, state: State) -> dict[str, list]:
         docs_content = "\n\n".join(doc.page_content for doc in state["context"])
-        messages = await self.prompt.ainvoke({"question": state["question"], "context": docs_content})
-        async for response in self.chat_model.astream(messages):
-            yield {"answer": response.content}
+        system_message_content = self.system_prompt + f"\n\n{docs_content}"
+
+        prompt = [SystemMessage(system_message_content)] + [state["question"]]
+
+        response = self.chat_model.invoke(prompt)
+        return {"answer": [response]}
 
     def get_last_pdf_sources(self) -> dict[str, list[int]]:
         """