Source code for trinity.algorithm.add_strategy.add_strategy

import asyncio
from abc import ABC, abstractmethod
from typing import Dict, List, Literal, Tuple

import numpy as np
import torch

from trinity.buffer import BufferWriter
from trinity.common.experience import Experience
from trinity.utils.monitor import gather_metrics
from trinity.utils.registry import Registry
from trinity.utils.timer import Timer

ADD_STRATEGY = Registry("add_strategy")


class AddStrategy(ABC):
    """Abstract base for strategies that decide how experiences reach a buffer.

    Concrete subclasses implement :meth:`add` (which experiences to write and
    what metrics to report) and :meth:`default_args`.
    """

    def __init__(self, writer: BufferWriter, **kwargs) -> None:
        # Extra keyword arguments are accepted but not used by the base class.
        self.writer = writer

    @abstractmethod
    async def add(self, experiences: List[Experience], step: int) -> Tuple[int, Dict]:
        """Add experiences to the buffer.

        Args:
            experiences (`List[Experience]`): The experiences to be added.
            step (`int`): The current step number.

        Returns:
            `int`: The number of experiences added to the buffer.
            `Dict`: Metrics for logging.
        """

    @classmethod
    @abstractmethod
    def default_args(cls) -> dict:
        """Get the default arguments of the add strategy.

        Returns:
            `dict`: The default arguments.
        """
class GroupAdvantageStrategy(AddStrategy):
    """An example AddStrategy that calculates group advantages."""

    @abstractmethod
    def group_experiences(self, exps: List[Experience]) -> Dict[str, List[Experience]]:
        """Group experiences by a certain criterion.

        Args:
            exps (List[Experience]): List of experiences to be grouped.

        Returns:
            Dict[str, List[Experience]]: A dictionary where keys are group identifiers
                and values are lists of experiences.
        """

    @abstractmethod
    def calculate_group_advantage(
        self, group_id: str, exps: List[Experience]
    ) -> Tuple[List[Experience], Dict]:
        """Calculate advantages for a group of experiences.

        Args:
            group_id (str): The identifier for the group of experiences.
            exps (List[Experience]): List of experiences in the group.

        Returns:
            Tuple[List[Experience], Dict]: A tuple containing the modified list of
                experiences and a dictionary of metrics.
        """

    async def add(self, exps: List[Experience], step: int) -> Tuple[int, Dict]:
        """Group ``exps``, compute advantages per group and write each non-empty group.

        Args:
            exps: Experiences to process.
            step: The current step number (not used here).

        Returns:
            The total number of experiences returned by
            ``calculate_group_advantage``, and the per-group metrics combined via
            ``gather_metrics(..., "group_advantages")`` (``{}`` when that call
            raises ``ValueError``).
        """
        if not exps:
            return 0, {}

        total = 0
        per_group_metrics = []
        pending_writes = []
        for gid, members in self.group_experiences(exps).items():
            processed, stats = self.calculate_group_advantage(gid, members)
            per_group_metrics.append(stats)
            total += len(processed)
            if processed:
                pending_writes.append(self.writer.write_async(processed))

        if pending_writes:
            # Await all buffer writes concurrently.
            await asyncio.gather(*pending_writes)

        try:
            metrics = gather_metrics(per_group_metrics, "group_advantages")
        except ValueError:
            # An empty metric list makes gather_metrics raise ValueError; report no metrics.
            metrics = {}
        return total, metrics
@ADD_STRATEGY.register_module("grpo")
class GRPOAddStrategy(GroupAdvantageStrategy):
    """An example AddStrategy that calculates GRPO advantages."""

    def __init__(self, writer: BufferWriter, epsilon: float = 1e-6, **kwargs) -> None:
        super().__init__(writer)
        # Added to the group reward std in the denominator so a zero std does not divide by zero.
        self.epsilon = epsilon

    def group_experiences(self, exps):
        """Group experiences by their task id."""
        return group_by(exps, id_type="task")

    def calculate_group_advantage(
        self, group_id: str, exps: List[Experience]
    ) -> Tuple[List[Experience], Dict]:
        """Set each experience's advantage to its group-normalized reward.

        The score ``(reward - mean) / (std + epsilon)`` is multiplied by
        ``action_mask`` and stored in ``advantages``; ``returns`` gets a copy.
        A singleton group uses mean 0 and std 1.

        Returns:
            The same experience list (modified in place) and the group's
            ``reward_mean`` / ``reward_std`` as Python floats.
        """
        with torch.no_grad():
            if len(exps) == 1:
                mean, std = torch.tensor(0.0), torch.tensor(1.0)
            else:
                reward_tensor = torch.tensor([e.reward for e in exps], dtype=torch.float32)
                mean, std = torch.mean(reward_tensor), torch.std(reward_tensor)
            denom = std + self.epsilon
            for e in exps:
                e.advantages = ((e.reward - mean) / denom) * e.action_mask
                e.returns = e.advantages.clone()
        return exps, {"reward_mean": mean.item(), "reward_std": std.item()}

    @classmethod
    def default_args(cls) -> dict:
        """Return the default constructor arguments."""
        return {"epsilon": 1e-6}
@ADD_STRATEGY.register_module("opmd")
class OPMDAddStrategy(GroupAdvantageStrategy):
    """An example AddStrategy that calculates OPMD advantages.

    Each experience's advantage is its reward minus a per-task-group baseline,
    multiplied by the experience's ``action_mask``.
    """

    def __init__(
        self, writer: BufferWriter, opmd_baseline: str = "mean", tau: float = 1.0, **kwargs
    ) -> None:
        """
        Args:
            writer: Buffer writer that receives the processed groups.
            opmd_baseline: ``"mean"`` uses the arithmetic mean of the group rewards;
                ``"logavgexp"`` uses ``tau * (logsumexp(r / tau) - log(n))``.
            tau: Temperature, used only by the ``"logavgexp"`` baseline.
        """
        super().__init__(writer)
        assert opmd_baseline in [
            "mean",
            "logavgexp",
        ], f"opmd_baseline must be 'mean' or 'logavgexp', got {opmd_baseline}"
        self.opmd_baseline = opmd_baseline
        self.tau = tau

    def group_experiences(self, exps):
        """Group experiences by their task id."""
        return group_by(exps, id_type="task")

    def calculate_group_advantage(
        self, group_id: str, exps: List[Experience]
    ) -> Tuple[List[Experience], Dict]:
        """Set ``advantages``/``returns`` to ``(reward - baseline) * action_mask``.

        A singleton group uses a baseline of 0.

        Returns:
            The same experience list (modified in place) and
            ``{"group_baseline": <float>}``.
        """
        with torch.no_grad():
            if len(exps) == 1:
                group_baseline = torch.tensor(0.0)
            else:
                group_rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32)
                if self.opmd_baseline == "mean":
                    group_baseline = torch.mean(group_rewards)
                else:
                    # tau * log(mean(exp(r / tau))), computed through logsumexp for stability.
                    group_baseline = self.tau * (
                        torch.logsumexp(group_rewards / self.tau, dim=-1)
                        - torch.log(torch.tensor(float(len(exps))))
                    )
            for exp in exps:
                score = exp.reward - group_baseline
                exp.advantages = score * exp.action_mask
                exp.returns = exp.advantages.clone()
        # Report a Python float (not a 0-d tensor), consistent with the GRPO metrics.
        metrics = {
            "group_baseline": group_baseline.item(),
        }
        return exps, metrics

    @classmethod
    def default_args(cls) -> dict:
        """Return the default constructor arguments."""
        return {"opmd_baseline": "mean", "tau": 1.0}
@ADD_STRATEGY.register_module("reward_variance")
class RewardVarianceAddStrategy(AddStrategy):
    """An example AddStrategy that filters experiences based on a reward variance threshold."""

    def __init__(self, writer: BufferWriter, variance_threshold: float = 0.0, **kwargs) -> None:
        super().__init__(writer)
        # Task groups whose reward variance is <= this value are not written.
        self.variance_threshold = variance_threshold

    async def add(self, experiences: List[Experience], step: int) -> Tuple[int, Dict]:
        """Write task groups that have >= 2 experiences and reward variance above the threshold.

        Args:
            experiences: Experiences to filter and write.
            step: The current step number (not used here).

        Returns:
            The number of experiences written, and the metrics dict that was
            passed to ``Timer(metrics, "add_strategy_time")``.
        """
        metrics = {}
        written = 0
        writes = []
        with Timer(metrics, "add_strategy_time"):
            for members in group_by(experiences, id_type="task").values():
                too_small = len(members) < 2
                if too_small or np.var([e.reward for e in members]) <= self.variance_threshold:
                    continue
                written += len(members)
                writes.append(self.writer.write_async(members))
            if writes:
                await asyncio.gather(*writes)
        return written, metrics

    @classmethod
    def default_args(cls) -> dict:
        """Return the default constructor arguments."""
        return {"variance_threshold": 0.0}
def group_by(
    experiences: List[Experience], id_type: Literal["task", "run", "step"]
) -> Dict[str, List[Experience]]:
    """Group experiences by ID.

    Args:
        experiences: Experiences to group.
        id_type: Which field of ``exp.eid`` to group on: ``"task"`` -> ``tid``,
            ``"run"`` -> ``rid``, ``"step"`` -> ``sid``.

    Returns:
        A plain dict mapping each distinct id value to the experiences that carry
        it; keys and members appear in first-seen order.

    Raises:
        ValueError: If ``id_type`` is not one of the supported values.
    """
    for name, attr in (("task", "tid"), ("run", "rid"), ("step", "sid")):
        if id_type == name:
            break
    else:
        raise ValueError(f"Unknown id_type: {id_type}")

    buckets = {}
    for exp in experiences:
        buckets.setdefault(getattr(exp.eid, attr), []).append(exp)
    return buckets