diff --git a/generic_rag/backend/models.py b/generic_rag/backend/models.py index ebcf45c..e09e5f5 100644 --- a/generic_rag/backend/models.py +++ b/generic_rag/backend/models.py @@ -19,6 +19,10 @@ class BackendType(Enum): aws = "aws" local = "local" + # make the enum pretty printable for argparse + def __str__(self): + return self.value + def get_chat_model(backend_type: BackendType) -> BaseChatModel: if backend_type == BackendType.azure: