38 lines
1.0 KiB
Python
38 lines
1.0 KiB
Python
from langchain_together import ChatTogether
|
|
from enum import StrEnum
|
|
from dotenv import load_dotenv
|
|
|
|
# Load variables from a .env file at import time; the return value is
# deliberately discarded by binding it to `_`.
# NOTE(review): presumably this supplies the Together API key that
# ChatTogether picks up from the environment — confirm against deployment.
_ = load_dotenv()
|
|
|
|
class ModelName(StrEnum):
|
|
"""String enum representing different available models."""
|
|
|
|
# Together AI models
|
|
LLAMA_3_8B = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
|
|
LLAMA_4_17B = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
|
|
MIXTRAL_8X7B = "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
|
MIXTRAL_8X22B = "mistralai/Mixtral-8x22B-Instruct-v0.1"
|
|
LLAMA_GUARD = "meta-llama/Meta-Llama-Guard-3-8B"
|
|
|
|
# # Open source models
|
|
GEMMA_7B = "google/gemma-7b-it"
|
|
GEMMA_2B = "google/gemma-2b-it"
|
|
|
|
# Default model
|
|
DEFAULT = LLAMA_3_8B
|
|
|
|
def __str__(self) -> str:
|
|
return self.value
|
|
|
|
|
|
def get_model(model_string):
    """Build a ``ChatTogether`` client for a supported model identifier.

    Args:
        model_string: Model id to look up; compared by equality against the
            supported ``ModelName`` members, so either a plain string or a
            ``ModelName`` member works.

    Returns:
        A ``ChatTogether`` instance constructed with the matching
        ``ModelName`` member as its ``model`` argument.

    Raises:
        ValueError: If ``model_string`` matches none of the supported models.
    """
    # Only these two members are wired up here; every other ModelName value
    # falls through to the ValueError below.
    for candidate in (ModelName.LLAMA_3_8B, ModelName.LLAMA_4_17B):
        if model_string == candidate:
            return ChatTogether(model=candidate)

    raise ValueError(f"{model_string} not known")
|
|
|
|
|
|
|