# -*- coding: utf-8 -*-
"""Constants."""
from enum import Enum, EnumMeta
from trinity.utils.log import get_logger
# Logger for this module, named after the module itself.
logger = get_logger(__name__)

# --- Name constants ---------------------------------------------------------
# Group name string identifying rollout weight synchronization.
ROLLOUT_WEIGHT_SYNC_GROUP_NAME: str = "rollout_weight_sync"

# --- Enumeration types ------------------------------------------------------
class CaseInsensitiveEnum(Enum, metaclass=CaseInsensitiveEnumMeta):
    """Shared base class for the enums in this module that use
    ``CaseInsensitiveEnumMeta``.

    The class adds no members or methods of its own; any behaviour beyond a
    plain ``Enum`` comes from the metaclass (presumably case-insensitive
    member lookup, judging by its name -- confirm against its definition).
    """

    # NOTE(review): CaseInsensitiveEnumMeta is not defined in the visible
    # source, while EnumMeta is imported but otherwise unused here. Confirm
    # the metaclass definition exists above this class.
class PromptType(CaseInsensitiveEnum):
    """Prompt Type: how prompt and response data are laid out.

    Members:
        MESSAGES:  prompt and response together form a single message list.
        CHATPAIR:  prompt is a message list and response is a message list.
        PLAINTEXT: prompt is plain text and response is plain text.
    """

    MESSAGES = "messages"
    CHATPAIR = "chatpair"
    PLAINTEXT = "plaintext"
class TaskType(Enum):
    """Task Type: either ``EXPLORE`` or ``EVAL``.

    Unlike most enums in this module this is a plain ``Enum`` with integer
    values; it does not derive from ``CaseInsensitiveEnum``.
    """

    EXPLORE = 0  # first member, value 0
    EVAL = 1  # second member, value 1
class ReadStrategy(CaseInsensitiveEnum):
    """Read (pop) strategy used when taking items out.

    ``DEFAULT`` carries the value ``None``; every other member's value is its
    lowercase name.
    """

    # Presumably "no explicit strategy chosen; use the default" -- confirm
    # with callers.
    DEFAULT = None
    FIFO = "fifo"
    RANDOM = "random"
    LRU = "lru"
    LFU = "lfu"
    PRIORITY = "priority"
class StorageType(CaseInsensitiveEnum):
    """Storage Type: the backend kind, one of ``SQL``, ``QUEUE`` or ``FILE``.

    Each member's value is its lowercase name.
    """

    SQL = "sql"
    QUEUE = "queue"
    FILE = "file"
class AlgorithmType(CaseInsensitiveEnum):
    """Algorithm Type, plus helpers that group members into families.

    Each member's value is its lowercase name.
    """

    SFT = "sft"
    PPO = "ppo"
    GRPO = "grpo"
    OPMD = "opmd"
    PAIRWISE_OPMD = "pairwise_opmd"
    DPO = "dpo"

    def is_rft(self) -> bool:
        """Return whether this member is one of the RFT algorithms.

        The RFT family is PPO, GRPO, OPMD and PAIRWISE_OPMD; SFT and DPO are
        not included.
        """
        # Built locally: a constant assigned in the class body would itself
        # become an enum member.
        rft_family = (
            AlgorithmType.PPO,
            AlgorithmType.GRPO,
            AlgorithmType.OPMD,
            AlgorithmType.PAIRWISE_OPMD,
        )
        return self in rft_family

    def is_sft(self) -> bool:
        """Return whether this member is ``AlgorithmType.SFT``."""
        return self == AlgorithmType.SFT

    def is_dpo(self) -> bool:
        """Return whether this member is ``AlgorithmType.DPO``."""
        return self == AlgorithmType.DPO
class MonitorType(CaseInsensitiveEnum):
    """Monitor Type: ``WANDB`` or ``TENSORBOARD``.

    Each member's value is its lowercase name.
    """

    WANDB = "wandb"
    TENSORBOARD = "tensorboard"
class SyncMethod(CaseInsensitiveEnum, metaclass=SyncMethodEnumMeta):
    """Sync Method: ``NCCL`` or ``CHECKPOINT``.

    Each member's value is its lowercase name. Member construction/lookup
    behaviour is customised by ``SyncMethodEnumMeta``.
    """

    # NOTE(review): SyncMethodEnumMeta is not defined in the visible source.
    # Python requires it to be a (non-strict) subclass of the metaclass of
    # CaseInsensitiveEnum, otherwise class creation fails with a metaclass
    # conflict -- confirm.
    NCCL = "nccl"
    CHECKPOINT = "checkpoint"