forked from AI_team/Philosophy-RAG-demo

✨ YAML parser using Pydantic classes

This commit is contained in: parent 59ab43d31e, commit 996f3bf7a2

.gitignore (vendored): 3 changes
@@ -167,3 +167,6 @@ chainlit.md
 
 # Chroma DB
 .chroma_db/
+
+# Settings
+config.yaml
@@ -1,14 +1,15 @@
-import argparse
 import json
 import logging
 import os
 from pathlib import Path
+import sys
 
 import chainlit as cl
 from chainlit.cli import run_chainlit
 from langchain_chroma import Chroma
 
-from generic_rag.backend.models import ChatBackend, EmbeddingBackend, get_chat_model, get_embedding_model
+from generic_rag.parsers.config import AppSettings, load_settings
+from generic_rag.backend.models import get_chat_model, get_embedding_model
 from generic_rag.graphs.cond_ret_gen import CondRetGenLangGraph
 from generic_rag.graphs.ret_gen import RetGenLangGraph
 from generic_rag.parsers.parser import add_pdf_files, add_urls
@@ -23,85 +24,36 @@ system_prompt = (
     "If you don't know the answer, say that you don't know."
 )
 
-parser = argparse.ArgumentParser(description="A Sogeti Nederland Generic RAG demo.")
-parser.add_argument(
-    "-c",
-    "--chat-backend",
-    type=ChatBackend,
-    choices=list(ChatBackend),
-    default=ChatBackend.local,
-    help="Cloud provider or local LLM to use as backend. In the case of 'local', Ollama needs to be installed.",
-)
-parser.add_argument(
-    "-e",
-    "--emb-backend",
-    type=EmbeddingBackend,
-    choices=list(EmbeddingBackend),
-    default=EmbeddingBackend.huggingface,
-    help="Cloud provider or local embedding to use as backend. In the case of 'local', Ollama needs to be installed. ",
-)
-parser.add_argument(
-    "-p",
-    "--pdf-data",
-    type=Path,
-    nargs="+",
-    default=[],
-    help="One or multiple paths to folders or files to use for retrieval. "
-    "If a path is a folder, all files in the folder will be used. "
-    "If a path is a file, only that file will be used. "
-    "If the path is relative it will be relative to the current working directory.",
-)
-parser.add_argument(
-    "-u",
-    "--unstructured-pdf",
-    action="store_true",
-    help="Use an unstructered PDF loader. "
-    "An unstructured PDF loader might be usefull for PDF files "
-    "that contain a lot of images with text, tables or (scanned) text as images. "
-    "Please use '-r' when switching parsers on already indexed data.",
-)
-parser.add_argument("--pdf-chunk_size", type=int, default=1000, help="The size of the chunks to split the text into.")
-parser.add_argument("--pdf-chunk_overlap", type=int, default=200, help="The overlap between the chunks.")
-parser.add_argument(
-    "--pdf-add-start-index", action="store_true", help="Add the start index to the metadata of the chunks."
-)
-parser.add_argument(
-    "-w", "--web-data", type=str, nargs="*", default=[], help="One or multiple URLs to use for retrieval."
-)
-parser.add_argument("--web-chunk-size", type=int, default=200, help="The size of the chunks to split the text into.")
-parser.add_argument(
-    "-d",
-    "--chroma-db-location",
-    type=Path,
-    default=Path(".chroma_db"),
-    help="File path to store or load a Chroma DB from/to.",
-)
-parser.add_argument("-r", "--reset-chrome-db", action="store_true", help="Reset the Chroma DB.")
-parser.add_argument(
-    "--use-conditional-graph",
-    action="store_true",
-    help="Use the conditial retrieve generate graph over the regular retrieve generate graph.",
-)
-args = parser.parse_args()
+CONFIG_FILE_PATH = Path("config.yaml")
+
+try:
+    settings: AppSettings = load_settings(CONFIG_FILE_PATH)
+except Exception as e:
+    logger.error(f"Failed to load configuration from {CONFIG_FILE_PATH}: {e}. Exiting.")
+    sys.exit(1)
+
+embedding_function = get_embedding_model(settings)
+
+chat_function = get_chat_model(settings)
 
 vector_store = Chroma(
     collection_name="generic_rag",
-    embedding_function=get_embedding_model(args.emb_backend),
-    persist_directory=str(args.chroma_db_location),
+    embedding_function=embedding_function,
+    persist_directory=str(settings.chroma_db.location),
 )
 
-if args.use_conditional_graph:
+if settings.use_conditional_graph:
     graph = CondRetGenLangGraph(
         vector_store=vector_store,
-        chat_model=get_chat_model(args.chat_backend),
-        embedding_model=get_embedding_model(args.emb_backend),
+        chat_model=chat_function,
+        embedding_model=embedding_function,
         system_prompt=system_prompt,
     )
 else:
     graph = RetGenLangGraph(
         vector_store=vector_store,
-        chat_model=get_chat_model(args.chat_backend),
-        embedding_model=get_embedding_model(args.emb_backend),
+        chat_model=chat_function,
+        embedding_model=embedding_function,
         system_prompt=system_prompt,
     )
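Since config.yaml is deliberately gitignored, it may help reviewers to see what one looks like. Below is a minimal sketch of a settings file that validates against the AppSettings schema added in generic_rag/parsers/config.py; the backend choices and model names are illustrative placeholders, not project defaults.

```python
# Hypothetical config.yaml contents, validated against AppSettings.
# Key names come from generic_rag/parsers/config.py; model names are placeholders.
import yaml
from generic_rag.parsers.config import AppSettings

SAMPLE_CONFIG = """
chat_backend: local
emb_backend: huggingface
use_conditional_graph: false
local:
  chat_model: llama3
  emb_model: nomic-embed-text
huggingface:
  emb_model: sentence-transformers/all-MiniLM-L6-v2
pdf:
  data: [docs/]
  chunk_size: 1000
  chunk_overlap: 200
web:
  data: [https://example.com]
  chunk_size: 200
chroma_db:
  location: .chroma_db
  reset: false
"""

settings = AppSettings(**yaml.safe_load(SAMPLE_CONFIG))
print(settings.chat_backend, settings.pdf.chunk_size)  # -> local 1000
```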
@@ -129,7 +81,9 @@ async def add_sources(chainlit_response: cl.Message, pdf_sources: dict, web_sour
     for source, page_numbers in pdf_sources.items():
         filename = Path(source).name
         await chainlit_response.stream_token(f"- {filename} on page(s): {sorted(page_numbers)}\n")
-        chainlit_response.elements.append(cl.Pdf(name=filename, display="side", path=source, page=sorted(page_numbers)[0]))
+        chainlit_response.elements.append(
+            cl.Pdf(name=filename, display="side", path=source, page=sorted(page_numbers)[0])
+        )
 
     if len(web_sources) > 0:
         await chainlit_response.stream_token("\n\nThe following web sources were consulted:\n")
@@ -159,17 +113,21 @@ async def set_starters():
 
 
 if __name__ == "__main__":
-    if args.reset_chrome_db:
+    if settings.chroma_db.reset:
         vector_store.reset_collection()
 
     add_pdf_files(
         vector_store,
-        args.pdf_data,
-        args.pdf_chunk_size,
-        args.pdf_chunk_overlap,
-        args.pdf_add_start_index,
-        args.unstructured_pdf,
+        settings.pdf.data,
+        settings.pdf.chunk_size,
+        settings.pdf.chunk_overlap,
+        settings.pdf.add_start_index,
+        settings.pdf.unstructured,
     )
-    add_urls(vector_store, args.web_data, args.web_chunk_size)
+    add_urls(
+        vector_store,
+        settings.web.data,
+        settings.web.chunk_size,
+    )
 
     run_chainlit(__file__)
@@ -1,85 +1,180 @@
 import os
-from enum import Enum
+import logging
 
 from langchain.chat_models import init_chat_model
-from langchain_aws import BedrockEmbeddings
+from generic_rag.parsers.config import AppSettings, ChatBackend, EmbeddingBackend
 
 # Langchain imports
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_google_vertexai import VertexAIEmbeddings
+from langchain_aws import BedrockEmbeddings, ChatBedrock  # Import ChatBedrock
+from langchain_google_vertexai import VertexAIEmbeddings, ChatVertexAI  # Import ChatVertexAI
 from langchain_huggingface import HuggingFaceEmbeddings
 from langchain_ollama import ChatOllama, OllamaEmbeddings
-from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings, OpenAIEmbeddings
+from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings, ChatOpenAI, OpenAIEmbeddings  # Import ChatOpenAI
+
+logger = logging.getLogger(__name__)
 
 
-class ChatBackend(Enum):
-    azure = "azure"
-    openai = "openai"
-    google_vertex = "google_vertex"
-    aws = "aws"
-    local = "local"
-
-    # make the enum pretty printable for argparse
-    def __str__(self):
-        return self.value
-
-
-class EmbeddingBackend(Enum):
-    azure = "azure"
-    openai = "openai"
-    google_vertex = "google_vertex"
-    aws = "aws"
-    local = "local"
-    huggingface = "huggingface"
-
-    # make the enum pretty printable for argparse
-    def __str__(self):
-        return self.value
-
-
-def get_chat_model(backend_type: ChatBackend) -> BaseChatModel:
-    if backend_type == ChatBackend.azure:
+def get_chat_model(settings: AppSettings) -> BaseChatModel:
+    """
+    Initializes and returns a chat model based on the backend type and configuration.
+
+    Args:
+        settings: The loaded AppSettings object containing configurations.
+
+    Returns:
+        An instance of BaseChatModel.
+
+    Raises:
+        ValueError: If the backend type is unknown or required configuration is missing.
+    """
+    logger.info(f"Initializing chat model for backend: {settings.chat_backend.value}")
+
+    if settings.chat_backend == ChatBackend.azure:
+        if not settings.azure:
+            raise ValueError("Azure chat backend selected, but 'azure' configuration section is missing in config.")
+        if (
+            not settings.azure.llm_endpoint
+            or not settings.azure.llm_deployment_name
+            or not settings.azure.llm_api_version
+        ):
+            raise ValueError(
+                "Azure configuration requires 'llm_endpoint', 'llm_deployment_name', and 'llm_api_version'."
+            )
         return AzureChatOpenAI(
-            azure_endpoint=os.environ["AZURE_LLM_ENDPOINT"],
-            azure_deployment=os.environ["AZURE_LLM_DEPLOYMENT_NAME"],
-            openai_api_version=os.environ["AZURE_LLM_API_VERSION"],
+            azure_endpoint=settings.azure.llm_endpoint,
+            azure_deployment=settings.azure.llm_deployment_name,
+            openai_api_version=settings.azure.llm_api_version,
+            openai_api_key=settings.azure.openai_api_key.get_secret_value() if settings.azure.openai_api_key else None,
         )
 
-    if backend_type == ChatBackend.openai:
-        return init_chat_model(os.environ["OPENAI_CHAT_MODEL"], model_provider="openai")
+    if settings.chat_backend == ChatBackend.openai:
+        if not settings.openai:
+            raise ValueError("OpenAI chat backend selected, but 'openai' configuration section is missing.")
+        if not settings.openai.api_key or not settings.openai.chat_model:
+            raise ValueError("OpenAI configuration requires 'api_key' and 'chat_model'.")
+        logger.info(f"Using OpenAI model: {settings.openai.chat_model}")
+        return ChatOpenAI(model=settings.openai.chat_model, openai_api_key=settings.openai.api_key.get_secret_value())
 
-    if backend_type == ChatBackend.google_vertex:
-        return init_chat_model(os.environ["GOOGLE_CHAT_MODEL"], model_provider="google_vertexai")
+    if settings.chat_backend == ChatBackend.google_vertex:
+        if not settings.google:
+            raise ValueError("Google Vertex chat backend selected, but 'google' configuration section is missing.")
+        if not settings.google.chat_model:
+            raise ValueError("Google configuration requires 'chat_model'.")
+        logger.info(f"Using Google Vertex model: {settings.google.chat_model}")
+        return ChatVertexAI(
+            model_name=settings.google.chat_model,
+            project=settings.google.project_id,
+            location=settings.google.location,
+        )
 
-    if backend_type == ChatBackend.aws:
-        return init_chat_model(model=os.environ["AWS_CHAT_MODEL"], model_provider="bedrock_converse")
+    if settings.chat_backend == ChatBackend.aws:
+        if not settings.aws:
+            raise ValueError("AWS Bedrock chat backend selected, but 'aws' configuration section is missing.")
+        model_name = "anthropic.claude-v2"  # Example default
+        if hasattr(settings.aws, "chat_model") and settings.aws.chat_model:
+            model_name = settings.aws.chat_model
+        logger.info(f"Using AWS Bedrock model: {model_name}")
+        return ChatBedrock(
+            model_id=model_name,
+            region_name=settings.aws.region_name,
+        )
 
-    if backend_type == ChatBackend.local:
-        return ChatOllama(model=os.environ["LOCAL_CHAT_MODEL"])
+    if settings.chat_backend == ChatBackend.local:
+        if not settings.local or not settings.local.chat_model:
+            raise ValueError("Local chat backend selected, but 'local.chat_model' is missing in config.")
+        logger.info(f"Using Local Ollama model: {settings.local.chat_model}")
+        # Base URL can also be configured, e.g., base_url=config.local.ollama_base_url
+        return ChatOllama(model=settings.local.chat_model)
 
-    raise ValueError(f"Unknown backend type: {backend_type}")
+    # This should not be reached if all Enum members are handled
+    raise ValueError(f"Unknown or unhandled chat backend type: {settings.chat_backend}")
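The factory functions now take the whole AppSettings object instead of a backend enum plus environment variables. A minimal usage sketch, mirroring the new main.py call sites above:

```python
# Usage sketch for the new factory signatures (mirrors main.py).
from pathlib import Path

from generic_rag.parsers.config import load_settings
from generic_rag.backend.models import get_chat_model, get_embedding_model

settings = load_settings(Path("config.yaml"))
chat_model = get_chat_model(settings)        # dispatches on settings.chat_backend
embeddings = get_embedding_model(settings)   # dispatches on settings.emb_backend
```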
 
 
-def get_embedding_model(backend_type: EmbeddingBackend) -> Embeddings:
-    if backend_type == EmbeddingBackend.azure:
+def get_embedding_model(settings: AppSettings) -> Embeddings:
+    """
+    Initializes and returns an embedding model based on the backend type and configuration.
+
+    Args:
+        settings: The loaded AppSettings object containing configurations.
+
+    Returns:
+        An instance of Embeddings.
+
+    Raises:
+        ValueError: If the backend type is unknown or required configuration is missing.
+    """
+    logger.info(f"Initializing embedding model for backend: {settings.emb_backend.value}")
+
+    if settings.emb_backend == EmbeddingBackend.azure:
+        if not settings.azure:
+            raise ValueError("Azure embedding backend selected, but 'azure' configuration section is missing.")
+        if (
+            not settings.azure.emb_endpoint
+            or not settings.azure.emb_deployment_name
+            or not settings.azure.emb_api_version
+        ):
+            raise ValueError(
+                "Azure configuration requires 'emb_endpoint', 'emb_deployment_name', and 'emb_api_version'."
+            )
         return AzureOpenAIEmbeddings(
-            azure_endpoint=os.environ["AZURE_EMB_ENDPOINT"],
-            azure_deployment=os.environ["AZURE_EMB_DEPLOYMENT_NAME"],
-            openai_api_version=os.environ["AZURE_EMB_API_VERSION"],
+            azure_endpoint=settings.azure.emb_endpoint,
+            azure_deployment=settings.azure.emb_deployment_name,
+            openai_api_version=settings.azure.emb_api_version,
+            openai_api_key=settings.azure.openai_api_key.get_secret_value() if settings.azure.openai_api_key else None,
         )
 
-    if backend_type == EmbeddingBackend.openai:
-        return OpenAIEmbeddings(model=os.environ["OPENAI_EMB_MODEL"])
+    if settings.emb_backend == EmbeddingBackend.openai:
+        if not settings.openai:
+            raise ValueError("OpenAI embedding backend selected, but 'openai' configuration section is missing.")
+        if not settings.openai.api_key:
+            raise ValueError("OpenAI configuration requires 'api_key'.")
+        model_name = "text-embedding-ada-002"  # Example default
+        if hasattr(settings.openai, "emb_model") and settings.openai.emb_model:
+            model_name = settings.openai.emb_model
+        logger.info(f"Using OpenAI embedding model: {model_name}")
+        return OpenAIEmbeddings(model=model_name, openai_api_key=settings.openai.api_key.get_secret_value())
 
-    if backend_type == EmbeddingBackend.google_vertex:
-        return VertexAIEmbeddings(model=os.environ["GOOGLE_EMB_MODEL"])
+    if settings.emb_backend == EmbeddingBackend.google_vertex:
+        if not settings.google:
+            raise ValueError("Google Vertex embedding backend selected, but 'google' configuration section is missing.")
+        model_name = "textembedding-gecko@001"  # Example default
+        if settings.google.emb_model:
+            model_name = settings.google.emb_model
+        logger.info(f"Using Google Vertex embedding model: {model_name}")
+        return VertexAIEmbeddings(
+            model_name=model_name, project=settings.google.project_id, location=settings.google.location
+        )
 
-    if backend_type == EmbeddingBackend.aws:
-        return BedrockEmbeddings(model_id=os.environ["AWS_EMB_MODEL"])
+    if settings.emb_backend == EmbeddingBackend.aws:
+        if not settings.aws:
+            raise ValueError("AWS Bedrock embedding backend selected, but 'aws' configuration section is missing.")
+        model_name = "amazon.titan-embed-text-v1"  # Example default
+        if hasattr(settings.aws, "emb_model") and settings.aws.emb_model:
+            model_name = settings.aws.emb_model
+        logger.info(f"Using AWS Bedrock embedding model: {model_name}")
+        return BedrockEmbeddings(model_id=model_name, region_name=settings.aws.region_name)
 
-    if backend_type == EmbeddingBackend.local:
-        return OllamaEmbeddings(model=os.environ["LOCAL_EMB_MODEL"])
+    if settings.emb_backend == EmbeddingBackend.local:
+        if not settings.local or not settings.local.emb_model:
+            raise ValueError("Local embedding backend selected, but 'local.emb_model' is missing in config.")
+        logger.info(f"Using Local Ollama embedding model: {settings.local.emb_model}")
+        return OllamaEmbeddings(model=settings.local.emb_model)
 
-    if backend_type == EmbeddingBackend.huggingface:
-        return HuggingFaceEmbeddings(model_name=os.environ["HUGGINGFACE_EMB_MODEL"])
-
-    raise ValueError(f"Unknown backend type: {backend_type}")
+    if settings.emb_backend == EmbeddingBackend.huggingface:
+        if not settings.huggingface or not settings.huggingface.emb_model:
+            if settings.local and settings.local.emb_model:
+                logger.warning(
+                    "HuggingFace backend selected, but 'huggingface.emb_model' missing. Using 'local.emb_model'."
+                )
+                model_name = settings.local.emb_model
+            else:
+                raise ValueError(
+                    "HuggingFace embedding backend selected, but 'huggingface.emb_model' (or 'local.emb_model') is missing in config."
+                )
+        else:
+            model_name = settings.huggingface.emb_model
+
+        logger.info(f"Using HuggingFace embedding model: {model_name}")
+        return HuggingFaceEmbeddings(model_name=model_name)
+
+    raise ValueError(f"Unknown or unhandled embedding backend type: {settings.emb_backend}")
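Note the fallback in the huggingface branch above: when huggingface.emb_model is absent but local.emb_model is set, the function warns and reuses the local model name. A quick sketch of that path, with a placeholder model name:

```python
# Exercising the documented fallback: the 'huggingface' section is absent,
# so the warning fires and local.emb_model is used instead.
from generic_rag.parsers.config import AppSettings
from generic_rag.backend.models import get_embedding_model

settings = AppSettings(
    emb_backend="huggingface",
    local={"emb_model": "sentence-transformers/all-MiniLM-L6-v2"},  # placeholder
)
embeddings = get_embedding_model(settings)  # logs a warning, then loads the local model name
```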
generic_rag/parsers/config.py (new file, 177 lines)
@@ -0,0 +1,177 @@
+import yaml
+from pathlib import Path
+from typing import List, Optional
+from enum import Enum
+from pydantic import (
+    BaseModel,
+    Field,
+    ValidationError,
+    SecretStr,
+)
+import sys
+
+
+class ChatBackend(str, Enum):
+    azure = "azure"
+    openai = "openai"
+    google_vertex = "google_vertex"
+    aws = "aws"
+    local = "local"
+
+    def __str__(self):
+        return self.value
+
+
+class EmbeddingBackend(str, Enum):
+    azure = "azure"
+    openai = "openai"
+    google_vertex = "google_vertex"
+    aws = "aws"
+    local = "local"
+    huggingface = "huggingface"
+
+    def __str__(self):
+        return self.value
+
+
+class AzureSettings(BaseModel):
+    """Azure specific settings."""
+
+    openai_api_key: Optional[SecretStr] = None
+    llm_endpoint: Optional[str] = None
+    llm_deployment_name: Optional[str] = None
+    llm_api_version: Optional[str] = None
+    emb_endpoint: Optional[str] = None
+    emb_deployment_name: Optional[str] = None
+    emb_api_version: Optional[str] = None
+
+
+class OpenAISettings(BaseModel):
+    """OpenAI specific settings."""
+
+    api_key: Optional[SecretStr] = None
+
+
+class GoogleSettings(BaseModel):
+    """Google specific settings (Vertex AI or GenAI)."""
+
+    api_key: Optional[SecretStr] = None
+    project_id: Optional[str] = None
+    location: Optional[str] = None
+    chat_model: Optional[str] = None
+    emb_model: Optional[str] = None
+
+
+class AwsSettings(BaseModel):
+    """AWS specific settings (e.g., for Bedrock)."""
+
+    access_key_id: Optional[SecretStr] = None
+    secret_access_key: Optional[SecretStr] = None
+    region_name: Optional[str] = None
+
+
+class LocalSettings(BaseModel):
+    """Local backend specific settings (e.g., Ollama models)."""
+
+    chat_model: Optional[str] = None
+    emb_model: Optional[str] = None
+
+
+class HuggingFaceSettings(BaseModel):
+    """HuggingFace specific settings (if different from local embeddings)."""
+
+    emb_model: Optional[str] = None
+    api_token: Optional[SecretStr] = None
+
+
+class PdfSettings(BaseModel):
+    """PDF processing settings."""
+
+    data: List[Path] = Field(default_factory=list)
+    unstructured: bool = Field(default=False)
+    chunk_size: int = Field(default=1000)
+    chunk_overlap: int = Field(default=200)
+    add_start_index: bool = Field(default=False)
+
+
+class WebSettings(BaseModel):
+    """Web data processing settings."""
+
+    data: List[str] = Field(default_factory=list)
+    chunk_size: int = Field(default=200)
+
+
+class ChromaDbSettings(BaseModel):
+    """Chroma DB settings."""
+
+    location: Path = Field(default=Path(".chroma_db"))
+    reset: bool = Field(default=False)
+
+
+class AppSettings(BaseModel):
+    """
+    Main application settings model.
+
+    Loads configuration from a YAML file using the structure defined
+    by the nested models.
+    """
+
+    # --- Top-level settings ---
+    chat_backend: ChatBackend = Field(default=ChatBackend.local)
+    emb_backend: EmbeddingBackend = Field(default=EmbeddingBackend.huggingface)
+    use_conditional_graph: bool = Field(default=False)
+
+    # --- Provider-specific settings ---
+    azure: Optional[AzureSettings] = None
+    openai: Optional[OpenAISettings] = None
+    google: Optional[GoogleSettings] = None
+    aws: Optional[AwsSettings] = None
+    local: Optional[LocalSettings] = None
+    huggingface: Optional[HuggingFaceSettings] = None  # Separate HF config if needed
+
+    # --- Data processing settings ---
+    pdf: PdfSettings = Field(default_factory=PdfSettings)
+    web: WebSettings = Field(default_factory=WebSettings)
+    chroma_db: ChromaDbSettings = Field(default_factory=ChromaDbSettings)
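For reviewers new to Pydantic: the models above provide type coercion, defaults, and masked secrets out of the box. A small sketch with hypothetical values:

```python
# What the Pydantic layer buys here (values are hypothetical).
from generic_rag.parsers.config import AppSettings

s = AppSettings(
    chat_backend="openai",                  # str is coerced to the ChatBackend enum
    openai={"api_key": "sk-hypothetical"},  # dict is coerced to OpenAISettings
    pdf={"chunk_size": "1000"},             # "1000" is coerced to int 1000
)
print(s.openai.api_key)                     # **********  (SecretStr masks the key)
print(s.openai.api_key.get_secret_value())  # sk-hypothetical
print(s.emb_backend)                        # huggingface (default applied)
```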
+
+# --- Configuration Loading Function ---
+def load_settings(config_path: Path = Path("config.yaml")) -> AppSettings:
+    """
+    Loads settings from a YAML file and validates them using Pydantic models.
+
+    Args:
+        config_path: The path to the configuration YAML file.
+
+    Returns:
+        An instance of AppSettings containing the loaded configuration.
+
+    Raises:
+        FileNotFoundError: If the config file does not exist.
+        yaml.YAMLError: If the file is not valid YAML.
+        ValidationError: If the data in the file doesn't match the AppSettings model.
+    """
+    if not config_path.is_file():
+        print(f"Error: Configuration file not found at '{config_path}'", file=sys.stderr)
+        raise FileNotFoundError(f"Configuration file not found: {config_path}")
+
+    print(f"--- Loading settings from '{config_path}' ---")
+    try:
+        with open(config_path, "r", encoding="utf-8") as f:
+            config_data = yaml.safe_load(f)
+        if config_data is None:
+            config_data = {}
+
+        settings = AppSettings(**config_data)
+        print("--- Settings loaded and validated successfully ---")
+        return settings
+
+    except yaml.YAMLError as e:
+        print(f"Error parsing YAML file '{config_path}':\n  {e}", file=sys.stderr)
+        raise
+    except ValidationError as e:
+        print(f"Error validating configuration from '{config_path}':\n{e}", file=sys.stderr)
+        raise
+    except Exception as e:
+        print(f"An unexpected error occurred while loading settings from '{config_path}': {e}", file=sys.stderr)
+        raise
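A minimal caller for load_settings that handles the documented failure modes explicitly, rather than the broad except Exception used in main.py (the config.yaml path is assumed, as in main.py):

```python
# Minimal driver: catch the specific errors load_settings documents.
import sys
from pathlib import Path

import yaml
from pydantic import ValidationError

from generic_rag.parsers.config import load_settings

try:
    settings = load_settings(Path("config.yaml"))
except FileNotFoundError:
    sys.exit("config.yaml not found next to the app; create one first.")
except (yaml.YAMLError, ValidationError) as err:
    sys.exit(f"config.yaml is present but invalid:\n{err}")
```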
@@ -32,7 +32,7 @@ def add_urls(vector_store: Chroma, urls: list[str], chunk_size: int) -> None:
     The URL's will be fetched and split into chunks of text with the provided chunk size.
     """
     logger.info("Web sources to the vector store.")
 
     all_splits = []
     for url in urls:
         if len(vector_store.get(where={"source": url}, limit=1)["ids"]) > 0:
@@ -87,7 +87,6 @@ def add_pdf_files(
     The PDF file will be parsed per page and split into chunks of text with the provided chunk size and overlap.
     """
     logger.info("Adding PDF files to the vector store.")
-
     pdf_files = get_all_local_pdf_files(file_paths)
     logger.info(f"Found {len(pdf_files)} PDF files to add to the vector store.")
 
@@ -100,8 +99,8 @@
 
     if len(new_pdfs) == 0:
         return
 
-    logger.info(f"{len(new_pdfs)} PDF's to add to the vector store.")
+    logger.info(f"{len(new_pdfs)} PDF(s) to add to the vector store.")
 
     loaded_document = []
     for file in new_pdfs: