trinity.common.workflows.on_policy_distill_workflow module#

On-Policy Distillation Workflow.

Reference: Tinker library’s on-policy distillation implementation.

Algorithm (sketched in the example below):

1. The student samples trajectories (with logprobs).
2. The teacher computes logprobs on the same trajectories.
3. The teacher logprobs are stored in experience.info["teacher_logprobs"].
4. The trainer's advantage_fn computes: advantages = teacher_logprobs - student_logprobs.
5. Training uses the importance_sampling loss.
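A minimal sketch of steps 1-4 in plain Python. The names student, teacher, and the experience layout are illustrative stand-ins, not Trinity's actual API:

    def distill_step(student, teacher, prompt):
        # 1. The student samples a trajectory and records its own logprobs.
        response_tokens, student_logprobs = student.sample(prompt)

        # 2. The teacher scores the *same* tokens (the teacher never samples).
        teacher_logprobs = teacher.compute_logprobs(prompt, response_tokens)

        # 3. The teacher logprobs travel with the experience to the trainer.
        experience = {
            "tokens": response_tokens,
            "logprobs": student_logprobs,
            "info": {"teacher_logprobs": teacher_logprobs},
        }

        # 4. Per-token advantage: positive where the teacher assigns the
        #    sampled token more probability than the student did.
        advantages = [t - s for t, s in zip(teacher_logprobs, student_logprobs)]
        return experience, advantages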

class trinity.common.workflows.on_policy_distill_workflow.OnPolicyDistillWorkflow(*, task: Task, model: ModelWrapper, auxiliary_models: List[ModelWrapper] | None = None)[source]#

Bases: Workflow

On-policy distillation workflow.

Computes and stores teacher_logprobs in experience.info. The trainer's advantage_fn then computes:

advantages = teacher_logprobs - student_logprobs

Note: This workflow does NOT use reward_fn because:

- the advantage is computed from the teacher-student logprob difference, and
- no external reward signal is needed.
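A hedged sketch of such an advantage_fn, treating the experience as a plain dict; the real trainer hook and Experience fields may differ:

    def distill_advantage_fn(experience):
        # No reward_fn involved: the per-token advantage is simply the
        # teacher-student logprob gap on the sampled tokens.
        student_lp = experience["logprobs"]                  # from the rollout
        teacher_lp = experience["info"]["teacher_logprobs"]  # stored by this workflow
        return [t - s for t, s in zip(teacher_lp, student_lp)]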

is_async: bool = True#
can_reset: bool = True#
can_repeat: bool = True#
__init__(*, task: Task, model: ModelWrapper, auxiliary_models: List[ModelWrapper] | None = None)[source]#
reset(task: Task)[source]#

Reset the workflow with a new task.

Unlike BaseSimpleWorkflow, this does NOT require reward_fn.

set_repeat_times(repeat_times, run_id_base)[source]#

Set the number of times to repeat the workflow.

Parameters:

- repeat_times (int) – number of times to repeat the workflow (if repeatable).
- run_id_base (int) – base run_id for setting run_id in experiences.
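For example (assuming an already constructed workflow instance):

    # Sample each task 4 times; runs are numbered upward from run_id_base.
    workflow.set_repeat_times(4, run_id_base=0)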

property rollout_args#
format_messages()[source]#

Format messages for the instruct model.

Default format: system_prompt (optional) + task_desc + reply_prefix (optional)
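A rough sketch of the shape this default could produce, assuming OpenAI-style chat dicts; the attribute names on task are assumptions for illustration:

    def format_messages_sketch(task):
        messages = []
        if task.system_prompt:  # optional
            messages.append({"role": "system", "content": task.system_prompt})
        messages.append({"role": "user", "content": task.task_desc})
        if task.reply_prefix:  # optional prefix seeding the assistant's reply
            messages.append({"role": "assistant", "content": task.reply_prefix})
        return messages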

compute_reward(response: Experience) → float[source]#

Compute reward for a response.

In the base class, this returns 0.0, since the advantage is computed from the teacher-student logprob difference. Subclasses can override this to compute actual rewards.

async run_async() → List[Experience][source]#

Run the workflow asynchronously and return a list of experiences.
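A hedged usage sketch; it assumes the first auxiliary model serves as the distillation teacher, and Task/ModelWrapper construction is elided:

    import asyncio

    from trinity.common.workflows.on_policy_distill_workflow import (
        OnPolicyDistillWorkflow,
    )

    async def collect(task, student, teacher):
        workflow = OnPolicyDistillWorkflow(
            task=task, model=student, auxiliary_models=[teacher]
        )
        experiences = await workflow.run_async()
        # Each experience carries the teacher's logprobs for the trainer.
        assert all("teacher_logprobs" in exp.info for exp in experiences)
        return experiences

    # experiences = asyncio.run(collect(task, student, teacher))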

class trinity.common.workflows.on_policy_distill_workflow.OnPolicyDistillMathWorkflow(*, task: Task, model: ModelWrapper, auxiliary_models: List[ModelWrapper] | None = None)[source]#

Bases: OnPolicyDistillWorkflow

On-policy distillation workflow with Qwen2.5-Math style format.

This workflow:

- uses the Qwen2.5-Math style prompt format (same as math_eval_workflow),
- computes accuracy via verify_math_answer as the reward, and
- is suitable for math reasoning tasks such as GSM8K and MATH.

format_messages()[source]#

Format messages using Qwen2.5-Math style.

System prompt: "You are a helpful assistant."

User prompt:

"{question}
Please reason step by step, and put your final answer within \boxed{}."
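Rendered as chat messages, this corresponds to something like the following (the dict shape is an assumption; the strings come from the docstring above):

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": "{question}\nPlease reason step by step, and put "
                       "your final answer within \\boxed{}.",
        },
    ]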

compute_reward(response: Experience) → float[source]#

Compute accuracy as the reward using Qwen2.5-Math evaluation.

Returns 1.0 if the answer is correct, 0.0 otherwise.
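A hedged sketch of this override; verify_math_answer is named above, but its signature and the fields on response and self are assumptions for illustration:

    def compute_reward(self, response):
        # Hypothetical signature: checks the model's boxed answer against
        # the task's ground-truth answer.
        is_correct = verify_math_answer(response.response_text, self.truth)
        return 1.0 if is_correct else 0.0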