forked from AI_team/Philosophy-RAG-demo
Add support for both lang Graphs
This commit is contained in:
parent
f25770e3ce
commit
ee0c731faf
@ -6,14 +6,12 @@ from pathlib import Path
|
|||||||
|
|
||||||
import chainlit as cl
|
import chainlit as cl
|
||||||
from backend.models import BackendType, get_chat_model, get_embedding_model
|
from backend.models import BackendType, get_chat_model, get_embedding_model
|
||||||
from graphs.ret_gen import RetGenLangGraph
|
|
||||||
from chainlit.cli import run_chainlit
|
from chainlit.cli import run_chainlit
|
||||||
|
from graphs.cond_ret_gen import CondRetGenLangGraph
|
||||||
|
from graphs.ret_gen import RetGenLangGraph
|
||||||
from langchain_chroma import Chroma
|
from langchain_chroma import Chroma
|
||||||
|
|
||||||
from parsers.parser import add_pdf_files, add_urls
|
from parsers.parser import add_pdf_files, add_urls
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -62,6 +60,13 @@ parser.add_argument(
|
|||||||
help="File path to store or load a Chroma DB from/to.",
|
help="File path to store or load a Chroma DB from/to.",
|
||||||
)
|
)
|
||||||
parser.add_argument("-r", "--reset-chrome-db", action="store_true", help="Reset the Chroma DB.")
|
parser.add_argument("-r", "--reset-chrome-db", action="store_true", help="Reset the Chroma DB.")
|
||||||
|
parser.add_argument(
|
||||||
|
"-c",
|
||||||
|
"--use-conditional-graph",
|
||||||
|
action="store_true",
|
||||||
|
help="Use the conditial retrieve generate graph over the regular retrieve generate graph. "
|
||||||
|
"The conditional version has build in (chat) memory and is capable of quering vectorstores on its own insight.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
vector_store = Chroma(
|
vector_store = Chroma(
|
||||||
@ -70,35 +75,51 @@ vector_store = Chroma(
|
|||||||
persist_directory=str(args.chroma_db_location),
|
persist_directory=str(args.chroma_db_location),
|
||||||
)
|
)
|
||||||
|
|
||||||
ret_gen_graph = RetGenLangGraph(
|
if args.use_conditional_graph:
|
||||||
vector_store, chat_model=get_chat_model(args.backend), embedding_model=get_embedding_model(args.backend)
|
graph = CondRetGenLangGraph(
|
||||||
)
|
vector_store, chat_model=get_chat_model(args.backend), embedding_model=get_embedding_model(args.backend)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
graph = RetGenLangGraph(
|
||||||
|
vector_store, chat_model=get_chat_model(args.backend), embedding_model=get_embedding_model(args.backend)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@cl.on_message
|
@cl.on_message
|
||||||
async def on_message(message: cl.Message):
|
async def on_message(message: cl.Message):
|
||||||
response = ret_gen_graph.invoke(message.content)
|
if isinstance(graph, CondRetGenLangGraph):
|
||||||
|
config = {"configurable": {"thread_id": cl.user_session.get("id")}}
|
||||||
|
|
||||||
answer = response["answer"]
|
chainlit_response = cl.Message(content="")
|
||||||
answer += "\n\n"
|
|
||||||
|
|
||||||
pdf_sources = ret_gen_graph.get_last_pdf_sources()
|
for response in graph.stream(message.content, config=config):
|
||||||
web_sources = ret_gen_graph.get_last_web_sources()
|
await chainlit_response.stream_token(response)
|
||||||
|
|
||||||
elements = []
|
await chainlit_response.send()
|
||||||
if len(pdf_sources) > 0:
|
|
||||||
answer += "The following PDF source were consulted:\n"
|
|
||||||
for source, page_numbers in pdf_sources.items():
|
|
||||||
page_numbers = list(page_numbers)
|
|
||||||
page_numbers.sort()
|
|
||||||
# display="side" seems to be not supported by chainlit for PDF's, so we use "inline" instead.
|
|
||||||
elements.append(cl.Pdf(name="pdf", display="inline", path=source, page=page_numbers[0]))
|
|
||||||
answer += f"'{source}' on page(s): {page_numbers}\n"
|
|
||||||
|
|
||||||
if len(web_sources) > 0:
|
elif isinstance(graph, RetGenLangGraph):
|
||||||
answer += f"The following web sources were consulted: {web_sources}\n"
|
response = graph.invoke(message.content)
|
||||||
|
|
||||||
await cl.Message(content=answer, elements=elements).send()
|
answer = response["answer"]
|
||||||
|
answer += "\n\n"
|
||||||
|
|
||||||
|
pdf_sources = graph.get_last_pdf_sources()
|
||||||
|
web_sources = graph.get_last_web_sources()
|
||||||
|
|
||||||
|
elements = []
|
||||||
|
if len(pdf_sources) > 0:
|
||||||
|
answer += "The following PDF source were consulted:\n"
|
||||||
|
for source, page_numbers in pdf_sources.items():
|
||||||
|
page_numbers = list(page_numbers)
|
||||||
|
page_numbers.sort()
|
||||||
|
# display="side" seems to be not supported by chainlit for PDF's, so we use "inline" instead.
|
||||||
|
elements.append(cl.Pdf(name="pdf", display="inline", path=source, page=page_numbers[0]))
|
||||||
|
answer += f"'{source}' on page(s): {page_numbers}\n"
|
||||||
|
|
||||||
|
if len(web_sources) > 0:
|
||||||
|
answer += f"The following web sources were consulted: {web_sources}\n"
|
||||||
|
|
||||||
|
await cl.Message(content=answer, elements=elements).send()
|
||||||
|
|
||||||
|
|
||||||
@cl.set_starters
|
@cl.set_starters
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user