diff --git a/generic_rag/backend/models.py b/generic_rag/backend/models.py index a5a261f..2065bd5 100644 --- a/generic_rag/backend/models.py +++ b/generic_rag/backend/models.py @@ -56,15 +56,17 @@ def get_chat_model(settings: AppSettings) -> BaseChatModel: return ChatOpenAI(model=settings.openai.chat_model, openai_api_key=settings.openai.api_key.get_secret_value()) if settings.chat_backend == ChatBackend.google_vertex: - if not settings.google: - raise ValueError("Google Vertex chat backend selected, but 'google' configuration section is missing.") - if settings.google.chat_model: + if not settings.google_vertex: + raise ValueError( + "Google Vertex chat backend selected, but 'google_vertex' configuration section is missing." + ) + if settings.google_vertex.chat_model: model_name = settings.google.chat_model logger.info(f"Using Google Vertex model: {model_name}") return ChatVertexAI( - model_name=settings.google.chat_model, - project=settings.google.project_id, - location=settings.google.location, + model_name=settings.google_vertex.chat_model, + project=settings.google_vertex.project_id, + location=settings.google_vertex.location, ) if settings.chat_backend == ChatBackend.aws: @@ -135,14 +137,22 @@ def get_embedding_model(settings: AppSettings) -> Embeddings: return OpenAIEmbeddings(model=model_name, openai_api_key=settings.openai.api_key.get_secret_value()) if settings.emb_backend == EmbeddingBackend.google_vertex: - if not settings.google: - raise ValueError("Google Vertex embedding backend selected, but 'google' configuration section is missing.") + if not settings.google_vertex: + raise ValueError( + "Google Vertex embedding backend selected, but 'google_vertex' configuration section is missing." + ) model_name = "textembedding-gecko@001" # Example default - if settings.google.emb_model: - model_name = settings.google.emb_model + if ( + not settings.google_vertex.emb_model + or not settings.google_vertex.project_id + or not settings.google_vertex.location + ): + raise ValueError("Google Vertex configuration requires 'emb_model', 'project_id', and 'location'.") logger.info(f"Using Google Vertex embedding model: {model_name}") return VertexAIEmbeddings( - model_name=model_name, project=settings.google.project_id, location=settings.google.location + model_name=settings.google_vertex.emb_model, + project=settings.google_vertex.project_id, + location=settings.google_vertex.location, ) if settings.emb_backend == EmbeddingBackend.aws: diff --git a/generic_rag/parsers/config.py b/generic_rag/parsers/config.py index af5c9b0..1815937 100644 --- a/generic_rag/parsers/config.py +++ b/generic_rag/parsers/config.py @@ -52,10 +52,9 @@ class OpenAISettings(BaseModel): api_key: Optional[SecretStr] = None -class GoogleSettings(BaseModel): - """Google specific settings (Vertex AI or GenAI).""" +class GoogleVertexSettings(BaseModel): + """Google Vertex specific settings.""" - api_key: Optional[SecretStr] = None project_id: Optional[str] = None location: Optional[str] = None chat_model: Optional[str] = None @@ -124,7 +123,7 @@ class AppSettings(BaseModel): # --- Provider-specific settings --- azure: Optional[AzureSettings] = None openai: Optional[OpenAISettings] = None - google: Optional[GoogleSettings] = None + google_vertex: Optional[GoogleVertexSettings] = None aws: Optional[AwsSettings] = None local: Optional[LocalSettings] = None huggingface: Optional[HuggingFaceSettings] = None # Separate HF config if needed