diff --git a/generic_rag/app.py b/generic_rag/app.py
index 73bb3f5..999b910 100644
--- a/generic_rag/app.py
+++ b/generic_rag/app.py
@@ -15,6 +15,13 @@ from parsers.parser import add_pdf_files, add_urls
 logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)
 
+system_prompt = (
+    "You are an assistant for question-answering tasks. "
+    "If the question is in Dutch, answer in Dutch. If the question is in English, answer in English."
+    "Use the following pieces of retrieved context to answer the question. "
+    "If you don't know the answer, say that you don't know."
+)
+
 parser = argparse.ArgumentParser(description="A Sogeti Nederland Generic RAG demo.")
 parser.add_argument(
     "-c",
@@ -84,27 +91,38 @@ vector_store = Chroma(
 
 if args.use_conditional_graph:
     graph = CondRetGenLangGraph(
-        vector_store,
+        vector_store=vector_store,
         chat_model=get_chat_model(args.chat_backend),
         embedding_model=get_embedding_model(args.emb_backend),
+        system_prompt=system_prompt,
     )
 else:
     graph = RetGenLangGraph(
-        vector_store,
+        vector_store=vector_store,
         chat_model=get_chat_model(args.chat_backend),
         embedding_model=get_embedding_model(args.emb_backend),
+        system_prompt=system_prompt,
    )
 
 
 @cl.on_message
 async def on_message(message: cl.Message):
+    config = {"configurable": {"thread_id": cl.user_session.get("id")}}
+
+    chainlit_response = cl.Message(content="")
+
+    async for response in graph.stream(message.content, config=config):
+        await chainlit_response.stream_token(response)
+
+    if isinstance(graph, RetGenLangGraph):
+        await add_sources(chainlit_response, graph.get_last_pdf_sources(), graph.get_last_web_sources())
     if isinstance(graph, CondRetGenLangGraph):
-        await process_cond_response(message)
-    elif isinstance(graph, RetGenLangGraph):
-        await process_response(message)
+        await add_sources(chainlit_response, graph.last_retrieved_docs, graph.last_retrieved_sources)
+
+    await chainlit_response.send()
 
 
-async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sources: set | list):
+async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sources: set | list) -> None:
     if len(pdf_sources) > 0:
         await chainlit_response.stream_token("\n\nThe following PDF source were consulted:\n")
         for source, page_numbers in pdf_sources.items():
@@ -114,40 +132,13 @@ async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sour
             chainlit_response.elements.append(cl.Pdf(name="pdf", display="inline", path=source, page=page_numbers[0]))
             await chainlit_response.update()
             await chainlit_response.stream_token(f"- '{source}' on page(s): {page_numbers}\n")
+
     if len(web_sources) > 0:
         await chainlit_response.stream_token("\n\nThe following web sources were consulted:\n")
         for source in web_sources:
             await chainlit_response.stream_token(f"- {source}\n")
 
 
-async def process_response(message):
-    config = {"configurable": {"thread_id": cl.user_session.get("id")}}
-
-    chainlit_response = cl.Message(content="")
-
-    async for response in graph.stream(message.content, config=config):
-        await chainlit_response.stream_token(response)
-
-    pdf_sources = graph.get_last_pdf_sources()
-    web_sources = graph.get_last_web_sources()
-    await add_sources(chainlit_response, pdf_sources, web_sources)
-
-    await chainlit_response.send()
-
-
-async def process_cond_response(message):
-    config = {"configurable": {"thread_id": cl.user_session.get("id")}}
-
-    chainlit_response = cl.Message(content="")
-
-    for response in graph.stream(message.content, config=config):
-        await chainlit_response.stream_token(response)
-
-    await add_sources(chainlit_response, graph.last_retrieved_docs, graph.last_retrieved_sources)
-
-    await chainlit_response.send()
-
-
 @cl.set_starters
 async def set_starters():
     chainlit_starters = os.environ.get("CHAINLIT_STARTERS", None)
diff --git a/generic_rag/graphs/cond_ret_gen.py b/generic_rag/graphs/cond_ret_gen.py
index 8fb554c..b2057be 100644
--- a/generic_rag/graphs/cond_ret_gen.py
+++ b/generic_rag/graphs/cond_ret_gen.py
@@ -1,16 +1,16 @@
-import logging
-from typing import Any, Iterator
-import re
 import ast
+import logging
+import re
 from pathlib import Path
+from typing import Any, AsyncGenerator
 
 from langchain_chroma import Chroma
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import BaseMessage, SystemMessage
-from langchain_core.tools import tool
 from langchain_core.runnables.config import RunnableConfig
+from langchain_core.tools import tool
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import END, MessagesState, StateGraph
 from langgraph.prebuilt import InjectedStore, ToolNode, tools_condition
@@ -19,15 +19,12 @@ from typing_extensions import Annotated
 logger = logging.getLogger(__name__)
 
 class CondRetGenLangGraph:
-    def __init__(self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings):
+    def __init__(
+        self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings, system_prompt: str
+    ):
         self.chat_model = chat_model
         self.embedding_model = embedding_model
-        self.system_prompt = (
-            "You are an assistant for question-answering tasks. "
-            "If the question is in Dutch, answer in Dutch. If the question is in English, answer in English."
-            "Use the following pieces of retrieved context to answer the question. "
-            "If you don't know the answer, say that you don't know."
-        )
+        self.system_prompt = system_prompt
 
         memory = MemorySaver()
         tools = ToolNode([self._retrieve])
@@ -52,8 +49,8 @@ class CondRetGenLangGraph:
         self.last_retrieved_docs = {}
         self.last_retrieved_sources = set()
 
-    def stream(self, message: str, config: RunnableConfig | None = None) -> Iterator[str]:
-        for llm_response, metadata in self.graph.stream(
+    async def stream(self, message: str, config: RunnableConfig | None = None) -> AsyncGenerator[Any, Any]:
+        async for llm_response, metadata in self.graph.astream(
             {"messages": [{"role": "user", "content": message}]}, stream_mode="messages", config=config
         ):
             if llm_response.content and metadata["langgraph_node"] == "_generate":
diff --git a/generic_rag/graphs/ret_gen.py b/generic_rag/graphs/ret_gen.py
index c4262c3..b293e03 100644
--- a/generic_rag/graphs/ret_gen.py
+++ b/generic_rag/graphs/ret_gen.py
@@ -2,11 +2,12 @@ import logging
 from pathlib import Path
 from typing import Any, AsyncGenerator
 
-from langchain import hub
 from langchain_chroma import Chroma
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import BaseMessage, SystemMessage
+from langchain_core.runnables.config import RunnableConfig
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import END, START, StateGraph
 from typing_extensions import List, TypedDict
@@ -17,15 +18,17 @@ logger = logging.getLogger(__name__)
 class State(TypedDict):
     question: str
     context: List[Document]
-    answer: str
+    answer: BaseMessage
 
 
 class RetGenLangGraph:
-    def __init__(self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings):
+    def __init__(
+        self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings, system_prompt: str
+    ):
         self.vector_store = vector_store
         self.chat_model = chat_model
         self.embedding_model = embedding_model
-        self.prompt = hub.pull("rlm/rag-prompt")
+        self.system_prompt = system_prompt
 
         memory = MemorySaver()
         graph_builder = StateGraph(State).add_sequence([self._retrieve, self._generate])
@@ -36,20 +39,23 @@ class RetGenLangGraph:
         self.graph = graph_builder.compile(memory)
         self.last_retrieved_docs = []
 
-    async def stream(self, message: str, config: dict) -> AsyncGenerator[Any, Any]:
+    async def stream(self, message: str, config: RunnableConfig | None = None) -> AsyncGenerator[Any, Any]:
         async for response, _ in self.graph.astream({"question": message}, stream_mode="messages", config=config):
             yield response.content
 
-    def _retrieve(self, state: State) -> dict:
+    def _retrieve(self, state: State) -> dict[str, list]:
         logger.debug(f"querying VS for: {state["question"]}")
         self.last_retrieved_docs = self.vector_store.similarity_search(state["question"])
         return {"context": self.last_retrieved_docs}
 
-    async def _generate(self, state: State) -> AsyncGenerator[Any, Any]:
+    def _generate(self, state: State) -> dict[str, list]:
         docs_content = "\n\n".join(doc.page_content for doc in state["context"])
-        messages = await self.prompt.ainvoke({"question": state["question"], "context": docs_content})
-        async for response in self.chat_model.astream(messages):
-            yield {"answer": response.content}
+        system_message_content = self.system_prompt + f"\n\n{docs_content}"
+
+        prompt = [SystemMessage(system_message_content)] + [state["question"]]
+
+        response = self.chat_model.invoke(prompt)
+        return {"answer": [response]}
 
     def get_last_pdf_sources(self) -> dict[str, list[int]]:
         """