YAML parser using Pydantic classes

Ruben Lucas 2025-04-15 16:22:55 +02:00
parent 59ab43d31e
commit 996f3bf7a2
5 changed files with 370 additions and 138 deletions

.gitignore

@@ -167,3 +167,6 @@ chainlit.md

 # Chroma DB
 .chroma_db/
+
+# Settings
+config.yaml


@@ -1,14 +1,15 @@
-import argparse
 import json
 import logging
 import os
 from pathlib import Path
+import sys

 import chainlit as cl
 from chainlit.cli import run_chainlit
 from langchain_chroma import Chroma

-from generic_rag.backend.models import ChatBackend, EmbeddingBackend, get_chat_model, get_embedding_model
+from generic_rag.parsers.config import AppSettings, load_settings
+from generic_rag.backend.models import get_chat_model, get_embedding_model
 from generic_rag.graphs.cond_ret_gen import CondRetGenLangGraph
 from generic_rag.graphs.ret_gen import RetGenLangGraph
 from generic_rag.parsers.parser import add_pdf_files, add_urls
@@ -23,85 +24,36 @@ system_prompt = (
     "If you don't know the answer, say that you don't know."
 )

-parser = argparse.ArgumentParser(description="A Sogeti Nederland Generic RAG demo.")
-parser.add_argument(
-    "-c",
-    "--chat-backend",
-    type=ChatBackend,
-    choices=list(ChatBackend),
-    default=ChatBackend.local,
-    help="Cloud provider or local LLM to use as backend. In the case of 'local', Ollama needs to be installed.",
-)
-parser.add_argument(
-    "-e",
-    "--emb-backend",
-    type=EmbeddingBackend,
-    choices=list(EmbeddingBackend),
-    default=EmbeddingBackend.huggingface,
-    help="Cloud provider or local embedding to use as backend. In the case of 'local', Ollama needs to be installed. ",
-)
-parser.add_argument(
-    "-p",
-    "--pdf-data",
-    type=Path,
-    nargs="+",
-    default=[],
-    help="One or multiple paths to folders or files to use for retrieval. "
-    "If a path is a folder, all files in the folder will be used. "
-    "If a path is a file, only that file will be used. "
-    "If the path is relative it will be relative to the current working directory.",
-)
-parser.add_argument(
-    "-u",
-    "--unstructured-pdf",
-    action="store_true",
-    help="Use an unstructered PDF loader. "
-    "An unstructured PDF loader might be usefull for PDF files "
-    "that contain a lot of images with text, tables or (scanned) text as images. "
-    "Please use '-r' when switching parsers on already indexed data.",
-)
-parser.add_argument("--pdf-chunk_size", type=int, default=1000, help="The size of the chunks to split the text into.")
-parser.add_argument("--pdf-chunk_overlap", type=int, default=200, help="The overlap between the chunks.")
-parser.add_argument(
-    "--pdf-add-start-index", action="store_true", help="Add the start index to the metadata of the chunks."
-)
-parser.add_argument(
-    "-w", "--web-data", type=str, nargs="*", default=[], help="One or multiple URLs to use for retrieval."
-)
-parser.add_argument("--web-chunk-size", type=int, default=200, help="The size of the chunks to split the text into.")
-parser.add_argument(
-    "-d",
-    "--chroma-db-location",
-    type=Path,
-    default=Path(".chroma_db"),
-    help="File path to store or load a Chroma DB from/to.",
-)
-parser.add_argument("-r", "--reset-chrome-db", action="store_true", help="Reset the Chroma DB.")
-parser.add_argument(
-    "--use-conditional-graph",
-    action="store_true",
-    help="Use the conditial retrieve generate graph over the regular retrieve generate graph.",
-)
-args = parser.parse_args()
+CONFIG_FILE_PATH = Path("config.yaml")
+
+try:
+    settings: AppSettings = load_settings(CONFIG_FILE_PATH)
+except Exception as e:
+    logger.error(f"Failed to load configuration from {CONFIG_FILE_PATH}: {e}. Exiting.")
+    sys.exit(1)
+
+embedding_function = get_embedding_model(settings)
+
+chat_function = get_chat_model(settings)

 vector_store = Chroma(
     collection_name="generic_rag",
-    embedding_function=get_embedding_model(args.emb_backend),
-    persist_directory=str(args.chroma_db_location),
+    embedding_function=embedding_function,
+    persist_directory=str(settings.chroma_db.location),
 )

-if args.use_conditional_graph:
+if settings.use_conditional_graph:
     graph = CondRetGenLangGraph(
         vector_store=vector_store,
-        chat_model=get_chat_model(args.chat_backend),
-        embedding_model=get_embedding_model(args.emb_backend),
+        chat_model=chat_function,
+        embedding_model=embedding_function,
         system_prompt=system_prompt,
     )
 else:
     graph = RetGenLangGraph(
         vector_store=vector_store,
-        chat_model=get_chat_model(args.chat_backend),
-        embedding_model=get_embedding_model(args.emb_backend),
+        chat_model=chat_function,
+        embedding_model=embedding_function,
         system_prompt=system_prompt,
     )
@@ -129,7 +81,9 @@ async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sources: dict):
     for source, page_numbers in pdf_sources.items():
         filename = Path(source).name
         await chainlit_response.stream_token(f"- {filename} on page(s): {sorted(page_numbers)}\n")
-        chainlit_response.elements.append(cl.Pdf(name=filename, display="side", path=source, page=sorted(page_numbers)[0]))
+        chainlit_response.elements.append(
+            cl.Pdf(name=filename, display="side", path=source, page=sorted(page_numbers)[0])
+        )

     if len(web_sources) > 0:
         await chainlit_response.stream_token("\n\nThe following web sources were consulted:\n")
@@ -159,17 +113,21 @@ async def set_starters():
 if __name__ == "__main__":
-    if args.reset_chrome_db:
+    if settings.chroma_db.reset:
         vector_store.reset_collection()

     add_pdf_files(
         vector_store,
-        args.pdf_data,
-        args.pdf_chunk_size,
-        args.pdf_chunk_overlap,
-        args.pdf_add_start_index,
-        args.unstructured_pdf,
+        settings.pdf.data,
+        settings.pdf.chunk_size,
+        settings.pdf.chunk_overlap,
+        settings.pdf.add_start_index,
+        settings.pdf.unstructured,
     )
-    add_urls(vector_store, args.web_data, args.web_chunk_size)
+    add_urls(
+        vector_store,
+        settings.web.data,
+        settings.web.chunk_size,
+    )

     run_chainlit(__file__)
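
For comparison, an invocation of the removed argparse interface translates into the new declarative settings roughly as follows. This is a sketch only: the script name app.py is an assumption, and the keys mirror the AppSettings model introduced in generic_rag/parsers/config.py below.

# Before: python app.py -c azure -e huggingface -p ./docs -r --use-conditional-graph
# After: the same run is described once in config.yaml:
chat_backend: azure
emb_backend: huggingface
use_conditional_graph: true
pdf:
  data: [./docs]
chroma_db:
  reset: true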

generic_rag/backend/models.py

@@ -1,85 +1,180 @@
-import os
-from enum import Enum
+import logging

-from langchain.chat_models import init_chat_model
-from langchain_aws import BedrockEmbeddings
+from generic_rag.parsers.config import AppSettings, ChatBackend, EmbeddingBackend
+
+# Langchain imports
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_google_vertexai import VertexAIEmbeddings
+from langchain_aws import BedrockEmbeddings, ChatBedrock
+from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
 from langchain_huggingface import HuggingFaceEmbeddings
 from langchain_ollama import ChatOllama, OllamaEmbeddings
-from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings, OpenAIEmbeddings
+from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings, ChatOpenAI, OpenAIEmbeddings
+
+logger = logging.getLogger(__name__)

-class ChatBackend(Enum):
-    azure = "azure"
-    openai = "openai"
-    google_vertex = "google_vertex"
-    aws = "aws"
-    local = "local"
-
-    # make the enum pretty printable for argparse
-    def __str__(self):
-        return self.value
-
-
-class EmbeddingBackend(Enum):
-    azure = "azure"
-    openai = "openai"
-    google_vertex = "google_vertex"
-    aws = "aws"
-    local = "local"
-    huggingface = "huggingface"
-
-    # make the enum pretty printable for argparse
-    def __str__(self):
-        return self.value
-
-
-def get_chat_model(backend_type: ChatBackend) -> BaseChatModel:
-    if backend_type == ChatBackend.azure:
+def get_chat_model(settings: AppSettings) -> BaseChatModel:
+    """
+    Initializes and returns a chat model based on the backend type and configuration.
+
+    Args:
+        settings: The loaded AppSettings object containing configurations.
+
+    Returns:
+        An instance of BaseChatModel.
+
+    Raises:
+        ValueError: If the backend type is unknown or required configuration is missing.
+    """
+    logger.info(f"Initializing chat model for backend: {settings.chat_backend.value}")
+
+    if settings.chat_backend == ChatBackend.azure:
+        if not settings.azure:
+            raise ValueError("Azure chat backend selected, but 'azure' configuration section is missing in config.")
+        if (
+            not settings.azure.llm_endpoint
+            or not settings.azure.llm_deployment_name
+            or not settings.azure.llm_api_version
+        ):
+            raise ValueError(
+                "Azure configuration requires 'llm_endpoint', 'llm_deployment_name', and 'llm_api_version'."
+            )
         return AzureChatOpenAI(
-            azure_endpoint=os.environ["AZURE_LLM_ENDPOINT"],
-            azure_deployment=os.environ["AZURE_LLM_DEPLOYMENT_NAME"],
-            openai_api_version=os.environ["AZURE_LLM_API_VERSION"],
+            azure_endpoint=settings.azure.llm_endpoint,
+            azure_deployment=settings.azure.llm_deployment_name,
+            openai_api_version=settings.azure.llm_api_version,
+            openai_api_key=settings.azure.openai_api_key.get_secret_value() if settings.azure.openai_api_key else None,
         )
-    if backend_type == ChatBackend.openai:
-        return init_chat_model(os.environ["OPENAI_CHAT_MODEL"], model_provider="openai")
+
+    if settings.chat_backend == ChatBackend.openai:
+        if not settings.openai:
+            raise ValueError("OpenAI chat backend selected, but 'openai' configuration section is missing.")
+        if not settings.openai.api_key or not settings.openai.chat_model:
+            raise ValueError("OpenAI configuration requires 'api_key' and 'chat_model'.")
+        logger.info(f"Using OpenAI model: {settings.openai.chat_model}")
+        return ChatOpenAI(model=settings.openai.chat_model, openai_api_key=settings.openai.api_key.get_secret_value())
+
-    if backend_type == ChatBackend.google_vertex:
-        return init_chat_model(os.environ["GOOGLE_CHAT_MODEL"], model_provider="google_vertexai")
+    if settings.chat_backend == ChatBackend.google_vertex:
+        if not settings.google:
+            raise ValueError("Google Vertex chat backend selected, but 'google' configuration section is missing.")
+        if not settings.google.chat_model:
+            raise ValueError("Google configuration requires 'chat_model'.")
+        logger.info(f"Using Google Vertex model: {settings.google.chat_model}")
+        return ChatVertexAI(
+            model_name=settings.google.chat_model,
+            project=settings.google.project_id,
+            location=settings.google.location,
+        )
+
-    if backend_type == ChatBackend.aws:
-        return init_chat_model(model=os.environ["AWS_CHAT_MODEL"], model_provider="bedrock_converse")
+    if settings.chat_backend == ChatBackend.aws:
+        if not settings.aws:
+            raise ValueError("AWS Bedrock chat backend selected, but 'aws' configuration section is missing.")
+        model_name = settings.aws.chat_model or "anthropic.claude-v2"  # example default
+        logger.info(f"Using AWS Bedrock model: {model_name}")
+        return ChatBedrock(
+            model_id=model_name,
+            region_name=settings.aws.region_name,
+        )
+
-    if backend_type == ChatBackend.local:
-        return ChatOllama(model=os.environ["LOCAL_CHAT_MODEL"])
+    if settings.chat_backend == ChatBackend.local:
+        if not settings.local or not settings.local.chat_model:
+            raise ValueError("Local chat backend selected, but 'local.chat_model' is missing in config.")
+        logger.info(f"Using local Ollama model: {settings.local.chat_model}")
+        # The Ollama base URL could also be made configurable here (e.g. a 'local.ollama_base_url' setting).
+        return ChatOllama(model=settings.local.chat_model)
+
-    raise ValueError(f"Unknown backend type: {backend_type}")
+    # This should not be reached if all Enum members are handled.
+    raise ValueError(f"Unknown or unhandled chat backend type: {settings.chat_backend}")


-def get_embedding_model(backend_type: EmbeddingBackend) -> Embeddings:
-    if backend_type == EmbeddingBackend.azure:
+def get_embedding_model(settings: AppSettings) -> Embeddings:
+    """
+    Initializes and returns an embedding model based on the backend type and configuration.
+
+    Args:
+        settings: The loaded AppSettings object containing configurations.
+
+    Returns:
+        An instance of Embeddings.
+
+    Raises:
+        ValueError: If the backend type is unknown or required configuration is missing.
+    """
+    logger.info(f"Initializing embedding model for backend: {settings.emb_backend.value}")
+
+    if settings.emb_backend == EmbeddingBackend.azure:
+        if not settings.azure:
+            raise ValueError("Azure embedding backend selected, but 'azure' configuration section is missing.")
+        if (
+            not settings.azure.emb_endpoint
+            or not settings.azure.emb_deployment_name
+            or not settings.azure.emb_api_version
+        ):
+            raise ValueError(
+                "Azure configuration requires 'emb_endpoint', 'emb_deployment_name', and 'emb_api_version'."
+            )
         return AzureOpenAIEmbeddings(
-            azure_endpoint=os.environ["AZURE_EMB_ENDPOINT"],
-            azure_deployment=os.environ["AZURE_EMB_DEPLOYMENT_NAME"],
-            openai_api_version=os.environ["AZURE_EMB_API_VERSION"],
+            azure_endpoint=settings.azure.emb_endpoint,
+            azure_deployment=settings.azure.emb_deployment_name,
+            openai_api_version=settings.azure.emb_api_version,
+            openai_api_key=settings.azure.openai_api_key.get_secret_value() if settings.azure.openai_api_key else None,
         )
-    if backend_type == EmbeddingBackend.openai:
-        return OpenAIEmbeddings(model=os.environ["OPENAI_EMB_MODEL"])
+
+    if settings.emb_backend == EmbeddingBackend.openai:
+        if not settings.openai:
+            raise ValueError("OpenAI embedding backend selected, but 'openai' configuration section is missing.")
+        if not settings.openai.api_key:
+            raise ValueError("OpenAI configuration requires 'api_key'.")
+        model_name = settings.openai.emb_model or "text-embedding-ada-002"  # example default
+        logger.info(f"Using OpenAI embedding model: {model_name}")
+        return OpenAIEmbeddings(model=model_name, openai_api_key=settings.openai.api_key.get_secret_value())
+
-    if backend_type == EmbeddingBackend.google_vertex:
-        return VertexAIEmbeddings(model=os.environ["GOOGLE_EMB_MODEL"])
+    if settings.emb_backend == EmbeddingBackend.google_vertex:
+        if not settings.google:
+            raise ValueError("Google Vertex embedding backend selected, but 'google' configuration section is missing.")
+        model_name = settings.google.emb_model or "textembedding-gecko@001"  # example default
+        logger.info(f"Using Google Vertex embedding model: {model_name}")
+        return VertexAIEmbeddings(
+            model_name=model_name, project=settings.google.project_id, location=settings.google.location
+        )
+
-    if backend_type == EmbeddingBackend.aws:
-        return BedrockEmbeddings(model_id=os.environ["AWS_EMB_MODEL"])
+    if settings.emb_backend == EmbeddingBackend.aws:
+        if not settings.aws:
+            raise ValueError("AWS Bedrock embedding backend selected, but 'aws' configuration section is missing.")
+        model_name = settings.aws.emb_model or "amazon.titan-embed-text-v1"  # example default
+        logger.info(f"Using AWS Bedrock embedding model: {model_name}")
+        return BedrockEmbeddings(model_id=model_name, region_name=settings.aws.region_name)
+
-    if backend_type == EmbeddingBackend.local:
-        return OllamaEmbeddings(model=os.environ["LOCAL_EMB_MODEL"])
+    if settings.emb_backend == EmbeddingBackend.local:
+        if not settings.local or not settings.local.emb_model:
+            raise ValueError("Local embedding backend selected, but 'local.emb_model' is missing in config.")
+        logger.info(f"Using local Ollama embedding model: {settings.local.emb_model}")
+        return OllamaEmbeddings(model=settings.local.emb_model)
+
-    if backend_type == EmbeddingBackend.huggingface:
-        return HuggingFaceEmbeddings(model_name=os.environ["HUGGINGFACE_EMB_MODEL"])
+    if settings.emb_backend == EmbeddingBackend.huggingface:
+        if settings.huggingface and settings.huggingface.emb_model:
+            model_name = settings.huggingface.emb_model
+        elif settings.local and settings.local.emb_model:
+            logger.warning(
+                "HuggingFace backend selected, but 'huggingface.emb_model' is missing. Falling back to 'local.emb_model'."
+            )
+            model_name = settings.local.emb_model
+        else:
+            raise ValueError(
+                "HuggingFace embedding backend selected, but 'huggingface.emb_model' (or 'local.emb_model') is missing in config."
+            )
+        logger.info(f"Using HuggingFace embedding model: {model_name}")
+        return HuggingFaceEmbeddings(model_name=model_name)
+
-    raise ValueError(f"Unknown backend type: {backend_type}")
+    raise ValueError(f"Unknown or unhandled embedding backend type: {settings.emb_backend}")

generic_rag/parsers/config.py

@@ -0,0 +1,177 @@
+import sys
+from enum import Enum
+from pathlib import Path
+from typing import List, Optional
+
+import yaml
+from pydantic import (
+    BaseModel,
+    Field,
+    SecretStr,
+    ValidationError,
+)
+
+
+class ChatBackend(str, Enum):
+    azure = "azure"
+    openai = "openai"
+    google_vertex = "google_vertex"
+    aws = "aws"
+    local = "local"
+
+    def __str__(self):
+        return self.value
+
+
+class EmbeddingBackend(str, Enum):
+    azure = "azure"
+    openai = "openai"
+    google_vertex = "google_vertex"
+    aws = "aws"
+    local = "local"
+    huggingface = "huggingface"
+
+    def __str__(self):
+        return self.value
+
+
+class AzureSettings(BaseModel):
+    """Azure specific settings."""
+
+    openai_api_key: Optional[SecretStr] = None
+    llm_endpoint: Optional[str] = None
+    llm_deployment_name: Optional[str] = None
+    llm_api_version: Optional[str] = None
+    emb_endpoint: Optional[str] = None
+    emb_deployment_name: Optional[str] = None
+    emb_api_version: Optional[str] = None
+
+
+class OpenAISettings(BaseModel):
+    """OpenAI specific settings."""
+
+    api_key: Optional[SecretStr] = None
+    chat_model: Optional[str] = None  # read by get_chat_model
+    emb_model: Optional[str] = None  # read by get_embedding_model
+
+
+class GoogleSettings(BaseModel):
+    """Google specific settings (Vertex AI or GenAI)."""
+
+    api_key: Optional[SecretStr] = None
+    project_id: Optional[str] = None
+    location: Optional[str] = None
+    chat_model: Optional[str] = None
+    emb_model: Optional[str] = None
+
+
+class AwsSettings(BaseModel):
+    """AWS specific settings (e.g., for Bedrock)."""
+
+    access_key_id: Optional[SecretStr] = None
+    secret_access_key: Optional[SecretStr] = None
+    region_name: Optional[str] = None
+    chat_model: Optional[str] = None  # read by get_chat_model
+    emb_model: Optional[str] = None  # read by get_embedding_model
+
+
+class LocalSettings(BaseModel):
+    """Local backend specific settings (e.g., Ollama models)."""
+
+    chat_model: Optional[str] = None
+    emb_model: Optional[str] = None
+
+
+class HuggingFaceSettings(BaseModel):
+    """HuggingFace specific settings (if different from local embeddings)."""
+
+    emb_model: Optional[str] = None
+    api_token: Optional[SecretStr] = None
+
+
+class PdfSettings(BaseModel):
+    """PDF processing settings."""
+
+    data: List[Path] = Field(default_factory=list)
+    unstructured: bool = Field(default=False)
+    chunk_size: int = Field(default=1000)
+    chunk_overlap: int = Field(default=200)
+    add_start_index: bool = Field(default=False)
+
+
+class WebSettings(BaseModel):
+    """Web data processing settings."""
+
+    data: List[str] = Field(default_factory=list)
+    chunk_size: int = Field(default=200)
+
+
+class ChromaDbSettings(BaseModel):
+    """Chroma DB settings."""
+
+    location: Path = Field(default=Path(".chroma_db"))
+    reset: bool = Field(default=False)
+
+
+class AppSettings(BaseModel):
+    """
+    Main application settings model.
+
+    Loads configuration from a YAML file using the structure defined
+    by the nested models.
+    """
+
+    # --- Top-level settings ---
+    chat_backend: ChatBackend = Field(default=ChatBackend.local)
+    emb_backend: EmbeddingBackend = Field(default=EmbeddingBackend.huggingface)
+    use_conditional_graph: bool = Field(default=False)
+
+    # --- Provider-specific settings ---
+    azure: Optional[AzureSettings] = None
+    openai: Optional[OpenAISettings] = None
+    google: Optional[GoogleSettings] = None
+    aws: Optional[AwsSettings] = None
+    local: Optional[LocalSettings] = None
+    huggingface: Optional[HuggingFaceSettings] = None  # separate HF config if needed
+
+    # --- Data processing settings ---
+    pdf: PdfSettings = Field(default_factory=PdfSettings)
+    web: WebSettings = Field(default_factory=WebSettings)
+    chroma_db: ChromaDbSettings = Field(default_factory=ChromaDbSettings)
+
+
+# --- Configuration Loading Function ---
+def load_settings(config_path: Path = Path("config.yaml")) -> AppSettings:
+    """
+    Loads settings from a YAML file and validates them using Pydantic models.
+
+    Args:
+        config_path: The path to the configuration YAML file.
+
+    Returns:
+        An instance of AppSettings containing the loaded configuration.
+
+    Raises:
+        FileNotFoundError: If the config file does not exist.
+        yaml.YAMLError: If the file is not valid YAML.
+        ValidationError: If the data in the file doesn't match the AppSettings model.
+    """
+    if not config_path.is_file():
+        print(f"Error: Configuration file not found at '{config_path}'", file=sys.stderr)
+        raise FileNotFoundError(f"Configuration file not found: {config_path}")
+
+    print(f"--- Loading settings from '{config_path}' ---")
+    try:
+        with open(config_path, "r", encoding="utf-8") as f:
+            config_data = yaml.safe_load(f)
+        if config_data is None:
+            config_data = {}
+        settings = AppSettings(**config_data)
+        print("--- Settings loaded and validated successfully ---")
+        return settings
+    except yaml.YAMLError as e:
+        print(f"Error parsing YAML file '{config_path}':\n  {e}", file=sys.stderr)
+        raise
+    except ValidationError as e:
+        print(f"Error validating configuration from '{config_path}':\n{e}", file=sys.stderr)
+        raise
+    except Exception as e:
+        print(f"An unexpected error occurred while loading settings from '{config_path}': {e}", file=sys.stderr)
+        raise
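
For illustration, a config.yaml matching this schema could look like the sketch below. Model names and paths are placeholders, not recommendations; keeping real keys in a file like this is exactly why config.yaml was added to .gitignore above.

chat_backend: local
emb_backend: huggingface
use_conditional_graph: false

local:
  chat_model: llama3
huggingface:
  emb_model: sentence-transformers/all-MiniLM-L6-v2

pdf:
  data: [./docs]
  chunk_size: 1000
  chunk_overlap: 200
web:
  data: []
  chunk_size: 200
chroma_db:
  location: .chroma_db
  reset: false

Because the file is validated against AppSettings, a typo such as chat_backend: azur now fails fast with a ValidationError at startup instead of surfacing as a missing environment variable at request time.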

generic_rag/parsers/parser.py

@@ -87,7 +87,6 @@ def add_pdf_files(
     The PDF file will be parsed per page and split into chunks of text with the provided chunk size and overlap.
     """
     logger.info("Adding PDF files to the vector store.")
-
     pdf_files = get_all_local_pdf_files(file_paths)
     logger.info(f"Found {len(pdf_files)} PDF files to add to the vector store.")
@@ -101,7 +100,7 @@ def add_pdf_files(
     if len(new_pdfs) == 0:
         return

-    logger.info(f"{len(new_pdfs)} PDF's to add to the vector store.")
+    logger.info(f"{len(new_pdfs)} PDF(s) to add to the vector store.")

     loaded_document = []
     for file in new_pdfs: