diff --git a/generic_rag/backend/models.py b/generic_rag/backend/models.py index e09e5f5..bebb093 100644 --- a/generic_rag/backend/models.py +++ b/generic_rag/backend/models.py @@ -7,7 +7,7 @@ from langchain_core.embeddings import Embeddings from langchain_core.language_models.chat_models import BaseChatModel from langchain_google_vertexai import VertexAIEmbeddings from langchain_huggingface import HuggingFaceEmbeddings -from langchain_ollama import OllamaLLM +from langchain_ollama import ChatOllama from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings from langchain_openai import OpenAIEmbeddings @@ -41,7 +41,7 @@ def get_chat_model(backend_type: BackendType) -> BaseChatModel: return init_chat_model(model=os.environ["AWS_CHAT_MODEL"], model_provider="bedrock_converse") if backend_type == BackendType.local: - return OllamaLLM(model=os.environ["LOCAL_CHAT_MODEL"]) + return ChatOllama(model=os.environ["LOCAL_CHAT_MODEL"]) raise ValueError(f"Unknown backend type: {backend_type}")