forked from AI_team/Philosophy-RAG-demo

Factor out some duplication

parent ec4edf9c92
commit 7277cd6ff9
@@ -107,13 +107,22 @@ else:
 @cl.on_message
 async def on_message(message: cl.Message):
+    config = {"configurable": {"thread_id": cl.user_session.get("id")}}
+
+    chainlit_response = cl.Message(content="")
+
+    async for response in graph.stream(message.content, config=config):
+        await chainlit_response.stream_token(response)
+
+    if isinstance(graph, RetGenLangGraph):
+        await add_sources(chainlit_response, graph.get_last_pdf_sources(), graph.get_last_web_sources())
     if isinstance(graph, CondRetGenLangGraph):
-        await process_cond_response(message)
-    elif isinstance(graph, RetGenLangGraph):
-        await process_response(message)
+        await add_sources(chainlit_response, graph.last_retrieved_docs, graph.last_retrieved_sources)
+    await chainlit_response.send()
 
 
-async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sources: set | list):
+async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sources: set | list) -> None:
     if len(pdf_sources) > 0:
         await chainlit_response.stream_token("\n\nThe following PDF source were consulted:\n")
         for source, page_numbers in pdf_sources.items():
@@ -123,40 +132,13 @@ async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sour
             chainlit_response.elements.append(cl.Pdf(name="pdf", display="inline", path=source, page=page_numbers[0]))
             await chainlit_response.update()
             await chainlit_response.stream_token(f"- '{source}' on page(s): {page_numbers}\n")
 
     if len(web_sources) > 0:
         await chainlit_response.stream_token("\n\nThe following web sources were consulted:\n")
         for source in web_sources:
             await chainlit_response.stream_token(f"- {source}\n")
 
 
-async def process_response(message):
-    config = {"configurable": {"thread_id": cl.user_session.get("id")}}
-
-    chainlit_response = cl.Message(content="")
-
-    async for response in graph.stream(message.content, config=config):
-        await chainlit_response.stream_token(response)
-
-    pdf_sources = graph.get_last_pdf_sources()
-    web_sources = graph.get_last_web_sources()
-    await add_sources(chainlit_response, pdf_sources, web_sources)
-
-    await chainlit_response.send()
-
-
-async def process_cond_response(message):
-    config = {"configurable": {"thread_id": cl.user_session.get("id")}}
-
-    chainlit_response = cl.Message(content="")
-
-    for response in graph.stream(message.content, config=config):
-        await chainlit_response.stream_token(response)
-
-    await add_sources(chainlit_response, graph.last_retrieved_docs, graph.last_retrieved_sources)
-
-    await chainlit_response.send()
 
 
 @cl.set_starters
 async def set_starters():
     chainlit_starters = os.environ.get("CHAINLIT_STARTERS", None)
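Both call sites in the new on_message hand add_sources the same two shapes: a mapping from PDF path to page numbers, and an iterable of web URLs. That shared interface is what lets one helper replace the two former process_* functions. A minimal call-site sketch, with made-up values for illustration (not taken from the repo):

    pdf_sources = {"data/republic.pdf": [12, 14]}                 # hypothetical: path -> page numbers
    web_sources = {"https://plato.stanford.edu/entries/plato/"}   # hypothetical URL
    await add_sources(chainlit_response, pdf_sources, web_sources)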
@@ -49,8 +49,8 @@ class CondRetGenLangGraph:
         self.last_retrieved_docs = {}
         self.last_retrieved_sources = set()
 
-    def stream(self, message: str, config: RunnableConfig | None = None) -> Iterator[str]:
-        for llm_response, metadata in self.graph.stream(
+    async def stream(self, message: str, config: RunnableConfig | None = None) -> AsyncGenerator[Any, Any]:
+        async for llm_response, metadata in self.graph.astream(
             {"messages": [{"role": "user", "content": message}]}, stream_mode="messages", config=config
         ):
             if llm_response.content and metadata["langgraph_node"] == "_generate":
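The hunk above is what makes the single async for loop in on_message serve both graph classes: CondRetGenLangGraph.stream becomes an async generator over LangGraph's astream, matching RetGenLangGraph and no longer blocking the event loop while tokens arrive. A self-contained sketch of the same pattern, assuming a placeholder class name and a yield body that the diff truncates:

    from typing import Any, AsyncGenerator

    from langchain_core.runnables import RunnableConfig


    class ExampleGraph:  # placeholder; stands in for CondRetGenLangGraph
        def __init__(self, compiled_graph):
            self.graph = compiled_graph  # a compiled LangGraph graph

        async def stream(self, message: str, config: RunnableConfig | None = None) -> AsyncGenerator[Any, Any]:
            # astream(..., stream_mode="messages") yields (chunk, metadata) pairs
            # token by token; keep only output emitted by the generation node.
            async for llm_response, metadata in self.graph.astream(
                {"messages": [{"role": "user", "content": message}]}, stream_mode="messages", config=config
            ):
                if llm_response.content and metadata["langgraph_node"] == "_generate":
                    yield llm_response.content  # assumed: the diff cuts off inside this branch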