diff --git a/generic_rag/backend/models.py b/generic_rag/backend/models.py
index 4f7081e..15be5cd 100644
--- a/generic_rag/backend/models.py
+++ b/generic_rag/backend/models.py
@@ -15,7 +15,7 @@ from langchain_ollama import OllamaEmbeddings
 class BackendType(Enum):
     azure = "azure"
     openai = "openai"
-    google = "google"
+    google_vertex = "google_vertex"
     aws = "aws"
     local = "local"

@@ -35,7 +35,7 @@ def get_chat_model(backend_type: BackendType) -> BaseChatModel:
     if backend_type == BackendType.openai:
         return init_chat_model(os.environ["OPENAI_CHAT_MODEL"], model_provider="openai")

-    if backend_type == BackendType.google:
+    if backend_type == BackendType.google_vertex:
         return init_chat_model(os.environ["GOOGLE_CHAT_MODEL"], model_provider="google_vertexai")

     if backend_type == BackendType.aws:
@@ -58,7 +58,7 @@ def get_embedding_model(backend_type: BackendType) -> Embeddings:
     if backend_type == BackendType.openai:
         return OpenAIEmbeddings(model=os.environ["OPENAI_EMB_MODEL"])

-    if backend_type == BackendType.google:
+    if backend_type == BackendType.google_vertex:
         return VertexAIEmbeddings(model=os.environ["GOOGLE_EMB_MODEL"])

     if backend_type == BackendType.aws: