trinity.common.config module

Configs for RFT.

class trinity.common.config.FormatConfig(prompt_type: PromptType = PromptType.MESSAGES, prompt_key: str = 'prompt', response_key: str = 'response', messages_key: str = 'message', tools_key: str = 'tools', chat_template: str = '', system_prompt: str | None = None, reply_prefix: str | None = None, reward_fn_key: str = '', workflow_key: str = '', solution_key: str = 'solution', reward_key: str = 'reward', chosen_key: str = 'chosen', rejected_key: str = 'rejected', label_key: str = '')[source]

Bases: object

Configuration for data formatting.

prompt_type: PromptType = 'messages'
prompt_key: str = 'prompt'
response_key: str = 'response'
messages_key: str = 'message'
tools_key: str = 'tools'
chat_template: str = ''
system_prompt: str | None = None
reply_prefix: str | None = None
reward_fn_key: str = ''
workflow_key: str = ''
solution_key: str = 'solution'
reward_key: str = 'reward'
chosen_key: str = 'chosen'
rejected_key: str = 'rejected'
label_key: str = ''
__init__(prompt_type: PromptType = PromptType.MESSAGES, prompt_key: str = 'prompt', response_key: str = 'response', messages_key: str = 'message', tools_key: str = 'tools', chat_template: str = '', system_prompt: str | None = None, reply_prefix: str | None = None, reward_fn_key: str = '', workflow_key: str = '', solution_key: str = 'solution', reward_key: str = 'reward', chosen_key: str = 'chosen', rejected_key: str = 'rejected', label_key: str = '') None
class trinity.common.config.GenerationConfig(temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, logprobs: int = 0, n: int = 1)[source]

Bases: object

temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1
logprobs: int = 0
n: int = 1
__init__(temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, logprobs: int = 0, n: int = 1) None
class trinity.common.config.StorageConfig(name: str = '', storage_type: ~trinity.common.constants.StorageType = StorageType.FILE, path: str | None = None, repeat_times: int | None = None, raw: bool = False, split: str = 'train', subset_name: str | None = None, format: ~trinity.common.config.FormatConfig = <factory>, index: int = 0, wrap_in_ray: bool = True, capacity: int = 10000, max_read_timeout: float = 1800, use_priority_queue: bool = False, reuse_cooldown_time: float | None = None, replay_buffer_kwargs: dict = <factory>, default_workflow_type: str | None = None, default_eval_workflow_type: str | None = None, default_reward_fn_type: str | None = None, rollout_args: ~trinity.common.config.GenerationConfig = <factory>, workflow_args: dict = <factory>, reward_fn_args: dict = <factory>, enable_progress_bar: bool | None = False, ray_namespace: str | None = None, algorithm_type: str | None = None, total_epochs: int = 1, total_steps: int | None = None, task_type: ~trinity.common.constants.TaskType = TaskType.EXPLORE)[source]

Bases: object

Storage config.

name: str = ''
storage_type: StorageType = 'file'
path: str | None = None
repeat_times: int | None = None
raw: bool = False
split: str = 'train'
subset_name: str | None = None
format: FormatConfig
index: int = 0
wrap_in_ray: bool = True
capacity: int = 10000
max_read_timeout: float = 1800
use_priority_queue: bool = False
reuse_cooldown_time: float | None = None
replay_buffer_kwargs: dict
default_workflow_type: str | None = None
default_eval_workflow_type: str | None = None
default_reward_fn_type: str | None = None
rollout_args: GenerationConfig
workflow_args: dict
reward_fn_args: dict
enable_progress_bar: bool | None = False
ray_namespace: str | None = None
algorithm_type: str | None = None
total_epochs: int = 1
total_steps: int | None = None
task_type: TaskType = 0
__init__(name: str = '', storage_type: ~trinity.common.constants.StorageType = StorageType.FILE, path: str | None = None, repeat_times: int | None = None, raw: bool = False, split: str = 'train', subset_name: str | None = None, format: ~trinity.common.config.FormatConfig = <factory>, index: int = 0, wrap_in_ray: bool = True, capacity: int = 10000, max_read_timeout: float = 1800, use_priority_queue: bool = False, reuse_cooldown_time: float | None = None, replay_buffer_kwargs: dict = <factory>, default_workflow_type: str | None = None, default_eval_workflow_type: str | None = None, default_reward_fn_type: str | None = None, rollout_args: ~trinity.common.config.GenerationConfig = <factory>, workflow_args: dict = <factory>, reward_fn_args: dict = <factory>, enable_progress_bar: bool | None = False, ray_namespace: str | None = None, algorithm_type: str | None = None, total_epochs: int = 1, total_steps: int | None = None, task_type: ~trinity.common.constants.TaskType = TaskType.EXPLORE) None
class trinity.common.config.OperatorConfig(name: str = '', args: Dict[str, Any] = <factory>)[source]

Bases: object

name: str = ''
args: Dict[str, Any]
__init__(name: str = '', args: ~typing.Dict[str, ~typing.Any] = <factory>) None
class trinity.common.config.ExperiencePipelineConfig(operators: ~typing.List[~trinity.common.config.OperatorConfig] = <factory>, save_input: bool = True, input_save_path: str | None = None, inputs: ~typing.Dict[str, ~trinity.common.config.StorageConfig] = <factory>, output: ~trinity.common.config.StorageConfig | None = None)[source]

Bases: object

Config for experience pipeline.

The experience pipeline pre-processes rollout experiences to improve training.

operators: List[OperatorConfig]
save_input: bool = True
input_save_path: str | None = None
inputs: Dict[str, StorageConfig]
output: StorageConfig | None = None
__init__(operators: ~typing.List[~trinity.common.config.OperatorConfig] = <factory>, save_input: bool = True, input_save_path: str | None = None, inputs: ~typing.Dict[str, ~trinity.common.config.StorageConfig] = <factory>, output: ~trinity.common.config.StorageConfig | None = None) None
class trinity.common.config.TaskPipelineConfig(operators: ~typing.List[~trinity.common.config.OperatorConfig] = <factory>, num_process: int = 4, config_path: str | None = None, inputs: ~typing.List[str] = <factory>, output: ~trinity.common.config.StorageConfig | None = None, target_fields: ~typing.List[str] = <factory>, priority_weights: ~typing.Dict[str, float] = <factory>, top_k: int = -1)[source]

Bases: object

Config for task pipeline.

The task pipeline pre-processes raw tasks to improve exploration. Currently, only Data-Juicer operators are supported in the task pipeline.

operators: List[OperatorConfig]
num_process: int = 4
config_path: str | None = None
inputs: List[str]
output: StorageConfig | None = None
target_fields: List[str]
priority_weights: Dict[str, float]
top_k: int = -1
__init__(operators: ~typing.List[~trinity.common.config.OperatorConfig] = <factory>, num_process: int = 4, config_path: str | None = None, inputs: ~typing.List[str] = <factory>, output: ~trinity.common.config.StorageConfig | None = None, target_fields: ~typing.List[str] = <factory>, priority_weights: ~typing.Dict[str, float] = <factory>, top_k: int = -1) None
class trinity.common.config.DataProcessorConfig(task_pipeline: ~trinity.common.config.TaskPipelineConfig | None = None, experience_pipeline: ~trinity.common.config.ExperiencePipelineConfig | None = <factory>, setup_data_processor: bool = False, data_processor_url: str | None = None)[source]

Bases: object

Config for the data processor.

task_pipeline: TaskPipelineConfig | None = None
experience_pipeline: ExperiencePipelineConfig | None
setup_data_processor: bool = False
data_processor_url: str | None = None
__init__(task_pipeline: ~trinity.common.config.TaskPipelineConfig | None = None, experience_pipeline: ~trinity.common.config.ExperiencePipelineConfig | None = <factory>, setup_data_processor: bool = False, data_processor_url: str | None = None) None
class trinity.common.config.ModelConfig(model_path: str = '', critic_model_path: str = '', max_model_len: int | None = None, max_prompt_tokens: int | None = None, max_response_tokens: int | None = None, min_response_tokens: int = 1, custom_chat_template: str | None = None)[source]

Bases: object

model_path: str = ''
critic_model_path: str = ''
max_model_len: int | None = None
max_prompt_tokens: int | None = None
max_response_tokens: int | None = None
min_response_tokens: int = 1
custom_chat_template: str | None = None
__init__(model_path: str = '', critic_model_path: str = '', max_model_len: int | None = None, max_prompt_tokens: int | None = None, max_response_tokens: int | None = None, min_response_tokens: int = 1, custom_chat_template: str | None = None) None
class trinity.common.config.InferenceModelConfig(model_path: str = '', engine_type: str = 'vllm_async', engine_num: int = 1, tensor_parallel_size: int = 1, use_v1: bool = True, enforce_eager: bool = True, enable_prefix_caching: bool = False, enable_chunked_prefill: bool = False, gpu_memory_utilization: float = 0.9, dtype: str = 'bfloat16', seed: int = 42, max_model_len: int | None = None, max_prompt_tokens: int | None = None, max_response_tokens: int | None = None, min_response_tokens: int | None = None, ignore_eos: bool = False, chat_template: str | None = None, enable_thinking: bool = False, enable_history: bool = False, enable_openai_api: bool = False, enable_auto_tool_choice: bool = False, tool_call_parser: str | None = None, reasoning_parser: str | None = None, bundle_indices: str = '')[source]

Bases: object

model_path: str = ''
engine_type: str = 'vllm_async'
engine_num: int = 1
tensor_parallel_size: int = 1
use_v1: bool = True
enforce_eager: bool = True
enable_prefix_caching: bool = False
enable_chunked_prefill: bool = False
gpu_memory_utilization: float = 0.9
dtype: str = 'bfloat16'
seed: int = 42
max_model_len: int | None = None
max_prompt_tokens: int | None = None
max_response_tokens: int | None = None
min_response_tokens: int | None = None
ignore_eos: bool = False
chat_template: str | None = None
enable_thinking: bool = False
enable_history: bool = False
enable_openai_api: bool = False
enable_auto_tool_choice: bool = False
tool_call_parser: str | None = None
reasoning_parser: str | None = None
bundle_indices: str = ''
__init__(model_path: str = '', engine_type: str = 'vllm_async', engine_num: int = 1, tensor_parallel_size: int = 1, use_v1: bool = True, enforce_eager: bool = True, enable_prefix_caching: bool = False, enable_chunked_prefill: bool = False, gpu_memory_utilization: float = 0.9, dtype: str = 'bfloat16', seed: int = 42, max_model_len: int | None = None, max_prompt_tokens: int | None = None, max_response_tokens: int | None = None, min_response_tokens: int | None = None, ignore_eos: bool = False, chat_template: str | None = None, enable_thinking: bool = False, enable_history: bool = False, enable_openai_api: bool = False, enable_auto_tool_choice: bool = False, tool_call_parser: str | None = None, reasoning_parser: str | None = None, bundle_indices: str = '') None
class trinity.common.config.AlgorithmConfig(algorithm_type: str = 'ppo', repeat_times: int = 1, sample_strategy: str | None = None, sample_strategy_args: dict | None = None, advantage_fn: str | None = None, advantage_fn_args: dict | None = None, kl_penalty_fn: str | None = None, kl_penalty_fn_args: dict | None = None, policy_loss_fn: str | None = None, policy_loss_fn_args: dict | None = None, kl_loss_fn: str | None = None, kl_loss_fn_args: dict | None = None, entropy_loss_fn: str | None = None, entropy_loss_fn_args: dict | None = None, use_token_level_loss: bool = True)[source]

Bases: object

Config for algorithm.

algorithm_type: str = 'ppo'
repeat_times: int = 1
sample_strategy: str | None = None
sample_strategy_args: dict | None = None
advantage_fn: str | None = None
advantage_fn_args: dict | None = None
kl_penalty_fn: str | None = None
kl_penalty_fn_args: dict | None = None
policy_loss_fn: str | None = None
policy_loss_fn_args: dict | None = None
kl_loss_fn: str | None = None
kl_loss_fn_args: dict | None = None
entropy_loss_fn: str | None = None
entropy_loss_fn_args: dict | None = None
use_token_level_loss: bool = True
__init__(algorithm_type: str = 'ppo', repeat_times: int = 1, sample_strategy: str | None = None, sample_strategy_args: dict | None = None, advantage_fn: str | None = None, advantage_fn_args: dict | None = None, kl_penalty_fn: str | None = None, kl_penalty_fn_args: dict | None = None, policy_loss_fn: str | None = None, policy_loss_fn_args: dict | None = None, kl_loss_fn: str | None = None, kl_loss_fn_args: dict | None = None, entropy_loss_fn: str | None = None, entropy_loss_fn_args: dict | None = None, use_token_level_loss: bool = True) None
class trinity.common.config.ClusterConfig(node_num: int = 1, gpu_per_node: int = 8)[source]

Bases: object

Config for the cluster.

node_num: int = 1
gpu_per_node: int = 8
__init__(node_num: int = 1, gpu_per_node: int = 8) None
class trinity.common.config.ExplorerInput(taskset: ~trinity.common.config.StorageConfig = <factory>, eval_tasksets: ~typing.List[~trinity.common.config.StorageConfig] = <factory>, default_workflow_type: str | None = None, default_eval_workflow_type: str | None = None, default_reward_fn_type: str | None = None, system_prompt: str | None = None, reply_prefix: str | None = None)[source]

Bases: object

Config for explorer input.

taskset: StorageConfig
eval_tasksets: List[StorageConfig]
default_workflow_type: str | None = None
default_eval_workflow_type: str | None = None
default_reward_fn_type: str | None = None
system_prompt: str | None = None
reply_prefix: str | None = None
__init__(taskset: ~trinity.common.config.StorageConfig = <factory>, eval_tasksets: ~typing.List[~trinity.common.config.StorageConfig] = <factory>, default_workflow_type: str | None = None, default_eval_workflow_type: str | None = None, default_reward_fn_type: str | None = None, system_prompt: str | None = None, reply_prefix: str | None = None) None
class trinity.common.config.TrainerInput(experience_buffer: StorageConfig | None = None, sft_warmup_dataset: StorageConfig | None = None, sft_warmup_steps: int = 0)[source]

Bases: object

Config for trainer input.

experience_buffer: StorageConfig | None = None
sft_warmup_dataset: StorageConfig | None = None
sft_warmup_steps: int = 0
__init__(experience_buffer: StorageConfig | None = None, sft_warmup_dataset: StorageConfig | None = None, sft_warmup_steps: int = 0) None
class trinity.common.config.BufferConfig(batch_size: int = 1, train_batch_size: int = 0, total_epochs: int = 1, total_steps: int | None = None, explorer_input: ~trinity.common.config.ExplorerInput = <factory>, trainer_input: ~trinity.common.config.TrainerInput = <factory>, max_retry_times: int = 3, max_retry_interval: int = 1, explorer_output: ~trinity.common.config.StorageConfig | None = None, tokenizer_path: str | None = None, pad_token_id: int | None = None, cache_dir: str | None = None)[source]

Bases: object

Config for buffer.

batch_size: int = 1
train_batch_size: int = 0
total_epochs: int = 1
total_steps: int | None = None
explorer_input: ExplorerInput
trainer_input: TrainerInput
max_retry_times: int = 3
max_retry_interval: int = 1
explorer_output: StorageConfig | None = None
tokenizer_path: str | None = None
pad_token_id: int | None = None
cache_dir: str | None = None
__init__(batch_size: int = 1, train_batch_size: int = 0, total_epochs: int = 1, total_steps: int | None = None, explorer_input: ~trinity.common.config.ExplorerInput = <factory>, trainer_input: ~trinity.common.config.TrainerInput = <factory>, max_retry_times: int = 3, max_retry_interval: int = 1, explorer_output: ~trinity.common.config.StorageConfig | None = None, tokenizer_path: str | None = None, pad_token_id: int | None = None, cache_dir: str | None = None) None
class trinity.common.config.ExplorerConfig(name: str = 'explorer', runner_per_model: int = 8, max_timeout: int = 1800, max_retry_times: int = 2, env_vars: dict = <factory>, max_repeat_times_per_runner: int | None = None, runner_num: int | None = None, rollout_model: ~trinity.common.config.InferenceModelConfig = <factory>, auxiliary_models: ~typing.List[~trinity.common.config.InferenceModelConfig] = <factory>, eval_interval: int = 100, eval_on_startup: bool = True, bench_on_latest_checkpoint: bool = False)[source]

Bases: object

Config for explorer.

name: str = 'explorer'
runner_per_model: int = 8
max_timeout: int = 1800
max_retry_times: int = 2
env_vars: dict
max_repeat_times_per_runner: int | None = None
runner_num: int | None = None
rollout_model: InferenceModelConfig
auxiliary_models: List[InferenceModelConfig]
eval_interval: int = 100
eval_on_startup: bool = True
bench_on_latest_checkpoint: bool = False
__init__(name: str = 'explorer', runner_per_model: int = 8, max_timeout: int = 1800, max_retry_times: int = 2, env_vars: dict = <factory>, max_repeat_times_per_runner: int | None = None, runner_num: int | None = None, rollout_model: ~trinity.common.config.InferenceModelConfig = <factory>, auxiliary_models: ~typing.List[~trinity.common.config.InferenceModelConfig] = <factory>, eval_interval: int = 100, eval_on_startup: bool = True, bench_on_latest_checkpoint: bool = False) None
class trinity.common.config.TrainerConfig(name: str = 'trainer', trainer_type: str = 'verl', save_interval: int = 0, enable_preview: bool = True, actor_grad_clip: Optional[float] = None, trainer_config: Any = <factory>, trainer_config_path: str = '')[source]

Bases: object

name: str = 'trainer'
trainer_type: str = 'verl'
save_interval: int = 0
enable_preview: bool = True
actor_grad_clip: float | None = None
trainer_config: Any
trainer_config_path: str = ''
__init__(name: str = 'trainer', trainer_type: str = 'verl', save_interval: int = 0, enable_preview: bool = True, actor_grad_clip: float | None = None, trainer_config: ~typing.Any = <factory>, trainer_config_path: str = '') None
class trinity.common.config.MonitorConfig(monitor_type: str = 'tensorboard', monitor_args: Dict | None = None, enable_ray_timeline: bool = False, cache_dir: str = '')[source]

Bases: object

monitor_type: str = 'tensorboard'
monitor_args: Dict | None = None
enable_ray_timeline: bool = False
cache_dir: str = ''
__init__(monitor_type: str = 'tensorboard', monitor_args: Dict | None = None, enable_ray_timeline: bool = False, cache_dir: str = '') None
class trinity.common.config.SynchronizerConfig(sync_method: SyncMethod = SyncMethod.NCCL, sync_style: SyncStyle = SyncStyle.FIXED, sync_interval: int = 1, sync_offset: int = 0, sync_timeout: int = 3600, wait_for_checkpoint: bool = False, explorer_world_size: int | None = None, ray_namespace: str = '')[source]

Bases: object

Configs for model weight synchronization.

sync_method: SyncMethod = 'nccl'
sync_style: SyncStyle = 'fixed'
sync_interval: int = 1
sync_offset: int = 0
sync_timeout: int = 3600
wait_for_checkpoint: bool = False
explorer_world_size: int | None = None
ray_namespace: str = ''
__init__(sync_method: SyncMethod = SyncMethod.NCCL, sync_style: SyncStyle = SyncStyle.FIXED, sync_interval: int = 1, sync_offset: int = 0, sync_timeout: int = 3600, wait_for_checkpoint: bool = False, explorer_world_size: int | None = None, ray_namespace: str = '') None
class trinity.common.config.DataJuicerServiceConfig(server_url: str | None = None, auto_start: bool = False, port: int | None = None)[source]

Bases: object

Config for Data-Juicer.

If you change the fields here, please update trinity.service.data_juicer.server.server.py accordingly.

server_url: str | None = None
auto_start: bool = False
port: int | None = None
__init__(server_url: str | None = None, auto_start: bool = False, port: int | None = None) None
class trinity.common.config.ServiceConfig(data_juicer: DataJuicerServiceConfig | None = None)[source]

Bases: object

Configs for external services.

data_juicer: DataJuicerServiceConfig | None = None
__init__(data_juicer: DataJuicerServiceConfig | None = None) None
class trinity.common.config.Config(mode: str = 'both', project: str = 'Trinity-RFT', group: str = '', name: str = 'rft', checkpoint_root_dir: str = '', checkpoint_job_dir: str = '', ray_namespace: str = '', continue_from_checkpoint: bool = True, algorithm: ~trinity.common.config.AlgorithmConfig = <factory>, data_processor: ~trinity.common.config.DataProcessorConfig = <factory>, model: ~trinity.common.config.ModelConfig = <factory>, cluster: ~trinity.common.config.ClusterConfig = <factory>, buffer: ~trinity.common.config.BufferConfig = <factory>, explorer: ~trinity.common.config.ExplorerConfig = <factory>, trainer: ~trinity.common.config.TrainerConfig = <factory>, monitor: ~trinity.common.config.MonitorConfig = <factory>, synchronizer: ~trinity.common.config.SynchronizerConfig = <factory>, service: ~trinity.common.config.ServiceConfig = <factory>)[source]

Bases: object

Global Configuration

mode: str = 'both'
project: str = 'Trinity-RFT'
group: str = ''
name: str = 'rft'
checkpoint_root_dir: str = ''
checkpoint_job_dir: str = ''
ray_namespace: str = ''
continue_from_checkpoint: bool = True
algorithm: AlgorithmConfig
data_processor: DataProcessorConfig
model: ModelConfig
cluster: ClusterConfig
buffer: BufferConfig
explorer: ExplorerConfig
trainer: TrainerConfig
monitor: MonitorConfig
synchronizer: SynchronizerConfig
service: ServiceConfig
save(config_path: str) → None[source]

Save config to file.

check_and_update() → None[source]

Check and update the config.

flatten() → Dict[str, Any][source]

Flatten the config into a single-level dict with dot-separated keys for nested fields.

__init__(mode: str = 'both', project: str = 'Trinity-RFT', group: str = '', name: str = 'rft', checkpoint_root_dir: str = '', checkpoint_job_dir: str = '', ray_namespace: str = '', continue_from_checkpoint: bool = True, algorithm: ~trinity.common.config.AlgorithmConfig = <factory>, data_processor: ~trinity.common.config.DataProcessorConfig = <factory>, model: ~trinity.common.config.ModelConfig = <factory>, cluster: ~trinity.common.config.ClusterConfig = <factory>, buffer: ~trinity.common.config.BufferConfig = <factory>, explorer: ~trinity.common.config.ExplorerConfig = <factory>, trainer: ~trinity.common.config.TrainerConfig = <factory>, monitor: ~trinity.common.config.MonitorConfig = <factory>, synchronizer: ~trinity.common.config.SynchronizerConfig = <factory>, service: ~trinity.common.config.ServiceConfig = <factory>) None
trinity.common.config.load_config(config_path: str) → Config[source]

Load the configuration from the given path.