# -*- coding: utf-8 -*-
"""Constants."""
from enum import Enum, EnumMeta
from trinity.utils.log import get_logger
# Module-level logger from trinity.utils.log.get_logger, named after this module.
# NOTE(review): no definition visible here references `logger`; presumably it is
# used by a metaclass defined elsewhere in this file -- confirm before removing.
logger = get_logger(__name__)
# ---------------------------------------------------------------------------
# Well-known string names shared across modules.
# ---------------------------------------------------------------------------
EXPLORER_NAME = "explorer"
TRAINER_NAME = "trainer"
# Group name for rollout weight synchronization.
ROLLOUT_WEIGHT_SYNC_GROUP_NAME = "rollout_weight_sync"
# Enumeration types
[docs]
class CaseInsensitiveEnum(Enum, metaclass=CaseInsensitiveEnumMeta):
    """Base class for the enums in this module that use ``CaseInsensitiveEnumMeta``.

    This class adds no members or methods of its own. All of its special
    behavior comes from ``CaseInsensitiveEnumMeta``, which is defined outside
    this block. Presumably the metaclass makes member lookup ignore case;
    confirm against the metaclass.
    """
[docs]
class PromptType(CaseInsensitiveEnum):
    """Layout of the prompt and response data.

    Each member's comment gives the shape of the prompt and the response
    for that layout.
    """

    # Prompt and response together form a single message list.
    MESSAGES = "messages"
    # Prompt is a message list; response is a message list.
    CHATPAIR = "chatpair"
    # Prompt is plain text; response is plain text.
    PLAINTEXT = "plaintext"
[docs]
class TaskType(Enum):
    """Kind of task: exploration or evaluation.

    Members carry integer values. Unlike most enums in this module, this is a
    plain ``Enum``, not a ``CaseInsensitiveEnum``.
    """

    EXPLORE = 0  # exploration task
    EVAL = 1  # evaluation task
[docs]
class ReadStrategy(CaseInsensitiveEnum):
    """Strategy used when reading items.

    The previous docstring said "Pop Strategy", which did not match the class
    name. The members are unchanged.

    ``DEFAULT`` has the value ``None``. Every other member has a lowercase
    string value.

    NOTE(review): ``CaseInsensitiveEnumMeta`` is defined outside this block.
    If it normalizes string values on lookup (e.g. ``.lower()``), confirm that
    ``ReadStrategy(None)`` still resolves to ``DEFAULT`` and does not raise.
    """

    DEFAULT = None
    FIFO = "fifo"
    RANDOM = "random"
    LRU = "lru"
    LFU = "lfu"
    PRIORITY = "priority"
[docs]
class StorageType(CaseInsensitiveEnum):
    """Kind of storage; each member's value is its lowercase name."""

    SQL = "sql"
    QUEUE = "queue"
    FILE = "file"
[docs]
class MonitorType(CaseInsensitiveEnum):
    """Monitoring backend; each member's value is its lowercase name."""

    WANDB = "wandb"
    TENSORBOARD = "tensorboard"
[docs]
class SyncMethod(CaseInsensitiveEnum, metaclass=SyncMethodEnumMeta):
    """Method used for synchronization.

    The class uses its own metaclass, ``SyncMethodEnumMeta``, which is defined
    outside this block. Python requires that metaclass to derive from
    ``CaseInsensitiveEnumMeta`` (the base class's metaclass). Otherwise class
    creation fails with a metaclass-conflict ``TypeError``.
    """

    NCCL = "nccl"
    CHECKPOINT = "checkpoint"
    MEMORY = "memory"
[docs]
class RunningStatus(Enum):
    """Running status of the explorer and the trainer.

    This is a plain ``Enum``, so lookup by value is exact and case-sensitive.
    """

    RUNNING = "running"
    # Sync-related states.
    REQUIRE_SYNC = "require_sync"
    WAITING_SYNC = "waiting_sync"
    STOPPED = "stopped"
[docs]
class DataProcessorPipelineType(Enum):
    """Kind of data-processor pipeline: experience or task."""

    EXPERIENCE = "experience_pipeline"  # pipeline for experiences
    TASK = "task_pipeline"  # pipeline for tasks
[docs]
class OpType(Enum):
    """Arithmetic operator used in reward shaping.

    Values are short lowercase operator names.
    """

    ADD = "add"  # addition
    SUB = "sub"  # subtraction
    MUL = "mul"  # multiplication
    DIV = "div"  # division
[docs]
class SyncStyle(CaseInsensitiveEnum):
    """Sync Style.

    Added for consistency: this was the only enum in the module without a
    docstring. There are three members: a fixed style, and two dynamic styles
    that are driven by the trainer and by the explorer respectively (as their
    names indicate).

    NOTE(review): the exact scheduling semantics of each style are not visible
    here -- confirm against the sync scheduler before documenting further.
    """

    FIXED = "fixed"
    DYNAMIC_BY_TRAINER = "dynamic_by_trainer"
    DYNAMIC_BY_EXPLORER = "dynamic_by_explorer"