Source code for trinity.manager.config_registry.explorer_config_manager

import streamlit as st

from trinity.common.constants import AlgorithmType, SyncMethod
from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
from trinity.manager.config_registry.model_config_manager import set_trainer_gpu_num


def explorer_visible() -> bool:
    return st.session_state["mode"] == "both"

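# Note: `explorer_visible` gates every explorer-specific field below; the
# widgets are rendered only when the manager runs in "both" mode (presumably
# explorer plus trainer). Each registered generator receives its Streamlit
# widget `key` through **kwargs, and that key doubles as the session-state
# slot holding the current value.
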
@CONFIG_GENERATORS.register_config(default_value=32, visible=explorer_visible)
def set_runner_num(**kwargs):
    st.number_input("Runner Num", min_value=1, **kwargs)


@CONFIG_GENERATORS.register_config(default_value=900, visible=explorer_visible)
def set_max_timeout(**kwargs):
    st.number_input("Max Timeout", min_value=0, **kwargs)


@CONFIG_GENERATORS.register_config(default_value=2, visible=explorer_visible)
def set_explorer_max_retry_times(**kwargs):
    st.number_input("Explorer Max Retry Times", min_value=0, **kwargs)


@CONFIG_GENERATORS.register_config(default_value=1000, visible=explorer_visible)
def set_eval_interval(**kwargs):
    st.number_input("Eval Interval", min_value=1, **kwargs)


@CONFIG_GENERATORS.register_config(default_value=True, visible=explorer_visible)
def set_eval_on_latest_checkpoint(**kwargs):
    st.checkbox("Eval on Latest Checkpoint", **kwargs)


@CONFIG_GENERATORS.register_config(default_value="vllm_async", visible=explorer_visible)
def set_engine_type(**kwargs):
    st.selectbox("Engine Type", ["vllm_async", "vllm"], **kwargs)


@CONFIG_GENERATORS.register_config(default_value=2, visible=explorer_visible)
def set_engine_num(**kwargs):
    key = kwargs.get("key")
    total_gpu_num = st.session_state["total_gpu_num"]
    # Reserve at least one GPU for the trainer when sizing the engine pool.
    max_engine_num = (total_gpu_num - 1) // st.session_state["tensor_parallel_size"]
    if st.session_state[key] > max_engine_num:
        st.session_state[key] = max_engine_num
        set_trainer_gpu_num()
    st.number_input(
        "Engine Num",
        min_value=1,
        max_value=max_engine_num,
        on_change=set_trainer_gpu_num,
        **kwargs,
    )


@CONFIG_GENERATORS.register_config(default_value=1, visible=explorer_visible)
def set_tensor_parallel_size(**kwargs):
    key = kwargs.get("key")
    total_gpu_num = st.session_state["total_gpu_num"]
    max_tensor_parallel_size = (total_gpu_num - 1) // st.session_state["engine_num"]
    if st.session_state[key] > max_tensor_parallel_size:
        st.session_state[key] = max_tensor_parallel_size
        set_trainer_gpu_num()
    st.number_input(
        "Tensor Parallel Size",
        min_value=1,
        max_value=max_tensor_parallel_size,
        on_change=set_trainer_gpu_num,
        **kwargs,
    )


@CONFIG_GENERATORS.register_check()
def check_tensor_parallel_size(unfinished_fields: set, key: str):
    if st.session_state["trainer_gpu_num"] <= 0:
        unfinished_fields.add("engine_num")
        unfinished_fields.add("tensor_parallel_size")
        st.warning(
            "Please check the settings of each `engine_num` and `tensor_parallel_size` "
            "to ensure that at least one GPU is reserved for the `trainer`."
        )
    elif (
        st.session_state["node_num"] > 1
        and st.session_state["trainer_gpu_num"] % st.session_state["gpu_per_node"] != 0
    ):
        unfinished_fields.add("engine_num")
        unfinished_fields.add("tensor_parallel_size")
        st.warning(
            "When `node_num > 1`, please check the settings of each `engine_num` and "
            "`tensor_parallel_size` to ensure that the number of GPUs reserved for the "
            "`trainer` is divisible by `gpu_per_node`."
        )

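# Worked example of the GPU budget enforced above (illustrative): with
# total_gpu_num = 8 and tensor_parallel_size = 2, the explorer may run at
# most (8 - 1) // 2 = 3 engines, leaving 8 - 3 * 2 = 2 GPUs for the trainer.
# The "- 1" in both bounds is what guarantees the trainer at least one GPU,
# and the check above rejects any combination that drives `trainer_gpu_num`
# (presumably total_gpu_num - engine_num * tensor_parallel_size, maintained
# by set_trainer_gpu_num) to zero.
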
@CONFIG_GENERATORS.register_config(default_value=True, visible=explorer_visible)
def set_use_v1(**kwargs):
    st.checkbox("Use V1 Engine", **kwargs)


@CONFIG_GENERATORS.register_config(default_value=True, visible=explorer_visible)
def set_enforce_eager(**kwargs):
    st.checkbox("Enforce Eager", **kwargs)


@CONFIG_GENERATORS.register_config(default_value=False, visible=explorer_visible)
def set_enable_prefix_caching(**kwargs):
    st.checkbox("Prefix Caching", **kwargs)


@CONFIG_GENERATORS.register_config(default_value=False, visible=explorer_visible)
def set_enable_chunked_prefill(**kwargs):
    st.checkbox("Chunked Prefill", **kwargs)


@CONFIG_GENERATORS.register_config(default_value=0.9, visible=explorer_visible)
def set_gpu_memory_utilization(**kwargs):
    st.number_input("GPU Memory Utilization", min_value=0.0, max_value=1.0, **kwargs)


@CONFIG_GENERATORS.register_config(default_value="bfloat16", visible=explorer_visible)
def set_dtype(**kwargs):
    st.selectbox("Dtype", ["bfloat16", "float16", "float32"], **kwargs)


@CONFIG_GENERATORS.register_config(default_value=42, visible=explorer_visible)
def set_seed(**kwargs):
    st.number_input("Seed", step=1, **kwargs)


# TODO: max_prompt_tokens
# TODO: max_response_tokens
# TODO: chat_template


@CONFIG_GENERATORS.register_config(default_value=False, visible=explorer_visible)
def set_enable_thinking(**kwargs):
    st.checkbox("Enable Thinking For Qwen3", **kwargs)


@CONFIG_GENERATORS.register_config(default_value=False, visible=explorer_visible)
def set_enable_openai_api(**kwargs):
    st.checkbox("Enable OpenAI API", **kwargs)

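# The helpers below render one tab per auxiliary model. All widget state is
# kept in index-scoped session keys, e.g. for idx = 0:
#     auxiliary_model_0_model_path, auxiliary_model_0_engine_num,
#     auxiliary_model_0_tensor_parallel_size, ...
# so adding or inspecting a model only needs the shared prefix plus index.
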
def _set_auxiliary_model_idx(idx):
    col1, col2 = st.columns([9, 1])
    col1.text_input(
        "Model Path",
        key=f"auxiliary_model_{idx}_model_path",
    )
    if col2.button("✖️", key=f"auxiliary_model_{idx}_del_flag", type="primary"):
        st.rerun()
    engine_type_col, engine_num_col, tensor_parallel_size_col = st.columns(3)
    total_gpu_num = st.session_state["total_gpu_num"]
    engine_type_col.selectbox(
        "Engine Type", ["vllm_async"], key=f"auxiliary_model_{idx}_engine_type"
    )
    engine_num_col.number_input(
        "Engine Num",
        min_value=1,
        max_value=total_gpu_num - 1,
        on_change=set_trainer_gpu_num,
        key=f"auxiliary_model_{idx}_engine_num",
    )
    tensor_parallel_size_col.number_input(
        "Tensor Parallel Size",
        min_value=1,
        max_value=8,
        on_change=set_trainer_gpu_num,
        key=f"auxiliary_model_{idx}_tensor_parallel_size",
    )
    gpu_memory_utilization_col, dtype_col, seed_col = st.columns(3)
    gpu_memory_utilization_col.number_input(
        "GPU Memory Utilization",
        min_value=0.0,
        max_value=1.0,
        key=f"auxiliary_model_{idx}_gpu_memory_utilization",
    )
    dtype_col.selectbox(
        "Dtype", ["bfloat16", "float16", "float32"], key=f"auxiliary_model_{idx}_dtype"
    )
    seed_col.number_input("Seed", step=1, key=f"auxiliary_model_{idx}_seed")
    (
        use_v1_col,
        enforce_eager_col,
        enable_prefix_caching_col,
        enable_chunked_prefill_col,
    ) = st.columns(4)
    use_v1_col.checkbox("Use V1 Engine", key=f"auxiliary_model_{idx}_use_v1")
    enforce_eager_col.checkbox("Enforce Eager", key=f"auxiliary_model_{idx}_enforce_eager")
    enable_prefix_caching_col.checkbox(
        "Prefix Caching", key=f"auxiliary_model_{idx}_enable_prefix_caching"
    )
    enable_chunked_prefill_col.checkbox(
        "Chunked Prefill", key=f"auxiliary_model_{idx}_enable_chunked_prefill"
    )
    enable_thinking_col, enable_openai_api = st.columns(2)
    enable_thinking_col.checkbox(
        "Enable Thinking For Qwen3", key=f"auxiliary_model_{idx}_enable_thinking"
    )
    enable_openai_api.checkbox("Enable OpenAI API", key=f"auxiliary_model_{idx}_enable_openai_api")


@CONFIG_GENERATORS.register_config(other_configs={"_auxiliary_models_num": 0})
def set_auxiliary_models(**kwargs):
    if st.button("Add Auxiliary Models"):
        idx = st.session_state["_auxiliary_models_num"]
        st.session_state[f"auxiliary_model_{idx}_engine_num"] = 1
        st.session_state[f"auxiliary_model_{idx}_tensor_parallel_size"] = 1
        st.session_state[f"auxiliary_model_{idx}_gpu_memory_utilization"] = 0.9
        st.session_state[f"auxiliary_model_{idx}_seed"] = 42
        st.session_state[f"auxiliary_model_{idx}_use_v1"] = True
        st.session_state[f"auxiliary_model_{idx}_enforce_eager"] = True
        st.session_state["_auxiliary_models_num"] += 1
        set_trainer_gpu_num()
    if st.session_state["_auxiliary_models_num"] > 0:
        tabs = st.tabs(
            [f"Auxiliary Model {i + 1}" for i in range(st.session_state["_auxiliary_models_num"])]
        )
        for idx, tab in enumerate(tabs):
            with tab:
                _set_auxiliary_model_idx(idx)


@CONFIG_GENERATORS.register_check()
def check_auxiliary_models(unfinished_fields: set, key: str):
    if st.session_state["trainer_gpu_num"] <= 0:
        unfinished_fields.add("engine_num")
        unfinished_fields.add("tensor_parallel_size")
        st.warning(
            "Please check the settings of each `engine_num` and `tensor_parallel_size` "
            "to ensure that at least one GPU is reserved for the `trainer`."
        )
    elif (
        st.session_state["node_num"] > 1
        and st.session_state["trainer_gpu_num"] % st.session_state["gpu_per_node"] != 0
    ):
        unfinished_fields.add("engine_num")
        unfinished_fields.add("tensor_parallel_size")
        st.warning(
            "When `node_num > 1`, please check the settings of each `engine_num` and "
            "`tensor_parallel_size` to ensure that the number of GPUs reserved for the "
            "`trainer` is divisible by `gpu_per_node`."
        )


# Synchronizer Configs
@CONFIG_GENERATORS.register_config(
    default_value=SyncMethod.NCCL.value,
    visible=explorer_visible,
    other_configs={"_not_dpo_sync_method": SyncMethod.NCCL.value},
)
def set_sync_method(**kwargs):
    key = kwargs.get("key")
    if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
        st.session_state[key] = SyncMethod.CHECKPOINT.value
        disabled = True
    else:
        st.session_state[key] = st.session_state["_not_dpo_sync_method"]
        disabled = False

    def on_change():
        if st.session_state["algorithm_type"] != AlgorithmType.DPO.value:
            st.session_state["_not_dpo_sync_method"] = st.session_state[key]

    st.selectbox(
        "Sync Method",
        [sync_method.value for sync_method in SyncMethod],
        help="""`nccl`: the explorer and trainer sync model weights once every `sync_interval` steps.

`checkpoint`: the trainer saves the model checkpoint, and the explorer loads it at `sync_interval`.""",
        disabled=disabled,
        on_change=on_change,
        **kwargs,
    )


@CONFIG_GENERATORS.register_config(default_value=10, visible=explorer_visible)
def set_sync_interval(**kwargs):
    st.number_input(
        "Sync Interval",
        min_value=1,
        help="""The step interval at which the `explorer` and `trainer` synchronize model weights.""",
        **kwargs,
    )


@CONFIG_GENERATORS.register_config(default_value=1200, visible=explorer_visible)
def set_sync_timeout(**kwargs):
    st.number_input(
        "Sync Timeout",
        min_value=1,
        help="The timeout value for the synchronization operation.",
        **kwargs,
    )
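
# Note on set_sync_method above: when the algorithm is DPO the selector is
# pinned to `checkpoint` and disabled (DPO presumably learns from a fixed
# preference dataset, so NCCL-based live weight sync does not apply); the
# last non-DPO choice is kept in the auxiliary `_not_dpo_sync_method` slot
# and restored automatically when the user switches back to another algorithm.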