forked from AI_team/Philosophy-RAG-demo
✨ Add Reranker model
parent 59ab43d31e
commit fbc0e74678
@@ -8,7 +8,13 @@ import chainlit as cl
 from chainlit.cli import run_chainlit
 from langchain_chroma import Chroma

-from generic_rag.backend.models import ChatBackend, EmbeddingBackend, get_chat_model, get_embedding_model
+from generic_rag.backend.models import (
+    ChatBackend,
+    EmbeddingBackend,
+    get_chat_model,
+    get_embedding_model,
+    get_compression_model,
+)
 from generic_rag.graphs.cond_ret_gen import CondRetGenLangGraph
 from generic_rag.graphs.ret_gen import RetGenLangGraph
 from generic_rag.parsers.parser import add_pdf_files, add_urls
@@ -103,6 +109,9 @@ else:
         chat_model=get_chat_model(args.chat_backend),
         embedding_model=get_embedding_model(args.emb_backend),
         system_prompt=system_prompt,
+        compression_model=get_compression_model(
+            "BAAI/bge-reranker-base", vector_store
+        ),  # TODO: implement in config parser
     )

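Note: the reranker model name is hard-coded pending the TODO above. A minimal sketch of how it could hang off the existing argument parser; the flag name "--reranker-model" is an assumption, not part of this commit:

    import argparse

    parser = argparse.ArgumentParser()
    # Hypothetical flag; the commit currently hard-codes "BAAI/bge-reranker-base".
    parser.add_argument(
        "--reranker-model",
        default="BAAI/bge-reranker-base",
        help="HuggingFace cross-encoder checkpoint used to rerank retrieved chunks",
    )
    args = parser.parse_args()

    # The wiring above would then become:
    #   compression_model=get_compression_model(args.reranker_model, vector_store)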
@@ -129,7 +138,9 @@ async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sour
     for source, page_numbers in pdf_sources.items():
         filename = Path(source).name
         await chainlit_response.stream_token(f"- {filename} on page(s): {sorted(page_numbers)}\n")
-        chainlit_response.elements.append(cl.Pdf(name=filename, display="side", path=source, page=sorted(page_numbers)[0]))
+        chainlit_response.elements.append(
+            cl.Pdf(name=filename, display="side", path=source, page=sorted(page_numbers)[0])
+        )

     if len(web_sources) > 0:
         await chainlit_response.stream_token("\n\nThe following web sources were consulted:\n")
generic_rag/backend/models.py
@@ -1,12 +1,17 @@
 import os
 from enum import Enum

+from langchain_chroma import Chroma
 from langchain.chat_models import init_chat_model
 from langchain_aws import BedrockEmbeddings
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.retrievers import BaseRetriever
 from langchain_google_vertexai import VertexAIEmbeddings
 from langchain_huggingface import HuggingFaceEmbeddings
+from langchain.retrievers import ContextualCompressionRetriever
+from langchain.retrievers.document_compressors import CrossEncoderReranker
+from langchain_community.cross_encoders import HuggingFaceCrossEncoder
 from langchain_ollama import ChatOllama, OllamaEmbeddings
 from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings, OpenAIEmbeddings

@@ -83,3 +88,10 @@ def get_embedding_model(backend_type: EmbeddingBackend) -> Embeddings:
         return HuggingFaceEmbeddings(model_name=os.environ["HUGGINGFACE_EMB_MODEL"])

     raise ValueError(f"Unknown backend type: {backend_type}")
+
+
+def get_compression_model(model_name: str, vector_store: Chroma) -> BaseRetriever:
+    base_retriever = vector_store.as_retriever(search_kwargs={"k": 20})
+    rerank_model = HuggingFaceCrossEncoder(model_name=model_name)
+    compressor = CrossEncoderReranker(model=rerank_model, top_n=4)
+    return ContextualCompressionRetriever(base_compressor=compressor, base_retriever=base_retriever)
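For context, a usage sketch of the new helper (the collection name, embedding model, and query below are illustrative assumptions): it pulls the top 20 candidates from Chroma, scores each (query, chunk) pair with the cross-encoder, and keeps the 4 best.

    from langchain_chroma import Chroma
    from langchain_huggingface import HuggingFaceEmbeddings

    # Any populated Chroma collection works here.
    vector_store = Chroma(
        collection_name="philosophy",
        embedding_function=HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2"),
    )

    retriever = get_compression_model("BAAI/bge-reranker-base", vector_store)
    # ContextualCompressionRetriever is a BaseRetriever, so invoke() takes a
    # query string and returns the reranked top-4 list of Documents.
    docs = retriever.invoke("What does Kant mean by the categorical imperative?")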
generic_rag/graphs/ret_gen.py
@@ -8,6 +8,7 @@ from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import BaseMessage, SystemMessage
 from langchain_core.runnables.config import RunnableConfig
+from langchain_core.retrievers import BaseRetriever
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import END, START, StateGraph
 from typing_extensions import List, TypedDict
@@ -23,12 +24,18 @@ class State(TypedDict):

 class RetGenLangGraph:
     def __init__(
-        self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings, system_prompt: str
+        self,
+        vector_store: Chroma,
+        chat_model: BaseChatModel,
+        embedding_model: Embeddings,
+        system_prompt: str,
+        compression_model: BaseRetriever | None = None,
     ):
         self.vector_store = vector_store
         self.chat_model = chat_model
         self.embedding_model = embedding_model
         self.system_prompt = system_prompt
+        self.compression_model = compression_model

         memory = MemorySaver()
         graph_builder = StateGraph(State).add_sequence([self._retrieve, self._generate])
@@ -45,6 +52,9 @@ class RetGenLangGraph:

     def _retrieve(self, state: State) -> dict[str, list]:
         logger.debug(f"querying VS for: {state["question"]}")
-        self.last_retrieved_docs = self.vector_store.similarity_search(state["question"])
+        if self.compression_model:
+            self.last_retrieved_docs = self.compression_model.invoke(state["question"])
+        else:
+            self.last_retrieved_docs = self.vector_store.similarity_search(state["question"])
         return {"context": self.last_retrieved_docs}

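The retrieval branch above condenses to the following standalone logic (a restatement for clarity, not code from the repo; vector_store and compression_model are as in the class):

    from langchain_core.documents import Document

    def retrieve(question: str, vector_store, compression_model=None) -> list[Document]:
        # With a reranker injected: widen to 20 candidates, rerank down to 4.
        if compression_model:
            return compression_model.invoke(question)
        # Without one: Chroma's plain similarity search (default k=4), as before.
        return vector_store.similarity_search(question)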