forked from AI_team/Philosophy-RAG-demo
Merge pull request 'CondRetGenLangGraph() returns web and pdf sources by leveraging the stream definition' (#19) from cond-sources into main
Reviewed-on: AI_team/generic-RAG-demo#19
Reviewed-by: nielsonj <nielson.janne@sogeti.com>
commit ab78fdc0c7
```diff
@@ -93,15 +93,7 @@ async def on_message(message: cl.Message):
     await process_response(message)
 
 
-async def process_response(message):
-    config = {"configurable": {"thread_id": cl.user_session.get("id")}}
-
-    chainlit_response = cl.Message(content="")
-
-    async for response in graph.stream(message.content, config=config):
-        await chainlit_response.stream_token(response)
-
-    pdf_sources = graph.get_last_pdf_sources()
+async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sources: set | list):
     if len(pdf_sources) > 0:
         await chainlit_response.stream_token("\nThe following PDF source were consulted:\n")
         for source, page_numbers in pdf_sources.items():
```
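The new `add_sources` helper takes the PDF sources as a mapping from file path to the set of consulted page numbers, and the web sources as a flat collection of URLs. A minimal sketch of those shapes (the path and URL below are invented for illustration):

```python
# Hypothetical example data, mirroring the shapes add_sources() consumes.
pdf_sources: dict[str, set[int]] = {"docs/critique_of_pure_reason.pdf": {3, 7}}
web_sources: set[str] = {"https://plato.stanford.edu/entries/kant/"}

for source, page_numbers in pdf_sources.items():
    print(f"- '{source}' on page(s): {page_numbers}")  # same format the handler streams
for source in web_sources:
    print(f"- {source}")
```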
```diff
@@ -111,13 +103,24 @@ async def process_response(message):
             chainlit_response.elements.append(cl.Pdf(name="pdf", display="inline", path=source, page=page_numbers[0]))
             await chainlit_response.update()
             await chainlit_response.stream_token(f"- '{source}' on page(s): {page_numbers}\n")
 
-    web_sources = graph.get_last_web_sources()
     if len(web_sources) > 0:
         await chainlit_response.stream_token("\nThe following web sources were consulted:\n")
         for source in web_sources:
             await chainlit_response.stream_token(f"- {source}\n")
 
 
+async def process_response(message):
+    config = {"configurable": {"thread_id": cl.user_session.get("id")}}
+
+    chainlit_response = cl.Message(content="")
+
+    async for response in graph.stream(message.content, config=config):
+        await chainlit_response.stream_token(response)
+
+    pdf_sources = graph.get_last_pdf_sources()
+    web_sources = graph.get_last_web_sources()
+    await add_sources(chainlit_response, pdf_sources, web_sources)
+
     await chainlit_response.send()
 
```
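`process_response` now only streams the answer and delegates source reporting to `add_sources`. A hedged sketch of the Chainlit streaming pattern it builds on, with a hard-coded token list standing in for `graph.stream(...)`:

```python
import chainlit as cl

@cl.on_message
async def on_message(message: cl.Message):
    tokens = ["Hello", ", ", "world"]  # stand-in for graph.stream(message.content, ...)
    msg = cl.Message(content="")       # empty message, filled incrementally
    for tok in tokens:
        await msg.stream_token(tok)    # push each chunk to the UI as it arrives
    await msg.send()                   # finalize the streamed message
```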
```diff
@@ -129,6 +132,8 @@ async def process_cond_response(message):
     for response in graph.stream(message.content, config=config):
         await chainlit_response.stream_token(response)
 
+    await add_sources(chainlit_response, graph.last_retrieved_docs, graph.last_retrieved_sources)
+
     await chainlit_response.send()
 
 
```
```diff
@@ -1,5 +1,8 @@
 import logging
-from typing import Any, Iterator, List
+from typing import Any, Iterator
+import re
+import ast
+from pathlib import Path
 
 from langchain_chroma import Chroma
 from langchain_core.documents import Document
```
```diff
@@ -7,6 +10,7 @@ from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
 from langchain_core.tools import tool
+from langchain_core.runnables.config import RunnableConfig
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import END, MessagesState, StateGraph
 from langgraph.prebuilt import InjectedStore, ToolNode, tools_condition
```
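The new `RunnableConfig` import types the `config` parameter of `stream()` below. The value itself is a plain dict; its `thread_id` is what the `MemorySaver` checkpointer keys per-session state on. A sketch, with a placeholder for the session id the Chainlit layer supplies:

```python
from langchain_core.runnables.config import RunnableConfig

# "session-1234" stands in for cl.user_session.get("id").
config: RunnableConfig = {"configurable": {"thread_id": "session-1234"}}
```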
```diff
@@ -39,28 +43,43 @@ class CondRetGenLangGraph:
 
         self.graph = graph_builder.compile(checkpointer=memory, store=vector_store)
 
-    def stream(self, message: str, config=None) -> Iterator[str]:
+        self.file_path_pattern = r"'file_path'\s*:\s*'((?:[^'\\]|\\.)*)'"
+        self.source_pattern = r"'source'\s*:\s*'((?:[^'\\]|\\.)*)'"
+        self.page_pattern = r"'page'\s*:\s*(\d+)"
+        self.pattern = r"Source:\s*(\{.*?\})"
+
+        self.last_retrieved_docs = {}
+        self.last_retrieved_sources = set()
+
+    def stream(self, message: str, config: RunnableConfig | None = None) -> Iterator[str]:
         for llm_response, metadata in self.graph.stream(
             {"messages": [{"role": "user", "content": message}]}, stream_mode="messages", config=config
         ):
-            if (
-                llm_response.content
-                and not isinstance(llm_response, HumanMessage)
-                and metadata["langgraph_node"] == "_generate"
-            ):
+            if llm_response.content and metadata["langgraph_node"] == "_generate":
                 yield llm_response.content
-        # TODO: read souces used in AIMessages and set internal value sources used in last received stream.
+            elif llm_response.name == "_retrieve":
+                dictionary_strings = re.findall(
+                    self.pattern, llm_response.content, re.DOTALL
+                ) # Use re.DOTALL if dicts might span newlines
+                for dict_str in dictionary_strings:
+                    parsed_dict = ast.literal_eval(dict_str)
+                    if "filetype" in parsed_dict and parsed_dict["filetype"] == "web":
+                        self.last_retrieved_sources.add(parsed_dict["source"])
+                    elif Path(parsed_dict["source"]).suffix == ".pdf":
+                        if parsed_dict["source"] in self.last_retrieved_docs:
+                            self.last_retrieved_docs[parsed_dict["source"]].add(parsed_dict["page"])
+                        else:
+                            self.last_retrieved_docs[parsed_dict["source"]] = {parsed_dict["page"]}
 
     @tool(response_format="content_and_artifact")
     def _retrieve(
         query: str, full_user_content: str, vector_store: Annotated[Any, InjectedStore()]
-    ) -> tuple[str, List[Document]]:
+    ) -> tuple[str, list[Document]]:
         """
         Retrieve information related to a query and user content.
         """
         # This method is used as a tool in the graph.
-        # It's doc-string is used for the pydentic model, please consider doc-string text carefully.
+        # It's doc-string is used for the pydantic model, please consider doc-string text carefully.
         # Furthermore, it can not and should not have the `self` parameter.
         # If you want to pass on state, please refer to:
         # https://python.langchain.com/docs/concepts/tools/#special-type-annotations
```
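With `stream_mode="messages"`, the compiled graph yields `(message_chunk, metadata)` pairs, and `metadata["langgraph_node"]` names the node that produced the chunk; that is what lets `stream()` route `_generate` tokens to the caller and `_retrieve` tool output to the source parser. The parsing step can be exercised on its own; a sketch, with an invented sample of tool output and the same `Source: {...}` pattern added in `__init__`:

```python
import ast
import re
from pathlib import Path

pattern = r"Source:\s*(\{.*?\})"  # the regex stored as self.pattern above
content = (  # invented example of _retrieve output
    "Source: {'source': 'docs/critique_of_pure_reason.pdf', 'page': 3}\n"
    "Content: ...\n"
    "Source: {'source': 'https://example.org/article', 'filetype': 'web'}\n"
    "Content: ...\n"
)

pdf_docs: dict[str, set[int]] = {}
web_sources: set[str] = set()

for dict_str in re.findall(pattern, content, re.DOTALL):
    parsed = ast.literal_eval(dict_str)  # parses dict literals without eval()
    if parsed.get("filetype") == "web":
        web_sources.add(parsed["source"])
    elif Path(parsed["source"]).suffix == ".pdf":
        pdf_docs.setdefault(parsed["source"], set()).add(parsed["page"])

print(pdf_docs)     # {'docs/critique_of_pure_reason.pdf': {3}}
print(web_sources)  # {'https://example.org/article'}
```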
```diff
@@ -76,6 +95,10 @@ class CondRetGenLangGraph:
 
     def _query_or_respond(self, state: MessagesState) -> dict[str, BaseMessage]:
         """Generate tool call for retrieval or respond."""
+        # Reset last retrieved docs
+        self.last_retrieved_docs = {}
+        self.last_retrieved_sources = set()
+
         llm_with_tools = self.chat_model.bind_tools([self._retrieve])
         response = llm_with_tools.invoke(state["messages"])
         return {"messages": [response]}
```
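Resetting both trackers at the top of `_query_or_respond` scopes the reported sources to the current turn; without it, sources gathered for an earlier question in the same thread would keep being reported. Illustrative only, using the attribute names from the diff as plain variables:

```python
# Illustrative: per-turn reset keeps old sources from leaking into new answers.
last_retrieved_docs: dict[str, set[int]] = {"turn1.pdf": {1}}   # left over from turn 1
last_retrieved_sources: set[str] = {"https://example.org/old"}

# Start of turn 2, mirroring the reset in _query_or_respond.
last_retrieved_docs = {}
last_retrieved_sources = set()
print(last_retrieved_docs, last_retrieved_sources)  # {} set()
```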
```diff
@@ -61,12 +61,11 @@ class RetGenLangGraph:
             return pdf_sources
 
         for doc in self.last_retrieved_docs:
-            try:
-                Path(doc.metadata["source"]).suffix == ".pdf"
-            except KeyError:
-                continue
-            else:
+            source_candidate = doc.metadata["source"]
+            if "source" in doc.metadata and Path(doc.metadata["source"]).suffix.lower() == ".pdf":
                 source = doc.metadata["source"]
+            else:
+                continue
 
             if source not in pdf_sources:
                 pdf_sources[source] = set()
```
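Besides replacing the `try/except KeyError` flow with an explicit membership test, the new check lower-cases the suffix, so a file named `Report.PDF` is no longer skipped:

```python
from pathlib import Path

print(Path("Report.PDF").suffix == ".pdf")          # False: comparison is case-sensitive
print(Path("Report.PDF").suffix.lower() == ".pdf")  # True: normalized before comparing
```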