trinity.common.config module

Contents

trinity.common.config module#

Configuration classes for Trinity-RFT.

class trinity.common.config.FormatConfig(prompt_type: PromptType = PromptType.MESSAGES, prompt_key: str = 'prompt', response_key: str = 'response', system_prompt_key: str | None = None, system_prompt: str | None = None, messages_key: str = 'message', tools_key: str = 'tools', image_key: str | None = None, video_key: str | None = None, reply_prefix: str | None = None, workflow_key: str = '', reward_fn_key: str = '', chosen_key: str = 'chosen', rejected_key: str = 'rejected', enable_concatenated_multi_turn: bool = False, chat_template: str | None = None)[source]#

Bases: object

Configuration for data formatting.

prompt_type: PromptType = 'messages'#
prompt_key: str = 'prompt'#
response_key: str = 'response'#
system_prompt_key: str | None = None#
system_prompt: str | None = None#
messages_key: str = 'message'#
tools_key: str = 'tools'#
image_key: str | None = None#
video_key: str | None = None#
reply_prefix: str | None = None#
workflow_key: str = ''#
reward_fn_key: str = ''#
chosen_key: str = 'chosen'#
rejected_key: str = 'rejected'#
enable_concatenated_multi_turn: bool = False#
chat_template: str | None = None#
__init__(prompt_type: PromptType = PromptType.MESSAGES, prompt_key: str = 'prompt', response_key: str = 'response', system_prompt_key: str | None = None, system_prompt: str | None = None, messages_key: str = 'message', tools_key: str = 'tools', image_key: str | None = None, video_key: str | None = None, reply_prefix: str | None = None, workflow_key: str = '', reward_fn_key: str = '', chosen_key: str = 'chosen', rejected_key: str = 'rejected', enable_concatenated_multi_turn: bool = False, chat_template: str | None = None) None#
class trinity.common.config.GenerationConfig(temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, logprobs: int = 0, n: int = 1)[source]#

Bases: object

temperature: float = 1.0#
top_p: float = 1.0#
top_k: int = -1#
logprobs: int = 0#
n: int = 1#
__init__(temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, logprobs: int = 0, n: int = 1) None#
class trinity.common.config.StorageConfig(name: str = '', storage_type: ~trinity.common.constants.StorageType = StorageType.FILE, path: str | None = None, repeat_times: int | None = None, index: int = 0, mm_data_kwargs: dict = <factory>, split: str = 'train', subset_name: str | None = None, format: ~trinity.common.config.FormatConfig = <factory>, capacity: int = 10000, max_read_timeout: float = 1800, use_priority_queue: bool = False, reuse_cooldown_time: float | None = None, replay_buffer_kwargs: dict = <factory>, max_retry_times: int = 3, max_retry_interval: int = 1, default_workflow_type: str | None = None, default_eval_workflow_type: str | None = None, default_reward_fn_type: str | None = None, rollout_args: ~trinity.common.config.GenerationConfig = <factory>, workflow_args: dict = <factory>, reward_fn_args: dict = <factory>, enable_progress_bar: bool | None = False, ray_namespace: str | None = None, wrap_in_ray: bool = True, schema_type: str | None = None, total_epochs: int = 1, total_steps: int | None = None, is_eval: bool = False)[source]#

Bases: object

Storage config.

name: str = ''#
storage_type: StorageType = 'file'#
path: str | None = None#
repeat_times: int | None = None#
index: int = 0#
mm_data_kwargs: dict#
split: str = 'train'#
subset_name: str | None = None#
format: FormatConfig#
capacity: int = 10000#
max_read_timeout: float = 1800#
use_priority_queue: bool = False#
reuse_cooldown_time: float | None = None#
replay_buffer_kwargs: dict#
max_retry_times: int = 3#
max_retry_interval: int = 1#
default_workflow_type: str | None = None#
default_eval_workflow_type: str | None = None#
default_reward_fn_type: str | None = None#
rollout_args: GenerationConfig#
workflow_args: dict#
reward_fn_args: dict#
enable_progress_bar: bool | None = False#
ray_namespace: str | None = None#
wrap_in_ray: bool = True#
schema_type: str | None = None#
total_epochs: int = 1#
total_steps: int | None = None#
is_eval: bool = False#
__init__(name: str = '', storage_type: ~trinity.common.constants.StorageType = StorageType.FILE, path: str | None = None, repeat_times: int | None = None, index: int = 0, mm_data_kwargs: dict = <factory>, split: str = 'train', subset_name: str | None = None, format: ~trinity.common.config.FormatConfig = <factory>, capacity: int = 10000, max_read_timeout: float = 1800, use_priority_queue: bool = False, reuse_cooldown_time: float | None = None, replay_buffer_kwargs: dict = <factory>, max_retry_times: int = 3, max_retry_interval: int = 1, default_workflow_type: str | None = None, default_eval_workflow_type: str | None = None, default_reward_fn_type: str | None = None, rollout_args: ~trinity.common.config.GenerationConfig = <factory>, workflow_args: dict = <factory>, reward_fn_args: dict = <factory>, enable_progress_bar: bool | None = False, ray_namespace: str | None = None, wrap_in_ray: bool = True, schema_type: str | None = None, total_epochs: int = 1, total_steps: int | None = None, is_eval: bool = False) None#
class trinity.common.config.OperatorConfig(name: str = '', args: Dict[str, Any] = <factory>)[source]#

Bases: object

name: str = ''#
args: Dict[str, Any]#
__init__(name: str = '', args: ~typing.Dict[str, ~typing.Any] = <factory>) None#
class trinity.common.config.ExperiencePipelineConfig(operators: ~typing.List[~trinity.common.config.OperatorConfig] = <factory>, save_input: bool = True, input_save_path: str | None = None, inputs: ~typing.Dict[str, ~trinity.common.config.StorageConfig] = <factory>, output: ~trinity.common.config.StorageConfig | None = None)[source]#

Bases: object

Config for experience pipeline.

The experience pipeline pre-processes rollout experiences to improve training.

operators: List[OperatorConfig]#
save_input: bool = True#
input_save_path: str | None = None#
inputs: Dict[str, StorageConfig]#
output: StorageConfig | None = None#
__init__(operators: ~typing.List[~trinity.common.config.OperatorConfig] = <factory>, save_input: bool = True, input_save_path: str | None = None, inputs: ~typing.Dict[str, ~trinity.common.config.StorageConfig] = <factory>, output: ~trinity.common.config.StorageConfig | None = None) None#
class trinity.common.config.TaskPipelineConfig(operators: ~typing.List[~trinity.common.config.OperatorConfig] = <factory>, num_process: int = 4, config_path: str | None = None, inputs: ~typing.List[str] = <factory>, output: ~trinity.common.config.StorageConfig | None = None, target_fields: ~typing.List[str] = <factory>, priority_weights: ~typing.Dict[str, float] = <factory>, top_k: int = -1)[source]#

Bases: object

Config for task pipeline.

The task pipeline pre-processes raw tasks to improve exploration. Currently, only Data-Juicer operators are supported in the task pipeline.

operators: List[OperatorConfig]#
num_process: int = 4#
config_path: str | None = None#
inputs: List[str]#
output: StorageConfig | None = None#
target_fields: List[str]#
priority_weights: Dict[str, float]#
top_k: int = -1#
__init__(operators: ~typing.List[~trinity.common.config.OperatorConfig] = <factory>, num_process: int = 4, config_path: str | None = None, inputs: ~typing.List[str] = <factory>, output: ~trinity.common.config.StorageConfig | None = None, target_fields: ~typing.List[str] = <factory>, priority_weights: ~typing.Dict[str, float] = <factory>, top_k: int = -1) None#
class trinity.common.config.DataProcessorConfig(task_pipeline: ~trinity.common.config.TaskPipelineConfig | None = None, experience_pipeline: ~trinity.common.config.ExperiencePipelineConfig | None = <factory>)[source]#

Bases: object

Config for the data processor.

task_pipeline: TaskPipelineConfig | None = None#
experience_pipeline: ExperiencePipelineConfig | None#
__init__(task_pipeline: ~trinity.common.config.TaskPipelineConfig | None = None, experience_pipeline: ~trinity.common.config.ExperiencePipelineConfig | None = <factory>) None#
class trinity.common.config.ModelConfig(model_path: str = '', critic_model_path: str = '', max_model_len: int | None = None, max_prompt_tokens: int | None = None, max_response_tokens: int | None = None, min_response_tokens: int = 1, custom_chat_template: str | None = None)[source]#

Bases: object

model_path: str = ''#
critic_model_path: str = ''#
max_model_len: int | None = None#
max_prompt_tokens: int | None = None#
max_response_tokens: int | None = None#
min_response_tokens: int = 1#
custom_chat_template: str | None = None#
__init__(model_path: str = '', critic_model_path: str = '', max_model_len: int | None = None, max_prompt_tokens: int | None = None, max_response_tokens: int | None = None, min_response_tokens: int = 1, custom_chat_template: str | None = None) None#
class trinity.common.config.InferenceModelConfig(model_path: str = '', engine_type: str = 'vllm_async', engine_num: int = 1, tensor_parallel_size: int = 1, use_v1: bool = True, enforce_eager: bool = True, enable_prefix_caching: bool = False, enable_chunked_prefill: bool = False, gpu_memory_utilization: float = 0.9, dtype: str = 'bfloat16', seed: int = 42, max_model_len: int | None = None, max_prompt_tokens: int | None = None, max_response_tokens: int | None = None, min_response_tokens: int | None = None, ignore_eos: bool = False, chat_template: str | None = None, enable_thinking: bool = False, enable_history: bool = False, enable_openai_api: bool = False, enable_auto_tool_choice: bool = False, tool_call_parser: str | None = None, reasoning_parser: str | None = None, bundle_indices: str = '')[source]#

Bases: object

model_path: str = ''#
engine_type: str = 'vllm_async'#
engine_num: int = 1#
tensor_parallel_size: int = 1#
use_v1: bool = True#
enforce_eager: bool = True#
enable_prefix_caching: bool = False#
enable_chunked_prefill: bool = False#
gpu_memory_utilization: float = 0.9#
dtype: str = 'bfloat16'#
seed: int = 42#
max_model_len: int | None = None#
max_prompt_tokens: int | None = None#
max_response_tokens: int | None = None#
min_response_tokens: int | None = None#
ignore_eos: bool = False#
chat_template: str | None = None#
enable_thinking: bool = False#
enable_history: bool = False#
enable_openai_api: bool = False#
enable_auto_tool_choice: bool = False#
tool_call_parser: str | None = None#
reasoning_parser: str | None = None#
bundle_indices: str = ''#
__init__(model_path: str = '', engine_type: str = 'vllm_async', engine_num: int = 1, tensor_parallel_size: int = 1, use_v1: bool = True, enforce_eager: bool = True, enable_prefix_caching: bool = False, enable_chunked_prefill: bool = False, gpu_memory_utilization: float = 0.9, dtype: str = 'bfloat16', seed: int = 42, max_model_len: int | None = None, max_prompt_tokens: int | None = None, max_response_tokens: int | None = None, min_response_tokens: int | None = None, ignore_eos: bool = False, chat_template: str | None = None, enable_thinking: bool = False, enable_history: bool = False, enable_openai_api: bool = False, enable_auto_tool_choice: bool = False, tool_call_parser: str | None = None, reasoning_parser: str | None = None, bundle_indices: str = '') None#
class trinity.common.config.AlgorithmConfig(algorithm_type: str = 'ppo', repeat_times: int = 1, sample_strategy: str | None = None, sample_strategy_args: dict | None = None, advantage_fn: str | None = None, advantage_fn_args: dict | None = None, kl_penalty_fn: str | None = None, kl_penalty_fn_args: dict | None = None, policy_loss_fn: str | None = None, policy_loss_fn_args: dict | None = None, kl_loss_fn: str | None = None, kl_loss_fn_args: dict | None = None, entropy_loss_fn: str | None = None, entropy_loss_fn_args: dict | None = None, use_token_level_loss: bool = True)[source]#

Bases: object

Config for algorithm.

algorithm_type: str = 'ppo'#
repeat_times: int = 1#
sample_strategy: str | None = None#
sample_strategy_args: dict | None = None#
advantage_fn: str | None = None#
advantage_fn_args: dict | None = None#
kl_penalty_fn: str | None = None#
kl_penalty_fn_args: dict | None = None#
policy_loss_fn: str | None = None#
policy_loss_fn_args: dict | None = None#
kl_loss_fn: str | None = None#
kl_loss_fn_args: dict | None = None#
entropy_loss_fn: str | None = None#
entropy_loss_fn_args: dict | None = None#
use_token_level_loss: bool = True#
__init__(algorithm_type: str = 'ppo', repeat_times: int = 1, sample_strategy: str | None = None, sample_strategy_args: dict | None = None, advantage_fn: str | None = None, advantage_fn_args: dict | None = None, kl_penalty_fn: str | None = None, kl_penalty_fn_args: dict | None = None, policy_loss_fn: str | None = None, policy_loss_fn_args: dict | None = None, kl_loss_fn: str | None = None, kl_loss_fn_args: dict | None = None, entropy_loss_fn: str | None = None, entropy_loss_fn_args: dict | None = None, use_token_level_loss: bool = True) None#
class trinity.common.config.ClusterConfig(node_num: int = 1, gpu_per_node: int = 8)[source]#

Bases: object

Config for the cluster.

node_num: int = 1#
gpu_per_node: int = 8#
__init__(node_num: int = 1, gpu_per_node: int = 8) None#
class trinity.common.config.ExplorerInput(taskset: ~trinity.common.config.StorageConfig = <factory>, eval_tasksets: ~typing.List[~trinity.common.config.StorageConfig] = <factory>, default_workflow_type: str | None = None, default_eval_workflow_type: str | None = None, default_reward_fn_type: str | None = None, system_prompt: str | None = None, reply_prefix: str | None = None)[source]#

Bases: object

Config for explorer input.

taskset: StorageConfig#
eval_tasksets: List[StorageConfig]#
default_workflow_type: str | None = None#
default_eval_workflow_type: str | None = None#
default_reward_fn_type: str | None = None#
system_prompt: str | None = None#
reply_prefix: str | None = None#
__init__(taskset: ~trinity.common.config.StorageConfig = <factory>, eval_tasksets: ~typing.List[~trinity.common.config.StorageConfig] = <factory>, default_workflow_type: str | None = None, default_eval_workflow_type: str | None = None, default_reward_fn_type: str | None = None, system_prompt: str | None = None, reply_prefix: str | None = None) None#
class trinity.common.config.TrainerInput(experience_buffer: StorageConfig | None = None, sft_warmup_dataset: StorageConfig | None = None, sft_warmup_steps: int = 0)[source]#

Bases: object

Config for trainer input.

experience_buffer: StorageConfig | None = None#
sft_warmup_dataset: StorageConfig | None = None#
sft_warmup_steps: int = 0#
__init__(experience_buffer: StorageConfig | None = None, sft_warmup_dataset: StorageConfig | None = None, sft_warmup_steps: int = 0) None#
class trinity.common.config.BufferConfig(batch_size: int = 1, train_batch_size: int = 0, total_epochs: int = 1, total_steps: int | None = None, explorer_input: ~trinity.common.config.ExplorerInput = <factory>, trainer_input: ~trinity.common.config.TrainerInput = <factory>, explorer_output: ~trinity.common.config.StorageConfig | None = None, tokenizer_path: str | None = None, pad_token_id: int | None = None, cache_dir: str | None = None)[source]#

Bases: object

Config for buffer.

batch_size: int = 1#
train_batch_size: int = 0#
total_epochs: int = 1#
total_steps: int | None = None#
explorer_input: ExplorerInput#
trainer_input: TrainerInput#
explorer_output: StorageConfig | None = None#
tokenizer_path: str | None = None#
pad_token_id: int | None = None#
cache_dir: str | None = None#
__init__(batch_size: int = 1, train_batch_size: int = 0, total_epochs: int = 1, total_steps: int | None = None, explorer_input: ~trinity.common.config.ExplorerInput = <factory>, trainer_input: ~trinity.common.config.TrainerInput = <factory>, explorer_output: ~trinity.common.config.StorageConfig | None = None, tokenizer_path: str | None = None, pad_token_id: int | None = None, cache_dir: str | None = None) None#
class trinity.common.config.ExplorerConfig(name: str = 'explorer', runner_per_model: int = 8, max_timeout: int = 1800, max_retry_times: int = 2, env_vars: dict = <factory>, max_repeat_times_per_runner: int | None = None, runner_num: int | None = None, rollout_model: ~trinity.common.config.InferenceModelConfig = <factory>, auxiliary_models: ~typing.List[~trinity.common.config.InferenceModelConfig] = <factory>, eval_interval: int = 100, eval_on_startup: bool = True, bench_on_latest_checkpoint: bool = False)[source]#

Bases: object

Config for explorer.

name: str = 'explorer'#
runner_per_model: int = 8#
max_timeout: int = 1800#
max_retry_times: int = 2#
env_vars: dict#
max_repeat_times_per_runner: int | None = None#
runner_num: int | None = None#
rollout_model: InferenceModelConfig#
auxiliary_models: List[InferenceModelConfig]#
eval_interval: int = 100#
eval_on_startup: bool = True#
bench_on_latest_checkpoint: bool = False#
__init__(name: str = 'explorer', runner_per_model: int = 8, max_timeout: int = 1800, max_retry_times: int = 2, env_vars: dict = <factory>, max_repeat_times_per_runner: int | None = None, runner_num: int | None = None, rollout_model: ~trinity.common.config.InferenceModelConfig = <factory>, auxiliary_models: ~typing.List[~trinity.common.config.InferenceModelConfig] = <factory>, eval_interval: int = 100, eval_on_startup: bool = True, bench_on_latest_checkpoint: bool = False) None#
class trinity.common.config.TrainerConfig(name: str = 'trainer', trainer_type: str = 'verl', save_interval: int = 0, enable_preview: bool = True, actor_grad_clip: float | None = None, trainer_config: Any = <factory>, trainer_config_path: str = '')[source]#

Bases: object

name: str = 'trainer'#
trainer_type: str = 'verl'#
save_interval: int = 0#
enable_preview: bool = True#
actor_grad_clip: float | None = None#
trainer_config: Any#
trainer_config_path: str = ''#
__init__(name: str = 'trainer', trainer_type: str = 'verl', save_interval: int = 0, enable_preview: bool = True, actor_grad_clip: float | None = None, trainer_config: ~typing.Any = <factory>, trainer_config_path: str = '') None#
class trinity.common.config.MonitorConfig(monitor_type: str = 'tensorboard', monitor_args: Dict | None = None, enable_ray_timeline: bool = False, cache_dir: str = '')[source]#

Bases: object

monitor_type: str = 'tensorboard'#
monitor_args: Dict | None = None#
enable_ray_timeline: bool = False#
cache_dir: str = ''#
__init__(monitor_type: str = 'tensorboard', monitor_args: Dict | None = None, enable_ray_timeline: bool = False, cache_dir: str = '') None#
class trinity.common.config.SynchronizerConfig(sync_method: SyncMethod = SyncMethod.NCCL, sync_style: SyncStyle = SyncStyle.FIXED, sync_interval: int = 1, sync_offset: int = 0, sync_timeout: int = 3600, wait_for_checkpoint: bool = False, explorer_world_size: int | None = None, ray_namespace: str = '')[source]#

Bases: object

Configs for model weight synchronization.

sync_method: SyncMethod = 'nccl'#
sync_style: SyncStyle = 'fixed'#
sync_interval: int = 1#
sync_offset: int = 0#
sync_timeout: int = 3600#
wait_for_checkpoint: bool = False#
explorer_world_size: int | None = None#
ray_namespace: str = ''#
__init__(sync_method: SyncMethod = SyncMethod.NCCL, sync_style: SyncStyle = SyncStyle.FIXED, sync_interval: int = 1, sync_offset: int = 0, sync_timeout: int = 3600, wait_for_checkpoint: bool = False, explorer_world_size: int | None = None, ray_namespace: str = '') None#
class trinity.common.config.DataJuicerServiceConfig(server_url: str | None = None, auto_start: bool = False, port: int | None = None)[source]#

Bases: object

Config for Data-Juicer.

If you change the fields here, update trinity/service/data_juicer/server/server.py accordingly.

server_url: str | None = None#
auto_start: bool = False#
port: int | None = None#
__init__(server_url: str | None = None, auto_start: bool = False, port: int | None = None) None#
class trinity.common.config.ServiceConfig(data_juicer: DataJuicerServiceConfig | None = None)[source]#

Bases: object

Configs for external services.

data_juicer: DataJuicerServiceConfig | None = None#
__init__(data_juicer: DataJuicerServiceConfig | None = None) None#
class trinity.common.config.LogConfig(level: str = 'INFO', group_by_node: bool = False, save_dir: str = '')[source]#

Bases: object

Configs for the logger.

level: str = 'INFO'#
group_by_node: bool = False#
save_dir: str = ''#
__init__(level: str = 'INFO', group_by_node: bool = False, save_dir: str = '') None#
class trinity.common.config.Config(mode: str = 'both', project: str = 'Trinity-RFT', group: str = '', name: str = 'rft', checkpoint_root_dir: str = '', checkpoint_job_dir: str = '', ray_namespace: str = '', continue_from_checkpoint: bool = True, algorithm: ~trinity.common.config.AlgorithmConfig = <factory>, data_processor: ~trinity.common.config.DataProcessorConfig = <factory>, model: ~trinity.common.config.ModelConfig = <factory>, cluster: ~trinity.common.config.ClusterConfig = <factory>, buffer: ~trinity.common.config.BufferConfig = <factory>, explorer: ~trinity.common.config.ExplorerConfig = <factory>, trainer: ~trinity.common.config.TrainerConfig = <factory>, monitor: ~trinity.common.config.MonitorConfig = <factory>, synchronizer: ~trinity.common.config.SynchronizerConfig = <factory>, service: ~trinity.common.config.ServiceConfig = <factory>, log: ~trinity.common.config.LogConfig = <factory>)[source]#

Bases: object

Global configuration.

mode: str = 'both'#
project: str = 'Trinity-RFT'#
group: str = ''#
name: str = 'rft'#
checkpoint_root_dir: str = ''#
checkpoint_job_dir: str = ''#
ray_namespace: str = ''#
continue_from_checkpoint: bool = True#
algorithm: AlgorithmConfig#
data_processor: DataProcessorConfig#
model: ModelConfig#
cluster: ClusterConfig#
buffer: BufferConfig#
explorer: ExplorerConfig#
trainer: TrainerConfig#
monitor: MonitorConfig#
synchronizer: SynchronizerConfig#
service: ServiceConfig#
log: LogConfig#
save(config_path: str) None[source]#

Save config to file.

check_and_update() None[source]#

Check and update the config.

flatten() Dict[str, Any][source]#

Flatten the config into a single-level dict with dot-separated keys for nested fields.

__init__(mode: str = 'both', project: str = 'Trinity-RFT', group: str = '', name: str = 'rft', checkpoint_root_dir: str = '', checkpoint_job_dir: str = '', ray_namespace: str = '', continue_from_checkpoint: bool = True, algorithm: ~trinity.common.config.AlgorithmConfig = <factory>, data_processor: ~trinity.common.config.DataProcessorConfig = <factory>, model: ~trinity.common.config.ModelConfig = <factory>, cluster: ~trinity.common.config.ClusterConfig = <factory>, buffer: ~trinity.common.config.BufferConfig = <factory>, explorer: ~trinity.common.config.ExplorerConfig = <factory>, trainer: ~trinity.common.config.TrainerConfig = <factory>, monitor: ~trinity.common.config.MonitorConfig = <factory>, synchronizer: ~trinity.common.config.SynchronizerConfig = <factory>, service: ~trinity.common.config.ServiceConfig = <factory>, log: ~trinity.common.config.LogConfig = <factory>) None#
trinity.common.config.load_config(config_path: str) Config[source]#

Load the configuration from the given path.