# Source code for trinity.manager.config_registry.buffer_config_manager

import streamlit as st

from trinity.buffer.storage.queue import PRIORITY_FUNC
from trinity.common.constants import PromptType, StorageType
from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS
from trinity.common.workflows.workflow import WORKFLOWS
from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS


@CONFIG_GENERATORS.register_config(default_value=20)
def set_total_epochs(**kwargs):
    """Render the number input selecting how many epochs to train for (>= 1)."""
    st.number_input(label="Total Epochs", min_value=1, **kwargs)


@CONFIG_GENERATORS.register_config(default_value=96)
def set_explore_batch_size(**kwargs):
    """Render the number input for the task batch size used per explore step."""
    help_text = "Number of tasks to explore in one explore step"
    st.number_input(label="Task Batch Size", min_value=1, help=help_text, **kwargs)


def get_train_batch_size() -> int:
    """Return the effective train batch size.

    Falls back to ``explore_batch_size * repeat_times`` when the user has not
    set ``train_batch_size`` (i.e. it is ``None`` or 0 in session state).
    """
    return (
        st.session_state["train_batch_size"]
        or st.session_state["explore_batch_size"] * st.session_state["repeat_times"]
    )
def get_train_batch_size_per_gpu() -> int:
    """Return the per-GPU train batch size.

    Uses the cached ``_train_batch_size_per_gpu`` when set; otherwise derives
    it from ``explore_batch_size * repeat_times // trainer_gpu_num``, clamped
    to at least 1.
    """
    return st.session_state["_train_batch_size_per_gpu"] or max(
        st.session_state["explore_batch_size"]
        * st.session_state["repeat_times"]
        // st.session_state["trainer_gpu_num"],
        1,
    )
def _str_for_train_batch_size():
    """Build the help/warning text describing the `train_batch_size` divisibility rule."""
    trainer_gpu_num_str = (
        "`gpu_per_node * node_num - engine_num * tensor_parallel_size`"
        if st.session_state["mode"] == "both"
        else "`gpu_per_node * node_num`"
    )
    return (
        f"`train_batch_size` defaults to `task_batch_size` * `repeat_times`.\n\n"
        f"Please ensure that `train_batch_size` ({get_train_batch_size()}) can be divided by "
        f"{trainer_gpu_num_str} ({st.session_state['trainer_gpu_num']})."
    )


@CONFIG_GENERATORS.register_config(
    default_value=None,
    visible=lambda: st.session_state["trainer_gpu_num"] > 0,
    other_configs={"_train_batch_size_per_gpu": None},
)
def set_train_batch_size(**kwargs):
    """Render the train-batch-size input, synced with the per-GPU cached value."""
    key = kwargs.get("key")
    trainer_gpu_num = st.session_state["trainer_gpu_num"]
    # Restore the displayed value from the per-GPU cache (if any) so that the
    # field survives changes to the trainer GPU count.
    st.session_state[key] = (
        st.session_state["_train_batch_size_per_gpu"] * st.session_state["trainer_gpu_num"]
        if st.session_state["_train_batch_size_per_gpu"] is not None
        else None
    )
    placeholder = st.session_state["explore_batch_size"] * st.session_state["repeat_times"]

    def on_change():
        # Persist the user's choice as a per-GPU value (at least 1 per GPU).
        st.session_state["_train_batch_size_per_gpu"] = max(
            st.session_state[key] // st.session_state["trainer_gpu_num"], 1
        )

    st.number_input(
        "Train Batch Size",
        min_value=trainer_gpu_num,
        step=trainer_gpu_num,
        help=_str_for_train_batch_size(),
        on_change=on_change,
        placeholder=placeholder,
        **kwargs,
    )


@CONFIG_GENERATORS.register_check()
def check_train_batch_size(unfinished_fields: set, key: str):
    """Flag the field when the train batch size is not divisible by the trainer GPU count."""
    if get_train_batch_size() % st.session_state["trainer_gpu_num"] != 0:
        unfinished_fields.add(key)
        st.warning(_str_for_train_batch_size())


@CONFIG_GENERATORS.register_config(default_value=3)
def set_buffer_max_retry_times(**kwargs):
    """Render the max-retry-times input for buffer operations."""
    st.number_input("Max Retry Times", min_value=1, **kwargs)


@CONFIG_GENERATORS.register_config(default_value=1)
def set_max_retry_interval(**kwargs):
    """Render the max-retry-interval input for buffer operations."""
    st.number_input("Max Retry Interval", min_value=1, **kwargs)


@CONFIG_GENERATORS.register_config(default_value="")
def set_taskset_path(**kwargs):
    """Render the text input for the training taskset path."""
    st.text_input("Taskset Path", **kwargs)


@CONFIG_GENERATORS.register_check()
def check_taskset_path(unfinished_fields: set, key: str):
    """Flag the field when the taskset path is empty or whitespace."""
    if not st.session_state[key].strip():
        unfinished_fields.add(key)
        st.warning("Please input taskset path.")


@CONFIG_GENERATORS.register_config(
    visible=lambda: st.session_state["taskset_path"]
    and "://" not in st.session_state["taskset_path"],
    other_configs={
        "taskset_subset_name": None,
        "taskset_split": "train",
        "taskset_prompt_key": "question",
        "taskset_response_key": "answer",
        "temperature": 1.0,
        "top_p": 1.0,  # TODO: to be used
        "top_k": -1,  # TODO: to be used
        "logprobs": 0,
    },
)
def set_taskset_args(**kwargs):
    """Render detailed taskset options (subset/split/keys/sampling params).

    Only shown for local (non-URI) taskset paths — see the `visible` lambda.
    """
    subset_name_col, split_col = st.columns(2)
    subset_name_col.text_input(
        "Subset Name :orange-badge[(Needs review)]",
        key="taskset_subset_name",
        help="The subset name used for `datasets.load_datasets`, see "
        "[here](https://huggingface.co/docs/datasets/v3.5.0/en/package_reference/loading_methods#datasets.load_dataset.name) for details.",
    )
    split_col.text_input("Train Split :orange-badge[(Needs review)]", key="taskset_split")
    prompt_key_col, response_key_col = st.columns(2)
    prompt_key_col.text_input("Prompt Key :orange-badge[(Needs review)]", key="taskset_prompt_key")
    response_key_col.text_input(
        "Response Key :orange-badge[(Needs review)]", key="taskset_response_key"
    )
    temperature_col, logprobs_col = st.columns(2)
    temperature_col.number_input("Temperature", key="temperature", min_value=0.0, max_value=2.0)
    logprobs_col.number_input("Logprobs", key="logprobs", min_value=0, max_value=20)


def _set_eval_taskset_idx(idx):
    """Render the widget group for the eval taskset at position `idx`."""
    col1, col2 = st.columns([9, 1])
    col1.text_input(
        "Taskset Name",
        key=f"eval_taskset_{idx}_name",
    )
    # The delete button sets the `_del_flag` key; rerun so the caller can drop the tab.
    if col2.button("✖️", key=f"eval_taskset_{idx}_del_flag", type="primary"):
        st.rerun()
    st.text_input(
        "Eval Taskset Path",
        key=f"eval_taskset_{idx}_path",
    )
    if not st.session_state[f"eval_taskset_{idx}_path"].strip():
        st.warning("Please input the taskset path, or it will be ignored.")
    subset_name_col, split_col = st.columns(2)
    subset_name_col.text_input(
        "Subset Name :orange-badge[(Needs review)]",
        key=f"eval_taskset_{idx}_subset_name",
        help="The subset name used for `datasets.load_datasets`, see "
        "[here](https://huggingface.co/docs/datasets/v3.5.0/en/package_reference/loading_methods#datasets.load_dataset.name) for details.",
    )
    split_col.text_input(
        "Eval Split :orange-badge[(Needs review)]",
        key=f"eval_taskset_{idx}_split",
    )
    prompt_key_col, response_key_col = st.columns(2)
    prompt_key_col.text_input(
        "Prompt Key :orange-badge[(Needs review)]",
        key=f"eval_taskset_{idx}_prompt_key",
    )
    response_key_col.text_input(
        "Response Key :orange-badge[(Needs review)]",
        key=f"eval_taskset_{idx}_response_key",
    )
    temperature_col, logprobs_col, n_col = st.columns(3)
    temperature_col.number_input(
        "Temperature",
        key=f"eval_taskset_{idx}_temperature",
        min_value=0.0,
        max_value=1.0,
    )
    logprobs_col.number_input(
        "Logprobs",
        key=f"eval_taskset_{idx}_logprobs",
        min_value=0,
        max_value=20,
    )
    n_col.number_input(
        "Eval repeat times",
        key=f"eval_taskset_{idx}_n",
        min_value=1,
        max_value=20,
    )


@CONFIG_GENERATORS.register_config(other_configs={"_eval_tasksets_num": 0})
def set_eval_tasksets(**kwargs):
    """Render the dynamic list of eval tasksets (add button + one tab per taskset)."""
    if st.button("Add Eval Taskset"):
        idx = st.session_state["_eval_tasksets_num"]
        # Seed sensible defaults for the newly added taskset's widgets.
        st.session_state[f"eval_taskset_{idx}_split"] = "test"
        st.session_state[f"eval_taskset_{idx}_prompt_key"] = "prompt"
        st.session_state[f"eval_taskset_{idx}_response_key"] = "response"
        st.session_state[f"eval_taskset_{idx}_temperature"] = 0.1
        st.session_state["_eval_tasksets_num"] += 1
    if st.session_state["_eval_tasksets_num"] > 0:
        tabs = st.tabs(
            [f"Eval Taskset {i + 1}" for i in range(st.session_state["_eval_tasksets_num"])]
        )
        for idx, tab in enumerate(tabs):
            with tab:
                _set_eval_taskset_idx(idx)


@CONFIG_GENERATORS.register_config(default_value="math_workflow")
def set_default_workflow_type(**kwargs):
    """Render the selectbox choosing the default (training) workflow type."""
    st.selectbox(
        "Default Workflow Type :orange-badge[(Needs review)]",
        WORKFLOWS.modules.keys(),
        help=r"""`simple_workflow`: call 'model.chat()' to get responses.

`math_workflow`: call 'model.chat()' with a pre-defined system prompt to get responses.

Other workflows: conduct multi-turn task for the given dataset.
""",
        **kwargs,
    )


@CONFIG_GENERATORS.register_config(default_value="math_workflow")
def set_default_eval_workflow_type(**kwargs):
    """Render the selectbox choosing the default evaluation workflow type."""
    st.selectbox(
        "Default Eval Workflow Type :orange-badge[(Needs review)]",
        WORKFLOWS.modules.keys(),
        help=r"""`simple_workflow`: call 'model.chat()' to get responses.

`math_workflow`: call 'model.chat()' with a pre-defined system prompt to get responses.

Other workflows: conduct multi-turn task for the given dataset.
""",
        **kwargs,
    )


@CONFIG_GENERATORS.register_config(default_value="math_reward")
def set_default_reward_fn_type(**kwargs):
    """Render the selectbox choosing the default reward function type."""
    st.selectbox(
        "Default Reward Fn Type :orange-badge[(Needs review)]",
        REWARD_FUNCTIONS.modules.keys(),
        help=r"""`accuracy_reward`: check the accuracy for math problems.

`format_reward`: check if the response matches the format (default: `<think>*</think>* <answer>*</answer>`).

`math_reward`: `accuracy_reward` (1 or 0) + `format_reward` (+0.1 or -0.1).
""",
        **kwargs,
    )


@CONFIG_GENERATORS.register_config(default_value=None)
def set_system_prompt(**kwargs):
    """Render the text area for an optional system prompt."""
    st.text_area(
        "System Prompt",
        placeholder="""You are a helpful assistant that solves MATH problems....""",
        **kwargs,
    )


@CONFIG_GENERATORS.register_config(default_value=None)
def set_reply_prefix(**kwargs):
    """Render the text area for an optional assistant reply prefix."""
    st.text_area(
        "Assistant Reply Prefix",
        placeholder="""Assistant reply prefix is used to specify the initial content of model reply, """
        """and a common setting is: \nLet me solve this step by step.
""",
        **kwargs,
    )


@CONFIG_GENERATORS.register_config(
    default_value=StorageType.QUEUE.value,
    other_configs={
        "_dpo_storage_type": StorageType.FILE.value,
        "_not_dpo_storage_type": StorageType.QUEUE.value,
    },
)
def set_storage_type(**kwargs):
    """Render the storage-type selectbox; candidates depend on the algorithm type."""
    key = kwargs.get("key")
    if st.session_state["algorithm_type"] == "dpo":
        st.session_state[key] = st.session_state["_dpo_storage_type"]
        storage_candidates = [StorageType.FILE.value, StorageType.SQL.value]
    else:
        st.session_state[key] = st.session_state["_not_dpo_storage_type"]
        storage_candidates = [StorageType.QUEUE.value]

    def on_change():
        # Keep a separate cached choice per algorithm family.
        if st.session_state["algorithm_type"] == "dpo":
            st.session_state["_dpo_storage_type"] = st.session_state[key]
        else:
            st.session_state["_not_dpo_storage_type"] = st.session_state[key]

    st.selectbox(
        "Storage Type",
        storage_candidates,
        on_change=on_change,
        **kwargs,
    )


@CONFIG_GENERATORS.register_config(default_value=False)
def set_use_priority_queue(**kwargs):
    """Render the checkbox enabling the priority queue storage."""
    st.checkbox("Use Priority Queue", **kwargs)


@CONFIG_GENERATORS.register_config(
    default_value=None, visible=lambda: st.session_state["use_priority_queue"]
)
def set_reuse_cooldown_time(**kwargs):
    """Render the reuse-cooldown-time input (blank means no reuse)."""
    st.number_input(
        "Reuse Cooldown Time",
        min_value=0.0,
        max_value=1e5,
        help="Leave blank to indicate no reuse",
        placeholder=None,
        **kwargs,
    )


@CONFIG_GENERATORS.register_config(
    default_value="linear_decay", visible=lambda: st.session_state["use_priority_queue"]
)
def set_priority_fn(**kwargs):
    """Render the selectbox choosing the priority function for the priority queue."""
    candidates = list(PRIORITY_FUNC.modules.keys())
    st.selectbox(
        "Priority Function",
        candidates,
        **kwargs,
    )


@CONFIG_GENERATORS.register_config(
    default_value=0.1, visible=lambda: st.session_state["use_priority_queue"]
)
def set_priority_decay(**kwargs):
    """Render the priority-decay input for the priority queue."""
    st.number_input(
        "Priority Decay",
        **kwargs,
    )


@CONFIG_GENERATORS.register_config(
    default_value="",
    other_configs={
        "_dpo_experience_buffer_path": "",
        "_not_dpo_experience_buffer_path": "",
    },
)
def set_experience_buffer_path(**kwargs):  # TODO
    """Render the experience-buffer path input (labelled as DPO dataset path for DPO)."""
    key = kwargs.get("key")
    if st.session_state["algorithm_type"] == "dpo":
        # Default the DPO dataset path to the taskset path the first time around.
        if st.session_state["taskset_path"] and not st.session_state["_dpo_experience_buffer_path"]:
            st.session_state["_dpo_experience_buffer_path"] = st.session_state["taskset_path"]
        st.session_state[key] = st.session_state["_dpo_experience_buffer_path"]
        title = "DPO Dataset Path"
        help_msg = r"""This path to DPO dataset,

if `storage_type == StorageType.FILE`, this should be a path to a file,

if `storage_type == StorageType.SQL`, this should be a path to database."""
    else:
        st.session_state[key] = st.session_state["_not_dpo_experience_buffer_path"]
        title = "Experience Buffer Path"
        help_msg = r"""This path is used for experiences persistent storage, default to `None`."""

    def on_change():
        # Cache the edited path separately for DPO / non-DPO modes.
        if st.session_state["algorithm_type"] == "dpo":
            st.session_state["_dpo_experience_buffer_path"] = st.session_state[key]
        else:
            st.session_state["_not_dpo_experience_buffer_path"] = st.session_state[key]

    st.text_input(title, help=help_msg, on_change=on_change, **kwargs)


@CONFIG_GENERATORS.register_check()
def check_experience_buffer_path(unfinished_fields: set, key: str):
    """Require a non-empty dataset path when the algorithm is DPO."""
    if st.session_state["algorithm_type"] == "dpo":
        if not st.session_state[key].strip():
            unfinished_fields.add(key)
            st.warning("Please input DPO dataset path.")


@CONFIG_GENERATORS.register_config(
    other_configs={
        "dpo_dataset_train_split": "train",
        "dpo_dataset_prompt_type": PromptType.MESSAGES.value,
        "dpo_dataset_prompt_key": "prompt",
        "dpo_dataset_chosen_key": "chosen",
        "dpo_dataset_rejected_key": "rejected",
    }
)
def set_dpo_dataset_kwargs(**kwargs):
    """Render split/prompt-type/key options for the DPO dataset."""
    dpo_dataset_train_split_col, dpo_dataset_prompt_type_col = st.columns(2)
    dpo_dataset_train_split_col.text_input(
        "DPO Dataset Train Split :orange-badge[(Needs review)]", key="dpo_dataset_train_split"
    )
    dpo_dataset_prompt_type_col.selectbox(
        "DPO Dataset Prompt Type :orange-badge[(Needs review)]",
        [prompt_type.value for prompt_type in PromptType],
        key="dpo_dataset_prompt_type",
    )
    (
        dpo_dataset_prompt_key_col,
        dpo_dataset_chosen_key_col,
        dpo_dataset_rejected_key_col,
    ) = st.columns(3)
    dpo_dataset_prompt_key_col.text_input(
        "DPO Dataset Prompt Key :orange-badge[(Needs review)]", key="dpo_dataset_prompt_key"
    )
    dpo_dataset_chosen_key_col.text_input(
        "DPO Dataset Chosen Key :orange-badge[(Needs review)]", key="dpo_dataset_chosen_key"
    )
    dpo_dataset_rejected_key_col.text_input(
        "DPO Dataset Rejected Key :orange-badge[(Needs review)]",
        key="dpo_dataset_rejected_key",
    )


@CONFIG_GENERATORS.register_config(default_value="")
def set_sft_warmup_dataset_path(**kwargs):
    """Render the text input for the SFT warmup dataset path."""
    st.text_input("SFT Warmup Dataset Path", **kwargs)


@CONFIG_GENERATORS.register_check()
def check_sft_warmup_dataset_path(unfinished_fields: set, key: str):
    """Require the SFT warmup dataset path when `sft_warmup_steps` is non-zero."""
    if st.session_state["sft_warmup_steps"]:
        if not st.session_state[key].strip():
            unfinished_fields.add(key)
            st.warning("Please input SFT warmup dataset path when `sft_warmup_steps` is not 0")


@CONFIG_GENERATORS.register_config(
    visible=lambda: st.session_state["sft_warmup_dataset_path"]
    and "://" not in st.session_state["sft_warmup_dataset_path"],
    other_configs={
        "sft_warmup_train_split": "train",
        "sft_warmup_prompt_type": PromptType.MESSAGES.value,
        "sft_warmup_messages_key": "messages",
        "sft_warmup_prompt_key": "prompt",
        "sft_warmup_response_key": "response",
    },
)
def set_sft_warmup_dataset_args(**kwargs):
    """Render split/prompt-type/key options for the SFT warmup dataset.

    Only shown for local (non-URI) dataset paths — see the `visible` lambda.
    """
    (
        sft_warmup_train_split_col,
        sft_warmup_prompt_type_col,
    ) = st.columns(2)
    sft_warmup_train_split_col.text_input(
        "SFT Dataset Train Split :orange-badge[(Needs review)]",
        key="sft_warmup_train_split",
    )
    sft_warmup_prompt_type_col.selectbox(
        "SFT Dataset Prompt Type :orange-badge[(Needs review)]",
        [prompt_type.value for prompt_type in PromptType],
        key="sft_warmup_prompt_type",
    )
    (
        sft_warmup_messages_key_col,
        sft_warmup_prompt_key_col,
        sft_warmup_response_key_col,
    ) = st.columns(
        3
    )  # TODO: select by prompt type
    sft_warmup_messages_key_col.text_input(
        "SFT Dataset Messages Key :orange-badge[(Needs review)]",
        key="sft_warmup_messages_key",
    )
    sft_warmup_prompt_key_col.text_input(
        "SFT Dataset Prompt Key :orange-badge[(Needs review)]", key="sft_warmup_prompt_key"
    )
    sft_warmup_response_key_col.text_input(
        "SFT Dataset Response Key :orange-badge[(Needs review)]",
        key="sft_warmup_response_key",
    )


@CONFIG_GENERATORS.register_config(default_value=0)
def set_sft_warmup_steps(**kwargs):
    """Render the number input for how many SFT warmup steps to run (0 disables warmup)."""
    st.number_input("SFT Warmup Steps", min_value=0, **kwargs)