Merge pull request 'Uses a single prompt for both graphs' (#23) from prompt_fix into main

Reviewed-on: AI_team/generic-RAG-demo#23
Commit 779d6d4ca6 by nielsonj, 2025-04-13 14:28:03 +02:00
3 changed files with 51 additions and 57 deletions

View File

@@ -15,6 +15,13 @@ from parsers.parser import add_pdf_files, add_urls
 logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)
 
+system_prompt = (
+    "You are an assistant for question-answering tasks. "
+    "If the question is in Dutch, answer in Dutch. If the question is in English, answer in English."
+    "Use the following pieces of retrieved context to answer the question. "
+    "If you don't know the answer, say that you don't know."
+)
+
 parser = argparse.ArgumentParser(description="A Sogeti Nederland Generic RAG demo.")
 parser.add_argument(
     "-c",
@@ -84,27 +91,38 @@ vector_store = Chroma(
 if args.use_conditional_graph:
     graph = CondRetGenLangGraph(
-        vector_store,
+        vector_store=vector_store,
         chat_model=get_chat_model(args.chat_backend),
         embedding_model=get_embedding_model(args.emb_backend),
+        system_prompt=system_prompt,
     )
 else:
     graph = RetGenLangGraph(
-        vector_store,
+        vector_store=vector_store,
         chat_model=get_chat_model(args.chat_backend),
         embedding_model=get_embedding_model(args.emb_backend),
+        system_prompt=system_prompt,
     )
 
 
 @cl.on_message
 async def on_message(message: cl.Message):
-    if isinstance(graph, CondRetGenLangGraph):
-        await process_cond_response(message)
-    elif isinstance(graph, RetGenLangGraph):
-        await process_response(message)
+    config = {"configurable": {"thread_id": cl.user_session.get("id")}}
+    chainlit_response = cl.Message(content="")
+    async for response in graph.stream(message.content, config=config):
+        await chainlit_response.stream_token(response)
+
+    if isinstance(graph, RetGenLangGraph):
+        await add_sources(chainlit_response, graph.get_last_pdf_sources(), graph.get_last_web_sources())
+    if isinstance(graph, CondRetGenLangGraph):
+        await add_sources(chainlit_response, graph.last_retrieved_docs, graph.last_retrieved_sources)
+
+    await chainlit_response.send()
 
 
-async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sources: set | list):
+async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sources: set | list) -> None:
     if len(pdf_sources) > 0:
         await chainlit_response.stream_token("\n\nThe following PDF source were consulted:\n")
         for source, page_numbers in pdf_sources.items():
@@ -114,40 +132,13 @@ async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sources: set | list):
             chainlit_response.elements.append(cl.Pdf(name="pdf", display="inline", path=source, page=page_numbers[0]))
             await chainlit_response.update()
             await chainlit_response.stream_token(f"- '{source}' on page(s): {page_numbers}\n")
     if len(web_sources) > 0:
         await chainlit_response.stream_token("\n\nThe following web sources were consulted:\n")
         for source in web_sources:
             await chainlit_response.stream_token(f"- {source}\n")
 
 
-async def process_response(message):
-    config = {"configurable": {"thread_id": cl.user_session.get("id")}}
-    chainlit_response = cl.Message(content="")
-    async for response in graph.stream(message.content, config=config):
-        await chainlit_response.stream_token(response)
-    pdf_sources = graph.get_last_pdf_sources()
-    web_sources = graph.get_last_web_sources()
-    await add_sources(chainlit_response, pdf_sources, web_sources)
-    await chainlit_response.send()
-
-
-async def process_cond_response(message):
-    config = {"configurable": {"thread_id": cl.user_session.get("id")}}
-    chainlit_response = cl.Message(content="")
-    for response in graph.stream(message.content, config=config):
-        await chainlit_response.stream_token(response)
-    await add_sources(chainlit_response, graph.last_retrieved_docs, graph.last_retrieved_sources)
-    await chainlit_response.send()
-
-
 @cl.set_starters
 async def set_starters():
     chainlit_starters = os.environ.get("CHAINLIT_STARTERS", None)
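For orientation: the single on_message handler above works because, after this change, both graph classes expose the same asynchronous streaming interface. A minimal sketch of that implicit contract follows; the Protocol and its name are illustrative only and not part of the PR.

# Illustrative only: the interface on_message now relies on.
# Both CondRetGenLangGraph and RetGenLangGraph satisfy it after this change.
from typing import Any, AsyncGenerator, Protocol

from langchain_core.runnables.config import RunnableConfig


class StreamingRAGGraph(Protocol):
    def stream(
        self, message: str, config: RunnableConfig | None = None
    ) -> AsyncGenerator[Any, Any]:
        """Yield answer tokens for a user message."""
        ...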

View File

@@ -1,16 +1,16 @@
-import logging
-from typing import Any, Iterator
-import re
 import ast
+import logging
+import re
 from pathlib import Path
+from typing import Any, AsyncGenerator
 
 from langchain_chroma import Chroma
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import BaseMessage, SystemMessage
-from langchain_core.tools import tool
 from langchain_core.runnables.config import RunnableConfig
+from langchain_core.tools import tool
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import END, MessagesState, StateGraph
 from langgraph.prebuilt import InjectedStore, ToolNode, tools_condition
@@ -19,15 +19,12 @@ from typing_extensions import Annotated
 logger = logging.getLogger(__name__)
 
 class CondRetGenLangGraph:
-    def __init__(self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings):
+    def __init__(
+        self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings, system_prompt: str
+    ):
         self.chat_model = chat_model
         self.embedding_model = embedding_model
-        self.system_prompt = (
-            "You are an assistant for question-answering tasks. "
-            "If the question is in Dutch, answer in Dutch. If the question is in English, answer in English."
-            "Use the following pieces of retrieved context to answer the question. "
-            "If you don't know the answer, say that you don't know."
-        )
+        self.system_prompt = system_prompt
 
         memory = MemorySaver()
         tools = ToolNode([self._retrieve])
@@ -52,8 +49,8 @@ class CondRetGenLangGraph:
         self.last_retrieved_docs = {}
         self.last_retrieved_sources = set()
 
-    def stream(self, message: str, config: RunnableConfig | None = None) -> Iterator[str]:
-        for llm_response, metadata in self.graph.stream(
+    async def stream(self, message: str, config: RunnableConfig | None = None) -> AsyncGenerator[Any, Any]:
+        async for llm_response, metadata in self.graph.astream(
             {"messages": [{"role": "user", "content": message}]}, stream_mode="messages", config=config
         ):
             if llm_response.content and metadata["langgraph_node"] == "_generate":
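Because CondRetGenLangGraph.stream is now an async generator, it can no longer be driven by a plain for loop. A small usage sketch outside Chainlit is shown below; the ask helper and the thread id are assumptions made up for illustration, not code from this PR.

# Hypothetical helper: collect the streamed tokens into a single answer string.
import asyncio

from langchain_core.runnables.config import RunnableConfig


async def ask(graph, question: str, thread_id: str = "local-test") -> str:
    config: RunnableConfig = {"configurable": {"thread_id": thread_id}}
    answer = ""
    async for token in graph.stream(question, config=config):
        answer += token
    return answer


# answer = asyncio.run(ask(graph, "Wat doet deze demo?"))  # graph built as in main.py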

View File

@@ -2,11 +2,12 @@ import logging
 from pathlib import Path
 from typing import Any, AsyncGenerator
 
-from langchain import hub
 from langchain_chroma import Chroma
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import BaseMessage, SystemMessage
+from langchain_core.runnables.config import RunnableConfig
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import END, START, StateGraph
 from typing_extensions import List, TypedDict
@@ -17,15 +18,17 @@ logger = logging.getLogger(__name__)
 class State(TypedDict):
     question: str
     context: List[Document]
-    answer: str
+    answer: BaseMessage
 
 
 class RetGenLangGraph:
-    def __init__(self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings):
+    def __init__(
+        self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings, system_prompt: str
+    ):
         self.vector_store = vector_store
         self.chat_model = chat_model
         self.embedding_model = embedding_model
-        self.prompt = hub.pull("rlm/rag-prompt")
+        self.system_prompt = system_prompt
 
         memory = MemorySaver()
         graph_builder = StateGraph(State).add_sequence([self._retrieve, self._generate])
@@ -36,20 +39,23 @@ class RetGenLangGraph:
         self.graph = graph_builder.compile(memory)
         self.last_retrieved_docs = []
 
-    async def stream(self, message: str, config: dict) -> AsyncGenerator[Any, Any]:
+    async def stream(self, message: str, config: RunnableConfig | None = None) -> AsyncGenerator[Any, Any]:
         async for response, _ in self.graph.astream({"question": message}, stream_mode="messages", config=config):
             yield response.content
 
-    def _retrieve(self, state: State) -> dict:
+    def _retrieve(self, state: State) -> dict[str, list]:
         logger.debug(f"querying VS for: {state["question"]}")
         self.last_retrieved_docs = self.vector_store.similarity_search(state["question"])
         return {"context": self.last_retrieved_docs}
 
-    async def _generate(self, state: State) -> AsyncGenerator[Any, Any]:
+    def _generate(self, state: State) -> dict[str, list]:
         docs_content = "\n\n".join(doc.page_content for doc in state["context"])
-        messages = await self.prompt.ainvoke({"question": state["question"], "context": docs_content})
-        async for response in self.chat_model.astream(messages):
-            yield {"answer": response.content}
+        system_message_content = self.system_prompt + f"\n\n{docs_content}"
+
+        prompt = [SystemMessage(system_message_content)] + [state["question"]]
+        response = self.chat_model.invoke(prompt)
+        return {"answer": [response]}
 
     def get_last_pdf_sources(self) -> dict[str, list[int]]:
         """