Algorithm Development

Note

This guide is an advanced version of the Algorithms section in the Developer Guide.

This guide shows how to integrate a new algorithm into Trinity-RFT. As an example, we incorporate “expert” data generated by a more capable LLM and propose an algorithm named MIX, which optimizes the following policy objective:

\[ \mathcal{J}_{\text{Mix}}(\theta) = (1-\mu) \mathcal{J}_{\text{GRPO}}(\theta) + \mu \cdot \underbrace{\frac{1}{B'} \sum_{b=1}^{B'} \left[ \frac{1}{T'_b} \sum_{t=1}^{T'_b} \log \pi_\theta(o'_{b,t} \mid q'_b, o'_{b,<t}) \right]}_{\text{Auxiliary Loss on Expert Data}}. \]

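The first term is the standard GRPO objective, which aims to maximize the expected reward. The second term is an auxiliary loss defined on expert data: the log-likelihood of the expert response \(o'_b\) to question \(q'_b\), averaged first over its \(T'_b\) tokens and then over the \(B'\) expert samples in a batch. Maximizing it encourages the policy to imitate expert behavior. \(\mu\) is a weighting factor that controls the relative importance of the two terms.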
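Since the trainer minimizes a loss rather than maximizing an objective, it helps to write the negated objective explicitly. The following is plain algebra on the equation above; \(\mathcal{L}_{\text{GRPO}}\) and \(\mathcal{L}_{\text{SFT}}\) are names introduced here for the two negated terms:

\[ \mathcal{L}_{\text{Mix}}(\theta) = -\mathcal{J}_{\text{Mix}}(\theta) = (1-\mu)\,\mathcal{L}_{\text{GRPO}}(\theta) + \mu\,\mathcal{L}_{\text{SFT}}(\theta), \]

\[ \mathcal{L}_{\text{GRPO}}(\theta) = -\mathcal{J}_{\text{GRPO}}(\theta), \qquad \mathcal{L}_{\text{SFT}}(\theta) = -\frac{1}{B'} \sum_{b=1}^{B'} \frac{1}{T'_b} \sum_{t=1}^{T'_b} \log \pi_\theta(o'_{b,t} \mid q'_b, o'_{b,<t}). \]

This has the same shape as loss = (1 - self.mu) * grpo_loss + self.mu * sft_loss in the policy loss of Step 3, where sft_loss plays the role of \(\mathcal{L}_{\text{SFT}}\) before the batch-size reweighting applied there.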

Step 0: Prepare the Expert Data

We prompt a powerful LLM to generate responses, including the chain-of-thought (CoT) reasoning process, for a set of predefined questions. The collected data are treated as experiences from an expert. We store them in a JSONL file, expert_data.jsonl, where each line is one JSON object of the following form (shown pretty-printed for readability):

{
    "messages": [
        { "role": "system", "content": "<system_prompt>" },
        { "role": "user", "content": "What is the sum of 4 and 12?" },
        { "role": "assistant", "content": "<think>thinking process...</think>\n<answer>16</answer>" }
    ]
}
...

The path to this file is set in buffer.trainer_input.sft_warmup_dataset, from which the sampling strategy defined in Step 2 reads expert experiences.
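A minimal sketch of producing such a file, assuming the expert's reasoning and final answer are already available as strings. The helpers make_expert_record, write_jsonl and read_jsonl are hypothetical and not part of Trinity-RFT; the key point is that JSONL stores one complete JSON object per line, with no commas between records.

```python
import json
from typing import Dict, Iterable, List


def make_expert_record(system_prompt: str, question: str, reasoning: str, answer: str) -> Dict:
    """Build one record in the messages format shown above (hypothetical helper)."""
    return {
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question},
            {"role": "assistant", "content": f"<think>{reasoning}</think>\n<answer>{answer}</answer>"},
        ]
    }


def write_jsonl(path: str, records: Iterable[Dict]) -> None:
    """Write one JSON object per line, which is what makes the file JSONL."""
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def read_jsonl(path: str) -> List[Dict]:
    """Parse every non-empty line back into a dict."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


record = make_expert_record(
    "You are a helpful assistant.", "What is the sum of 4 and 12?", "4 + 12 = 16", "16"
)
write_jsonl("expert_data.jsonl", [record])
```

Reading the file back with read_jsonl is a cheap way to confirm every line parses before pointing sft_warmup_dataset at it.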

Step 1: Define the Algorithm

In trinity/algorithm/algorithm.py, we introduce a new algorithm type MIX.

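from typing import Dict


@ALGORITHM_TYPE.register_module("mix")
class MIXAlgorithm(AlgorithmType):
    """MIX algorithm: GRPO on usual experiences plus an SFT term on expert experiences."""

    use_critic: bool = False  # no value model; GRPO derives advantages from group rewards
    use_reference: bool = True  # reference-model log-probs are computed (e.g., for a KL term)
    use_advantage: bool = True
    can_balance_batch: bool = True
    schema: type = ExperienceModel

    @classmethod
    def default_config(cls) -> Dict:
        return {
            "repeat_times": 8,  # responses sampled per question (the GRPO group size)
            "policy_loss_fn": "mix",  # MIXPolicyLossFn, defined in Step 3
            "advantage_fn": "grpo",  # reuse the GRPO advantage function
            "sample_strategy": "mix",  # MixSampleStrategy, defined in Step 2
        }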
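The @ALGORITHM_TYPE.register_module("mix") decorator is what lets the string mix in a config refer to MIXAlgorithm. Trinity-RFT's registry code is not shown in this guide, so the following is only a minimal sketch of the general name-to-class registry pattern; the Registry class is hypothetical, and the names in default_config (such as "mix" for the policy loss) are presumably resolved through registries of the same kind.

```python
from typing import Callable, Dict, Type


class Registry:
    """Minimal name-to-class registry (hypothetical, not Trinity-RFT's own code)."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._modules: Dict[str, Type] = {}

    def register_module(self, module_name: str) -> Callable[[Type], Type]:
        """Return a class decorator that records `cls` under `module_name`."""

        def decorator(cls: Type) -> Type:
            if module_name in self._modules:
                raise KeyError(f"{module_name!r} is already registered in {self.name}")
            self._modules[module_name] = cls
            return cls  # the class itself is returned unchanged

        return decorator

    def get(self, module_name: str) -> Type:
        """Look up a registered class by the name used in the config."""
        return self._modules[module_name]


ALGORITHM_TYPE = Registry("algorithm")


@ALGORITHM_TYPE.register_module("mix")
class MIXAlgorithm:
    """Stripped-down stand-in for the class above."""

    @classmethod
    def default_config(cls) -> Dict:
        return {"repeat_times": 8, "policy_loss_fn": "mix", "advantage_fn": "grpo", "sample_strategy": "mix"}


# `algorithm_type: mix` in a config would then resolve to MIXAlgorithm.
algo_cls = ALGORITHM_TYPE.get("mix")
```

Returning the class unchanged keeps it usable directly, and refusing duplicate names avoids silently replacing an existing algorithm.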

Step 2: Define the Sampling Strategy

In each step, we need to read two kinds of experiences: usual experiences from the experience buffer and expert experiences. For this purpose, we define a new experience sampling strategy, MixSampleStrategy.

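import copy
from math import ceil
from typing import Any, Dict, List, Tuple

import torch


class MixSampleStrategy(SampleStrategy):
    """Sample strategy that reads both usual and expert experiences in each step."""

    def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
        super().__init__(buffer_config, trainer_type)
        # Fraction of each batch read from the expert dataset (default: 0.5).
        self.expert_data_ratio = kwargs.get("expert_data_ratio", 0.5)
        tot_batch_size = buffer_config.read_batch_size
        # The expert share is rounded up; usual experiences fill the rest of the batch.
        expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size)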

        # Reader for usual experiences from the experience buffer
        usual_buffer_config = copy.deepcopy(buffer_config)
        usual_buffer_config.read_batch_size = tot_batch_size - expert_batch_size
        self.usual_exp_buffer = get_buffer_reader(
            buffer_config.trainer_input.experience_buffer, usual_buffer_config  # type: ignore
        )

        if buffer_config.trainer_input.sft_warmup_dataset is None:
            raise ValueError(
                "`buffer_config.trainer_input.sft_warmup_dataset` is required in MIX algorithm"
            )

        # Reader for expert experiences from the dataset prepared in Step 0
        expert_buffer_config = copy.deepcopy(buffer_config)
        expert_buffer_config.read_batch_size = expert_batch_size
        self.expert_exp_buffer = get_buffer_reader(
            buffer_config.trainer_input.sft_warmup_dataset, expert_buffer_config
        )

    def sample(self, step: int) -> Tuple[Any, Dict, List]:
        metrics = {}
        with Timer(metrics, "read_time"):
            usual_exp_list = self.usual_exp_buffer.read()
            for exp in usual_exp_list:
                if exp.info is None:
                    exp.info = {}
                exp.info["is_expert"] = False

            expert_exp_list = self.expert_exp_buffer.read()
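            for exp in expert_exp_list:
                # Expert data has no reward or rollout log-probs; fill in placeholders so it
                # can be gathered with usual experiences. The loss in Step 3 does not use
                # them for expert rows.
                exp.reward = 0.0
                exp.logprobs = torch.zeros_like(exp.tokens, dtype=torch.float32)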
                if exp.info is None:
                    exp.info = {}
                exp.info["is_expert"] = True

            exp_list = usual_exp_list + expert_exp_list
            repr_samples = representative_sample(exp_list)

        # Row i of the gathered batch is an expert experience iff is_expert_mask[i] is True.
        is_expert_mask = torch.tensor([exp.info["is_expert"] for exp in exp_list], dtype=torch.bool)

        with Timer(metrics, "gather_time"):
            exps = Experiences.gather_experiences(exp_list, self.pad_token_id)  # type: ignore

        if self.trainer_type == "verl":
            with Timer(metrics, "convert_time"):
                data = to_data_proto_mix(exps, is_expert_mask)
            return data, metrics, repr_samples
        else:
            raise NotImplementedError(f"backend {self.trainer_type} is not supported")
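The batch split and the mask can be replayed in plain Python to see what MixSampleStrategy produces. This sketch uses hypothetical helper names and lists instead of tensors; select_rows mimics the boolean indexing (logprob[is_expert_mask] and logprob[~is_expert_mask]) that the policy loss in Step 3 applies to the mask.

```python
from math import ceil
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_read_batch(total: int, expert_data_ratio: float) -> Tuple[int, int]:
    """Return (usual, expert) read sizes; the expert share is rounded up, as in MixSampleStrategy."""
    expert = ceil(expert_data_ratio * total)
    return total - expert, expert


def build_is_expert_mask(n_usual: int, n_expert: int) -> List[bool]:
    """Usual experiences come first, mirroring `usual_exp_list + expert_exp_list`."""
    return [False] * n_usual + [True] * n_expert


def select_rows(rows: Sequence[T], mask: Sequence[bool], expert: bool) -> List[T]:
    """Pure-Python analogue of `x[is_expert_mask]` (expert=True) or `x[~is_expert_mask]`."""
    return [row for row, is_expert in zip(rows, mask) if is_expert == expert]


usual, expert = split_read_batch(256, 0.25)  # 192 + 64 = 256, as in the Step 4 config
mask = build_is_expert_mask(usual, expert)
```

Because the expert share is rounded up, small batches get slightly more than expert_data_ratio of expert samples: a batch of 10 with ratio 0.25 reads 3 expert and 7 usual experiences.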

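We also need to add an is_expert_mask field when converting the experiences to a DataProto, so that the policy loss can tell expert samples from usual ones. Compared with the default conversion, to_data_proto_mix takes the mask as an extra argument and adds it to the batch dictionary: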

import numpy as np
import torch


def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.Tensor) -> DataProto:
    attention_mask = experiences.attention_masks
    cumsum = torch.cumsum(attention_mask, dim=-1)
    # Positions count attended tokens; positions before the first attended token get 0.
    position_ids = torch.clip(cumsum - 1, 0, None).long()
    batch_dict = {
        "uid": np.array(experiences.run_ids),
        "position_ids": position_ids,
        "input_ids": experiences.tokens.long(),
        "responses": experiences.tokens[:, experiences.prompt_length :].long(),
        "attention_mask": attention_mask.long(),
        "response_mask": (
            experiences.action_masks[:, experiences.prompt_length :].long()
            if hasattr(experiences, "action_masks") and experiences.action_masks is not None
            else attention_mask[:, experiences.prompt_length :].long()
        ),
        "is_expert_mask": is_expert_mask,  # new field: True for expert experiences
    }
    if experiences.rewards is not None:
        token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype)
        # argmax returns the first index of the maximum, i.e. the last attended token
        eos_mask_idx = cumsum.argmax(dim=-1)
        token_level_rewards[
            torch.arange(experiences.batch_size), eos_mask_idx
        ] = experiences.rewards
        token_level_rewards = token_level_rewards[:, experiences.prompt_length :]
        batch_dict.update(
            {
                "token_level_scores": token_level_rewards,
                "old_log_probs": experiences.logprobs[:, experiences.prompt_length :],  # type: ignore
            }
        )
    return DataProto.from_single_dict(batch_dict)
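The position and reward handling in to_data_proto_mix is compact tensor code. The sketch below replays it for a single row in plain Python, assuming each row is a left-padded prompt followed by a right-padded response; the function names are hypothetical. Since expert experiences get a reward of 0.0 in Step 2, their token-level scores are all zero.

```python
from itertools import accumulate
from typing import List


def position_ids_from_mask(mask: List[int]) -> List[int]:
    """Pure-Python version of torch.clip(torch.cumsum(mask, -1) - 1, 0, None) for one row."""
    return [max(c - 1, 0) for c in accumulate(mask)]


def token_level_rewards(mask: List[int], reward: float, prompt_length: int) -> List[float]:
    """Place a scalar reward on the last attended token and keep only the response part."""
    cumsum = list(accumulate(mask))
    # Like torch.argmax, list.index returns the first position of the maximum,
    # which is the last token with mask == 1.
    last = cumsum.index(max(cumsum))
    rewards = [0.0] * len(mask)
    rewards[last] = reward
    return rewards[prompt_length:]


# One row: a left-padded 3-token prompt followed by a right-padded 3-token response.
mask = [0, 1, 1, 1, 1, 0]
print(position_ids_from_mask(mask))  # [0, 0, 1, 2, 3, 3]
print(token_level_rewards(mask, 1.0, prompt_length=3))  # [0.0, 1.0, 0.0]
```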

Step 3: Define the Policy Loss Function

We define a MIXPolicyLossFn class in trinity/algorithm/policy_loss_fn/mix_policy_loss.py. It computes a GRPO (PPO-style clipped) loss on usual experiences and an SFT loss on expert experiences, and combines them as \((1-\mu)\) times the former plus \(\mu\) times the latter.

from typing import Dict, Optional, Tuple

import torch


@POLICY_LOSS_FN.register_module("mix")
class MIXPolicyLossFn(PolicyLossFn):
    def __init__(
        self,
        backend: str = "verl",
        mu: float = 0.1,
        clip_range: Optional[float] = None,
        clip_range_low: Optional[float] = None,
        clip_range_high: Optional[float] = None,
        use_dynamic_bsz: Optional[bool] = None,
        repeat_times: Optional[int] = None,
        ppo_mini_batch_size: Optional[int] = None,
        ppo_micro_batch_size_per_gpu: Optional[int] = None,
        ngpus_trainer: Optional[int] = None,
        read_batch_size_usual: Optional[int] = None,
        read_batch_size_expert: Optional[int] = None,
        use_token_level_loss_in_sft: bool = True,
    ) -> None:
        super().__init__(backend=backend)
        self.mu = mu
        self.use_dynamic_bsz = use_dynamic_bsz
        self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer  # type: ignore
        self.gradient_accumulation = (
            ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu  # type: ignore
        )
        self.read_batch_size_usual = read_batch_size_usual
        self.read_batch_size_expert = read_batch_size_expert
        self.grpo_loss_fn = PPOPolicyLossFn(
            clip_range=clip_range,
            clip_range_low=clip_range_low,
            clip_range_high=clip_range_high,
        )
        self.sft_loss_fn = SFTLossFn(use_token_level_loss=use_token_level_loss_in_sft)

    def __call__(  # type: ignore
        self,
        logprob: torch.Tensor,
        old_logprob: torch.Tensor,
        action_mask: torch.Tensor,
        advantages: torch.Tensor,
        is_expert_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        assert (
            len(is_expert_mask) == logprob.shape[0]
        ), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}"

        n_usual_exp = torch.sum(~is_expert_mask).item()
        n_expert_exp = torch.sum(is_expert_mask).item()

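        # Each sub-loss below is a mean over its rows in the current micro-batch. Scaling it
        # by the row count and a per-micro-batch weight makes each term, after gradient
        # accumulation, roughly normalized by read_batch_size_usual (or read_batch_size_expert)
        # rather than by how a micro-batch happens to mix the two kinds of experiences.
        if self.use_dynamic_bsz: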
            per_micro_batch_weight_usual = self.experience_per_gpu / (
                logprob.shape[0] * self.read_batch_size_usual
            )
            per_micro_batch_weight_expert = self.experience_per_gpu / (
                logprob.shape[0] * self.read_batch_size_expert
            )
        else:
            per_micro_batch_weight_usual = self.gradient_accumulation / self.read_batch_size_usual  # type: ignore
            per_micro_batch_weight_expert = self.gradient_accumulation / self.read_batch_size_expert  # type: ignore

        # GRPO loss (usual)
        if n_usual_exp > 0:
            grpo_loss, grpo_metrics = self.grpo_loss_fn(
                logprob[~is_expert_mask],
                old_logprob[~is_expert_mask],
                action_mask[~is_expert_mask],
                advantages[~is_expert_mask],
                **kwargs,
            )
            grpo_loss = grpo_loss * n_usual_exp * per_micro_batch_weight_usual
            grpo_metrics = {
                k: v * n_usual_exp * per_micro_batch_weight_usual for k, v in grpo_metrics.items()
            }
        else:
            grpo_loss = torch.tensor(0.0, device=logprob.device)
            grpo_metrics = {}

        # SFT Loss (expert)
        if n_expert_exp > 0:
            sft_loss, sft_metrics = self.sft_loss_fn(
                logprob[is_expert_mask],
                action_mask[is_expert_mask],
            )
            sft_loss = sft_loss * n_expert_exp * per_micro_batch_weight_expert
            sft_metrics = {
                k: v * n_expert_exp * per_micro_batch_weight_expert for k, v in sft_metrics.items()
            }
        else:
            sft_loss = torch.tensor(0.0, device=logprob.device)
            sft_metrics = {}

        loss = (1 - self.mu) * grpo_loss + self.mu * sft_loss

        metrics = {f"usual/{k}": v for k, v in grpo_metrics.items()}
        metrics.update({f"expert/{k}": v for k, v in sft_metrics.items()})
        metrics.update({"loss": loss.item()})

        return loss, metrics

    @classmethod
    def default_args(cls) -> Dict:
        return {
            "mu": 0.1,
            "clip_range": 0.2,
        }
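MIXPolicyLossFn delegates the expert term to SFTLossFn, whose code is not shown here. As an assumption about what use_token_level_loss_in_sft controls, the sketch below contrasts two common ways of normalizing the negative log-likelihood of expert tokens: over all tokens in the batch, or within each response first and then across responses (the latter matches the 1/T'_b and 1/B' factors in the MIX objective). It then combines the result with a GRPO loss value using the same (1 - mu) and mu weights, ignoring the micro-batch reweighting. Both helper names are hypothetical.

```python
from typing import Sequence


def sft_loss(token_logprobs: Sequence[Sequence[float]], token_level: bool) -> float:
    """Negative log-likelihood of expert responses (one inner list per response)."""
    if token_level:
        # Every expert token counts equally, so long responses dominate.
        total = sum(sum(seq) for seq in token_logprobs)
        n_tokens = sum(len(seq) for seq in token_logprobs)
        return -total / n_tokens
    # Average within each response, then across responses (1/T'_b, then 1/B').
    per_response = [sum(seq) / len(seq) for seq in token_logprobs]
    return -sum(per_response) / len(per_response)


def mix_loss(grpo_loss: float, expert_token_logprobs: Sequence[Sequence[float]],
             mu: float, token_level: bool = False) -> float:
    """(1 - mu) * GRPO loss + mu * SFT loss; no expert rows means an SFT term of 0."""
    sft = sft_loss(expert_token_logprobs, token_level) if expert_token_logprobs else 0.0
    return (1 - mu) * grpo_loss + mu * sft


expert_logprobs = [[-1.0, -3.0], [-0.5]]  # two expert responses with 2 and 1 tokens
print(sft_loss(expert_logprobs, token_level=True))   # 1.5
print(sft_loss(expert_logprobs, token_level=False))  # 1.25
```

The Step 4 configuration sets use_token_level_loss_in_sft: False; if the flag means what this sketch assumes, the expert term is then averaged per response, as in the objective.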

Step 4: Run the Experiment

With the classes and functions defined above, the experiment can be run without modifying any other components. The key settings are shown below, including the weighting factor \(\mu\) (algorithm.policy_loss_fn_args['mu']) and the expert batch size \(B'\), which equals the product of buffer.batch_size, algorithm.sample_strategy_args['expert_data_ratio'], and algorithm.repeat_times. For the full configuration, please refer to mix_math.yaml and train_mix_math.yaml.

algorithm:
  algorithm_type: mix
  repeat_times: 8
  sample_strategy_args:
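    expert_data_ratio: 0.25  # fraction of each batch read from expert data
  policy_loss_fn_args:
    mu: 0.1  # weighting factor mu
    clip_range: 0.2
    use_token_level_loss_in_sft: False
    use_dynamic_bsz: False
    repeat_times: 8  # keep equal to algorithm.repeat_times
    ppo_mini_batch_size: 32
    ppo_micro_batch_size_per_gpu: 4
    ngpus_trainer: 4
    read_batch_size_expert: 64  # B'
    read_batch_size_usual: 192  # remaining experiences per step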

With the above configurations, the experiment can be run with the following command:

trinity run --config examples/mix_math/mix_math.yaml
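Because read_batch_size_usual and read_batch_size_expert are set by hand in policy_loss_fn_args, it is worth checking that they agree with the split computed by MixSampleStrategy. The hypothetical helper below redoes that arithmetic and the integer divisions in MIXPolicyLossFn.__init__. buffer.batch_size does not appear in the configuration excerpt, so the value 32 is an assumption, chosen because it is consistent with \(B' = 32 \times 0.25 \times 8 = 64\); the per-step read batch size is likewise assumed to be buffer.batch_size × repeat_times, in line with the \(B'\) formula.

```python
from math import ceil
from typing import Dict


def derive_mix_sizes(
    batch_size: int,
    repeat_times: int,
    expert_data_ratio: float,
    ppo_mini_batch_size: int,
    ppo_micro_batch_size_per_gpu: int,
    ngpus_trainer: int,
) -> Dict[str, float]:
    """Recompute sizes used by MixSampleStrategy and MIXPolicyLossFn (hypothetical helper)."""
    read_batch_size = batch_size * repeat_times  # assumed experiences read per step
    expert = ceil(expert_data_ratio * read_batch_size)  # B'
    usual = read_batch_size - expert
    # Same integer arithmetic as MIXPolicyLossFn.__init__
    experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer
    gradient_accumulation = ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu
    return {
        "read_batch_size_usual": usual,
        "read_batch_size_expert": expert,
        "experience_per_gpu": experience_per_gpu,
        "gradient_accumulation": gradient_accumulation,
        # Weights applied per micro-batch when use_dynamic_bsz is False
        "per_micro_batch_weight_usual": gradient_accumulation / usual,
        "per_micro_batch_weight_expert": gradient_accumulation / expert,
    }


sizes = derive_mix_sizes(
    batch_size=32,  # assumption: buffer.batch_size is not shown in the excerpt
    repeat_times=8,
    expert_data_ratio=0.25,
    ppo_mini_batch_size=32,
    ppo_micro_batch_size_per_gpu=4,
    ngpus_trainer=4,
)
```

Under these assumptions the derived sizes match read_batch_size_expert: 64 and read_batch_size_usual: 192; if expert_data_ratio or the batch size changes, both values in policy_loss_fn_args need to be updated by hand.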