forked from AI_team/Philosophy-RAG-demo
🔀 Merge remote-tracking branch 'origin/main' into setting-parser
This commit is contained in:
commit
af5cbcacc3
@ -10,7 +10,7 @@ from chainlit.cli import run_chainlit
|
||||
from langchain_chroma import Chroma
|
||||
|
||||
from generic_rag.parsers.config import AppSettings, load_settings
|
||||
from generic_rag.backend.models import get_chat_model, get_embedding_model
|
||||
from generic_rag.backend.models import get_chat_model, get_embedding_model, get_compression_model
|
||||
from generic_rag.graphs.cond_ret_gen import CondRetGenLangGraph
|
||||
from generic_rag.graphs.ret_gen import RetGenLangGraph
|
||||
from generic_rag.parsers.parser import add_pdf_files, add_urls
|
||||
@ -67,6 +67,9 @@ else:
|
||||
chat_model=chat_function,
|
||||
embedding_model=embedding_function,
|
||||
system_prompt=system_prompt,
|
||||
compression_model=get_compression_model(
|
||||
"BAAI/bge-reranker-base", vector_store
|
||||
), # TODO: implement in config parser
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -1,11 +1,17 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from langchain_chroma import Chroma
|
||||
from langchain_aws import BedrockEmbeddings
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_aws import BedrockEmbeddings, ChatBedrock
|
||||
from langchain_google_vertexai import VertexAIEmbeddings, ChatVertexAI
|
||||
from langchain_huggingface import HuggingFaceEmbeddings, ChatHuggingFace, HuggingFacePipeline
|
||||
from langchain.retrievers import ContextualCompressionRetriever
|
||||
from langchain.retrievers.document_compressors import CrossEncoderReranker
|
||||
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
|
||||
from langchain_ollama import ChatOllama, OllamaEmbeddings
|
||||
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings, ChatOpenAI, OpenAIEmbeddings
|
||||
|
||||
@ -201,4 +207,11 @@ def get_embedding_model(settings: AppSettings) -> Embeddings:
|
||||
raise ValueError("HuggingFace configuration requires 'emb_model'.")
|
||||
return HuggingFaceEmbeddings(model_name=settings.huggingface.emb_model)
|
||||
|
||||
raise ValueError(f"Unknown or unhandled embedding backend type: {settings.emb_backend}")
|
||||
raise ValueError(f"Unknown backend type: {settings.backend_type}")
|
||||
|
||||
|
||||
def get_compression_model(model_name: str, vector_store: Chroma) -> BaseRetriever:
|
||||
base_retriever = vector_store.as_retriever(search_kwargs={"k": 20})
|
||||
rerank_model = HuggingFaceCrossEncoder(model_name=model_name)
|
||||
compressor = CrossEncoderReranker(model=rerank_model, top_n=4)
|
||||
return ContextualCompressionRetriever(base_compressor=compressor, base_retriever=base_retriever)
|
||||
|
||||
@ -8,6 +8,7 @@ from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import BaseMessage, SystemMessage
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from typing_extensions import List, TypedDict
|
||||
@ -23,12 +24,18 @@ class State(TypedDict):
|
||||
|
||||
class RetGenLangGraph:
|
||||
def __init__(
|
||||
self, vector_store: Chroma, chat_model: BaseChatModel, embedding_model: Embeddings, system_prompt: str
|
||||
self,
|
||||
vector_store: Chroma,
|
||||
chat_model: BaseChatModel,
|
||||
embedding_model: Embeddings,
|
||||
system_prompt: str,
|
||||
compression_model: BaseRetriever | None = None,
|
||||
):
|
||||
self.vector_store = vector_store
|
||||
self.chat_model = chat_model
|
||||
self.embedding_model = embedding_model
|
||||
self.system_prompt = system_prompt
|
||||
self.compression_model = compression_model
|
||||
|
||||
memory = MemorySaver()
|
||||
graph_builder = StateGraph(State).add_sequence([self._retrieve, self._generate])
|
||||
@ -45,7 +52,10 @@ class RetGenLangGraph:
|
||||
|
||||
def _retrieve(self, state: State) -> dict[str, list]:
|
||||
logger.debug(f"querying VS for: {state["question"]}")
|
||||
self.last_retrieved_docs = self.vector_store.similarity_search(state["question"])
|
||||
if self.compression_model:
|
||||
self.last_retrieved_docs = self.compression_model.invoke(state["question"])
|
||||
else:
|
||||
self.last_retrieved_docs = self.vector_store.similarity_search(state["question"])
|
||||
return {"context": self.last_retrieved_docs}
|
||||
|
||||
def _generate(self, state: State) -> dict[str, list]:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user