forked from AI_team/Philosophy-RAG-demo
Merge pull request 'Uses a single prompt for both graphs' (#23) from prompt_fix into main
Reviewed-on: AI_team/generic-RAG-demo#23
Commit: 779d6d4ca6
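In short: before this change, CondRetGenLangGraph hard-coded its own system prompt and RetGenLangGraph pulled the rlm/rag-prompt template from the LangChain hub; now a single system_prompt string is defined once in the entry script and passed to whichever graph gets built. A minimal sketch of that injection pattern (the CondGraph/PlainGraph classes below are illustrative stand-ins, not the project's real constructors):

class CondGraph:
    def __init__(self, system_prompt: str) -> None:
        self.system_prompt = system_prompt


class PlainGraph:
    def __init__(self, system_prompt: str) -> None:
        self.system_prompt = system_prompt


# One prompt, defined once, shared by both graph flavours.
system_prompt = "You are an assistant for question-answering tasks."

use_conditional_graph = True  # stand-in for args.use_conditional_graph
graph = (CondGraph if use_conditional_graph else PlainGraph)(system_prompt=system_prompt)
assert graph.system_prompt == system_prompt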
@@ -15,6 +15,13 @@ from parsers.parser import add_pdf_files, add_urls
 logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)
 
+system_prompt = (
+    "You are an assistant for question-answering tasks. "
+    "If the question is in Dutch, answer in Dutch. If the question is in English, answer in English."
+    "Use the following pieces of retrieved context to answer the question. "
+    "If you don't know the answer, say that you don't know."
+)
+
 parser = argparse.ArgumentParser(description="A Sogeti Nederland Generic RAG demo.")
 parser.add_argument(
     "-c",
@@ -84,27 +91,38 @@ vector_store = Chroma(
 
 if args.use_conditional_graph:
     graph = CondRetGenLangGraph(
-        vector_store,
+        vector_store=vector_store,
         chat_model=get_chat_model(args.chat_backend),
         embedding_model=get_embedding_model(args.emb_backend),
+        system_prompt=system_prompt,
     )
 else:
     graph = RetGenLangGraph(
-        vector_store,
+        vector_store=vector_store,
         chat_model=get_chat_model(args.chat_backend),
         embedding_model=get_embedding_model(args.emb_backend),
+        system_prompt=system_prompt,
     )
 
 
 @cl.on_message
 async def on_message(message: cl.Message):
+    config = {"configurable": {"thread_id": cl.user_session.get("id")}}
+
+    chainlit_response = cl.Message(content="")
+
+    async for response in graph.stream(message.content, config=config):
+        await chainlit_response.stream_token(response)
+
+    if isinstance(graph, RetGenLangGraph):
+        await add_sources(chainlit_response, graph.get_last_pdf_sources(), graph.get_last_web_sources())
     if isinstance(graph, CondRetGenLangGraph):
-        await process_cond_response(message)
-    elif isinstance(graph, RetGenLangGraph):
-        await process_response(message)
+        await add_sources(chainlit_response, graph.last_retrieved_docs, graph.last_retrieved_sources)
+
+    await chainlit_response.send()
 
 
-async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sources: set | list):
+async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sources: set | list) -> None:
     if len(pdf_sources) > 0:
         await chainlit_response.stream_token("\n\nThe following PDF source were consulted:\n")
         for source, page_numbers in pdf_sources.items():
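The new on_message body is the standard Chainlit token-streaming pattern: open an empty cl.Message, push tokens into it with stream_token, then send() it. A stripped-down sketch of just that pattern, with a hard-coded token source standing in for the RAG graph (save as e.g. app.py, a hypothetical file name, and launch with `chainlit run app.py`):

import chainlit as cl


@cl.on_message
async def on_message(message: cl.Message) -> None:
    response = cl.Message(content="")

    # Stand-in for `async for token in graph.stream(message.content, config=config)`.
    for token in ("You said: ", message.content):
        await response.stream_token(token)

    await response.send()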
@@ -114,40 +132,13 @@ async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sour
             chainlit_response.elements.append(cl.Pdf(name="pdf", display="inline", path=source, page=page_numbers[0]))
             await chainlit_response.update()
             await chainlit_response.stream_token(f"- '{source}' on page(s): {page_numbers}\n")
 
     if len(web_sources) > 0:
         await chainlit_response.stream_token("\n\nThe following web sources were consulted:\n")
         for source in web_sources:
             await chainlit_response.stream_token(f"- {source}\n")
 
 
-async def process_response(message):
-    config = {"configurable": {"thread_id": cl.user_session.get("id")}}
-
-    chainlit_response = cl.Message(content="")
-
-    async for response in graph.stream(message.content, config=config):
-        await chainlit_response.stream_token(response)
-
-    pdf_sources = graph.get_last_pdf_sources()
-    web_sources = graph.get_last_web_sources()
-    await add_sources(chainlit_response, pdf_sources, web_sources)
-
-    await chainlit_response.send()
-
-
-async def process_cond_response(message):
-    config = {"configurable": {"thread_id": cl.user_session.get("id")}}
-
-    chainlit_response = cl.Message(content="")
-
-    for response in graph.stream(message.content, config=config):
-        await chainlit_response.stream_token(response)
-
-    await add_sources(chainlit_response, graph.last_retrieved_docs, graph.last_retrieved_sources)
-
-    await chainlit_response.send()
-
-
 @cl.set_starters
 async def set_starters():
     chainlit_starters = os.environ.get("CHAINLIT_STARTERS", None)
@@ -1,16 +1,16 @@
-import logging
-from typing import Any, Iterator
-import re
 import ast
+import logging
+import re
 from pathlib import Path
+from typing import Any, AsyncGenerator
 
 from langchain_chroma import Chroma
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import BaseMessage, SystemMessage
-from langchain_core.tools import tool
 from langchain_core.runnables.config import RunnableConfig
+from langchain_core.tools import tool
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import END, MessagesState, StateGraph
 from langgraph.prebuilt import InjectedStore, ToolNode, tools_condition
@@ -19,15 +19,12 @@ from typing_extensions import Annotated
 logger = logging.getLogger(__name__)
 
 class CondRetGenLangGraph:
-    def __init__(self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings):
+    def __init__(
+        self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings, system_prompt: str
+    ):
         self.chat_model = chat_model
         self.embedding_model = embedding_model
-        self.system_prompt = (
-            "You are an assistant for question-answering tasks. "
-            "If the question is in Dutch, answer in Dutch. If the question is in English, answer in English."
-            "Use the following pieces of retrieved context to answer the question. "
-            "If you don't know the answer, say that you don't know."
-        )
+        self.system_prompt = system_prompt
 
         memory = MemorySaver()
         tools = ToolNode([self._retrieve])
@@ -52,8 +49,8 @@ class CondRetGenLangGraph:
         self.last_retrieved_docs = {}
         self.last_retrieved_sources = set()
 
-    def stream(self, message: str, config: RunnableConfig | None = None) -> Iterator[str]:
-        for llm_response, metadata in self.graph.stream(
+    async def stream(self, message: str, config: RunnableConfig | None = None) -> AsyncGenerator[Any, Any]:
+        async for llm_response, metadata in self.graph.astream(
             {"messages": [{"role": "user", "content": message}]}, stream_mode="messages", config=config
         ):
             if llm_response.content and metadata["langgraph_node"] == "_generate":
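CondRetGenLangGraph.stream changes from a synchronous generator over self.graph.stream(...) to an async generator over self.graph.astream(...), which is what lets the on_message handler above await each token. A generic sketch of that wrapping pattern, with fake_astream as an illustrative stand-in for the LangGraph astream call:

import asyncio
from typing import Any, AsyncGenerator


async def fake_astream(message: str) -> AsyncGenerator[str, Any]:
    # Stand-in for self.graph.astream(...): yields response chunks one at a time.
    for token in ("Hello", ", ", "world"):
        yield token


async def stream(message: str) -> AsyncGenerator[str, Any]:
    # Same shape as the rewritten stream(): re-yield chunks from the
    # underlying async stream so callers can `async for` over them.
    async for token in fake_astream(message):
        yield token


async def main() -> None:
    async for token in stream("hi"):
        print(token, end="", flush=True)
    print()


asyncio.run(main())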
@@ -2,11 +2,12 @@ import logging
 from pathlib import Path
 from typing import Any, AsyncGenerator
 
-from langchain import hub
 from langchain_chroma import Chroma
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import BaseMessage, SystemMessage
+from langchain_core.runnables.config import RunnableConfig
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import END, START, StateGraph
 from typing_extensions import List, TypedDict
@@ -17,15 +18,17 @@ logger = logging.getLogger(__name__)
 class State(TypedDict):
     question: str
     context: List[Document]
-    answer: str
+    answer: BaseMessage
 
 
 class RetGenLangGraph:
-    def __init__(self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings):
+    def __init__(
+        self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings, system_prompt: str
+    ):
         self.vector_store = vector_store
         self.chat_model = chat_model
         self.embedding_model = embedding_model
-        self.prompt = hub.pull("rlm/rag-prompt")
+        self.system_prompt = system_prompt
 
         memory = MemorySaver()
         graph_builder = StateGraph(State).add_sequence([self._retrieve, self._generate])
@@ -36,20 +39,23 @@ class RetGenLangGraph:
         self.graph = graph_builder.compile(memory)
         self.last_retrieved_docs = []
 
-    async def stream(self, message: str, config: dict) -> AsyncGenerator[Any, Any]:
+    async def stream(self, message: str, config: RunnableConfig | None = None) -> AsyncGenerator[Any, Any]:
         async for response, _ in self.graph.astream({"question": message}, stream_mode="messages", config=config):
             yield response.content
 
-    def _retrieve(self, state: State) -> dict:
+    def _retrieve(self, state: State) -> dict[str, list]:
         logger.debug(f"querying VS for: {state["question"]}")
         self.last_retrieved_docs = self.vector_store.similarity_search(state["question"])
         return {"context": self.last_retrieved_docs}
 
-    async def _generate(self, state: State) -> AsyncGenerator[Any, Any]:
+    def _generate(self, state: State) -> dict[str, list]:
         docs_content = "\n\n".join(doc.page_content for doc in state["context"])
-        messages = await self.prompt.ainvoke({"question": state["question"], "context": docs_content})
-        async for response in self.chat_model.astream(messages):
-            yield {"answer": response.content}
+        system_message_content = self.system_prompt + f"\n\n{docs_content}"
+        prompt = [SystemMessage(system_message_content)] + [state["question"]]
 
+        response = self.chat_model.invoke(prompt)
+        return {"answer": [response]}
 
     def get_last_pdf_sources(self) -> dict[str, list[int]]:
         """
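For reference, the rewritten _generate no longer streams from a hub-pulled prompt template; it concatenates the shared system prompt with the retrieved context into one SystemMessage, follows it with the raw user question, and calls the chat model once. A small sketch of that prompt assembly (the sample document and question are illustrative; in the real method the list is passed to self.chat_model.invoke(prompt)):

from langchain_core.documents import Document
from langchain_core.messages import SystemMessage

system_prompt = (
    "You are an assistant for question-answering tasks. "
    "Use the following pieces of retrieved context to answer the question. "
    "If you don't know the answer, say that you don't know."
)

# Stand-ins for what _retrieve would have placed in state["context"] and state["question"].
context = [Document(page_content="Chainlit is a UI framework for LLM apps.")]
question = "What is Chainlit?"

# Same assembly as the new _generate: system prompt plus context in one
# SystemMessage, followed by the user's question.
docs_content = "\n\n".join(doc.page_content for doc in context)
prompt = [SystemMessage(system_prompt + f"\n\n{docs_content}"), question]

for message in prompt:
    print(message)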