forked from AI_team/Philosophy-RAG-demo
🎨 Cleanup model init
✨ Add huggingface chat model

parent 572a278a7d
commit 770f341c1f
@@ -28,6 +28,8 @@ google_vertex:
   emb_model: "textembedding-gecko@001"

 aws:
+  chat_model: "amazon.titan-llm-v1"
   emb_model: "amazon.titan-embed-text-v1"
   region: "us-east-1"
   credentials: "PATH_TO_YOUR_CREDENTIALS_FILE.json"

@@ -36,6 +38,7 @@ local: # Settings for local models (e.g., Ollama)
   emb_model: "llama3.1:8b"

 huggingface: # Settings specific to HuggingFace embedding backend
+  chat_model: "meta-llama/Llama-2-7b-chat-hf"
   emb_model: "sentence-transformers/paraphrase-MiniLM-L12-v2"

 # --- Data Processing Settings ---
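Taken together, the two config hunks above give the huggingface section both a chat model and an embedding model. A minimal config sketch that would exercise the new chat path (the top-level selector keys chat_backend and emb_backend are inferred from the AppSettings fields used in the code below and may be named differently in the real file):

chat_backend: "huggingface"
emb_backend: "huggingface"

huggingface:
  chat_model: "meta-llama/Llama-2-7b-chat-hf"
  emb_model: "sentence-transformers/paraphrase-MiniLM-L12-v2"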
@@ -2,14 +2,13 @@ import logging

 from generic_rag.parsers.config import AppSettings, ChatBackend, EmbeddingBackend

 # Langchain imports
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_aws import BedrockEmbeddings, ChatBedrock  # Import ChatBedrock
-from langchain_google_vertexai import VertexAIEmbeddings, ChatVertexAI  # Import ChatVertexAI
-from langchain_huggingface import HuggingFaceEmbeddings
+from langchain_aws import BedrockEmbeddings, ChatBedrock
+from langchain_google_vertexai import VertexAIEmbeddings, ChatVertexAI
+from langchain_huggingface import HuggingFaceEmbeddings, ChatHuggingFace, HuggingFacePipeline
 from langchain_ollama import ChatOllama, OllamaEmbeddings
-from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings, ChatOpenAI, OpenAIEmbeddings  # Import ChatOpenAI
+from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings, ChatOpenAI, OpenAIEmbeddings

 logger = logging.getLogger(__name__)
@@ -52,7 +51,6 @@ def get_chat_model(settings: AppSettings) -> BaseChatModel:
             raise ValueError("OpenAI chat backend selected, but 'openai' configuration section is missing.")
         if not settings.openai.api_key or not settings.openai.chat_model:
             raise ValueError("OpenAI configuration requires 'api_key' and 'chat_model'.")
-        logger.info(f"Using OpenAI model: {model_name}")
         return ChatOpenAI(model=settings.openai.chat_model, openai_api_key=settings.openai.api_key.get_secret_value())

     if settings.chat_backend == ChatBackend.google_vertex:
@@ -60,9 +58,12 @@ def get_chat_model(settings: AppSettings) -> BaseChatModel:
             raise ValueError(
                 "Google Vertex chat backend selected, but 'google_vertex' configuration section is missing."
             )
-        if settings.google_vertex.chat_model:
-            model_name = settings.google.chat_model
-        logger.info(f"Using Google Vertex model: {model_name}")
+        if (
+            not settings.google_vertex.chat_model
+            or not settings.google_vertex.project_id
+            or not settings.google_vertex.location
+        ):
+            raise ValueError("Google Vertex configuration requires 'chat_model', 'project_id', and 'location'.")
         return ChatVertexAI(
             model_name=settings.google_vertex.chat_model,
             project=settings.google_vertex.project_id,
@@ -72,22 +73,35 @@ def get_chat_model(settings: AppSettings) -> BaseChatModel:
     if settings.chat_backend == ChatBackend.aws:
         if not settings.aws:
             raise ValueError("AWS Bedrock chat backend selected, but 'aws' configuration section is missing.")
-        model_name = "anthropic.claude-v2"  # Example default
-        if hasattr(settings.aws, "chat_model") and settings.aws.chat_model:
-            model_name = settings.aws.chat_model
-        logger.info(f"Using AWS Bedrock model: {model_name}")
+        if not settings.aws.chat_model or not settings.aws.region_name:
+            raise ValueError("AWS Bedrock configuration requires 'chat_model' and 'region_name'")
         return ChatBedrock(
-            model_id=model_name,
+            model_id=settings.aws.chat_model,
             region_name=settings.aws.region_name,
         )

     if settings.chat_backend == ChatBackend.local:
-        if not settings.local or not settings.local.chat_model:
-            raise ValueError("Local chat backend selected, but 'local.chat_model' is missing in config.")
-        logger.info(f"Using Local Ollama model: {settings.local.chat_model}")
-        # Base URL can also be configured, e.g., base_url=config.local.ollama_base_url
+        if not settings.local:
+            raise ValueError("Local chat backend selected, but 'local' configuration section is missing.")
+        if not settings.local.chat_model:
+            raise ValueError("Local configuration requires 'chat_model'")
         return ChatOllama(model=settings.local.chat_model)

+    if settings.chat_backend == ChatBackend.huggingface:
+        if not settings.huggingface:
+            raise ValueError("Huggingface chat backend selected, but 'huggingface' configuration section is missing.")
+        if not settings.huggingface.chat_model:
+            raise ValueError("Huggingface configuration requires 'chat_model'")
+        llm = HuggingFacePipeline.from_model_id(
+            model_id=settings.huggingface.chat_model,
+            task="text-generation",
+            pipeline_kwargs=dict(
+                max_new_tokens=512,
+                do_sample=False,
+                repetition_penalty=1.03,
+            ),
+        )
+        return ChatHuggingFace(llm=llm)
+
     # This should not be reached if all Enum members are handled
     raise ValueError(f"Unknown or unhandled chat backend type: {settings.chat_backend}")
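For reference, a minimal sketch of how the new HuggingFace branch would be exercised. Constructing AppSettings directly is hypothetical illustration; the project presumably populates it from the YAML config above:

from generic_rag.parsers.config import AppSettings, ChatBackend, HuggingFaceSettings

# Hypothetical direct construction; field names follow the config classes below.
settings = AppSettings(
    chat_backend=ChatBackend.huggingface,
    huggingface=HuggingFaceSettings(chat_model="meta-llama/Llama-2-7b-chat-hf"),
)

# Loads the model through transformers and wraps it in LangChain's chat interface.
chat = get_chat_model(settings)
print(chat.invoke("Summarize Kant's categorical imperative in one sentence.").content)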
@@ -130,25 +144,21 @@ def get_embedding_model(settings: AppSettings) -> Embeddings:
             raise ValueError("OpenAI embedding backend selected, but 'openai' configuration section is missing.")
         if not settings.openai.api_key:
             raise ValueError("OpenAI configuration requires 'api_key'.")
-        model_name = "text-embedding-ada-002"  # Example default
-        if hasattr(settings.openai, "emb_model") and settings.openai.emb_model:
-            model_name = settings.openai.emb_model
-        logger.info(f"Using OpenAI embedding model: {model_name}")
-        return OpenAIEmbeddings(model=model_name, openai_api_key=settings.openai.api_key.get_secret_value())
+        return OpenAIEmbeddings(
+            model=settings.openai.emb_model, openai_api_key=settings.openai.api_key.get_secret_value()
+        )

     if settings.emb_backend == EmbeddingBackend.google_vertex:
         if not settings.google_vertex:
             raise ValueError(
                 "Google Vertex embedding backend selected, but 'google_vertex' configuration section is missing."
             )
-        model_name = "textembedding-gecko@001"  # Example default
         if (
             not settings.google_vertex.emb_model
             or not settings.google_vertex.project_id
             or not settings.google_vertex.location
         ):
             raise ValueError("Google Vertex configuration requires 'emb_model', 'project_id', and 'location'.")
-        logger.info(f"Using Google Vertex embedding model: {model_name}")
         return VertexAIEmbeddings(
             model_name=settings.google_vertex.emb_model,
             project=settings.google_vertex.project_id,
@@ -158,33 +168,24 @@ def get_embedding_model(settings: AppSettings) -> Embeddings:
     if settings.emb_backend == EmbeddingBackend.aws:
         if not settings.aws:
             raise ValueError("AWS Bedrock embedding backend selected, but 'aws' configuration section is missing.")
-        model_name = "amazon.titan-embed-text-v1"  # Example default
-        if hasattr(settings.aws, "emb_model") and settings.aws.emb_model:
-            model_name = settings.aws.emb_model
-        logger.info(f"Using AWS Bedrock embedding model: {model_name}")
-        return BedrockEmbeddings(model_id=model_name, region_name=settings.aws.region_name)
+        if not settings.aws.emb_model or not settings.aws.region_name:
+            raise ValueError("AWS Bedrock configuration requires 'emb_model' and 'region_name'")
+        return BedrockEmbeddings(model_id=settings.aws.emb_model, region_name=settings.aws.region_name)

     if settings.emb_backend == EmbeddingBackend.local:
-        if not settings.local or not settings.local.emb_model:
-            raise ValueError("Local embedding backend selected, but 'local.emb_model' is missing in config.")
-        logger.info(f"Using Local Ollama embedding model: {settings.local.emb_model}")
+        if not settings.local:
+            raise ValueError("Local embedding backend selected, but 'local' configuration section is missing.")
+        if not settings.local.emb_model:
+            raise ValueError("Local configuration requires 'emb_model'")
         return OllamaEmbeddings(model=settings.local.emb_model)

     if settings.emb_backend == EmbeddingBackend.huggingface:
-        if not settings.huggingface or not settings.huggingface.emb_model:
-            if settings.local and settings.local.emb_model:
-                logger.warning(
-                    "HuggingFace backend selected, but 'huggingface.emb_model' missing. Using 'local.emb_model'."
-                )
-                model_name = settings.local.emb_model
-            else:
-                raise ValueError(
-                    "HuggingFace embedding backend selected, but 'huggingface.emb_model' (or 'local.emb_model') is missing in config."
-                )
-        else:
-            model_name = settings.huggingface.emb_model
-
-        logger.info(f"Using HuggingFace embedding model: {model_name}")
-        return HuggingFaceEmbeddings(model_name=model_name)
+        if not settings.huggingface:
+            raise ValueError(
+                "HuggingFace embedding backend selected, but 'huggingface' configuration section is missing."
+            )
+        if not settings.huggingface.emb_model:
+            raise ValueError("HuggingFace configuration requires 'emb_model'.")
+        return HuggingFaceEmbeddings(model_name=settings.huggingface.emb_model)

     raise ValueError(f"Unknown or unhandled embedding backend type: {settings.emb_backend}")
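And the embedding counterpart (again a sketch; the direct AppSettings construction is hypothetical):

from generic_rag.parsers.config import AppSettings, EmbeddingBackend, HuggingFaceSettings

settings = AppSettings(
    emb_backend=EmbeddingBackend.huggingface,
    huggingface=HuggingFaceSettings(emb_model="sentence-transformers/paraphrase-MiniLM-L12-v2"),
)

embeddings = get_embedding_model(settings)
vector = embeddings.embed_query("What is epistemology?")
print(len(vector))  # embedding dimensionality (384 for this MiniLM model)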
@@ -17,6 +17,7 @@ class ChatBackend(str, Enum):
     google_vertex = "google_vertex"
     aws = "aws"
     local = "local"
+    huggingface = "huggingface"

     def __str__(self):
         return self.value
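Because ChatBackend subclasses str, the raw string from the YAML config coerces cleanly into the enum during pydantic parsing. A small illustration:

from generic_rag.parsers.config import ChatBackend

backend = ChatBackend("huggingface")  # constructed from the raw config value
assert backend == ChatBackend.huggingface
assert backend == "huggingface"       # str subclass: equal to the plain string
assert str(backend) == "huggingface"  # __str__ returns self.value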
@@ -50,6 +51,8 @@ class OpenAISettings(BaseModel):
     """OpenAI specific settings."""

     api_key: Optional[SecretStr] = None
+    chat_model: Optional[str] = None
+    emb_model: Optional[str] = None


 class GoogleVertexSettings(BaseModel):
@@ -64,9 +67,9 @@ class GoogleVertexSettings(BaseModel):
 class AwsSettings(BaseModel):
     """AWS specific settings (e.g., for Bedrock)."""

     access_key_id: Optional[SecretStr] = None
     secret_access_key: Optional[SecretStr] = None
     region_name: Optional[str] = None
+    chat_model: Optional[str] = None
+    emb_model: Optional[str] = None
-    region: Optional[str] = None


 class LocalSettings(BaseModel):
@@ -79,6 +82,7 @@ class LocalSettings(BaseModel):
 class HuggingFaceSettings(BaseModel):
     """HuggingFace specific settings (if different from local embeddings)."""

+    chat_model: Optional[str] = None
     emb_model: Optional[str] = None
     api_token: Optional[SecretStr] = None
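A design note on these models: every field is Optional and defaults to None, so an empty or partial config section still parses; it is the get_chat_model/get_embedding_model factories above that fail fast with a descriptive ValueError when a required field is missing. A quick sketch of that split:

from generic_rag.parsers.config import HuggingFaceSettings

hf = HuggingFaceSettings()  # an empty 'huggingface:' section parses without error
print(hf.chat_model)        # None
# Passing this to get_chat_model() would raise:
# ValueError: Huggingface configuration requires 'chat_model'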