trinity.common.config module#
Configs for RFT.
- class trinity.common.config.FormatConfig(prompt_type: PromptType = PromptType.MESSAGES, prompt_key: str = 'prompt', response_key: str = 'response', system_prompt_key: str | None = None, system_prompt: str | None = None, messages_key: str = 'message', tools_key: str = 'tools', image_key: str | None = None, video_key: str | None = None, reply_prefix: str | None = None, workflow_key: str = '', reward_fn_key: str = '', chosen_key: str = 'chosen', rejected_key: str = 'rejected', enable_concatenated_multi_turn: bool = False, chat_template: str | None = None)[source]#
Bases:
object
Configuration for data formatting
- prompt_type: PromptType = 'messages'#
- prompt_key: str = 'prompt'#
- response_key: str = 'response'#
- system_prompt_key: str | None = None#
- system_prompt: str | None = None#
- messages_key: str = 'message'#
- tools_key: str = 'tools'#
- image_key: str | None = None#
- video_key: str | None = None#
- reply_prefix: str | None = None#
- workflow_key: str = ''#
- reward_fn_key: str = ''#
- chosen_key: str = 'chosen'#
- rejected_key: str = 'rejected'#
- enable_concatenated_multi_turn: bool = False#
- chat_template: str | None = None#
- __init__(prompt_type: PromptType = PromptType.MESSAGES, prompt_key: str = 'prompt', response_key: str = 'response', system_prompt_key: str | None = None, system_prompt: str | None = None, messages_key: str = 'message', tools_key: str = 'tools', image_key: str | None = None, video_key: str | None = None, reply_prefix: str | None = None, workflow_key: str = '', reward_fn_key: str = '', chosen_key: str = 'chosen', rejected_key: str = 'rejected', enable_concatenated_multi_turn: bool = False, chat_template: str | None = None) None #
- class trinity.common.config.GenerationConfig(temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, logprobs: int = 0, n: int = 1)[source]#
Bases:
object
- temperature: float = 1.0#
- top_p: float = 1.0#
- top_k: int = -1#
- logprobs: int = 0#
- n: int = 1#
- __init__(temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, logprobs: int = 0, n: int = 1) None #
- class trinity.common.config.StorageConfig(name: str = '', storage_type: ~trinity.common.constants.StorageType = StorageType.FILE, path: str | None = None, repeat_times: int | None = None, index: int = 0, mm_data_kwargs: dict = <factory>, split: str = 'train', subset_name: str | None = None, format: ~trinity.common.config.FormatConfig = <factory>, capacity: int = 10000, max_read_timeout: float = 1800, use_priority_queue: bool = False, reuse_cooldown_time: float | None = None, replay_buffer_kwargs: dict = <factory>, max_retry_times: int = 3, max_retry_interval: int = 1, default_workflow_type: str | None = None, default_eval_workflow_type: str | None = None, default_reward_fn_type: str | None = None, rollout_args: ~trinity.common.config.GenerationConfig = <factory>, workflow_args: dict = <factory>, reward_fn_args: dict = <factory>, enable_progress_bar: bool | None = False, ray_namespace: str | None = None, wrap_in_ray: bool = True, schema_type: str | None = None, total_epochs: int = 1, total_steps: int | None = None, is_eval: bool = False)[source]#
Bases:
object
Storage config.
- name: str = ''#
- storage_type: StorageType = 'file'#
- path: str | None = None#
- repeat_times: int | None = None#
- index: int = 0#
- mm_data_kwargs: dict#
- split: str = 'train'#
- subset_name: str | None = None#
- format: FormatConfig#
- capacity: int = 10000#
- max_read_timeout: float = 1800#
- use_priority_queue: bool = False#
- reuse_cooldown_time: float | None = None#
- replay_buffer_kwargs: dict#
- max_retry_times: int = 3#
- max_retry_interval: int = 1#
- default_workflow_type: str | None = None#
- default_eval_workflow_type: str | None = None#
- default_reward_fn_type: str | None = None#
- rollout_args: GenerationConfig#
- workflow_args: dict#
- reward_fn_args: dict#
- enable_progress_bar: bool | None = False#
- ray_namespace: str | None = None#
- wrap_in_ray: bool = True#
- schema_type: str | None = None#
- total_epochs: int = 1#
- total_steps: int | None = None#
- is_eval: bool = False#
- __init__(name: str = '', storage_type: ~trinity.common.constants.StorageType = StorageType.FILE, path: str | None = None, repeat_times: int | None = None, index: int = 0, mm_data_kwargs: dict = <factory>, split: str = 'train', subset_name: str | None = None, format: ~trinity.common.config.FormatConfig = <factory>, capacity: int = 10000, max_read_timeout: float = 1800, use_priority_queue: bool = False, reuse_cooldown_time: float | None = None, replay_buffer_kwargs: dict = <factory>, max_retry_times: int = 3, max_retry_interval: int = 1, default_workflow_type: str | None = None, default_eval_workflow_type: str | None = None, default_reward_fn_type: str | None = None, rollout_args: ~trinity.common.config.GenerationConfig = <factory>, workflow_args: dict = <factory>, reward_fn_args: dict = <factory>, enable_progress_bar: bool | None = False, ray_namespace: str | None = None, wrap_in_ray: bool = True, schema_type: str | None = None, total_epochs: int = 1, total_steps: int | None = None, is_eval: bool = False) None #
- class trinity.common.config.OperatorConfig(name: str = '', args: Dict[str, Any] = <factory>)[source]#
Bases:
object
- name: str = ''#
- args: Dict[str, Any]#
- __init__(name: str = '', args: ~typing.Dict[str, ~typing.Any] = <factory>) None #
- class trinity.common.config.ExperiencePipelineConfig(operators: ~typing.List[~trinity.common.config.OperatorConfig] = <factory>, save_input: bool = True, input_save_path: str | None = None, inputs: ~typing.Dict[str, ~trinity.common.config.StorageConfig] = <factory>, output: ~trinity.common.config.StorageConfig | None = None)[source]#
Bases:
object
Config for experience pipeline.
Experience Pipeline is used to pre-process rollout experiences for better training.
- operators: List[OperatorConfig]#
- save_input: bool = True#
- input_save_path: str | None = None#
- inputs: Dict[str, StorageConfig]#
- output: StorageConfig | None = None#
- __init__(operators: ~typing.List[~trinity.common.config.OperatorConfig] = <factory>, save_input: bool = True, input_save_path: str | None = None, inputs: ~typing.Dict[str, ~trinity.common.config.StorageConfig] = <factory>, output: ~trinity.common.config.StorageConfig | None = None) None #
- class trinity.common.config.TaskPipelineConfig(operators: ~typing.List[~trinity.common.config.OperatorConfig] = <factory>, num_process: int = 4, config_path: str | None = None, inputs: ~typing.List[str] = <factory>, output: ~trinity.common.config.StorageConfig | None = None, target_fields: ~typing.List[str] = <factory>, priority_weights: ~typing.Dict[str, float] = <factory>, top_k: int = -1)[source]#
Bases:
object
Config for task pipeline.
Task Pipeline is used to pre-process raw tasks for better exploration. Currently, only Data-Juicer operators are supported in the task pipeline.
- operators: List[OperatorConfig]#
- num_process: int = 4#
- config_path: str | None = None#
- inputs: List[str]#
- output: StorageConfig | None = None#
- target_fields: List[str]#
- priority_weights: Dict[str, float]#
- top_k: int = -1#
- __init__(operators: ~typing.List[~trinity.common.config.OperatorConfig] = <factory>, num_process: int = 4, config_path: str | None = None, inputs: ~typing.List[str] = <factory>, output: ~trinity.common.config.StorageConfig | None = None, target_fields: ~typing.List[str] = <factory>, priority_weights: ~typing.Dict[str, float] = <factory>, top_k: int = -1) None #
- class trinity.common.config.DataProcessorConfig(task_pipeline: ~trinity.common.config.TaskPipelineConfig | None = None, experience_pipeline: ~trinity.common.config.ExperiencePipelineConfig | None = <factory>)[source]#
Bases:
object
Data Processor config
- task_pipeline: TaskPipelineConfig | None = None#
- experience_pipeline: ExperiencePipelineConfig | None#
- __init__(task_pipeline: ~trinity.common.config.TaskPipelineConfig | None = None, experience_pipeline: ~trinity.common.config.ExperiencePipelineConfig | None = <factory>) None #
- class trinity.common.config.ModelConfig(model_path: str = '', critic_model_path: str = '', max_model_len: int | None = None, max_prompt_tokens: int | None = None, max_response_tokens: int | None = None, min_response_tokens: int = 1, custom_chat_template: str | None = None)[source]#
Bases:
object
- model_path: str = ''#
- critic_model_path: str = ''#
- max_model_len: int | None = None#
- max_prompt_tokens: int | None = None#
- max_response_tokens: int | None = None#
- min_response_tokens: int = 1#
- custom_chat_template: str | None = None#
- __init__(model_path: str = '', critic_model_path: str = '', max_model_len: int | None = None, max_prompt_tokens: int | None = None, max_response_tokens: int | None = None, min_response_tokens: int = 1, custom_chat_template: str | None = None) None #
- class trinity.common.config.InferenceModelConfig(model_path: str = '', engine_type: str = 'vllm_async', engine_num: int = 1, tensor_parallel_size: int = 1, use_v1: bool = True, enforce_eager: bool = True, enable_prefix_caching: bool = False, enable_chunked_prefill: bool = False, gpu_memory_utilization: float = 0.9, dtype: str = 'bfloat16', seed: int = 42, max_model_len: int | None = None, max_prompt_tokens: int | None = None, max_response_tokens: int | None = None, min_response_tokens: int | None = None, ignore_eos: bool = False, chat_template: str | None = None, enable_thinking: bool = False, enable_history: bool = False, enable_openai_api: bool = False, enable_auto_tool_choice: bool = False, tool_call_parser: str | None = None, reasoning_parser: str | None = None, bundle_indices: str = '')[source]#
Bases:
object
- model_path: str = ''#
- engine_type: str = 'vllm_async'#
- engine_num: int = 1#
- tensor_parallel_size: int = 1#
- use_v1: bool = True#
- enforce_eager: bool = True#
- enable_prefix_caching: bool = False#
- enable_chunked_prefill: bool = False#
- gpu_memory_utilization: float = 0.9#
- dtype: str = 'bfloat16'#
- seed: int = 42#
- max_model_len: int | None = None#
- max_prompt_tokens: int | None = None#
- max_response_tokens: int | None = None#
- min_response_tokens: int | None = None#
- ignore_eos: bool = False#
- chat_template: str | None = None#
- enable_thinking: bool = False#
- enable_history: bool = False#
- enable_openai_api: bool = False#
- enable_auto_tool_choice: bool = False#
- tool_call_parser: str | None = None#
- reasoning_parser: str | None = None#
- bundle_indices: str = ''#
- __init__(model_path: str = '', engine_type: str = 'vllm_async', engine_num: int = 1, tensor_parallel_size: int = 1, use_v1: bool = True, enforce_eager: bool = True, enable_prefix_caching: bool = False, enable_chunked_prefill: bool = False, gpu_memory_utilization: float = 0.9, dtype: str = 'bfloat16', seed: int = 42, max_model_len: int | None = None, max_prompt_tokens: int | None = None, max_response_tokens: int | None = None, min_response_tokens: int | None = None, ignore_eos: bool = False, chat_template: str | None = None, enable_thinking: bool = False, enable_history: bool = False, enable_openai_api: bool = False, enable_auto_tool_choice: bool = False, tool_call_parser: str | None = None, reasoning_parser: str | None = None, bundle_indices: str = '') None #
- class trinity.common.config.AlgorithmConfig(algorithm_type: str = 'ppo', repeat_times: int = 1, sample_strategy: str | None = None, sample_strategy_args: dict | None = None, advantage_fn: str | None = None, advantage_fn_args: dict | None = None, kl_penalty_fn: str | None = None, kl_penalty_fn_args: dict | None = None, policy_loss_fn: str | None = None, policy_loss_fn_args: dict | None = None, kl_loss_fn: str | None = None, kl_loss_fn_args: dict | None = None, entropy_loss_fn: str | None = None, entropy_loss_fn_args: dict | None = None, use_token_level_loss: bool = True)[source]#
Bases:
object
Config for algorithm.
- algorithm_type: str = 'ppo'#
- repeat_times: int = 1#
- sample_strategy: str | None = None#
- sample_strategy_args: dict | None = None#
- advantage_fn: str | None = None#
- advantage_fn_args: dict | None = None#
- kl_penalty_fn: str | None = None#
- kl_penalty_fn_args: dict | None = None#
- policy_loss_fn: str | None = None#
- policy_loss_fn_args: dict | None = None#
- kl_loss_fn: str | None = None#
- kl_loss_fn_args: dict | None = None#
- entropy_loss_fn: str | None = None#
- entropy_loss_fn_args: dict | None = None#
- use_token_level_loss: bool = True#
- __init__(algorithm_type: str = 'ppo', repeat_times: int = 1, sample_strategy: str | None = None, sample_strategy_args: dict | None = None, advantage_fn: str | None = None, advantage_fn_args: dict | None = None, kl_penalty_fn: str | None = None, kl_penalty_fn_args: dict | None = None, policy_loss_fn: str | None = None, policy_loss_fn_args: dict | None = None, kl_loss_fn: str | None = None, kl_loss_fn_args: dict | None = None, entropy_loss_fn: str | None = None, entropy_loss_fn_args: dict | None = None, use_token_level_loss: bool = True) None #
- class trinity.common.config.ClusterConfig(node_num: int = 1, gpu_per_node: int = 8)[source]#
Bases:
object
Config for the cluster.
- node_num: int = 1#
- gpu_per_node: int = 8#
- __init__(node_num: int = 1, gpu_per_node: int = 8) None #
- class trinity.common.config.ExplorerInput(taskset: ~trinity.common.config.StorageConfig = <factory>, eval_tasksets: ~typing.List[~trinity.common.config.StorageConfig] = <factory>, default_workflow_type: str | None = None, default_eval_workflow_type: str | None = None, default_reward_fn_type: str | None = None, system_prompt: str | None = None, reply_prefix: str | None = None)[source]#
Bases:
object
Config for explorer input.
- taskset: StorageConfig#
- eval_tasksets: List[StorageConfig]#
- default_workflow_type: str | None = None#
- default_eval_workflow_type: str | None = None#
- default_reward_fn_type: str | None = None#
- system_prompt: str | None = None#
- reply_prefix: str | None = None#
- __init__(taskset: ~trinity.common.config.StorageConfig = <factory>, eval_tasksets: ~typing.List[~trinity.common.config.StorageConfig] = <factory>, default_workflow_type: str | None = None, default_eval_workflow_type: str | None = None, default_reward_fn_type: str | None = None, system_prompt: str | None = None, reply_prefix: str | None = None) None #
- class trinity.common.config.TrainerInput(experience_buffer: StorageConfig | None = None, sft_warmup_dataset: StorageConfig | None = None, sft_warmup_steps: int = 0)[source]#
Bases:
object
Config for trainer input.
- experience_buffer: StorageConfig | None = None#
- sft_warmup_dataset: StorageConfig | None = None#
- sft_warmup_steps: int = 0#
- __init__(experience_buffer: StorageConfig | None = None, sft_warmup_dataset: StorageConfig | None = None, sft_warmup_steps: int = 0) None #
- class trinity.common.config.BufferConfig(batch_size: int = 1, train_batch_size: int = 0, total_epochs: int = 1, total_steps: int | None = None, explorer_input: ~trinity.common.config.ExplorerInput = <factory>, trainer_input: ~trinity.common.config.TrainerInput = <factory>, explorer_output: ~trinity.common.config.StorageConfig | None = None, tokenizer_path: str | None = None, pad_token_id: int | None = None, cache_dir: str | None = None)[source]#
Bases:
object
Config for buffer.
- batch_size: int = 1#
- train_batch_size: int = 0#
- total_epochs: int = 1#
- total_steps: int | None = None#
- explorer_input: ExplorerInput#
- trainer_input: TrainerInput#
- explorer_output: StorageConfig | None = None#
- tokenizer_path: str | None = None#
- pad_token_id: int | None = None#
- cache_dir: str | None = None#
- __init__(batch_size: int = 1, train_batch_size: int = 0, total_epochs: int = 1, total_steps: int | None = None, explorer_input: ~trinity.common.config.ExplorerInput = <factory>, trainer_input: ~trinity.common.config.TrainerInput = <factory>, explorer_output: ~trinity.common.config.StorageConfig | None = None, tokenizer_path: str | None = None, pad_token_id: int | None = None, cache_dir: str | None = None) None #
- class trinity.common.config.ExplorerConfig(name: str = 'explorer', runner_per_model: int = 8, max_timeout: int = 1800, max_retry_times: int = 2, env_vars: dict = <factory>, max_repeat_times_per_runner: int | None = None, runner_num: int | None = None, rollout_model: ~trinity.common.config.InferenceModelConfig = <factory>, auxiliary_models: ~typing.List[~trinity.common.config.InferenceModelConfig] = <factory>, eval_interval: int = 100, eval_on_startup: bool = True, bench_on_latest_checkpoint: bool = False)[source]#
Bases:
object
Config for explorer.
- name: str = 'explorer'#
- runner_per_model: int = 8#
- max_timeout: int = 1800#
- max_retry_times: int = 2#
- env_vars: dict#
- max_repeat_times_per_runner: int | None = None#
- runner_num: int | None = None#
- rollout_model: InferenceModelConfig#
- auxiliary_models: List[InferenceModelConfig]#
- eval_interval: int = 100#
- eval_on_startup: bool = True#
- bench_on_latest_checkpoint: bool = False#
- __init__(name: str = 'explorer', runner_per_model: int = 8, max_timeout: int = 1800, max_retry_times: int = 2, env_vars: dict = <factory>, max_repeat_times_per_runner: int | None = None, runner_num: int | None = None, rollout_model: ~trinity.common.config.InferenceModelConfig = <factory>, auxiliary_models: ~typing.List[~trinity.common.config.InferenceModelConfig] = <factory>, eval_interval: int = 100, eval_on_startup: bool = True, bench_on_latest_checkpoint: bool = False) None #
- class trinity.common.config.TrainerConfig(name: str = 'trainer', trainer_type: str = 'verl', save_interval: int = 0, enable_preview: bool = True, actor_grad_clip: float | None = None, trainer_config: Any = <factory>, trainer_config_path: str = '')[source]#
Bases:
object
- name: str = 'trainer'#
- trainer_type: str = 'verl'#
- save_interval: int = 0#
- enable_preview: bool = True#
- actor_grad_clip: float | None = None#
- trainer_config: Any#
- trainer_config_path: str = ''#
- __init__(name: str = 'trainer', trainer_type: str = 'verl', save_interval: int = 0, enable_preview: bool = True, actor_grad_clip: float | None = None, trainer_config: ~typing.Any = <factory>, trainer_config_path: str = '') None #
- class trinity.common.config.MonitorConfig(monitor_type: str = 'tensorboard', monitor_args: Dict | None = None, enable_ray_timeline: bool = False, cache_dir: str = '')[source]#
Bases:
object
- monitor_type: str = 'tensorboard'#
- monitor_args: Dict | None = None#
- enable_ray_timeline: bool = False#
- cache_dir: str = ''#
- __init__(monitor_type: str = 'tensorboard', monitor_args: Dict | None = None, enable_ray_timeline: bool = False, cache_dir: str = '') None #
- class trinity.common.config.SynchronizerConfig(sync_method: SyncMethod = SyncMethod.NCCL, sync_style: SyncStyle = SyncStyle.FIXED, sync_interval: int = 1, sync_offset: int = 0, sync_timeout: int = 3600, wait_for_checkpoint: bool = False, explorer_world_size: int | None = None, ray_namespace: str = '')[source]#
Bases:
object
Configs for model weight synchronization.
- sync_method: SyncMethod = 'nccl'#
- sync_style: SyncStyle = 'fixed'#
- sync_interval: int = 1#
- sync_offset: int = 0#
- sync_timeout: int = 3600#
- wait_for_checkpoint: bool = False#
- explorer_world_size: int | None = None#
- ray_namespace: str = ''#
- __init__(sync_method: SyncMethod = SyncMethod.NCCL, sync_style: SyncStyle = SyncStyle.FIXED, sync_interval: int = 1, sync_offset: int = 0, sync_timeout: int = 3600, wait_for_checkpoint: bool = False, explorer_world_size: int | None = None, ray_namespace: str = '') None #
- class trinity.common.config.DataJuicerServiceConfig(server_url: str | None = None, auto_start: bool = False, port: int | None = None)[source]#
Bases:
object
Config for Data-Juicer.
If you change the fields here, please update trinity.service.data_juicer.server.server.py accordingly.
- server_url: str | None = None#
- auto_start: bool = False#
- port: int | None = None#
- __init__(server_url: str | None = None, auto_start: bool = False, port: int | None = None) None #
- class trinity.common.config.ServiceConfig(data_juicer: DataJuicerServiceConfig | None = None)[source]#
Bases:
object
Configs for outside services.
- data_juicer: DataJuicerServiceConfig | None = None#
- __init__(data_juicer: DataJuicerServiceConfig | None = None) None #
- class trinity.common.config.LogConfig(level: str = 'INFO', group_by_node: bool = False, save_dir: str = '')[source]#
Bases:
object
Configs for logger.
- level: str = 'INFO'#
- group_by_node: bool = False#
- save_dir: str = ''#
- __init__(level: str = 'INFO', group_by_node: bool = False, save_dir: str = '') None #
- class trinity.common.config.Config(mode: str = 'both', project: str = 'Trinity-RFT', group: str = '', name: str = 'rft', checkpoint_root_dir: str = '', checkpoint_job_dir: str = '', ray_namespace: str = '', continue_from_checkpoint: bool = True, algorithm: ~trinity.common.config.AlgorithmConfig = <factory>, data_processor: ~trinity.common.config.DataProcessorConfig = <factory>, model: ~trinity.common.config.ModelConfig = <factory>, cluster: ~trinity.common.config.ClusterConfig = <factory>, buffer: ~trinity.common.config.BufferConfig = <factory>, explorer: ~trinity.common.config.ExplorerConfig = <factory>, trainer: ~trinity.common.config.TrainerConfig = <factory>, monitor: ~trinity.common.config.MonitorConfig = <factory>, synchronizer: ~trinity.common.config.SynchronizerConfig = <factory>, service: ~trinity.common.config.ServiceConfig = <factory>, log: ~trinity.common.config.LogConfig = <factory>)[source]#
Bases:
object
Global Configuration
- mode: str = 'both'#
- project: str = 'Trinity-RFT'#
- group: str = ''#
- name: str = 'rft'#
- checkpoint_root_dir: str = ''#
- checkpoint_job_dir: str = ''#
- ray_namespace: str = ''#
- continue_from_checkpoint: bool = True#
- algorithm: AlgorithmConfig#
- data_processor: DataProcessorConfig#
- model: ModelConfig#
- cluster: ClusterConfig#
- buffer: BufferConfig#
- explorer: ExplorerConfig#
- trainer: TrainerConfig#
- monitor: MonitorConfig#
- synchronizer: SynchronizerConfig#
- service: ServiceConfig#
- log: LogConfig#
- flatten() → Dict[str, Any][source]#
Flatten the config into a single-level dict with dot-separated keys for nested fields.
- __init__(mode: str = 'both', project: str = 'Trinity-RFT', group: str = '', name: str = 'rft', checkpoint_root_dir: str = '', checkpoint_job_dir: str = '', ray_namespace: str = '', continue_from_checkpoint: bool = True, algorithm: ~trinity.common.config.AlgorithmConfig = <factory>, data_processor: ~trinity.common.config.DataProcessorConfig = <factory>, model: ~trinity.common.config.ModelConfig = <factory>, cluster: ~trinity.common.config.ClusterConfig = <factory>, buffer: ~trinity.common.config.BufferConfig = <factory>, explorer: ~trinity.common.config.ExplorerConfig = <factory>, trainer: ~trinity.common.config.TrainerConfig = <factory>, monitor: ~trinity.common.config.MonitorConfig = <factory>, synchronizer: ~trinity.common.config.SynchronizerConfig = <factory>, service: ~trinity.common.config.ServiceConfig = <factory>, log: ~trinity.common.config.LogConfig = <factory>) None #