forked from AI_team/Philosophy-RAG-demo
🎨 Cleanup model init ✨ Add huggingface chat model
parent 572a278a7d
commit 770f341c1f
@@ -28,6 +28,8 @@ google_vertex:
   emb_model: "textembedding-gecko@001"
 
 aws:
+  chat_model: "amazon.titan-llm-v1"
+  emb_model: "amazon.titan-embed-text-v1"
   region: "us-east-1"
   credentials: "PATH_TO_YOUR_CREDENTIALS_FILE.json"
 
@@ -36,6 +38,7 @@ local: # Settings for local models (e.g., Ollama)
   emb_model: "llama3.1:8b"
 
 huggingface: # Settings specific to HuggingFace embedding backend
+  chat_model: "meta-llama/Llama-2-7b-chat-hf"
   emb_model: "sentence-transformers/paraphrase-MiniLM-L12-v2"
 
 # --- Data Processing Settings ---
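With these additions every backend section carries both a chat_model and an emb_model key. A minimal sketch of loading this YAML into the settings object (the file name and the keyword-argument construction are assumptions; the repo's actual loader may differ):

    import yaml

    from generic_rag.parsers.config import AppSettings

    # Parse the sample config and validate it against the pydantic schema.
    with open("config.yaml") as fh:  # path is an assumption
        settings = AppSettings(**yaml.safe_load(fh))

    print(settings.huggingface.chat_model)  # "meta-llama/Llama-2-7b-chat-hf"
    print(settings.aws.emb_model)           # "amazon.titan-embed-text-v1"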
@@ -2,14 +2,13 @@ import logging
 
 from generic_rag.parsers.config import AppSettings, ChatBackend, EmbeddingBackend
 
-# Langchain imports
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_aws import BedrockEmbeddings, ChatBedrock # Import ChatBedrock
-from langchain_google_vertexai import VertexAIEmbeddings, ChatVertexAI # Import ChatVertexAI
-from langchain_huggingface import HuggingFaceEmbeddings
+from langchain_aws import BedrockEmbeddings, ChatBedrock
+from langchain_google_vertexai import VertexAIEmbeddings, ChatVertexAI
+from langchain_huggingface import HuggingFaceEmbeddings, ChatHuggingFace, HuggingFacePipeline
 from langchain_ollama import ChatOllama, OllamaEmbeddings
-from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings, ChatOpenAI, OpenAIEmbeddings # Import ChatOpenAI
+from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings, ChatOpenAI, OpenAIEmbeddings
 
 logger = logging.getLogger(__name__)
 
@@ -52,7 +51,6 @@ def get_chat_model(settings: AppSettings) -> BaseChatModel:
             raise ValueError("OpenAI chat backend selected, but 'openai' configuration section is missing.")
         if not settings.openai.api_key or not settings.openai.chat_model:
             raise ValueError("OpenAI configuration requires 'api_key' and 'chat_model'.")
-        logger.info(f"Using OpenAI model: {model_name}")
         return ChatOpenAI(model=settings.openai.chat_model, openai_api_key=settings.openai.api_key.get_secret_value())
 
     if settings.chat_backend == ChatBackend.google_vertex:
@@ -60,9 +58,12 @@ def get_chat_model(settings: AppSettings) -> BaseChatModel:
             raise ValueError(
                 "Google Vertex chat backend selected, but 'google_vertex' configuration section is missing."
             )
-        if settings.google_vertex.chat_model:
-            model_name = settings.google.chat_model
-        logger.info(f"Using Google Vertex model: {model_name}")
+        if (
+            not settings.google_vertex.chat_model
+            or not settings.google_vertex.project_id
+            or not settings.google_vertex.location
+        ):
+            raise ValueError("Google Vertex configuration requires 'chat_model', 'project_id', and 'location'.")
         return ChatVertexAI(
             model_name=settings.google_vertex.chat_model,
             project=settings.google_vertex.project_id,
@@ -72,22 +73,35 @@ def get_chat_model(settings: AppSettings) -> BaseChatModel:
     if settings.chat_backend == ChatBackend.aws:
         if not settings.aws:
             raise ValueError("AWS Bedrock chat backend selected, but 'aws' configuration section is missing.")
-        model_name = "anthropic.claude-v2" # Example default
-        if hasattr(settings.aws, "chat_model") and settings.aws.chat_model:
-            model_name = settings.aws.chat_model
-        logger.info(f"Using AWS Bedrock model: {model_name}")
+        if not settings.aws.chat_model or not settings.aws.region_name:
+            raise ValueError("AWS Bedrock configuration requires 'chat_model' and 'region_name'")
         return ChatBedrock(
-            model_id=model_name,
+            model_id=settings.aws.chat_model,
             region_name=settings.aws.region_name,
         )
 
     if settings.chat_backend == ChatBackend.local:
-        if not settings.local or not settings.local.chat_model:
-            raise ValueError("Local chat backend selected, but 'local.chat_model' is missing in config.")
-        logger.info(f"Using Local Ollama model: {settings.local.chat_model}")
-        # Base URL can also be configured, e.g., base_url=config.local.ollama_base_url
+        if not settings.local:
+            raise ValueError("Local chat backend selected, but 'local' configuration section is missing.")
+        if not settings.local.chat_model:
+            raise ValueError("Local configuration requires 'chat_model'")
         return ChatOllama(model=settings.local.chat_model)
 
+    if settings.chat_backend == ChatBackend.huggingface:
+        if not settings.huggingface:
+            raise ValueError("Huggingface chat backend selected, but 'huggingface' configuration section is missing.")
+        if not settings.huggingface.chat_model:
+            raise ValueError("Huggingface configuration requires 'chat_model'")
+        llm = HuggingFacePipeline.from_model_id(
+            model_id=settings.huggingface.chat_model,
+            task="text-generation",
+            pipeline_kwargs=dict(
+                max_new_tokens=512,
+                do_sample=False,
+                repetition_penalty=1.03,
+            ),
+        )
+        return ChatHuggingFace(llm=llm)
     # This should not be reached if all Enum members are handled
     raise ValueError(f"Unknown or unhandled chat backend type: {settings.chat_backend}")
 
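The new huggingface branch above runs the model locally through a transformers pipeline and wraps it in LangChain's chat interface. A standalone sketch of that path, with an illustrative model id and prompt (Llama-2 is a gated repo, so this download requires accepted access and enough local memory):

    from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline

    # Load a local text-generation pipeline, as the new branch does.
    llm = HuggingFacePipeline.from_model_id(
        model_id="meta-llama/Llama-2-7b-chat-hf",  # illustrative model id
        task="text-generation",
        pipeline_kwargs=dict(max_new_tokens=512, do_sample=False, repetition_penalty=1.03),
    )

    # ChatHuggingFace applies the model's chat template before generation.
    chat = ChatHuggingFace(llm=llm)
    print(chat.invoke("Summarize Kant's categorical imperative in one sentence.").content)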
@@ -130,25 +144,21 @@ def get_embedding_model(settings: AppSettings) -> Embeddings:
             raise ValueError("OpenAI embedding backend selected, but 'openai' configuration section is missing.")
         if not settings.openai.api_key:
             raise ValueError("OpenAI configuration requires 'api_key'.")
-        model_name = "text-embedding-ada-002" # Example default
-        if hasattr(settings.openai, "emb_model") and settings.openai.emb_model:
-            model_name = settings.openai.emb_model
-        logger.info(f"Using OpenAI embedding model: {model_name}")
-        return OpenAIEmbeddings(model=model_name, openai_api_key=settings.openai.api_key.get_secret_value())
+        return OpenAIEmbeddings(
+            model=settings.openai.emb_model, openai_api_key=settings.openai.api_key.get_secret_value()
+        )
 
     if settings.emb_backend == EmbeddingBackend.google_vertex:
         if not settings.google_vertex:
             raise ValueError(
                 "Google Vertex embedding backend selected, but 'google_vertex' configuration section is missing."
             )
-        model_name = "textembedding-gecko@001" # Example default
         if (
             not settings.google_vertex.emb_model
             or not settings.google_vertex.project_id
             or not settings.google_vertex.location
         ):
             raise ValueError("Google Vertex configuration requires 'emb_model', 'project_id', and 'location'.")
-        logger.info(f"Using Google Vertex embedding model: {model_name}")
         return VertexAIEmbeddings(
             model_name=settings.google_vertex.emb_model,
             project=settings.google_vertex.project_id,
@@ -158,33 +168,24 @@ def get_embedding_model(settings: AppSettings) -> Embeddings:
     if settings.emb_backend == EmbeddingBackend.aws:
         if not settings.aws:
             raise ValueError("AWS Bedrock embedding backend selected, but 'aws' configuration section is missing.")
-        model_name = "amazon.titan-embed-text-v1" # Example default
-        if hasattr(settings.aws, "emb_model") and settings.aws.emb_model:
-            model_name = settings.aws.emb_model
-        logger.info(f"Using AWS Bedrock embedding model: {model_name}")
-        return BedrockEmbeddings(model_id=model_name, region_name=settings.aws.region_name)
+        if not settings.aws.emb_model or not settings.aws.region_name:
+            raise ValueError("AWS Bedrock configuration requires 'emb_model' and 'region_name'")
+        return BedrockEmbeddings(model_id=settings.aws.emb_model, region_name=settings.aws.region_name)
 
     if settings.emb_backend == EmbeddingBackend.local:
-        if not settings.local or not settings.local.emb_model:
-            raise ValueError("Local embedding backend selected, but 'local.emb_model' is missing in config.")
-        logger.info(f"Using Local Ollama embedding model: {settings.local.emb_model}")
+        if not settings.local:
+            raise ValueError("Local embedding backend selected, but 'local' configuration section is missing.")
+        if not settings.local.emb_model:
+            raise ValueError("Local configuration requires 'emb_model'")
         return OllamaEmbeddings(model=settings.local.emb_model)
 
     if settings.emb_backend == EmbeddingBackend.huggingface:
-        if not settings.huggingface or not settings.huggingface.emb_model:
-            if settings.local and settings.local.emb_model:
-                logger.warning(
-                    "HuggingFace backend selected, but 'huggingface.emb_model' missing. Using 'local.emb_model'."
-                )
-                model_name = settings.local.emb_model
-            else:
-                raise ValueError(
-                    "HuggingFace embedding backend selected, but 'huggingface.emb_model' (or 'local.emb_model') is missing in config."
-                )
-        else:
-            model_name = settings.huggingface.emb_model
-        logger.info(f"Using HuggingFace embedding model: {model_name}")
-        return HuggingFaceEmbeddings(model_name=model_name)
+        if not settings.huggingface:
+            raise ValueError(
+                "HuggingFace embedding backend selected, but 'huggingface' configuration section is missing."
+            )
+        if not settings.huggingface.emb_model:
+            raise ValueError("HuggingFace configuration requires 'emb_model'.")
+        return HuggingFaceEmbeddings(model_name=settings.huggingface.emb_model)
 
     raise ValueError(f"Unknown or unhandled embedding backend type: {settings.emb_backend}")
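Each embedding branch now validates its section and returns the backend class directly, with no logging or silent fallbacks. A minimal sketch of the HuggingFace path using the model named in the sample config (the query text is illustrative):

    from langchain_huggingface import HuggingFaceEmbeddings

    # The same constructor call the huggingface branch now makes.
    emb = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-MiniLM-L12-v2")

    vec = emb.embed_query("The unexamined life is not worth living.")
    print(len(vec))  # 384 dimensions for this MiniLM model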
@@ -17,6 +17,7 @@ class ChatBackend(str, Enum):
     google_vertex = "google_vertex"
     aws = "aws"
     local = "local"
+    huggingface = "huggingface"
 
     def __str__(self):
         return self.value
@@ -50,6 +51,8 @@ class OpenAISettings(BaseModel):
     """OpenAI specific settings."""
 
     api_key: Optional[SecretStr] = None
+    chat_model: Optional[str] = None
+    emb_model: Optional[str] = None
 
 
 class GoogleVertexSettings(BaseModel):
@@ -64,9 +67,9 @@ class GoogleVertexSettings(BaseModel):
 class AwsSettings(BaseModel):
     """AWS specific settings (e.g., for Bedrock)."""
 
-    access_key_id: Optional[SecretStr] = None
-    secret_access_key: Optional[SecretStr] = None
-    region_name: Optional[str] = None
+    chat_model: Optional[str] = None
+    emb_model: Optional[str] = None
+    region: Optional[str] = None
 
 
 class LocalSettings(BaseModel):
@@ -79,6 +82,7 @@ class LocalSettings(BaseModel):
 class HuggingFaceSettings(BaseModel):
     """HuggingFace specific settings (if different from local embeddings)."""
 
+    chat_model: Optional[str] = None
     emb_model: Optional[str] = None
     api_token: Optional[SecretStr] = None
 
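Every field stays Optional, so a config file can omit a section's keys and still parse; the factory functions above, not the schema, reject incomplete sections. A small sketch of that split (the class is re-declared standalone here; the AppSettings container is not shown in this diff):

    from typing import Optional

    from pydantic import BaseModel, SecretStr

    class HuggingFaceSettings(BaseModel):
        """HuggingFace specific settings (if different from local embeddings)."""

        chat_model: Optional[str] = None
        emb_model: Optional[str] = None
        api_token: Optional[SecretStr] = None

    # Parsing succeeds without chat_model; get_chat_model raises
    # "Huggingface configuration requires 'chat_model'" only when that
    # backend is actually selected.
    hf = HuggingFaceSettings(emb_model="sentence-transformers/paraphrase-MiniLM-L12-v2")
    print(hf.chat_model)  # None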