forked from AI_team/Philosophy-RAG-demo
Refactor out Retrieval/Generator LangGraph
This commit is contained in:
parent
3412dea813
commit
3fa0e31521
@ -6,14 +6,13 @@ from pathlib import Path
|
|||||||
|
|
||||||
import chainlit as cl
|
import chainlit as cl
|
||||||
from backend.models import BackendType, get_chat_model, get_embedding_model
|
from backend.models import BackendType, get_chat_model, get_embedding_model
|
||||||
|
from graphs.ret_gen import RetGenLangGraph
|
||||||
from chainlit.cli import run_chainlit
|
from chainlit.cli import run_chainlit
|
||||||
from langchain import hub
|
|
||||||
from langchain_chroma import Chroma
|
from langchain_chroma import Chroma
|
||||||
from langchain_core.documents import Document
|
|
||||||
from langgraph.graph import START, StateGraph
|
|
||||||
from langgraph.pregel.io import AddableValuesDict
|
|
||||||
from parsers.parser import add_pdf_files, add_urls
|
from parsers.parser import add_pdf_files, add_urls
|
||||||
from typing_extensions import List, TypedDict
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -65,62 +64,26 @@ parser.add_argument(
|
|||||||
parser.add_argument("-r", "--reset-chrome-db", action="store_true", help="Reset the Chroma DB.")
|
parser.add_argument("-r", "--reset-chrome-db", action="store_true", help="Reset the Chroma DB.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
vector_store = Chroma(
|
||||||
|
collection_name="generic_rag",
|
||||||
|
embedding_function=get_embedding_model(args.backend),
|
||||||
|
persist_directory=str(args.chroma_db_location),
|
||||||
|
)
|
||||||
|
|
||||||
class State(TypedDict):
|
ret_gen_graph = RetGenLangGraph(
|
||||||
question: str
|
vector_store, chat_model=get_chat_model(args.backend), embedding_model=get_embedding_model(args.backend)
|
||||||
context: List[Document]
|
)
|
||||||
answer: str
|
|
||||||
|
|
||||||
|
|
||||||
def retrieve(state: State):
|
|
||||||
vector_store = cl.user_session.get("vector_store")
|
|
||||||
|
|
||||||
retrieved_docs = vector_store.similarity_search(state["question"])
|
|
||||||
|
|
||||||
return {"context": retrieved_docs}
|
|
||||||
|
|
||||||
|
|
||||||
def generate(state: State):
|
|
||||||
prompt = cl.user_session.get("prompt")
|
|
||||||
llm = cl.user_session.get("chat_model")
|
|
||||||
|
|
||||||
docs_content = "\n\n".join(doc.page_content for doc in state["context"])
|
|
||||||
messages = prompt.invoke({"question": state["question"], "context": docs_content})
|
|
||||||
response = llm.invoke(messages)
|
|
||||||
|
|
||||||
return {"answer": response.content}
|
|
||||||
|
|
||||||
|
|
||||||
@cl.on_chat_start
|
|
||||||
async def on_chat_start():
|
|
||||||
vector_store = Chroma(
|
|
||||||
collection_name="generic_rag",
|
|
||||||
embedding_function=get_embedding_model(args.back_end),
|
|
||||||
persist_directory=str(args.chroma_db_location),
|
|
||||||
)
|
|
||||||
|
|
||||||
cl.user_session.set("vector_store", vector_store)
|
|
||||||
cl.user_session.set("emb_model", get_embedding_model(args.back_end))
|
|
||||||
cl.user_session.set("chat_model", get_chat_model(args.back_end))
|
|
||||||
cl.user_session.set("prompt", hub.pull("rlm/rag-prompt"))
|
|
||||||
|
|
||||||
graph_builder = StateGraph(State).add_sequence([retrieve, generate])
|
|
||||||
graph_builder.add_edge(START, "retrieve")
|
|
||||||
graph = graph_builder.compile()
|
|
||||||
|
|
||||||
cl.user_session.set("graph", graph)
|
|
||||||
|
|
||||||
|
|
||||||
@cl.on_message
|
@cl.on_message
|
||||||
async def on_message(message: cl.Message):
|
async def on_message(message: cl.Message):
|
||||||
graph = cl.user_session.get("graph")
|
response = ret_gen_graph.invoke(message.content)
|
||||||
response = graph.invoke({"question": message.content})
|
|
||||||
|
|
||||||
answer = response["answer"]
|
answer = response["answer"]
|
||||||
answer += "\n\n"
|
answer += "\n\n"
|
||||||
|
|
||||||
pdf_sources = get_pdf_sources(response)
|
pdf_sources = ret_gen_graph.get_last_pdf_sources()
|
||||||
web_sources = get_web_sources(response)
|
web_sources = ret_gen_graph.get_last_web_sources()
|
||||||
|
|
||||||
elements = []
|
elements = []
|
||||||
if len(pdf_sources) > 0:
|
if len(pdf_sources) > 0:
|
||||||
@ -128,7 +91,7 @@ async def on_message(message: cl.Message):
|
|||||||
for source, page_numbers in pdf_sources.items():
|
for source, page_numbers in pdf_sources.items():
|
||||||
page_numbers = list(page_numbers)
|
page_numbers = list(page_numbers)
|
||||||
page_numbers.sort()
|
page_numbers.sort()
|
||||||
# display="side" seems to be not supported by chainlit for PDF's, so we use "inline" instead
|
# display="side" seems to be not supported by chainlit for PDF's, so we use "inline" instead.
|
||||||
elements.append(cl.Pdf(name="pdf", display="inline", path=source, page=page_numbers[0]))
|
elements.append(cl.Pdf(name="pdf", display="inline", path=source, page=page_numbers[0]))
|
||||||
answer += f"'{source}' on page(s): {page_numbers}\n"
|
answer += f"'{source}' on page(s): {page_numbers}\n"
|
||||||
|
|
||||||
@ -138,39 +101,6 @@ async def on_message(message: cl.Message):
|
|||||||
await cl.Message(content=answer, elements=elements).send()
|
await cl.Message(content=answer, elements=elements).send()
|
||||||
|
|
||||||
|
|
||||||
def get_pdf_sources(response: AddableValuesDict) -> dict[str, list[int]]:
|
|
||||||
"""
|
|
||||||
Function that retrieves the PDF sources with page numbers from a response.
|
|
||||||
"""
|
|
||||||
pdf_sources = {}
|
|
||||||
for context in response["context"]:
|
|
||||||
try:
|
|
||||||
if context.metadata["filetype"] == "application/pdf":
|
|
||||||
source = context.metadata["source"]
|
|
||||||
page_number = context.metadata["page_number"]
|
|
||||||
if source in pdf_sources:
|
|
||||||
pdf_sources[source].add(page_number)
|
|
||||||
else:
|
|
||||||
pdf_sources[source] = {page_number}
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
return pdf_sources
|
|
||||||
|
|
||||||
|
|
||||||
def get_web_sources(response: AddableValuesDict) -> set:
|
|
||||||
"""
|
|
||||||
Function that retrieves the web sources from a response.
|
|
||||||
"""
|
|
||||||
web_sources = set()
|
|
||||||
for context in response["context"]:
|
|
||||||
try:
|
|
||||||
if context.metadata["filetype"] == "web":
|
|
||||||
web_sources.add(context.metadata["source"])
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
return web_sources
|
|
||||||
|
|
||||||
|
|
||||||
@cl.set_starters
|
@cl.set_starters
|
||||||
async def set_starters():
|
async def set_starters():
|
||||||
chainlit_starters = os.environ["CHAINLIT_STARTERS"]
|
chainlit_starters = os.environ["CHAINLIT_STARTERS"]
|
||||||
@ -193,12 +123,6 @@ async def set_starters():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
vector_store = Chroma(
|
|
||||||
collection_name="generic_rag",
|
|
||||||
embedding_function=get_embedding_model(args.back_end),
|
|
||||||
persist_directory=str(args.chroma_db_location),
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.reset_chrome_db:
|
if args.reset_chrome_db:
|
||||||
vector_store.reset_collection()
|
vector_store.reset_collection()
|
||||||
|
|
||||||
|
|||||||
80
generic_rag/graphs/ret_gen.py
Normal file
80
generic_rag/graphs/ret_gen.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
from langgraph.graph import START, END, StateGraph
|
||||||
|
from typing_extensions import List, TypedDict
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
from langchain import hub
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
|
||||||
|
class State(TypedDict):
    """State dictionary threaded through the retrieve/generate graph."""

    # Question text; read by both the retrieval and the generation node.
    question: str
    # Documents produced by the retrieval node and consumed by generation.
    context: List[Document]
    # Final answer text produced by the generation node.
    answer: str
|
||||||
|
|
||||||
|
|
||||||
|
class RetGenLangGraph:
    """Retrieve-then-generate pipeline wrapped around a compiled LangGraph graph.

    The graph has two nodes that run in order: ``_retrieve`` looks up documents
    in the vector store for the question, and ``_generate`` feeds the question
    plus the retrieved text to the chat model.

    The result of the most recent :meth:`invoke` is kept on the instance so the
    ``get_last_*_sources`` methods can report which documents were used.
    NOTE(review): because that result lives on the instance, sharing a single
    instance between concurrent callers would let one call overwrite another's
    sources -- confirm how callers share this object.
    """

    def __init__(self, vector_store, chat_model, embedding_model):
        """Store the collaborators, pull the RAG prompt and compile the graph.

        Args:
            vector_store: Object exposing ``similarity_search(query)``.
            chat_model: Object exposing ``invoke(messages)`` whose result has a
                ``content`` attribute.
            embedding_model: Stored on the instance; not used by this class.
        """
        self.vector_store = vector_store
        self.chat_model = chat_model
        self.embedding_model = embedding_model
        self.prompt = hub.pull("rlm/rag-prompt")

        # add_sequence registers both nodes and chains them in the given order,
        # so only the entry and exit edges have to be added explicitly.
        graph_builder = StateGraph(State).add_sequence([self._retrieve, self._generate])
        graph_builder.add_edge(START, "_retrieve")
        graph_builder.add_edge("_generate", END)

        self.graph = graph_builder.compile()
        self.last_invoke = None

    def invoke(self, message: str) -> Union[dict[str, Any], Any]:
        """Run the graph for one question and remember the resulting state.

        Args:
            message: The question text.

        Returns:
            Whatever the compiled graph returns; the nodes in this class fill
            the ``context`` and ``answer`` keys.
        """
        # The graph state is a mapping, so the question must be wrapped; the
        # nodes read it back as state["question"].
        self.last_invoke = self.graph.invoke({"question": message})
        return self.last_invoke

    def _retrieve(self, state: State) -> dict:
        """Graph node: fetch documents similar to the question."""
        retrieved_docs = self.vector_store.similarity_search(state["question"])
        return {"context": retrieved_docs}

    def _generate(self, state: State) -> dict:
        """Graph node: answer the question from the retrieved documents."""
        docs_content = "\n\n".join(doc.page_content for doc in state["context"])
        messages = self.prompt.invoke({"question": state["question"], "context": docs_content})
        response = self.chat_model.invoke(messages)
        return {"answer": response.content}

    def _last_context(self) -> list:
        """Return the documents of the last invoke, or an empty list if none ran."""
        if self.last_invoke is None:
            return []
        return self.last_invoke["context"]

    def get_last_pdf_sources(self) -> dict[str, set[int]]:
        """
        Method that retrieves the PDF sources used during the last invoke.

        Returns:
            Mapping of PDF source to the set of page numbers that were used.
            Empty when nothing has been invoked yet. Documents lacking any of
            the ``filetype``, ``source`` or ``page_number`` metadata keys are
            skipped.
        """
        pdf_sources: dict[str, set[int]] = {}
        for doc in self._last_context():
            metadata = doc.metadata
            if metadata.get("filetype") != "application/pdf":
                continue
            if "source" not in metadata or "page_number" not in metadata:
                continue
            pdf_sources.setdefault(metadata["source"], set()).add(metadata["page_number"])

        return pdf_sources

    def get_last_web_sources(self) -> set:
        """
        Method that retrieves the web sources used during the last invoke.

        Returns:
            Set of web sources. Empty when nothing has been invoked yet.
            Documents lacking the ``filetype`` or ``source`` metadata key are
            skipped.
        """
        return {
            doc.metadata["source"]
            for doc in self._last_context()
            if doc.metadata.get("filetype") == "web" and "source" in doc.metadata
        }
|
||||||
Loading…
Reference in New Issue
Block a user