trinity.common.workflows.on_policy_distill_workflow module#
On-Policy Distillation Workflow.
Reference: Tinker library’s on-policy distillation implementation.
Algorithm:

1. Student samples trajectories (with logprobs).
2. Teacher computes logprobs on the same trajectories.
3. Store teacher_logprobs in experience.info["teacher_logprobs"].
4. Trainer's advantage_fn computes: advantages = teacher_logprobs - student_logprobs.
5. Train with the importance_sampling loss.
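Step 4 above can be sketched in plain Python. The per-token logprob lists below are hypothetical stand-ins for the real trajectory data:

```python
# Per-token log-probabilities the student assigned to its own sampled tokens
# (hypothetical values for illustration).
student_logprobs = [-1.20, -0.85, -2.10]

# The teacher scores the SAME tokens, so the lists align position-wise.
teacher_logprobs = [-0.90, -0.80, -1.50]

# Step 4: advantages = teacher_logprobs - student_logprobs, per token.
advantages = [t - s for t, s in zip(teacher_logprobs, student_logprobs)]
# A positive advantage means the teacher assigns higher probability to that
# token, so the importance_sampling loss pushes the student toward the teacher.
```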
- class trinity.common.workflows.on_policy_distill_workflow.OnPolicyDistillWorkflow(*, task: Task, model: ModelWrapper, auxiliary_models: List[ModelWrapper] | None = None)[source]#
Bases: Workflow
On-policy distillation workflow.
Computes and stores teacher_logprobs in experience.info. The advantage_fn in trainer will compute:
advantages = teacher_logprobs - student_logprobs
Note: This workflow does NOT use reward_fn because:

- Advantage is computed from the teacher-student logprob difference.
- No external reward signal is needed.
- is_async: bool = True#
- can_reset: bool = True#
- can_repeat: bool = True#
- __init__(*, task: Task, model: ModelWrapper, auxiliary_models: List[ModelWrapper] | None = None)[source]#
- reset(task: Task)[source]#
Reset the workflow with a new task.
Unlike BaseSimpleWorkflow, this does NOT require reward_fn.
- set_repeat_times(repeat_times, run_id_base)[source]#
Set the number of times to repeat the workflow.

Parameters:
- repeat_times (int) – number of times to repeat the workflow (if repeatable).
- run_id_base (int) – base run_id for setting run_id in experiences.
- property rollout_args#
- format_messages()[source]#
Format messages for the instruct model.
Default format: system_prompt (optional) + task_desc + reply_prefix (optional)
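The default format above can be sketched as follows. The standalone helper and its parameter names are illustrative assumptions, not the actual method signature:

```python
def build_messages(system_prompt=None, task_desc="", reply_prefix=None):
    """Illustrative sketch of the default format:
    optional system prompt + task description + optional reply prefix."""
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": task_desc})
    if reply_prefix:
        # A partial assistant turn the model continues from.
        messages.append({"role": "assistant", "content": reply_prefix})
    return messages
```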
- compute_reward(response: Experience) float[source]#
Compute reward for a response.
In the base class, this returns 0.0, since the advantage is computed from the teacher-student logprob difference. Subclasses can override this to compute actual rewards.
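A minimal sketch of this base behavior, with a hypothetical subclass override (class and method names here are illustrative, not the library's):

```python
class DistillRewardBase:
    def compute_reward(self, response_text: str) -> float:
        # Base behavior: no external reward signal; the advantage comes
        # entirely from the teacher-student logprob difference.
        return 0.0


class LengthPenaltyReward(DistillRewardBase):
    def compute_reward(self, response_text: str) -> float:
        # Hypothetical override: a small penalty on response length,
        # added on top of the distillation signal.
        return -0.001 * len(response_text)
```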
- async run_async() List[Experience][source]#
Run workflow in async and return a list of experiences.
- class trinity.common.workflows.on_policy_distill_workflow.OnPolicyDistillMathWorkflow(*, task: Task, model: ModelWrapper, auxiliary_models: List[ModelWrapper] | None = None)[source]#
Bases: OnPolicyDistillWorkflow
On-policy distillation workflow with Qwen2.5-Math style format.
This workflow:

- Uses the Qwen2.5-Math style prompt format (same as math_eval_workflow).
- Computes accuracy using verify_math_answer as the reward.
- Is suitable for math reasoning tasks such as GSM8K and MATH.
- format_messages()[source]#
Format messages using Qwen2.5-Math style.
System prompt: "You are a helpful assistant."
User prompt: "{question}
Please reason step by step, and put your final answer within \boxed{}."
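The template above can be sketched as a standalone helper (the function name is an assumption; the prompt strings follow the docstring):

```python
QWEN_MATH_SYSTEM = "You are a helpful assistant."


def format_math_messages(question: str) -> list:
    # Qwen2.5-Math style: fixed system prompt plus the step-by-step
    # instruction appended after the question.
    user = (
        f"{question}\n"
        "Please reason step by step, and put your final answer within \\boxed{}."
    )
    return [
        {"role": "system", "content": QWEN_MATH_SYSTEM},
        {"role": "user", "content": user},
    ]
```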
- compute_reward(response: Experience) float[source]#
Compute accuracy as reward using Qwen2.5-Math evaluation.
Returns 1.0 if answer is correct, 0.0 otherwise.
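A sketch of this binary accuracy reward. The real workflow delegates to verify_math_answer (whose signature is not shown here); the boxed-answer extractor below is a simplified stand-in that only handles answers without nested braces:

```python
def extract_boxed(text: str):
    # Simplified extractor: return the content of the last \boxed{...},
    # or None if no boxed answer is present. Does not handle nested braces.
    marker = r"\boxed{"
    i = text.rfind(marker)
    if i == -1:
        return None
    j = text.find("}", i + len(marker))
    return text[i + len(marker):j] if j != -1 else None


def accuracy_reward(response_text: str, ground_truth: str) -> float:
    # Binary reward: 1.0 if the boxed answer matches the ground truth,
    # 0.0 otherwise (including when no boxed answer is found).
    return 1.0 if extract_boxed(response_text) == ground_truth else 0.0
```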