🎨 Cleanup model init
Add huggingface chat model

Ruben Lucas 2025-04-16 16:06:58 +02:00
parent 572a278a7d
commit 770f341c1f
3 changed files with 59 additions and 51 deletions

View File

@@ -28,6 +28,8 @@ google_vertex:
   emb_model: "textembedding-gecko@001"
 aws:
+  chat_model: "amazon.titan-llm-v1"
+  emb_model: "amazon.titan-embed-text-v1"
   region: "us-east-1"
   credentials: "PATH_TO_YOUR_CREDENTIALS_FILE.json"
@@ -36,6 +38,7 @@ local: # Settings for local models (e.g., Ollama)
   emb_model: "llama3.1:8b"
 huggingface: # Settings specific to HuggingFace embedding backend
+  chat_model: "meta-llama/Llama-2-7b-chat-hf"
   emb_model: "sentence-transformers/paraphrase-MiniLM-L12-v2"

 # --- Data Processing Settings ---
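For reference, a minimal sketch of how the new keys surface in code, assuming the YAML above maps one-to-one onto the pydantic models in generic_rag.parsers.config (fields constructed directly here; the project's actual loader may differ):

# Sketch only: field names mirror the YAML keys added above.
from generic_rag.parsers.config import AwsSettings, HuggingFaceSettings

aws = AwsSettings(
    chat_model="amazon.titan-llm-v1",
    emb_model="amazon.titan-embed-text-v1",
    region="us-east-1",
)
hf = HuggingFaceSettings(chat_model="meta-llama/Llama-2-7b-chat-hf")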

View File

@@ -2,14 +2,13 @@ import logging

 from generic_rag.parsers.config import AppSettings, ChatBackend, EmbeddingBackend
-# Langchain imports
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_aws import BedrockEmbeddings, ChatBedrock  # Import ChatBedrock
-from langchain_google_vertexai import VertexAIEmbeddings, ChatVertexAI  # Import ChatVertexAI
-from langchain_huggingface import HuggingFaceEmbeddings
+from langchain_aws import BedrockEmbeddings, ChatBedrock
+from langchain_google_vertexai import VertexAIEmbeddings, ChatVertexAI
+from langchain_huggingface import HuggingFaceEmbeddings, ChatHuggingFace, HuggingFacePipeline
 from langchain_ollama import ChatOllama, OllamaEmbeddings
-from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings, ChatOpenAI, OpenAIEmbeddings  # Import ChatOpenAI
+from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings, ChatOpenAI, OpenAIEmbeddings

 logger = logging.getLogger(__name__)
@@ -52,7 +51,6 @@ def get_chat_model(settings: AppSettings) -> BaseChatModel:
             raise ValueError("OpenAI chat backend selected, but 'openai' configuration section is missing.")
         if not settings.openai.api_key or not settings.openai.chat_model:
             raise ValueError("OpenAI configuration requires 'api_key' and 'chat_model'.")
-        logger.info(f"Using OpenAI model: {model_name}")
         return ChatOpenAI(model=settings.openai.chat_model, openai_api_key=settings.openai.api_key.get_secret_value())

     if settings.chat_backend == ChatBackend.google_vertex:
@@ -60,9 +58,12 @@ def get_chat_model(settings: AppSettings) -> BaseChatModel:
             raise ValueError(
                 "Google Vertex chat backend selected, but 'google_vertex' configuration section is missing."
             )
-        if settings.google_vertex.chat_model:
-            model_name = settings.google.chat_model
-            logger.info(f"Using Google Vertex model: {model_name}")
+        if (
+            not settings.google_vertex.chat_model
+            or not settings.google_vertex.project_id
+            or not settings.google_vertex.location
+        ):
+            raise ValueError("Google Vertex configuration requires 'chat_model', 'project_id', and 'location'.")
         return ChatVertexAI(
             model_name=settings.google_vertex.chat_model,
             project=settings.google_vertex.project_id,
@@ -72,22 +73,35 @@ def get_chat_model(settings: AppSettings) -> BaseChatModel:
     if settings.chat_backend == ChatBackend.aws:
         if not settings.aws:
             raise ValueError("AWS Bedrock chat backend selected, but 'aws' configuration section is missing.")
-        model_name = "anthropic.claude-v2"  # Example default
-        if hasattr(settings.aws, "chat_model") and settings.aws.chat_model:
-            model_name = settings.aws.chat_model
-        logger.info(f"Using AWS Bedrock model: {model_name}")
+        if not settings.aws.chat_model or not settings.aws.region:
+            raise ValueError("AWS Bedrock configuration requires 'chat_model' and 'region'.")
         return ChatBedrock(
-            model_id=model_name,
-            region_name=settings.aws.region_name,
+            model_id=settings.aws.chat_model,
+            region_name=settings.aws.region,
         )

     if settings.chat_backend == ChatBackend.local:
-        if not settings.local or not settings.local.chat_model:
-            raise ValueError("Local chat backend selected, but 'local.chat_model' is missing in config.")
-        logger.info(f"Using Local Ollama model: {settings.local.chat_model}")
-        # Base URL can also be configured, e.g., base_url=config.local.ollama_base_url
+        if not settings.local:
+            raise ValueError("Local chat backend selected, but 'local' configuration section is missing.")
+        if not settings.local.chat_model:
+            raise ValueError("Local configuration requires 'chat_model'.")
         return ChatOllama(model=settings.local.chat_model)

+    if settings.chat_backend == ChatBackend.huggingface:
+        if not settings.huggingface:
+            raise ValueError("HuggingFace chat backend selected, but 'huggingface' configuration section is missing.")
+        if not settings.huggingface.chat_model:
+            raise ValueError("HuggingFace configuration requires 'chat_model'.")
+        llm = HuggingFacePipeline.from_model_id(
+            model_id=settings.huggingface.chat_model,
+            task="text-generation",
+            pipeline_kwargs=dict(
+                max_new_tokens=512,
+                do_sample=False,
+                repetition_penalty=1.03,
+            ),
+        )
+        return ChatHuggingFace(llm=llm)
+
     # This should not be reached if all Enum members are handled
     raise ValueError(f"Unknown or unhandled chat backend type: {settings.chat_backend}")
@@ -130,25 +144,21 @@ def get_embedding_model(settings: AppSettings) -> Embeddings:
             raise ValueError("OpenAI embedding backend selected, but 'openai' configuration section is missing.")
         if not settings.openai.api_key:
             raise ValueError("OpenAI configuration requires 'api_key'.")
-        model_name = "text-embedding-ada-002"  # Example default
-        if hasattr(settings.openai, "emb_model") and settings.openai.emb_model:
-            model_name = settings.openai.emb_model
-        logger.info(f"Using OpenAI embedding model: {model_name}")
-        return OpenAIEmbeddings(model=model_name, openai_api_key=settings.openai.api_key.get_secret_value())
+        return OpenAIEmbeddings(
+            model=settings.openai.emb_model, openai_api_key=settings.openai.api_key.get_secret_value()
+        )

     if settings.emb_backend == EmbeddingBackend.google_vertex:
         if not settings.google_vertex:
             raise ValueError(
                 "Google Vertex embedding backend selected, but 'google_vertex' configuration section is missing."
             )
-        model_name = "textembedding-gecko@001"  # Example default
         if (
             not settings.google_vertex.emb_model
             or not settings.google_vertex.project_id
             or not settings.google_vertex.location
         ):
             raise ValueError("Google Vertex configuration requires 'emb_model', 'project_id', and 'location'.")
-        logger.info(f"Using Google Vertex embedding model: {model_name}")
         return VertexAIEmbeddings(
             model_name=settings.google_vertex.emb_model,
             project=settings.google_vertex.project_id,
@@ -158,33 +168,24 @@ def get_embedding_model(settings: AppSettings) -> Embeddings:
     if settings.emb_backend == EmbeddingBackend.aws:
         if not settings.aws:
             raise ValueError("AWS Bedrock embedding backend selected, but 'aws' configuration section is missing.")
-        model_name = "amazon.titan-embed-text-v1"  # Example default
-        if hasattr(settings.aws, "emb_model") and settings.aws.emb_model:
-            model_name = settings.aws.emb_model
-        logger.info(f"Using AWS Bedrock embedding model: {model_name}")
-        return BedrockEmbeddings(model_id=model_name, region_name=settings.aws.region_name)
+        if not settings.aws.emb_model or not settings.aws.region:
+            raise ValueError("AWS Bedrock configuration requires 'emb_model' and 'region'.")
+        return BedrockEmbeddings(model_id=settings.aws.emb_model, region_name=settings.aws.region)

     if settings.emb_backend == EmbeddingBackend.local:
-        if not settings.local or not settings.local.emb_model:
-            raise ValueError("Local embedding backend selected, but 'local.emb_model' is missing in config.")
-        logger.info(f"Using Local Ollama embedding model: {settings.local.emb_model}")
+        if not settings.local:
+            raise ValueError("Local embedding backend selected, but 'local' configuration section is missing.")
+        if not settings.local.emb_model:
+            raise ValueError("Local configuration requires 'emb_model'.")
         return OllamaEmbeddings(model=settings.local.emb_model)

     if settings.emb_backend == EmbeddingBackend.huggingface:
-        if not settings.huggingface or not settings.huggingface.emb_model:
-            if settings.local and settings.local.emb_model:
-                logger.warning(
-                    "HuggingFace backend selected, but 'huggingface.emb_model' missing. Using 'local.emb_model'."
-                )
-                model_name = settings.local.emb_model
-            else:
-                raise ValueError(
-                    "HuggingFace embedding backend selected, but 'huggingface.emb_model' (or 'local.emb_model') is missing in config."
-                )
-        else:
-            model_name = settings.huggingface.emb_model
-        logger.info(f"Using HuggingFace embedding model: {model_name}")
-        return HuggingFaceEmbeddings(model_name=model_name)
+        if not settings.huggingface:
+            raise ValueError(
+                "HuggingFace embedding backend selected, but 'huggingface' configuration section is missing."
+            )
+        if not settings.huggingface.emb_model:
+            raise ValueError("HuggingFace configuration requires 'emb_model'.")
+        return HuggingFaceEmbeddings(model_name=settings.huggingface.emb_model)

     raise ValueError(f"Unknown or unhandled embedding backend type: {settings.emb_backend}")

View File

@@ -17,6 +17,7 @@ class ChatBackend(str, Enum):
     google_vertex = "google_vertex"
     aws = "aws"
     local = "local"
+    huggingface = "huggingface"

     def __str__(self):
         return self.value
@@ -50,6 +51,8 @@ class OpenAISettings(BaseModel):
     """OpenAI specific settings."""

     api_key: Optional[SecretStr] = None
+    chat_model: Optional[str] = None
+    emb_model: Optional[str] = None


 class GoogleVertexSettings(BaseModel):
@@ -64,9 +67,9 @@ class GoogleVertexSettings(BaseModel):
 class AwsSettings(BaseModel):
     """AWS specific settings (e.g., for Bedrock)."""

-    access_key_id: Optional[SecretStr] = None
-    secret_access_key: Optional[SecretStr] = None
-    region_name: Optional[str] = None
+    chat_model: Optional[str] = None
+    emb_model: Optional[str] = None
+    region: Optional[str] = None


 class LocalSettings(BaseModel):
@@ -79,6 +82,7 @@ class LocalSettings(BaseModel):
 class HuggingFaceSettings(BaseModel):
     """HuggingFace specific settings (if different from local embeddings)."""

+    chat_model: Optional[str] = None
     emb_model: Optional[str] = None
     api_token: Optional[SecretStr] = None
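Because ChatBackend subclasses str, the raw value from YAML converts and compares directly; a minimal sketch:

from generic_rag.parsers.config import ChatBackend

backend = ChatBackend("huggingface")  # the new enum member parses from the config string
assert backend == "huggingface"       # str subclass: compares equal to its value
print(backend)                        # __str__ returns self.value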