# Source code for trinity.common.constants
# -*- coding: utf-8 -*-
"""Constants."""
from enum import Enum, EnumMeta
# --- Component / group names -------------------------------------------------
# Names of the explorer and trainer components.
EXPLORER_NAME = "explorer"
TRAINER_NAME = "trainer"
# Name of the group used for rollout weight synchronisation.
ROLLOUT_WEIGHT_SYNC_GROUP_NAME = "rollout_weight_sync"

# --- Trinity environment-variable names --------------------------------------
# Paths and directories.
CHECKPOINT_ROOT_DIR_ENV_VAR = "TRINITY_CHECKPOINT_ROOT_DIR"
MODEL_PATH_ENV_VAR = "TRINITY_MODEL_PATH"
TASKSET_PATH_ENV_VAR = "TRINITY_TASKSET_PATH"
BUFFER_PATH_ENV_VAR = "TRINITY_BUFFER_PATH"
PLUGIN_DIRS_ENV_VAR = "TRINITY_PLUGIN_DIRS"

# Logging: output directory, global log level, and whether logs are
# organised by node IP.
LOG_DIR_ENV_VAR = "TRINITY_LOG_DIR"
LOG_LEVEL_ENV_VAR = "TRINITY_LOG_LEVEL"
LOG_NODE_IP_ENV_VAR = "TRINITY_LOG_NODE_IP"

# --- Numeric defaults ---------------------------------------------------------
# Maximum model length (unit not stated here; presumably tokens, confirm).
MAX_MODEL_LEN = 4096
# Enum types
class CaseInsensitiveEnumMeta(EnumMeta):
    """Enum metaclass that makes member-name and member-value lookups case-insensitive.

    String keys passed to ``Cls[...]`` are upper-cased and string values passed
    to ``Cls(...)`` are lower-cased before the normal ``EnumMeta`` lookup. This
    matches the convention used by the enums in this module: member names are
    UPPER_CASE and string values are lower_case. So for a member
    ``FILE = "file"``, ``Cls["file"]``, ``Cls["FILE"]``, ``Cls("file")`` and
    ``Cls("FILE")`` all resolve to ``Cls.FILE``.
    """

    def __getitem__(cls, name):
        # Member names are upper-case by convention, so normalise string keys.
        if isinstance(name, str):
            name = name.upper()
        return super().__getitem__(name)

    def __call__(cls, value, *args, **kwargs):
        # Only normalise a plain value lookup such as ``Cls("FILE")``. Extra
        # positional or keyword arguments mean the functional API
        # (``Cls("NewEnum", names)``) is creating a new enum. In that case
        # ``value`` is a class name and must be passed through unchanged.
        if not args and not kwargs and isinstance(value, str):
            value = value.lower()
        return super().__call__(value, *args, **kwargs)


class CaseInsensitiveEnum(Enum, metaclass=CaseInsensitiveEnumMeta):
    """Base class for enums whose members can be looked up in any letter case."""
class PromptType(CaseInsensitiveEnum):
    """How a prompt is represented.

    ``MESSAGES``: a list of message dicts.
    ``PLAINTEXT``: user prompt text plus assistant response text.
    """

    MESSAGES = "messages"
    PLAINTEXT = "plaintext"
class StorageType(CaseInsensitiveEnum):
    """Kind of storage: ``SQL``, ``QUEUE`` or ``FILE``.

    Each value is the lower-case form of its member name.
    """

    SQL = "sql"
    QUEUE = "queue"
    FILE = "file"
class SyncMethodEnumMeta(CaseInsensitiveEnumMeta):
    """Dedicated metaclass for :class:`SyncMethod`.

    It adds no behaviour of its own and inherits all lookup behaviour from
    ``CaseInsensitiveEnumMeta``. A class's metaclass must derive from the
    metaclass of each of its bases. ``CaseInsensitiveEnum`` is declared with
    ``metaclass=CaseInsensitiveEnumMeta``, so this metaclass subclasses it to
    avoid a metaclass conflict.
    """


class SyncMethod(CaseInsensitiveEnum, metaclass=SyncMethodEnumMeta):
    """Sync Method: ``NCCL``, ``CHECKPOINT`` or ``MEMORY``.

    NOTE(review): presumably the mechanism used to synchronise weights between
    trainer and explorer; confirm against the callers.
    """

    NCCL = "nccl"
    CHECKPOINT = "checkpoint"
    MEMORY = "memory"
class RunningStatus(Enum):
    """Current running state of the explorer or the trainer.

    This is a plain ``Enum``, not a ``CaseInsensitiveEnum``, so name and
    value lookups are case-sensitive.
    """

    # Actively running.
    RUNNING = "running"
    # Sync-related states.
    REQUIRE_SYNC = "require_sync"
    WAITING_SYNC = "waiting_sync"
    # Terminal state.
    STOPPED = "stopped"
class OpType(Enum):
    """Arithmetic operator used in reward shaping.

    ``ADD`` / ``SUB`` / ``MUL`` / ``DIV`` correspond to add, subtract,
    multiply and divide. This is a plain ``Enum``, so lookups are
    case-sensitive.
    """

    ADD = "add"
    SUB = "sub"
    MUL = "mul"
    DIV = "div"
class SyncStyle(CaseInsensitiveEnum):
    """Sync style: ``FIXED``, ``DYNAMIC_BY_TRAINER`` or ``DYNAMIC_BY_EXPLORER``.

    Judging by the member names, the dynamic styles are driven by the trainer
    or by the explorer respectively (confirm against the sync logic).
    """

    FIXED = "fixed"
    DYNAMIC_BY_TRAINER = "dynamic_by_trainer"
    DYNAMIC_BY_EXPLORER = "dynamic_by_explorer"