Source code for trinity.manager.config_registry.model_config_manager

import os

import streamlit as st

from trinity.common.constants import AlgorithmType, MonitorType
from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
from trinity.manager.config_registry.trainer_config_manager import use_critic
from trinity.trainer.verl.ray_trainer import AdvantageEstimator


def set_total_gpu_num():
    st.session_state["total_gpu_num"] = (
        st.session_state["gpu_per_node"] * st.session_state["node_num"]
    )
    set_trainer_gpu_num()


def set_trainer_gpu_num():
    if st.session_state["mode"] == "both":
        trainer_gpu_num = (
            st.session_state["total_gpu_num"]
            - st.session_state["engine_num"] * st.session_state["tensor_parallel_size"]
        )
        for idx in range(st.session_state["_auxiliary_models_num"]):
            engine_num = st.session_state[f"auxiliary_model_{idx}_engine_num"]
            tensor_parallel_size = st.session_state[f"auxiliary_model_{idx}_tensor_parallel_size"]
            trainer_gpu_num -= engine_num * tensor_parallel_size
        st.session_state["trainer_gpu_num"] = trainer_gpu_num
    else:  # mode == "train"
        st.session_state["trainer_gpu_num"] = st.session_state["total_gpu_num"]
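For reference, the bookkeeping above is plain arithmetic: the trainer gets whatever is left after the explorer engines and any auxiliary model engines claim their GPUs. The standalone sketch below is illustrative only; the function name `compute_trainer_gpu_num` and the example numbers are made up and not part of this module.

def compute_trainer_gpu_num(
    gpu_per_node: int,
    node_num: int,
    engine_num: int,
    tensor_parallel_size: int,
    auxiliary_models: list[tuple[int, int]],  # (engine_num, tensor_parallel_size) pairs
) -> int:
    """Mirror of the session-state arithmetic above, without Streamlit."""
    total = gpu_per_node * node_num
    # Explorer engines claim engine_num * tensor_parallel_size GPUs.
    trainer = total - engine_num * tensor_parallel_size
    # Each auxiliary model claims its own engine_num * tensor_parallel_size GPUs.
    for aux_engine_num, aux_tp_size in auxiliary_models:
        trainer -= aux_engine_num * aux_tp_size
    return trainer


# e.g. 1 node x 8 GPUs, 2 explorer engines with TP=1, no auxiliary models -> 6 trainer GPUs
assert compute_trainer_gpu_num(8, 1, 2, 1, []) == 6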
@CONFIG_GENERATORS.register_config(default_value="Trinity-RFT") def set_project(**kwargs): st.text_input("Project", **kwargs) @CONFIG_GENERATORS.register_config(default_value="qwen2.5-1.5B") def set_exp_name(**kwargs): st.text_input("Experiment Name", **kwargs) @CONFIG_GENERATORS.register_config(default_value="") def set_checkpoint_root_dir(**kwargs): st.text_input("Checkpoint Root Dir", **kwargs) @CONFIG_GENERATORS.register_check() def check_checkpoint_root_dir(unfinished_fields: set, key: str): if not st.session_state[key].strip(): # TODO: may auto generate unfinished_fields.add(key) st.warning("Please input checkpoint root dir.") elif not os.path.isabs(st.session_state[key].strip()): unfinished_fields.add("checkpoint_root_dir") st.warning("Please input an absolute path.") @CONFIG_GENERATORS.register_config(default_value=MonitorType.TENSORBOARD.value) def set_monitor_type(**kwargs): st.selectbox( "Monitor Type", options=[monitor_type.value for monitor_type in MonitorType], **kwargs, ) # Algorithm Configs @CONFIG_GENERATORS.register_config( default_value=AlgorithmType.PPO.value, other_configs={"mode": "both", "adv_estimator": AdvantageEstimator.GAE.value}, ) def set_algorithm_type(**kwargs): def on_change(): if st.session_state["algorithm_type"] == AlgorithmType.PPO.value: st.session_state["mode"] = "both" st.session_state["adv_estimator"] = AdvantageEstimator.GAE.value elif st.session_state["algorithm_type"] == AlgorithmType.GRPO.value: st.session_state["mode"] = "both" st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value elif st.session_state["algorithm_type"] == AlgorithmType.DPO.value: st.session_state["mode"] = "train" st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value elif st.session_state["algorithm_type"] == AlgorithmType.OPMD.value: st.session_state["mode"] = "both" st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value else: # TODO: add more algorithms pass set_trainer_gpu_num() st.selectbox( "Algorithm Type", [ AlgorithmType.PPO.value, AlgorithmType.GRPO.value, AlgorithmType.DPO.value, AlgorithmType.OPMD.value, ], key="algorithm_type", on_change=on_change, ) @CONFIG_GENERATORS.register_config( default_value=1, visible=lambda: st.session_state["mode"] == "both", other_configs={ "_grouped_adv_repeat_times": 2, "_not_grouped_adv_repeat_times": 1, }, ) def set_repeat_times(**kwargs): # TODO key = kwargs.get("key") grouped_adv_algorithms = [ AlgorithmType.GRPO.value, AlgorithmType.OPMD.value, # TODO: may add rloo ] if st.session_state["algorithm_type"] in grouped_adv_algorithms: min_repeat_times = 2 st.session_state[key] = st.session_state["_grouped_adv_repeat_times"] else: min_repeat_times = 1 st.session_state[key] = st.session_state["_not_grouped_adv_repeat_times"] def on_change(): if st.session_state["algorithm_type"] in grouped_adv_algorithms: st.session_state["_grouped_adv_repeat_times"] = st.session_state[key] else: st.session_state["_not_grouped_adv_repeat_times"] = st.session_state[key] st.number_input( "Repeat Times", min_value=min_repeat_times, help="`repeat_times` is used to set how many experiences each task can generate, " "and it must be greater than `1` when `algorithm_type` is `opmd` or `grpo`.", on_change=on_change, **kwargs, ) @CONFIG_GENERATORS.register_config(default_value=1.0) def set_gamma(**kwargs): st.number_input(r"Gamma :blue-badge[$\gamma$]", **kwargs) @CONFIG_GENERATORS.register_config(default_value=1.0) def set_lam(**kwargs): st.number_input(r"Lambda :blue-badge[$\lambda$]", **kwargs) # Model Configs 
@CONFIG_GENERATORS.register_config(default_value="") def set_model_path(**kwargs): st.text_input("Model Path", **kwargs) @CONFIG_GENERATORS.register_check() def check_model_path(unfinished_fields: set, key: str): if not st.session_state[key].strip(): unfinished_fields.add(key) st.warning("Please input model path.") @CONFIG_GENERATORS.register_config( default_value="", visible=use_critic, ) def set_critic_model_path(**kwargs): st.text_input( "Critic Model Path (defaults to `model_path`)", key="critic_model_path", ) @CONFIG_GENERATORS.register_config(default_value=1024) def set_max_prompt_tokens(**kwargs): st.number_input("Max Prompt Tokens", min_value=1, **kwargs) @CONFIG_GENERATORS.register_config(default_value=1024) def set_max_response_tokens(**kwargs): st.number_input("Max Response Tokens", min_value=1, **kwargs) # Cluster Config @CONFIG_GENERATORS.register_config(default_value=1) def set_node_num(**kwargs): st.number_input("Node Num", min_value=1, on_change=set_total_gpu_num, **kwargs) @CONFIG_GENERATORS.register_config( default_value=8, other_configs={"total_gpu_num": 8, "trainer_gpu_num": 6} ) def set_gpu_per_node(**kwargs): st.number_input( "GPU Per Node", min_value=1, max_value=8, on_change=set_total_gpu_num, **kwargs, )