From fbc0e746787f05de0c5e2897d6aa4bb4fa89b1b3 Mon Sep 17 00:00:00 2001
From: Ruben Lucas
Date: Thu, 17 Apr 2025 12:42:53 +0200
Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Add=20Reranker=20model?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 generic_rag/app.py            | 15 +++++++++++++--
 generic_rag/backend/models.py | 12 ++++++++++++
 generic_rag/graphs/ret_gen.py | 14 ++++++++++++--
 3 files changed, 37 insertions(+), 4 deletions(-)

diff --git a/generic_rag/app.py b/generic_rag/app.py
index 50ee910..eada787 100644
--- a/generic_rag/app.py
+++ b/generic_rag/app.py
@@ -8,7 +8,13 @@ import chainlit as cl
 from chainlit.cli import run_chainlit
 from langchain_chroma import Chroma
 
-from generic_rag.backend.models import ChatBackend, EmbeddingBackend, get_chat_model, get_embedding_model
+from generic_rag.backend.models import (
+    ChatBackend,
+    EmbeddingBackend,
+    get_chat_model,
+    get_embedding_model,
+    get_compression_model,
+)
 from generic_rag.graphs.cond_ret_gen import CondRetGenLangGraph
 from generic_rag.graphs.ret_gen import RetGenLangGraph
 from generic_rag.parsers.parser import add_pdf_files, add_urls
@@ -103,6 +109,9 @@ else:
         chat_model=get_chat_model(args.chat_backend),
         embedding_model=get_embedding_model(args.emb_backend),
         system_prompt=system_prompt,
+        compression_model=get_compression_model(
+            "BAAI/bge-reranker-base", vector_store
+        ),  # TODO: implement in config parser
     )
 
 
@@ -129,7 +138,9 @@ async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sour
     for source, page_numbers in pdf_sources.items():
         filename = Path(source).name
         await chainlit_response.stream_token(f"- {filename} on page(s): {sorted(page_numbers)}\n")
-        chainlit_response.elements.append(cl.Pdf(name=filename, display="side", path=source, page=sorted(page_numbers)[0]))
+        chainlit_response.elements.append(
+            cl.Pdf(name=filename, display="side", path=source, page=sorted(page_numbers)[0])
+        )
 
     if len(web_sources) > 0:
         await chainlit_response.stream_token("\n\nThe following web sources were consulted:\n")
diff --git a/generic_rag/backend/models.py b/generic_rag/backend/models.py
index 1e100ef..f2e8f2d 100644
--- a/generic_rag/backend/models.py
+++ b/generic_rag/backend/models.py
@@ -1,12 +1,17 @@
 import os
 from enum import Enum
 
+from langchain_chroma import Chroma
 from langchain.chat_models import init_chat_model
 from langchain_aws import BedrockEmbeddings
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.retrievers import BaseRetriever
 from langchain_google_vertexai import VertexAIEmbeddings
 from langchain_huggingface import HuggingFaceEmbeddings
+from langchain.retrievers import ContextualCompressionRetriever
+from langchain.retrievers.document_compressors import CrossEncoderReranker
+from langchain_community.cross_encoders import HuggingFaceCrossEncoder
 from langchain_ollama import ChatOllama, OllamaEmbeddings
 from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings, OpenAIEmbeddings
 
@@ -83,3 +88,10 @@ def get_embedding_model(backend_type: EmbeddingBackend) -> Embeddings:
         return HuggingFaceEmbeddings(model_name=os.environ["HUGGINGFACE_EMB_MODEL"])
 
     raise ValueError(f"Unknown backend type: {backend_type}")
+
+
+def get_compression_model(model_name: str, vector_store: Chroma) -> BaseRetriever:
+    base_retriever = vector_store.as_retriever(search_kwargs={"k": 20})
+    rerank_model = HuggingFaceCrossEncoder(model_name=model_name)
+    compressor = CrossEncoderReranker(model=rerank_model, top_n=4)
+    return ContextualCompressionRetriever(base_compressor=compressor, base_retriever=base_retriever)
diff --git a/generic_rag/graphs/ret_gen.py b/generic_rag/graphs/ret_gen.py
index 6cdd4f2..2ed6174 100644
--- a/generic_rag/graphs/ret_gen.py
+++ b/generic_rag/graphs/ret_gen.py
@@ -8,6 +8,7 @@ from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import BaseMessage, SystemMessage
 from langchain_core.runnables.config import RunnableConfig
+from langchain_core.retrievers import BaseRetriever
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import END, START, StateGraph
 from typing_extensions import List, TypedDict
@@ -23,12 +24,18 @@ class State(TypedDict):
 
 class RetGenLangGraph:
     def __init__(
-        self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings, system_prompt: str
+        self,
+        vector_store: Chroma,
+        chat_model: BaseChatModel,
+        embedding_model: Embeddings,
+        system_prompt: str,
+        compression_model: BaseRetriever | None = None,
     ):
         self.vector_store = vector_store
         self.chat_model = chat_model
         self.embedding_model = embedding_model
         self.system_prompt = system_prompt
+        self.compression_model = compression_model
 
         memory = MemorySaver()
         graph_builder = StateGraph(State).add_sequence([self._retrieve, self._generate])
@@ -45,7 +52,10 @@ def _retrieve(self, state: State) -> dict[str, list]:
         logger.debug(f"querying VS for: {state["question"]}")
-        self.last_retrieved_docs = self.vector_store.similarity_search(state["question"])
+        if self.compression_model:
+            self.last_retrieved_docs = self.compression_model.invoke(state["question"])
+        else:
+            self.last_retrieved_docs = self.vector_store.similarity_search(state["question"])
         return {"context": self.last_retrieved_docs}
 
     def _generate(self, state: State) -> dict[str, list]:
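
Reviewer note (not part of the patch): the new get_compression_model follows
LangChain's standard contextual-compression pattern — the Chroma retriever
over-fetches 20 candidates with a cheap bi-encoder, the BAAI/bge-reranker-base
cross-encoder scores each (query, document) pair, and CrossEncoderReranker keeps
the 4 highest-scoring documents. Below is a minimal self-contained sketch of the
same pipeline, runnable outside the app; the embedding model and sample texts
are illustrative assumptions, not values taken from this repository.

    from langchain_chroma import Chroma
    from langchain_community.cross_encoders import HuggingFaceCrossEncoder
    from langchain_huggingface import HuggingFaceEmbeddings
    from langchain.retrievers import ContextualCompressionRetriever
    from langchain.retrievers.document_compressors import CrossEncoderReranker

    # Toy corpus; in the app the store is populated by its PDF/URL parsers (assumption).
    vector_store = Chroma.from_texts(
        ["Chroma is a vector database.", "A reranker reorders retrieved documents."],
        embedding=HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2"),
    )

    # Stage 1: cheap bi-encoder recall, deliberately over-fetching candidates.
    base_retriever = vector_store.as_retriever(search_kwargs={"k": 20})

    # Stage 2: the cross-encoder scores each (query, document) pair and keeps the top 4.
    reranker = CrossEncoderReranker(
        model=HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base"),
        top_n=4,
    )

    retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=base_retriever)
    docs = retriever.invoke("What does a reranker do?")  # at most 4 Documents, best first

Over-fetching with k=20 and then cutting to top_n=4 is the usual two-stage
tradeoff: the bi-encoder recall stage is fast but coarse, while the
cross-encoder is far more precise but only affordable on a small candidate set.
This also matches the _retrieve change above, where the compression retriever
replaces the plain similarity_search when configured.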