diff --git a/config.example.yaml b/config.example.yaml index 2b778ce..12bd131 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -28,6 +28,8 @@ google_vertex: emb_model: "textembedding-gecko@001" aws: + chat_model: "amazon.titan-llm-v1" + emb_model: "amazon.titan-embed-text-v1" region: "us-east-1" credentials: "PATH_TO_YOUR_CREDENTIALS_FILE.json" @@ -36,6 +38,7 @@ local: # Settings for local models (e.g., Ollama) emb_model: "llama3.1:8b" huggingface: # Settings specific to HuggingFace embedding backend + chat_model: "meta-llama/Llama-2-7b-chat-hf" emb_model: "sentence-transformers/paraphrase-MiniLM-L12-v2" # --- Data Processing Settings --- diff --git a/generic_rag/backend/models.py b/generic_rag/backend/models.py index 2065bd5..e8a3063 100644 --- a/generic_rag/backend/models.py +++ b/generic_rag/backend/models.py @@ -2,14 +2,13 @@ import logging from generic_rag.parsers.config import AppSettings, ChatBackend, EmbeddingBackend -# Langchain imports from langchain_core.embeddings import Embeddings from langchain_core.language_models.chat_models import BaseChatModel -from langchain_aws import BedrockEmbeddings, ChatBedrock # Import ChatBedrock -from langchain_google_vertexai import VertexAIEmbeddings, ChatVertexAI # Import ChatVertexAI -from langchain_huggingface import HuggingFaceEmbeddings +from langchain_aws import BedrockEmbeddings, ChatBedrock +from langchain_google_vertexai import VertexAIEmbeddings, ChatVertexAI +from langchain_huggingface import HuggingFaceEmbeddings, ChatHuggingFace, HuggingFacePipeline from langchain_ollama import ChatOllama, OllamaEmbeddings -from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings, ChatOpenAI, OpenAIEmbeddings # Import ChatOpenAI +from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings, ChatOpenAI, OpenAIEmbeddings logger = logging.getLogger(__name__) @@ -52,7 +51,6 @@ def get_chat_model(settings: AppSettings) -> BaseChatModel: raise ValueError("OpenAI chat backend selected, but 'openai' configuration section is missing.") if not settings.openai.api_key or not settings.openai.chat_model: raise ValueError("OpenAI configuration requires 'api_key' and 'chat_model'.") - logger.info(f"Using OpenAI model: {model_name}") return ChatOpenAI(model=settings.openai.chat_model, openai_api_key=settings.openai.api_key.get_secret_value()) if settings.chat_backend == ChatBackend.google_vertex: @@ -60,9 +58,12 @@ def get_chat_model(settings: AppSettings) -> BaseChatModel: raise ValueError( "Google Vertex chat backend selected, but 'google_vertex' configuration section is missing." ) - if settings.google_vertex.chat_model: - model_name = settings.google.chat_model - logger.info(f"Using Google Vertex model: {model_name}") + if ( + not settings.google_vertex.chat_model + or not settings.google_vertex.project_id + or not settings.google_vertex.location + ): + raise ValueError("Google Vertex configuration requires 'chat_model' and 'project_id'.") return ChatVertexAI( model_name=settings.google_vertex.chat_model, project=settings.google_vertex.project_id, @@ -72,22 +73,35 @@ def get_chat_model(settings: AppSettings) -> BaseChatModel: if settings.chat_backend == ChatBackend.aws: if not settings.aws: raise ValueError("AWS Bedrock chat backend selected, but 'aws' configuration section is missing.") - model_name = "anthropic.claude-v2" # Example default - if hasattr(settings.aws, "chat_model") and settings.aws.chat_model: - model_name = settings.aws.chat_model - logger.info(f"Using AWS Bedrock model: {model_name}") + if not settings.aws.chat_model or not settings.aws.region_name: + raise ValueError("AWS Bedrock configuration requires 'chat_model' and 'region_name'") return ChatBedrock( - model_id=model_name, + model_id=settings.aws.chat_model, region_name=settings.aws.region_name, ) if settings.chat_backend == ChatBackend.local: - if not settings.local or not settings.local.chat_model: - raise ValueError("Local chat backend selected, but 'local.chat_model' is missing in config.") - logger.info(f"Using Local Ollama model: {settings.local.chat_model}") - # Base URL can also be configured, e.g., base_url=config.local.ollama_base_url + if not settings.local: + raise ValueError("Local chat backend selected, but 'local' configuration section is missing.") + if not settings.local.chat_model: + raise ValueError("Local configuration requires 'chat_model'") return ChatOllama(model=settings.local.chat_model) + if settings.chat_backend == ChatBackend.huggingface: + if not settings.huggingface: + raise ValueError("Huggingface chat backend selected, but 'huggingface' configuration section is missing.") + if not settings.huggingface.chat_model: + raise ValueError("Huggingface configuration requires 'chat_model'") + llm = HuggingFacePipeline.from_model_id( + model_id=settings.huggingface.chat_model, + task="text-generation", + pipeline_kwargs=dict( + max_new_tokens=512, + do_sample=False, + repetition_penalty=1.03, + ), + ) + return ChatHuggingFace(llm=llm) # This should not be reached if all Enum members are handled raise ValueError(f"Unknown or unhandled chat backend type: {settings.chat_backend}") @@ -130,25 +144,21 @@ def get_embedding_model(settings: AppSettings) -> Embeddings: raise ValueError("OpenAI embedding backend selected, but 'openai' configuration section is missing.") if not settings.openai.api_key: raise ValueError("OpenAI configuration requires 'api_key'.") - model_name = "text-embedding-ada-002" # Example default - if hasattr(settings.openai, "emb_model") and settings.openai.emb_model: - model_name = settings.openai.emb_model - logger.info(f"Using OpenAI embedding model: {model_name}") - return OpenAIEmbeddings(model=model_name, openai_api_key=settings.openai.api_key.get_secret_value()) + return OpenAIEmbeddings( + model=settings.openai.emb_model, openai_api_key=settings.openai.api_key.get_secret_value() + ) if settings.emb_backend == EmbeddingBackend.google_vertex: if not settings.google_vertex: raise ValueError( "Google Vertex embedding backend selected, but 'google_vertex' configuration section is missing." ) - model_name = "textembedding-gecko@001" # Example default if ( not settings.google_vertex.emb_model or not settings.google_vertex.project_id or not settings.google_vertex.location ): raise ValueError("Google Vertex configuration requires 'emb_model', 'project_id', and 'location'.") - logger.info(f"Using Google Vertex embedding model: {model_name}") return VertexAIEmbeddings( model_name=settings.google_vertex.emb_model, project=settings.google_vertex.project_id, @@ -158,33 +168,24 @@ def get_embedding_model(settings: AppSettings) -> Embeddings: if settings.emb_backend == EmbeddingBackend.aws: if not settings.aws: raise ValueError("AWS Bedrock embedding backend selected, but 'aws' configuration section is missing.") - model_name = "amazon.titan-embed-text-v1" # Example default - if hasattr(settings.aws, "emb_model") and settings.aws.emb_model: - model_name = settings.aws.emb_model - logger.info(f"Using AWS Bedrock embedding model: {model_name}") - return BedrockEmbeddings(model_id=model_name, region_name=settings.aws.region_name) + if not settings.aws.emb_model or not settings.aws.region_name: + raise ValueError("AWS Bedrock configuration requires 'emb_model' and 'region_name'") + return BedrockEmbeddings(model_id=settings.aws.emb_model, region_name=settings.aws.region_name) if settings.emb_backend == EmbeddingBackend.local: - if not settings.local or not settings.local.emb_model: - raise ValueError("Local embedding backend selected, but 'local.emb_model' is missing in config.") - logger.info(f"Using Local Ollama embedding model: {settings.local.emb_model}") + if not settings.local: + raise ValueError("Local embedding backend selected, but 'local' configuration section is missing.") + if not settings.local.emb_model: + raise ValueError("Local configuration requires 'emb_model'") return OllamaEmbeddings(model=settings.local.emb_model) if settings.emb_backend == EmbeddingBackend.huggingface: - if not settings.huggingface or not settings.huggingface.emb_model: - if settings.local and settings.local.emb_model: - logger.warning( - "HuggingFace backend selected, but 'huggingface.emb_model' missing. Using 'local.emb_model'." - ) - model_name = settings.local.emb_model - else: - raise ValueError( - "HuggingFace embedding backend selected, but 'huggingface.emb_model' (or 'local.emb_model') is missing in config." - ) - else: - model_name = settings.huggingface.emb_model - - logger.info(f"Using HuggingFace embedding model: {model_name}") - return HuggingFaceEmbeddings(model_name=model_name) + if not settings.huggingface: + raise ValueError( + "HuggingFace embedding backend selected, but 'huggingface' configuration section is missing." + ) + if not settings.huggingface.emb_model: + raise ValueError("HuggingFace configuration requires 'emb_model'.") + return HuggingFaceEmbeddings(model_name=settings.huggingface.emb_model) raise ValueError(f"Unknown or unhandled embedding backend type: {settings.emb_backend}") diff --git a/generic_rag/parsers/config.py b/generic_rag/parsers/config.py index 1815937..6ce2402 100644 --- a/generic_rag/parsers/config.py +++ b/generic_rag/parsers/config.py @@ -17,6 +17,7 @@ class ChatBackend(str, Enum): google_vertex = "google_vertex" aws = "aws" local = "local" + huggingface = "huggingface" def __str__(self): return self.value @@ -50,6 +51,8 @@ class OpenAISettings(BaseModel): """OpenAI specific settings.""" api_key: Optional[SecretStr] = None + chat_model: Optional[str] = None + emb_model: Optional[str] = None class GoogleVertexSettings(BaseModel): @@ -64,9 +67,9 @@ class GoogleVertexSettings(BaseModel): class AwsSettings(BaseModel): """AWS specific settings (e.g., for Bedrock).""" - access_key_id: Optional[SecretStr] = None - secret_access_key: Optional[SecretStr] = None - region_name: Optional[str] = None + chat_model: Optional[str] = None + emb_model: Optional[str] = None + region: Optional[str] = None class LocalSettings(BaseModel): @@ -79,6 +82,7 @@ class LocalSettings(BaseModel): class HuggingFaceSettings(BaseModel): """HuggingFace specific settings (if different from local embeddings).""" + chat_model: Optional[str] = None emb_model: Optional[str] = None api_token: Optional[SecretStr] = None