Source code for trinity.manager.config_registry.trainer_config_manager

import streamlit as st

from trinity.common.constants import AlgorithmType, SyncMethod
from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
from trinity.trainer.verl.ray_trainer import AdvantageEstimator


def use_critic():
    # The critic is only trained when the advantage estimator is GAE.
    return st.session_state["adv_estimator"] == AdvantageEstimator.GAE.value
@CONFIG_GENERATORS.register_config(default_value="verl") def set_trainer_type(**kwargs): st.selectbox("Trainer Type", ["verl"], **kwargs) @CONFIG_GENERATORS.register_config(default_value=100, other_configs={"_nccl_save_interval": 100}) def set_save_interval(**kwargs): key = kwargs.get("key") if ( st.session_state["algorithm_type"] == AlgorithmType.DPO.value or st.session_state["sync_method"] == SyncMethod.NCCL.value ): st.session_state[key] = st.session_state["_nccl_save_interval"] freeze_save_interval = False else: st.session_state[key] = st.session_state["sync_interval"] freeze_save_interval = True def on_change(): if ( st.session_state["algorithm_type"] == AlgorithmType.DPO.value or st.session_state["sync_method"] == SyncMethod.NCCL.value ): st.session_state["_nccl_save_interval"] = st.session_state[key] st.number_input( "Save Interval", min_value=1, help="Set to `sync_interval` when `algorithm_type != DPO && sync_method == checkpoint`", disabled=freeze_save_interval, on_change=on_change, **kwargs, ) @CONFIG_GENERATORS.register_config(default_value=True) def set_enable_preview(**kwargs): st.checkbox("Enable Preview", **kwargs) def _actor_use_kl_loss_visible(): if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: st.session_state["actor_use_kl_loss"] = True return False return True @CONFIG_GENERATORS.register_config( default_value=True, visible=_actor_use_kl_loss_visible, other_configs={"_not_dpo_actor_use_kl_loss": True}, ) def set_actor_use_kl_loss(**kwargs): key = kwargs.get("key") st.session_state[key] = st.session_state["_not_dpo_actor_use_kl_loss"] def on_change(): st.session_state["_not_dpo_actor_use_kl_loss"] = st.session_state[key] st.checkbox("Use KL Loss", on_change=on_change, **kwargs) @CONFIG_GENERATORS.register_config( default_value=0.001, visible=lambda: st.session_state["actor_use_kl_loss"] ) def set_actor_kl_loss_coef(**kwargs): st.number_input( r"KL Loss Coef :blue-badge[$\beta$]", min_value=0.0, max_value=1.0, format="%.1e", **kwargs, ) @CONFIG_GENERATORS.register_config( default_value=0.001, visible=lambda: st.session_state["actor_use_kl_loss"] ) def set_actor_entropy_coef(**kwargs): st.number_input( "Entropy Coeff", min_value=0.0, max_value=1.0, format="%.1e", **kwargs, ) @CONFIG_GENERATORS.register_config(default_value=1.0) def set_actor_grad_clip(**kwargs): st.number_input( "Grad Clip :blue-badge[(Actor)]", min_value=0.0, max_value=1.0, help="Clipping by Norm", **kwargs, ) @CONFIG_GENERATORS.register_config(default_value=0.2) def set_actor_clip_ratio(**kwargs): st.number_input( r"Clip Ratio :blue-badge[$\epsilon$]", min_value=0.0, max_value=1.0, **kwargs, ) # veRL Trainer Configs @CONFIG_GENERATORS.register_config( default_value=[ "balance_batch", "gradient_checkpointing", "remove_padding", "dynamic_bsz", ] ) def set_training_args(**kwargs): st.multiselect( "Training Args", [ "balance_batch", "gradient_checkpointing", "remove_padding", "dynamic_bsz", ], **kwargs, ) @CONFIG_GENERATORS.register_config(default_value=1) def set_ppo_epochs(**kwargs): st.number_input("PPO Epochs", min_value=1, **kwargs) @CONFIG_GENERATORS.register_config(default_value="fsdp") def set_training_strategy(**kwargs): st.selectbox( "Training Strategy", ["fsdp", "megatron"], help="megatron is not tested", **kwargs, )


def use_fsdp():
    # Visibility predicate for the FSDP-only widgets below.
    return st.session_state["training_strategy"] == "fsdp"


@CONFIG_GENERATORS.register_config(default_value=False, visible=use_fsdp)
def set_param_offload(**kwargs):
    st.checkbox("FSDP Param Offload", **kwargs)


@CONFIG_GENERATORS.register_config(default_value=False, visible=use_fsdp)
def set_optimizer_offload(**kwargs):
    st.checkbox("FSDP Optimizer Offload", **kwargs)


@CONFIG_GENERATORS.register_config(default_value="auto")
def set_resume_mode(**kwargs):
    st.selectbox("Resume Mode", ["disable", "auto", "resume_path"], **kwargs)


@CONFIG_GENERATORS.register_config(
    default_value="", visible=lambda: st.session_state["resume_mode"] == "resume_path"
)
def set_resume_from_path(**kwargs):
    st.text_input("Resume Path", **kwargs)


@CONFIG_GENERATORS.register_check()
def check_resume_from_path(unfinished_fields: set, key: str):
    if st.session_state["resume_mode"] == "resume_path" and (
        not st.session_state[key].strip() or "global_step_" not in st.session_state[key]
    ):
        unfinished_fields.add(key)
        st.warning("Please input a valid resume path when `resume_mode == resume_path`")


@CONFIG_GENERATORS.register_config(default_value=0)
def set_critic_warmup(**kwargs):
    st.number_input("Critic Warmup Steps", min_value=0, **kwargs)


@CONFIG_GENERATORS.register_config(default_value=None)
def set_total_training_steps(**kwargs):
    st.number_input("Total Training Steps", min_value=1, **kwargs)


@CONFIG_GENERATORS.register_config(default_value=None)
def set_default_hdfs_dir(**kwargs):
    st.text_input("Default HDFS Dir", **kwargs)


@CONFIG_GENERATORS.register_config(default_value=False)
def set_remove_previous_ckpt_in_save(**kwargs):
    st.checkbox("Remove Previous Checkpoint in Save", **kwargs)


@CONFIG_GENERATORS.register_config(default_value=False)
def set_del_local_ckpt_after_load(**kwargs):
    st.checkbox("Delete Local Checkpoint After Load", **kwargs)


@CONFIG_GENERATORS.register_config(default_value=None)
def set_max_actor_ckpt_to_keep(**kwargs):
    st.number_input("Max Actor Checkpoint to Keep", min_value=1, **kwargs)


@CONFIG_GENERATORS.register_config(default_value=None)
def set_max_critic_ckpt_to_keep(**kwargs):
    st.number_input("Max Critic Checkpoint to Keep", min_value=1, **kwargs)


@CONFIG_GENERATORS.register_config(default_value=True)
def set_norm_adv_by_std_in_grpo(**kwargs):
    st.checkbox("Norm Adv by Std in GRPO", **kwargs)


@CONFIG_GENERATORS.register_config(default_value=False)
def set_use_kl_in_reward(**kwargs):
    st.checkbox("Use KL in Reward", **kwargs)


@CONFIG_GENERATORS.register_config(default_value="low_var_kl")
def set_kl_penalty(**kwargs):
    st.selectbox("KL Penalty", ["kl", "abs", "mse", "low_var_kl"], **kwargs)


@CONFIG_GENERATORS.register_config(default_value="fixed")
def set_kl_ctrl_type(**kwargs):
    st.selectbox("KL Ctrl Type", ["fixed", "adaptive"], **kwargs)


@CONFIG_GENERATORS.register_config(default_value=0.001)
def set_kl_ctrl_coef(**kwargs):
    st.number_input("KL Ctrl Coef", format="%.1e", **kwargs)


@CONFIG_GENERATORS.register_config(default_value=10000)
def set_horizon(**kwargs):
    st.number_input("Horizon", min_value=1.0, **kwargs)


@CONFIG_GENERATORS.register_config(default_value=0.1)
def set_target_kl(**kwargs):
    st.number_input("Target KL", format="%.1e", **kwargs)


@CONFIG_GENERATORS.register_config(default_value=4)
def set_actor_ppo_micro_batch_size_per_gpu(**kwargs):
    key = kwargs.get("key")
    # Clamp the widget to the per-GPU train batch size.
    max_value = st.session_state["_train_batch_size_per_gpu"]
    st.session_state[key] = min(st.session_state[key], max_value)
    st.number_input(
        "Micro Batch Size Per GPU :blue-badge[(Actor)]", min_value=1, max_value=max_value, **kwargs
    )


@CONFIG_GENERATORS.register_config(default_value=8)
def set_ref_log_prob_micro_batch_size_per_gpu(**kwargs):
    key = kwargs.get("key")
    # Clamp the widget to the per-GPU train batch size.
    max_value = st.session_state["_train_batch_size_per_gpu"]
    st.session_state[key] = min(st.session_state[key], max_value)
    st.number_input(
        "Micro Batch Size Per GPU :blue-badge[(Ref)]", min_value=1, max_value=max_value, **kwargs
    )


@CONFIG_GENERATORS.register_config(default_value=1)
def set_actor_ulysses_sequence_parallel_size(**kwargs):
    st.number_input(
        "Ulysses Sequence Parallel Size",
        min_value=1,
        max_value=8,
        **kwargs,
    )


@CONFIG_GENERATORS.register_config(default_value=1e-6)
def set_actor_lr(**kwargs):
    st.number_input(
        "Learning Rate :blue-badge[(Actor)]",
        min_value=1e-7,
        max_value=1e-3,
        format="%.1e",
        **kwargs,
    )


@CONFIG_GENERATORS.register_config(default_value="constant")
def set_actor_warmup_style(**kwargs):
    st.selectbox(
        "LR Warmup Style :blue-badge[(Actor)]",
        ["constant", "cosine"],
        **kwargs,
    )


@CONFIG_GENERATORS.register_config(default_value=0.0)
def set_actor_lr_warmup_steps_ratio(**kwargs):
    st.number_input(
        "LR Warmup Steps Ratio :blue-badge[(Actor)]",
        min_value=0.0,
        max_value=1.0,
        **kwargs,
    )


@CONFIG_GENERATORS.register_config(
    default_value=0.0, visible=lambda: st.session_state["algorithm_type"] == "opmd"
)
def set_actor_tau(**kwargs):
    st.number_input("Tau for OPMD", min_value=0.0, format="%.1e", **kwargs)


@CONFIG_GENERATORS.register_config(
    default_value="mean", visible=lambda: st.session_state["algorithm_type"] == "opmd"
)
def set_actor_opmd_baseline(**kwargs):
    st.selectbox(
        "OPMD Baseline",
        ["mean", "logavgexp"],
        **kwargs,
    )


@CONFIG_GENERATORS.register_config(
    default_value=False, visible=lambda: st.session_state["algorithm_type"] == "opmd"
)
def set_actor_use_uid(**kwargs):
    st.checkbox("Use UID for OPMD", **kwargs)


@CONFIG_GENERATORS.register_config(default_value="low_var_kl")
def set_actor_kl_loss_type(**kwargs):
    st.selectbox(
        "KL Loss Type",
        ["kl", "abs", "mse", "low_var_kl"],
        **kwargs,
    )


@CONFIG_GENERATORS.register_config(default_value=["model", "hf_model", "optimizer", "extra"])
def set_actor_checkpoint(**kwargs):
    st.multiselect(
        "Checkpoint",
        ["model", "hf_model", "optimizer", "extra"],
        **kwargs,
    )


# Critic configs (only visible when the advantage estimator is GAE).
@CONFIG_GENERATORS.register_config(default_value=1e-6, visible=use_critic)
def set_critic_lr(**kwargs):
    st.number_input(
        "Learning Rate :blue-badge[(Critic)]",
        min_value=1e-7,
        max_value=1e-3,
        format="%.1e",
        **kwargs,
    )


@CONFIG_GENERATORS.register_config(default_value="constant", visible=use_critic)
def set_critic_warmup_style(**kwargs):
    st.selectbox(
        "LR Warmup Style :blue-badge[(Critic)]",
        ["constant", "cosine"],
        **kwargs,
    )


@CONFIG_GENERATORS.register_config(default_value=0.0, visible=use_critic)
def set_critic_lr_warmup_steps_ratio(**kwargs):
    st.number_input(
        "LR Warmup Steps Ratio :blue-badge[(Critic)]",
        min_value=0.0,
        max_value=1.0,
        **kwargs,
    )


@CONFIG_GENERATORS.register_config(default_value=1.0, visible=use_critic)
def set_critic_grad_clip(**kwargs):
    st.number_input(
        "Grad Clip :blue-badge[(Critic)]",
        min_value=0.0,
        max_value=1.0,
        help="Clipping by Norm",
        **kwargs,
    )


@CONFIG_GENERATORS.register_config(default_value=0.5, visible=use_critic)
def set_critic_cliprange_value(**kwargs):
    st.number_input(
        "Cliprange Value",
        min_value=0.0,
        max_value=1.0,
        **kwargs,
    )


@CONFIG_GENERATORS.register_config(default_value=8, visible=use_critic)
def set_critic_ppo_micro_batch_size_per_gpu(**kwargs):
    key = kwargs.get("key")
    # Clamp the widget to the per-GPU train batch size.
    max_value = st.session_state["_train_batch_size_per_gpu"]
    st.session_state[key] = min(st.session_state[key], max_value)
    st.number_input(
        "Micro Batch Size Per GPU :blue-badge[(Critic)]",
        min_value=1,
        max_value=max_value,
        **kwargs,
    )


@CONFIG_GENERATORS.register_config(default_value=1, visible=use_critic)
def set_critic_ulysses_sequence_parallel_size(**kwargs):
    st.number_input(
        "Ulysses Sequence Parallel Size",
        min_value=1,
        max_value=8,
        **kwargs,
    )


@CONFIG_GENERATORS.register_config(
    default_value=["model", "optimizer", "extra"], visible=use_critic
)
def set_critic_checkpoint(**kwargs):
    st.multiselect(
        "Checkpoint",
        ["model", "hf_model", "optimizer", "extra"],
        **kwargs,
    )
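

# Usage sketch (hypothetical, not part of this module): checks registered via
# `register_check`, such as `check_resume_from_path` above, are presumably
# collected by the config manager and run before a config file is generated,
# along the lines of:
#
#     unfinished_fields: set = set()
#     check_resume_from_path(unfinished_fields, key="resume_from_path")
#     if unfinished_fields:
#         st.error(f"Please complete: {', '.join(sorted(unfinished_fields))}")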