import streamlit as st
from trinity.common.constants import AlgorithmType, SyncMethod
from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
from trinity.trainer.verl.ray_trainer import AdvantageEstimator
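
# Note: the registry semantics below are inferred from usage in this module.
# `CONFIG_GENERATORS.register_config` registers each `set_*` generator together
# with a default value, an optional `visible` predicate that decides whether the
# widget is rendered, and `other_configs`, which seeds auxiliary
# `st.session_state` entries; the widget's session-state key arrives via the
# `key` kwarg. `register_check` registers a validator that collects incomplete
# fields into `unfinished_fields`.
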
def use_critic():
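    # Only GAE advantage estimation trains a separate critic, so the
    # critic-specific widgets below are rendered only in that case
    # (see `visible=use_critic`).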
    return st.session_state["adv_estimator"] == AdvantageEstimator.GAE.value

@CONFIG_GENERATORS.register_config(default_value="verl")
def set_trainer_type(**kwargs):
st.selectbox("Trainer Type", ["verl"], **kwargs)
@CONFIG_GENERATORS.register_config(default_value=100, other_configs={"_nccl_save_interval": 100})
def set_save_interval(**kwargs):
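    # With DPO or NCCL-based weight sync the save interval is a free parameter
    # and the user's last choice is cached in `_nccl_save_interval`; with
    # checkpoint-based sync the checkpoint itself drives synchronization, so the
    # field is pinned to `sync_interval` and disabled.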
    key = kwargs.get("key")
    if (
        st.session_state["algorithm_type"] == AlgorithmType.DPO.value
        or st.session_state["sync_method"] == SyncMethod.NCCL.value
    ):
        st.session_state[key] = st.session_state["_nccl_save_interval"]
        freeze_save_interval = False
    else:
        st.session_state[key] = st.session_state["sync_interval"]
        freeze_save_interval = True

    def on_change():
        if (
            st.session_state["algorithm_type"] == AlgorithmType.DPO.value
            or st.session_state["sync_method"] == SyncMethod.NCCL.value
        ):
            st.session_state["_nccl_save_interval"] = st.session_state[key]

    st.number_input(
        "Save Interval",
        min_value=1,
        help="Set to `sync_interval` when `algorithm_type != DPO && sync_method == checkpoint`",
        disabled=freeze_save_interval,
        on_change=on_change,
        **kwargs,
    )

@CONFIG_GENERATORS.register_config(default_value=True)
def set_enable_preview(**kwargs):
st.checkbox("Enable Preview", **kwargs)
def _actor_use_kl_loss_visible():
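    # DPO always trains with the KL loss, so force the flag on and hide the toggle.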
    if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
        st.session_state["actor_use_kl_loss"] = True
        return False
    return True

@CONFIG_GENERATORS.register_config(
    default_value=True,
    visible=_actor_use_kl_loss_visible,
    other_configs={"_not_dpo_actor_use_kl_loss": True},
)
def set_actor_use_kl_loss(**kwargs):
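    # The checkbox mirrors `_not_dpo_actor_use_kl_loss`, which caches the user's
    # non-DPO choice so it survives switching to DPO (where the flag is forced
    # on and the widget is hidden) and back.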
    key = kwargs.get("key")
    st.session_state[key] = st.session_state["_not_dpo_actor_use_kl_loss"]

    def on_change():
        st.session_state["_not_dpo_actor_use_kl_loss"] = st.session_state[key]

    st.checkbox("Use KL Loss", on_change=on_change, **kwargs)

@CONFIG_GENERATORS.register_config(
    default_value=0.001, visible=lambda: st.session_state["actor_use_kl_loss"]
)
def set_actor_kl_loss_coef(**kwargs):
    st.number_input(
        r"KL Loss Coef :blue-badge[$\beta$]",
        min_value=0.0,
        max_value=1.0,
        format="%.1e",
        **kwargs,
    )

@CONFIG_GENERATORS.register_config(
    default_value=0.001, visible=lambda: st.session_state["actor_use_kl_loss"]
)
def set_actor_entropy_coef(**kwargs):
    st.number_input(
        "Entropy Coeff",
        min_value=0.0,
        max_value=1.0,
        format="%.1e",
        **kwargs,
    )

@CONFIG_GENERATORS.register_config(default_value=1.0)
def set_actor_grad_clip(**kwargs):
    st.number_input(
        "Grad Clip :blue-badge[(Actor)]",
        min_value=0.0,
        max_value=1.0,
        help="Clipping by Norm",
        **kwargs,
    )

@CONFIG_GENERATORS.register_config(default_value=0.2)
def set_actor_clip_ratio(**kwargs):
    st.number_input(
        r"Clip Ratio :blue-badge[$\epsilon$]",
        min_value=0.0,
        max_value=1.0,
        **kwargs,
    )

# veRL Trainer Configs
@CONFIG_GENERATORS.register_config(
    default_value=[
        "balance_batch",
        "gradient_checkpointing",
        "remove_padding",
        "dynamic_bsz",
    ]
)
def set_training_args(**kwargs):
    st.multiselect(
        "Training Args",
        [
            "balance_batch",
            "gradient_checkpointing",
            "remove_padding",
            "dynamic_bsz",
        ],
        **kwargs,
    )

@CONFIG_GENERATORS.register_config(default_value=1)
def set_ppo_epochs(**kwargs):
    st.number_input("PPO Epochs", min_value=1, **kwargs)

@CONFIG_GENERATORS.register_config(default_value="fsdp")
def set_training_strategy(**kwargs):
    st.selectbox(
        "Training Strategy",
        ["fsdp", "megatron"],
        help="The megatron strategy is not yet tested",
        **kwargs,
    )

def use_fsdp():
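    # The FSDP offload toggles below are rendered only under the FSDP strategy.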
    return st.session_state["training_strategy"] == "fsdp"

@CONFIG_GENERATORS.register_config(default_value=False, visible=use_fsdp)
def set_param_offload(**kwargs):
    st.checkbox("FSDP Param Offload", **kwargs)

@CONFIG_GENERATORS.register_config(default_value=False, visible=use_fsdp)
def set_optimizer_offload(**kwargs):
    st.checkbox("FSDP Optimizer Offload", **kwargs)

@CONFIG_GENERATORS.register_config(default_value="auto")
def set_resume_mode(**kwargs):
    st.selectbox("Resume Mode", ["disable", "auto", "resume_path"], **kwargs)

@CONFIG_GENERATORS.register_config(
    default_value="", visible=lambda: st.session_state["resume_mode"] == "resume_path"
)
def set_resume_from_path(**kwargs):
    st.text_input("Resume Path", **kwargs)

@CONFIG_GENERATORS.register_check()
def check_resume_from_path(unfinished_fields: set, key: str):
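    # A usable resume path must be non-empty and point at a `global_step_*`
    # checkpoint directory.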
    if st.session_state["resume_mode"] == "resume_path" and (
        not st.session_state[key].strip() or "global_step_" not in st.session_state[key]
    ):
        unfinished_fields.add(key)
        st.warning("Please input a valid resume path when `resume_mode == resume_path`")

@CONFIG_GENERATORS.register_config(default_value=0)
def set_critic_warmup(**kwargs):
st.number_input("Critic Warmup Steps", min_value=0, **kwargs)
@CONFIG_GENERATORS.register_config(default_value=None)
def set_total_training_steps(**kwargs):
st.number_input("Total Training Steps", min_value=1, **kwargs)
@CONFIG_GENERATORS.register_config(default_value=None)
def set_default_hdfs_dir(**kwargs):
st.text_input("Default HDFS Dir", **kwargs)
@CONFIG_GENERATORS.register_config(default_value=False)
def set_remove_previous_ckpt_in_save(**kwargs):
st.checkbox("Remove Previous Checkpoint in Save", **kwargs)
@CONFIG_GENERATORS.register_config(default_value=False)
def set_del_local_ckpt_after_load(**kwargs):
st.checkbox("Delete Local Checkpoint After Load", **kwargs)
@CONFIG_GENERATORS.register_config(default_value=None)
def set_max_actor_ckpt_to_keep(**kwargs):
st.number_input("Max Actor Checkpoint to Keep", min_value=1, **kwargs)
@CONFIG_GENERATORS.register_config(default_value=None)
def set_max_critic_ckpt_to_keep(**kwargs):
st.number_input("Max Critic Checkpoint to Keep", min_value=1, **kwargs)
@CONFIG_GENERATORS.register_config(default_value=True)
def set_norm_adv_by_std_in_grpo(**kwargs):
st.checkbox("Norm Adv by Std in GRPO", **kwargs)
@CONFIG_GENERATORS.register_config(default_value=False)
def set_use_kl_in_reward(**kwargs):
st.checkbox("Use KL in Reward", **kwargs)
@CONFIG_GENERATORS.register_config(default_value="low_var_kl")
def set_kl_penalty(**kwargs):
st.selectbox("KL Penalty", ["kl", "abs", "mse", "low_var_kl"], **kwargs)
@CONFIG_GENERATORS.register_config(default_value="fixed")
def set_kl_ctrl_type(**kwargs):
st.selectbox("KL Ctrl Type", ["fixed", "adaptive"], **kwargs)
@CONFIG_GENERATORS.register_config(default_value=0.001)
def set_kl_ctrl_coef(**kwargs):
st.number_input("KL Ctrl Coef", format="%.1e", **kwargs)
@CONFIG_GENERATORS.register_config(default_value=10000)
def set_horizon(**kwargs):
st.number_input("Horizon", min_value=1.0, **kwargs)
@CONFIG_GENERATORS.register_config(default_value=0.1)
def set_target_kl(**kwargs):
st.number_input("Target KL", format="%.1e", **kwargs)
@CONFIG_GENERATORS.register_config(default_value=4)
def set_actor_ppo_micro_batch_size_per_gpu(**kwargs):
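    # Clamp the micro batch size to the per-GPU train batch size derived
    # elsewhere and cached in `_train_batch_size_per_gpu`; the ref and critic
    # micro-batch widgets below apply the same clamp.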
    key = kwargs.get("key")
    max_value = st.session_state["_train_batch_size_per_gpu"]
    st.session_state[key] = min(st.session_state[key], max_value)
    st.number_input(
        "Micro Batch Size Per GPU :blue-badge[(Actor)]", min_value=1, max_value=max_value, **kwargs
    )

@CONFIG_GENERATORS.register_config(default_value=8)
def set_ref_log_prob_micro_batch_size_per_gpu(**kwargs):
    key = kwargs.get("key")
    max_value = st.session_state["_train_batch_size_per_gpu"]
    st.session_state[key] = min(st.session_state[key], max_value)
    st.number_input(
        "Micro Batch Size Per GPU :blue-badge[(Ref)]", min_value=1, max_value=max_value, **kwargs
    )

@CONFIG_GENERATORS.register_config(default_value=1)
def set_actor_ulysses_sequence_parallel_size(**kwargs):
    st.number_input(
        "Ulysses Sequence Parallel Size",
        min_value=1,
        max_value=8,
        **kwargs,
    )

@CONFIG_GENERATORS.register_config(default_value=1e-6)
def set_actor_lr(**kwargs):
    st.number_input(
        "Learning Rate :blue-badge[(Actor)]",
        min_value=1e-7,
        max_value=1e-3,
        format="%.1e",
        **kwargs,
    )

@CONFIG_GENERATORS.register_config(default_value="constant")
def set_actor_warmup_style(**kwargs):
    st.selectbox(
        "LR Warmup Style :blue-badge[(Actor)]",
        ["constant", "cosine"],
        **kwargs,
    )

@CONFIG_GENERATORS.register_config(default_value=0.0)
def set_actor_lr_warmup_steps_ratio(**kwargs):
    st.number_input(
        "LR Warmup Steps Ratio :blue-badge[(Actor)]",
        min_value=0.0,
        max_value=1.0,
        **kwargs,
    )

@CONFIG_GENERATORS.register_config(
    default_value=0.0, visible=lambda: st.session_state["algorithm_type"] == "opmd"
)
def set_actor_tau(**kwargs):
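    # OPMD-only knob; tau presumably scales OPMD's regularization term (assumption).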
st.number_input("Tau for OPMD", min_value=0.0, format="%.1e", **kwargs)
@CONFIG_GENERATORS.register_config(
default_value="mean", visible=lambda: st.session_state["algorithm_type"] == "opmd"
)
def set_actor_opmd_baseline(**kwargs):
    st.selectbox(
        "OPMD Baseline",
        ["mean", "logavgexp"],
        **kwargs,
    )

@CONFIG_GENERATORS.register_config(
    default_value=False, visible=lambda: st.session_state["algorithm_type"] == "opmd"
)
def set_actor_use_uid(**kwargs):
    st.checkbox("Use UID for OPMD", **kwargs)

@CONFIG_GENERATORS.register_config(default_value="low_var_kl")
def set_actor_kl_loss_type(**kwargs):
    st.selectbox(
        "KL Loss Type",
        ["kl", "abs", "mse", "low_var_kl"],
        **kwargs,
    )

@CONFIG_GENERATORS.register_config(default_value=["model", "hf_model", "optimizer", "extra"])
def set_actor_checkpoint(**kwargs):
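    # Artifacts written at each save; `hf_model` presumably exports weights in
    # HuggingFace format alongside the native checkpoint (assumption).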
    st.multiselect(
        "Checkpoint",
        ["model", "hf_model", "optimizer", "extra"],
        **kwargs,
    )

@CONFIG_GENERATORS.register_config(default_value=1e-6, visible=use_critic)
def set_critic_lr(**kwargs):
    st.number_input(
        "Learning Rate :blue-badge[(Critic)]",
        min_value=1e-7,
        max_value=1e-3,
        format="%.1e",
        **kwargs,
    )

@CONFIG_GENERATORS.register_config(default_value="constant", visible=use_critic)
def set_critic_warmup_style(**kwargs):
    st.selectbox(
        "LR Warmup Style :blue-badge[(Critic)]",
        ["constant", "cosine"],
        **kwargs,
    )

@CONFIG_GENERATORS.register_config(default_value=0.0, visible=use_critic)
def set_critic_lr_warmup_steps_ratio(**kwargs):
    st.number_input(
        "LR Warmup Steps Ratio :blue-badge[(Critic)]",
        min_value=0.0,
        max_value=1.0,
        **kwargs,
    )

@CONFIG_GENERATORS.register_config(default_value=1.0, visible=use_critic)
def set_critic_grad_clip(**kwargs):
    st.number_input(
        "Grad Clip :blue-badge[(Critic)]",
        min_value=0.0,
        max_value=1.0,
        help="Clipping by Norm",
        **kwargs,
    )

@CONFIG_GENERATORS.register_config(default_value=0.5, visible=use_critic)
def set_critic_cliprange_value(**kwargs):
    st.number_input(
        "Cliprange Value",
        min_value=0.0,
        max_value=1.0,
        **kwargs,
    )

@CONFIG_GENERATORS.register_config(default_value=8, visible=use_critic)
def set_critic_ppo_micro_batch_size_per_gpu(**kwargs):
    key = kwargs.get("key")
    max_value = st.session_state["_train_batch_size_per_gpu"]
    st.session_state[key] = min(st.session_state[key], max_value)
    st.number_input(
        "Micro Batch Size Per GPU :blue-badge[(Critic)]",
        min_value=1,
        max_value=max_value,
        **kwargs,
    )

@CONFIG_GENERATORS.register_config(default_value=1, visible=use_critic)
def set_critic_ulysses_sequence_parallel_size(**kwargs):
    st.number_input(
        "Ulysses Sequence Parallel Size",
        min_value=1,
        max_value=8,
        **kwargs,
    )

@CONFIG_GENERATORS.register_config(
    default_value=["model", "optimizer", "extra"], visible=use_critic
)
def set_critic_checkpoint(**kwargs):
    st.multiselect(
        "Checkpoint",
        ["model", "hf_model", "optimizer", "extra"],
        **kwargs,
    )