From 996f3bf7a2d656cc1146a32fedd06daf906bf2b7 Mon Sep 17 00:00:00 2001
From: Ruben Lucas
Date: Tue, 15 Apr 2025 16:22:55 +0200
Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Yaml=20parser=20using=20Pydantic=20?=
 =?UTF-8?q?classes?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .gitignore                    |   3 +
 generic_rag/app.py            | 112 ++++++------
 generic_rag/backend/models.py | 209 ++++++++++++++++++++++++----
 generic_rag/parsers/config.py | 181 ++++++++++++++++++++++++++++
 generic_rag/parsers/parser.py |   7 +-
 5 files changed, 374 insertions(+), 138 deletions(-)
 create mode 100644 generic_rag/parsers/config.py

diff --git a/.gitignore b/.gitignore
index 74598a5..6578b8f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -167,3 +167,6 @@ chainlit.md
 
 # Chroma DB
 .chroma_db/
+
+# Settings
+config.yaml
\ No newline at end of file
diff --git a/generic_rag/app.py b/generic_rag/app.py
index 50ee910..d17b211 100644
--- a/generic_rag/app.py
+++ b/generic_rag/app.py
@@ -1,14 +1,15 @@
-import argparse
 import json
 import logging
 import os
+import sys
 from pathlib import Path
 
 import chainlit as cl
 from chainlit.cli import run_chainlit
 from langchain_chroma import Chroma
 
-from generic_rag.backend.models import ChatBackend, EmbeddingBackend, get_chat_model, get_embedding_model
+from generic_rag.backend.models import get_chat_model, get_embedding_model
 from generic_rag.graphs.cond_ret_gen import CondRetGenLangGraph
 from generic_rag.graphs.ret_gen import RetGenLangGraph
+from generic_rag.parsers.config import AppSettings, load_settings
 from generic_rag.parsers.parser import add_pdf_files, add_urls
@@ -23,85 +24,36 @@ system_prompt = (
     "If you don't know the answer, say that you don't know."
 )
 
-parser = argparse.ArgumentParser(description="A Sogeti Nederland Generic RAG demo.")
-parser.add_argument(
-    "-c",
-    "--chat-backend",
-    type=ChatBackend,
-    choices=list(ChatBackend),
-    default=ChatBackend.local,
-    help="Cloud provider or local LLM to use as backend. In the case of 'local', Ollama needs to be installed.",
-)
-parser.add_argument(
-    "-e",
-    "--emb-backend",
-    type=EmbeddingBackend,
-    choices=list(EmbeddingBackend),
-    default=EmbeddingBackend.huggingface,
-    help="Cloud provider or local embedding to use as backend. In the case of 'local', Ollama needs to be installed. ",
-)
-parser.add_argument(
-    "-p",
-    "--pdf-data",
-    type=Path,
-    nargs="+",
-    default=[],
-    help="One or multiple paths to folders or files to use for retrieval. "
-    "If a path is a folder, all files in the folder will be used. "
-    "If a path is a file, only that file will be used. "
-    "If the path is relative it will be relative to the current working directory.",
-)
-parser.add_argument(
-    "-u",
-    "--unstructured-pdf",
-    action="store_true",
-    help="Use an unstructered PDF loader. "
-    "An unstructured PDF loader might be usefull for PDF files "
-    "that contain a lot of images with text, tables or (scanned) text as images. "
-    "Please use '-r' when switching parsers on already indexed data.",
-)
-parser.add_argument("--pdf-chunk_size", type=int, default=1000, help="The size of the chunks to split the text into.")
-parser.add_argument("--pdf-chunk_overlap", type=int, default=200, help="The overlap between the chunks.")
-parser.add_argument(
-    "--pdf-add-start-index", action="store_true", help="Add the start index to the metadata of the chunks."
-)
-parser.add_argument(
-    "-w", "--web-data", type=str, nargs="*", default=[], help="One or multiple URLs to use for retrieval."
-)
-parser.add_argument("--web-chunk-size", type=int, default=200, help="The size of the chunks to split the text into.")
-parser.add_argument(
-    "-d",
-    "--chroma-db-location",
-    type=Path,
-    default=Path(".chroma_db"),
-    help="File path to store or load a Chroma DB from/to.",
-)
-parser.add_argument("-r", "--reset-chrome-db", action="store_true", help="Reset the Chroma DB.")
-parser.add_argument(
-    "--use-conditional-graph",
-    action="store_true",
-    help="Use the conditial retrieve generate graph over the regular retrieve generate graph.",
-)
-args = parser.parse_args()
+CONFIG_FILE_PATH = Path("config.yaml")
+
+try:
+    settings: AppSettings = load_settings(CONFIG_FILE_PATH)
+except Exception as e:
+    logger.error(f"Failed to load configuration from {CONFIG_FILE_PATH}: {e}. Exiting.")
+    sys.exit(1)
+
+embedding_function = get_embedding_model(settings)
+
+chat_function = get_chat_model(settings)
 
 vector_store = Chroma(
     collection_name="generic_rag",
-    embedding_function=get_embedding_model(args.emb_backend),
-    persist_directory=str(args.chroma_db_location),
+    embedding_function=embedding_function,
+    persist_directory=str(settings.chroma_db.location),
 )
 
-if args.use_conditional_graph:
+if settings.use_conditional_graph:
     graph = CondRetGenLangGraph(
         vector_store=vector_store,
-        chat_model=get_chat_model(args.chat_backend),
-        embedding_model=get_embedding_model(args.emb_backend),
+        chat_model=chat_function,
+        embedding_model=embedding_function,
         system_prompt=system_prompt,
     )
 else:
     graph = RetGenLangGraph(
         vector_store=vector_store,
-        chat_model=get_chat_model(args.chat_backend),
-        embedding_model=get_embedding_model(args.emb_backend),
+        chat_model=chat_function,
+        embedding_model=embedding_function,
         system_prompt=system_prompt,
     )
 
@@ -129,7 +81,9 @@ async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sour
     for source, page_numbers in pdf_sources.items():
         filename = Path(source).name
         await chainlit_response.stream_token(f"- {filename} on page(s): {sorted(page_numbers)}\n")
-        chainlit_response.elements.append(cl.Pdf(name=filename, display="side", path=source, page=sorted(page_numbers)[0]))
+        chainlit_response.elements.append(
+            cl.Pdf(name=filename, display="side", path=source, page=sorted(page_numbers)[0])
+        )
 
     if len(web_sources) > 0:
         await chainlit_response.stream_token("\n\nThe following web sources were consulted:\n")
@@ -159,17 +113,21 @@ async def set_starters():
 
 
 if __name__ == "__main__":
-    if args.reset_chrome_db:
+    if settings.chroma_db.reset:
         vector_store.reset_collection()
 
     add_pdf_files(
         vector_store,
-        args.pdf_data,
-        args.pdf_chunk_size,
-        args.pdf_chunk_overlap,
-        args.pdf_add_start_index,
-        args.unstructured_pdf,
+        settings.pdf.data,
+        settings.pdf.chunk_size,
+        settings.pdf.chunk_overlap,
+        settings.pdf.add_start_index,
+        settings.pdf.unstructured,
+    )
+    add_urls(
+        vector_store,
+        settings.web.data,
+        settings.web.chunk_size,
     )
-    add_urls(vector_store, args.web_data, args.web_chunk_size)
 
     run_chainlit(__file__)
diff --git a/generic_rag/backend/models.py b/generic_rag/backend/models.py
index 1e100ef..a5a261f 100644
--- a/generic_rag/backend/models.py
+++ b/generic_rag/backend/models.py
@@ -1,85 +1,180 @@
-import os
-from enum import Enum
+import logging
 
-from langchain.chat_models import init_chat_model
-from langchain_aws import BedrockEmbeddings
+from generic_rag.parsers.config import AppSettings, ChatBackend, EmbeddingBackend
+
+# Langchain imports
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_google_vertexai import VertexAIEmbeddings
+from langchain_aws import BedrockEmbeddings, ChatBedrock
+from langchain_google_vertexai import VertexAIEmbeddings, ChatVertexAI
 from langchain_huggingface import HuggingFaceEmbeddings
 from langchain_ollama import ChatOllama, OllamaEmbeddings
-from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings, OpenAIEmbeddings
+from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings, ChatOpenAI, OpenAIEmbeddings
+
+logger = logging.getLogger(__name__)
 
 
-class ChatBackend(Enum):
-    azure = "azure"
-    openai = "openai"
-    google_vertex = "google_vertex"
-    aws = "aws"
-    local = "local"
+def get_chat_model(settings: AppSettings) -> BaseChatModel:
+    """
+    Initializes and returns a chat model based on the backend type and configuration.
 
-    # make the enum pretty printable for argparse
-    def __str__(self):
-        return self.value
+    Args:
+        settings: The loaded AppSettings object containing configurations.
 
+    Returns:
+        An instance of BaseChatModel.
 
-class EmbeddingBackend(Enum):
-    azure = "azure"
-    openai = "openai"
-    google_vertex = "google_vertex"
-    aws = "aws"
-    local = "local"
-    huggingface = "huggingface"
+    Raises:
+        ValueError: If the backend type is unknown or required configuration is missing.
+    """
+    logger.info(f"Initializing chat model for backend: {settings.chat_backend.value}")
 
-    # make the enum pretty printable for argparse
-    def __str__(self):
-        return self.value
-
-
-def get_chat_model(backend_type: ChatBackend) -> BaseChatModel:
-    if backend_type == ChatBackend.azure:
+    if settings.chat_backend == ChatBackend.azure:
+        if not settings.azure:
+            raise ValueError("Azure chat backend selected, but 'azure' configuration section is missing in config.")
+        if (
+            not settings.azure.llm_endpoint
+            or not settings.azure.llm_deployment_name
+            or not settings.azure.llm_api_version
+        ):
+            raise ValueError(
+                "Azure configuration requires 'llm_endpoint', 'llm_deployment_name', and 'llm_api_version'."
+            )
         return AzureChatOpenAI(
-            azure_endpoint=os.environ["AZURE_LLM_ENDPOINT"],
-            azure_deployment=os.environ["AZURE_LLM_DEPLOYMENT_NAME"],
-            openai_api_version=os.environ["AZURE_LLM_API_VERSION"],
+            azure_endpoint=settings.azure.llm_endpoint,
+            azure_deployment=settings.azure.llm_deployment_name,
+            openai_api_version=settings.azure.llm_api_version,
+            openai_api_key=settings.azure.openai_api_key.get_secret_value() if settings.azure.openai_api_key else None,
         )
 
-    if backend_type == ChatBackend.openai:
-        return init_chat_model(os.environ["OPENAI_CHAT_MODEL"], model_provider="openai")
+    if settings.chat_backend == ChatBackend.openai:
+        if not settings.openai:
+            raise ValueError("OpenAI chat backend selected, but 'openai' configuration section is missing.")
+        if not settings.openai.api_key or not settings.openai.chat_model:
+            raise ValueError("OpenAI configuration requires 'api_key' and 'chat_model'.")
+        logger.info(f"Using OpenAI model: {settings.openai.chat_model}")
+        return ChatOpenAI(model=settings.openai.chat_model, openai_api_key=settings.openai.api_key.get_secret_value())
 
-    if backend_type == ChatBackend.google_vertex:
-        return init_chat_model(os.environ["GOOGLE_CHAT_MODEL"], model_provider="google_vertexai")
+    if settings.chat_backend == ChatBackend.google_vertex:
+        if not settings.google:
+            raise ValueError("Google Vertex chat backend selected, but 'google' configuration section is missing.")
+        if not settings.google.chat_model:
+            raise ValueError("Google configuration requires 'chat_model' for the Vertex chat backend.")
+        logger.info(f"Using Google Vertex model: {settings.google.chat_model}")
+        return ChatVertexAI(
+            model_name=settings.google.chat_model,
+            project=settings.google.project_id,
+            location=settings.google.location,
+        )
 
-    if backend_type == ChatBackend.aws:
-        return init_chat_model(model=os.environ["AWS_CHAT_MODEL"], model_provider="bedrock_converse")
+    if settings.chat_backend == ChatBackend.aws:
+        if not settings.aws:
+            raise ValueError("AWS Bedrock chat backend selected, but 'aws' configuration section is missing.")
+        model_name = "anthropic.claude-v2"  # Example default
+        if settings.aws.chat_model:
+            model_name = settings.aws.chat_model
+        logger.info(f"Using AWS Bedrock model: {model_name}")
+        return ChatBedrock(
+            model_id=model_name,
+            region_name=settings.aws.region_name,
+        )
 
-    if backend_type == ChatBackend.local:
-        return ChatOllama(model=os.environ["LOCAL_CHAT_MODEL"])
+    if settings.chat_backend == ChatBackend.local:
+        if not settings.local or not settings.local.chat_model:
+            raise ValueError("Local chat backend selected, but 'local.chat_model' is missing in config.")
+        logger.info(f"Using Local Ollama model: {settings.local.chat_model}")
+        # The Ollama base URL could also be made configurable (e.g., a 'local.ollama_base_url' setting)
+        return ChatOllama(model=settings.local.chat_model)
 
-    raise ValueError(f"Unknown backend type: {backend_type}")
+    # This should not be reached if all Enum members are handled
+    raise ValueError(f"Unknown or unhandled chat backend type: {settings.chat_backend}")
 
 
-def get_embedding_model(backend_type: EmbeddingBackend) -> Embeddings:
-    if backend_type == EmbeddingBackend.azure:
+def get_embedding_model(settings: AppSettings) -> Embeddings:
+    """
+    Initializes and returns an embedding model based on the backend type and configuration.
+
+    Args:
+        settings: The loaded AppSettings object containing configurations.
+
+    Returns:
+        An instance of Embeddings.
+
+    Raises:
+        ValueError: If the backend type is unknown or required configuration is missing.
+    """
+    logger.info(f"Initializing embedding model for backend: {settings.emb_backend.value}")
+
+    if settings.emb_backend == EmbeddingBackend.azure:
+        if not settings.azure:
+            raise ValueError("Azure embedding backend selected, but 'azure' configuration section is missing.")
+        if (
+            not settings.azure.emb_endpoint
+            or not settings.azure.emb_deployment_name
+            or not settings.azure.emb_api_version
+        ):
+            raise ValueError(
+                "Azure configuration requires 'emb_endpoint', 'emb_deployment_name', and 'emb_api_version'."
+            )
         return AzureOpenAIEmbeddings(
-            azure_endpoint=os.environ["AZURE_EMB_ENDPOINT"],
-            azure_deployment=os.environ["AZURE_EMB_DEPLOYMENT_NAME"],
-            openai_api_version=os.environ["AZURE_EMB_API_VERSION"],
+            azure_endpoint=settings.azure.emb_endpoint,
+            azure_deployment=settings.azure.emb_deployment_name,
+            openai_api_version=settings.azure.emb_api_version,
+            openai_api_key=settings.azure.openai_api_key.get_secret_value() if settings.azure.openai_api_key else None,
         )
 
-    if backend_type == EmbeddingBackend.openai:
-        return OpenAIEmbeddings(model=os.environ["OPENAI_EMB_MODEL"])
+    if settings.emb_backend == EmbeddingBackend.openai:
+        if not settings.openai:
+            raise ValueError("OpenAI embedding backend selected, but 'openai' configuration section is missing.")
+        if not settings.openai.api_key:
+            raise ValueError("OpenAI configuration requires 'api_key'.")
+        model_name = "text-embedding-ada-002"  # Example default
+        if settings.openai.emb_model:
+            model_name = settings.openai.emb_model
+        logger.info(f"Using OpenAI embedding model: {model_name}")
+        return OpenAIEmbeddings(model=model_name, openai_api_key=settings.openai.api_key.get_secret_value())
 
-    if backend_type == EmbeddingBackend.google_vertex:
-        return VertexAIEmbeddings(model=os.environ["GOOGLE_EMB_MODEL"])
+    if settings.emb_backend == EmbeddingBackend.google_vertex:
+        if not settings.google:
+            raise ValueError("Google Vertex embedding backend selected, but 'google' configuration section is missing.")
+        model_name = "textembedding-gecko@001"  # Example default
+        if settings.google.emb_model:
+            model_name = settings.google.emb_model
+        logger.info(f"Using Google Vertex embedding model: {model_name}")
+        return VertexAIEmbeddings(
+            model_name=model_name, project=settings.google.project_id, location=settings.google.location
+        )
 
-    if backend_type == EmbeddingBackend.aws:
-        return BedrockEmbeddings(model_id=os.environ["AWS_EMB_MODEL"])
+    if settings.emb_backend == EmbeddingBackend.aws:
+        if not settings.aws:
+            raise ValueError("AWS Bedrock embedding backend selected, but 'aws' configuration section is missing.")
+        model_name = "amazon.titan-embed-text-v1"  # Example default
+        if settings.aws.emb_model:
+            model_name = settings.aws.emb_model
+        logger.info(f"Using AWS Bedrock embedding model: {model_name}")
+        return BedrockEmbeddings(model_id=model_name, region_name=settings.aws.region_name)
 
-    if backend_type == EmbeddingBackend.local:
-        return OllamaEmbeddings(model=os.environ["LOCAL_EMB_MODEL"])
+    if settings.emb_backend == EmbeddingBackend.local:
+        if not settings.local or not settings.local.emb_model:
+            raise ValueError("Local embedding backend selected, but 'local.emb_model' is missing in config.")
+        logger.info(f"Using Local Ollama embedding model: {settings.local.emb_model}")
+        return OllamaEmbeddings(model=settings.local.emb_model)
 
-    if backend_type == EmbeddingBackend.huggingface:
-        return HuggingFaceEmbeddings(model_name=os.environ["HUGGINGFACE_EMB_MODEL"])
+    if settings.emb_backend == EmbeddingBackend.huggingface:
+        if not settings.huggingface or not settings.huggingface.emb_model:
+            if settings.local and settings.local.emb_model:
+                logger.warning(
+                    "HuggingFace backend selected, but 'huggingface.emb_model' missing. Using 'local.emb_model'."
+                )
+                model_name = settings.local.emb_model
+            else:
+                raise ValueError(
+                    "HuggingFace embedding backend selected, but 'huggingface.emb_model' (or 'local.emb_model') is missing in config."
+                )
+        else:
+            model_name = settings.huggingface.emb_model
 
-    raise ValueError(f"Unknown backend type: {backend_type}")
+        logger.info(f"Using HuggingFace embedding model: {model_name}")
+        return HuggingFaceEmbeddings(model_name=model_name)
+
+    raise ValueError(f"Unknown or unhandled embedding backend type: {settings.emb_backend}")
diff --git a/generic_rag/parsers/config.py b/generic_rag/parsers/config.py
new file mode 100644
index 0000000..af5c9b0
--- /dev/null
+++ b/generic_rag/parsers/config.py
@@ -0,0 +1,181 @@
+import sys
+import yaml
+from pathlib import Path
+from typing import List, Optional
+from enum import Enum
+from pydantic import (
+    BaseModel,
+    Field,
+    ValidationError,
+    SecretStr,
+)
+
+
+class ChatBackend(str, Enum):
+    azure = "azure"
+    openai = "openai"
+    google_vertex = "google_vertex"
+    aws = "aws"
+    local = "local"
+
+    def __str__(self):
+        return self.value
+
+
+class EmbeddingBackend(str, Enum):
+    azure = "azure"
+    openai = "openai"
+    google_vertex = "google_vertex"
+    aws = "aws"
+    local = "local"
+    huggingface = "huggingface"
+
+    def __str__(self):
+        return self.value
+
+
+class AzureSettings(BaseModel):
+    """Azure specific settings."""
+
+    openai_api_key: Optional[SecretStr] = None
+    llm_endpoint: Optional[str] = None
+    llm_deployment_name: Optional[str] = None
+    llm_api_version: Optional[str] = None
+    emb_endpoint: Optional[str] = None
+    emb_deployment_name: Optional[str] = None
+    emb_api_version: Optional[str] = None
+
+
+class OpenAISettings(BaseModel):
+    """OpenAI specific settings."""
+
+    api_key: Optional[SecretStr] = None
+    chat_model: Optional[str] = None
+    emb_model: Optional[str] = None
+
+
+class GoogleSettings(BaseModel):
+    """Google specific settings (Vertex AI or GenAI)."""
+
+    api_key: Optional[SecretStr] = None
+    project_id: Optional[str] = None
+    location: Optional[str] = None
+    chat_model: Optional[str] = None
+    emb_model: Optional[str] = None
+
+
+class AwsSettings(BaseModel):
+    """AWS specific settings (e.g., for Bedrock)."""
+
+    access_key_id: Optional[SecretStr] = None
+    secret_access_key: Optional[SecretStr] = None
+    region_name: Optional[str] = None
+    chat_model: Optional[str] = None
+    emb_model: Optional[str] = None
+
+
+class LocalSettings(BaseModel):
+    """Local backend specific settings (e.g., Ollama models)."""
+
+    chat_model: Optional[str] = None
+    emb_model: Optional[str] = None
+
+
+class HuggingFaceSettings(BaseModel):
+    """HuggingFace specific settings (if different from local embeddings)."""
+
+    emb_model: Optional[str] = None
+    api_token: Optional[SecretStr] = None
+
+
+class PdfSettings(BaseModel):
+    """PDF processing settings."""
+
+    data: List[Path] = Field(default_factory=list)
+    unstructured: bool = Field(default=False)
+    chunk_size: int = Field(default=1000)
+    chunk_overlap: int = Field(default=200)
+    add_start_index: bool = Field(default=False)
+
+
+class WebSettings(BaseModel):
+    """Web data processing settings."""
+
+    data: List[str] = Field(default_factory=list)
+    chunk_size: int = Field(default=200)
+
+
+class ChromaDbSettings(BaseModel):
+    """Chroma DB settings."""
+
+    location: Path = Field(default=Path(".chroma_db"))
+    reset: bool = Field(default=False)
+
+
+class AppSettings(BaseModel):
+    """
+    Main application settings model.
+
+    Loads configuration from a YAML file using the structure defined
+    by the nested models.
+    """
+
+    # --- Top-level settings ---
+    chat_backend: ChatBackend = Field(default=ChatBackend.local)
+    emb_backend: EmbeddingBackend = Field(default=EmbeddingBackend.huggingface)
+    use_conditional_graph: bool = Field(default=False)
+
+    # --- Provider-specific settings ---
+    azure: Optional[AzureSettings] = None
+    openai: Optional[OpenAISettings] = None
+    google: Optional[GoogleSettings] = None
+    aws: Optional[AwsSettings] = None
+    local: Optional[LocalSettings] = None
+    huggingface: Optional[HuggingFaceSettings] = None  # Separate HF config if needed
+
+    # --- Data processing settings ---
+    pdf: PdfSettings = Field(default_factory=PdfSettings)
+    web: WebSettings = Field(default_factory=WebSettings)
+    chroma_db: ChromaDbSettings = Field(default_factory=ChromaDbSettings)
+
+
+# --- Configuration Loading Function ---
+def load_settings(config_path: Path = Path("config.yaml")) -> AppSettings:
+    """
+    Loads settings from a YAML file and validates them using Pydantic models.
+
+    Args:
+        config_path: The path to the configuration YAML file.
+
+    Returns:
+        An instance of AppSettings containing the loaded configuration.
+
+    Raises:
+        FileNotFoundError: If the config file does not exist.
+        yaml.YAMLError: If the file is not valid YAML.
+        ValidationError: If the data in the file doesn't match the AppSettings model.
+    """
+    if not config_path.is_file():
+        print(f"Error: Configuration file not found at '{config_path}'", file=sys.stderr)
+        raise FileNotFoundError(f"Configuration file not found: {config_path}")
+
+    print(f"--- Loading settings from '{config_path}' ---")
+    try:
+        with open(config_path, "r", encoding="utf-8") as f:
+            config_data = yaml.safe_load(f)
+            if config_data is None:
+                config_data = {}
+
+        settings = AppSettings(**config_data)
+        print("--- Settings loaded and validated successfully ---")
+        return settings
+
+    except yaml.YAMLError as e:
+        print(f"Error parsing YAML file '{config_path}':\n {e}", file=sys.stderr)
+        raise
+    except ValidationError as e:
+        print(f"Error validating configuration from '{config_path}':\n{e}", file=sys.stderr)
+        raise
+    except Exception as e:
+        print(f"An unexpected error occurred while loading settings from '{config_path}': {e}", file=sys.stderr)
+        raise
diff --git a/generic_rag/parsers/parser.py b/generic_rag/parsers/parser.py
index ef3b374..be5bd28 100644
--- a/generic_rag/parsers/parser.py
+++ b/generic_rag/parsers/parser.py
@@ -32,7 +32,7 @@ def add_urls(vector_store: Chroma, urls: list[str], chunk_size: int) -> None:
     The URL's will be fetched and split into chunks of text with the provided chunk size.
     """
     logger.info("Web sources to the vector store.")
-    
+
     all_splits = []
     for url in urls:
         if len(vector_store.get(where={"source": url}, limit=1)["ids"]) > 0:
@@ -87,7 +87,6 @@ def add_pdf_files(
     The PDF file will be parsed per page and split into chunks of text with the provided chunk size and overlap.
""" logger.info("Adding PDF files to the vector store.") - pdf_files = get_all_local_pdf_files(file_paths) logger.info(f"Found {len(pdf_files)} PDF files to add to the vector store.") @@ -100,8 +99,8 @@ def add_pdf_files( if len(new_pdfs) == 0: return - - logger.info(f"{len(new_pdfs)} PDF's to add to the vector store.") + + logger.info(f"{len(new_pdfs)} PDF(s) to add to the vector store.") loaded_document = [] for file in new_pdfs: