forked from AI_team/Philosophy-RAG-demo
Use a single system prompt everywhere
This commit is contained in:
parent b1e8f19f00
commit 92b57224fa
@@ -15,6 +15,13 @@ from parsers.parser import add_pdf_files, add_urls
 logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)
 
+system_prompt = (
+    "You are an assistant for question-answering tasks. "
+    "If the question is in Dutch, answer in Dutch. If the question is in English, answer in English."
+    "Use the following pieces of retrieved context to answer the question. "
+    "If you don't know the answer, say that you don't know."
+)
+
 parser = argparse.ArgumentParser(description="A Sogeti Nederland Generic RAG demo.")
 parser.add_argument(
     "-c",
@@ -87,12 +94,14 @@ if args.use_conditional_graph:
         vector_store,
         chat_model=get_chat_model(args.chat_backend),
         embedding_model=get_embedding_model(args.emb_backend),
+        system_prompt=system_prompt,
     )
 else:
     graph = RetGenLangGraph(
         vector_store,
         chat_model=get_chat_model(args.chat_backend),
         embedding_model=get_embedding_model(args.emb_backend),
+        system_prompt=system_prompt,
     )
 
 
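Both graph classes now accept the same keyword arguments, so the branch above could be collapsed further. A minimal sketch of that variant (not part of this commit; the argument values are assumed to be built exactly as in the diff above):

graph_cls = CondRetGenLangGraph if args.use_conditional_graph else RetGenLangGraph
graph = graph_cls(
    vector_store,
    chat_model=get_chat_model(args.chat_backend),
    embedding_model=get_embedding_model(args.emb_backend),
    system_prompt=system_prompt,
)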
@@ -19,15 +19,12 @@ from typing_extensions import Annotated
 logger = logging.getLogger(__name__)
 
 class CondRetGenLangGraph:
-    def __init__(self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings):
+    def __init__(
+        self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings, system_prompt: str
+    ):
         self.chat_model = chat_model
         self.embedding_model = embedding_model
-        self.system_prompt = (
-            "You are an assistant for question-answering tasks. "
-            "If the question is in Dutch, answer in Dutch. If the question is in English, answer in English."
-            "Use the following pieces of retrieved context to answer the question. "
-            "If you don't know the answer, say that you don't know."
-        )
+        self.system_prompt = system_prompt
 
         memory = MemorySaver()
         tools = ToolNode([self._retrieve])
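With the hard-coded default gone, the prompt wording is the caller's choice. A minimal sketch of injecting a different prompt (the vector_store, chat_model, and embedding_model objects are assumed to be created as in the demo script; the wording below is only an example):

custom_prompt = (
    "You are an assistant for question-answering tasks. "
    "Answer strictly from the retrieved context and say so when the context is insufficient."
)
graph = CondRetGenLangGraph(
    vector_store,
    chat_model=chat_model,
    embedding_model=embedding_model,
    system_prompt=custom_prompt,
)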
@@ -2,11 +2,11 @@ import logging
 from pathlib import Path
 from typing import Any, AsyncGenerator
 
-from langchain import hub
 from langchain_chroma import Chroma
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import BaseMessage, SystemMessage
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import END, START, StateGraph
 from typing_extensions import List, TypedDict
@@ -17,15 +17,17 @@ logger = logging.getLogger(__name__)
 class State(TypedDict):
     question: str
     context: List[Document]
-    answer: str
+    answer: BaseMessage
 
 
 class RetGenLangGraph:
-    def __init__(self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings):
+    def __init__(
+        self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings, system_prompt: str
+    ):
         self.vector_store = vector_store
         self.chat_model = chat_model
         self.embedding_model = embedding_model
-        self.prompt = hub.pull("rlm/rag-prompt")
+        self.system_prompt = system_prompt
 
         memory = MemorySaver()
         graph_builder = StateGraph(State).add_sequence([self._retrieve, self._generate])
@@ -45,11 +47,14 @@ class RetGenLangGraph:
         self.last_retrieved_docs = self.vector_store.similarity_search(state["question"])
         return {"context": self.last_retrieved_docs}
 
-    async def _generate(self, state: State) -> AsyncGenerator[Any, Any]:
+    def _generate(self, state: State) -> dict[str, list]:
         docs_content = "\n\n".join(doc.page_content for doc in state["context"])
-        messages = await self.prompt.ainvoke({"question": state["question"], "context": docs_content})
-        async for response in self.chat_model.astream(messages):
-            yield {"answer": response.content}
+        system_message_content = self.system_prompt + f"\n\n{docs_content}"
+
+        prompt = [SystemMessage(system_message_content)] + [state["question"]]
+
+        response = self.chat_model.invoke(prompt)
+        return {"answer": [response]}
 
     def get_last_pdf_sources(self) -> dict[str, list[int]]:
         """
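The rewritten _generate packs the instructions and the retrieved context into one SystemMessage and sends it together with the user's question in a single synchronous call (a raw string in the message list is treated by LangChain as a human message). A self-contained sketch of that pattern outside the class; build_prompt is a hypothetical helper and the chat_model / retrieved_docs values stand in for whatever the graph provides:

from langchain_core.documents import Document
from langchain_core.messages import HumanMessage, SystemMessage

def build_prompt(system_prompt: str, retrieved_docs: list[Document], question: str) -> list:
    # Join the page contents of the retrieved documents, as _generate does.
    docs_content = "\n\n".join(doc.page_content for doc in retrieved_docs)
    # System message = instructions + context; the question goes in as a human message.
    return [SystemMessage(system_prompt + f"\n\n{docs_content}"), HumanMessage(question)]

# messages = build_prompt(system_prompt, retrieved_docs, "Wat is fenomenologie?")
# response = chat_model.invoke(messages)  # returns a BaseMessage; the text is response.content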