trinity.common.rewards

Submodules

trinity.common.rewards.accuracy_reward module

class trinity.common.rewards.accuracy_reward.AccuracyRewardShapper(answer_parser: Callable[[str], str], correct_reward: float = 1.0, incorrect_reward: float = 0.0, kwargs: Dict[str, Any] = {})

Bases: RewardShapper

Reward shaper for accuracy-based rewards.

__init__(answer_parser: Callable[[str], str], correct_reward: float = 1.0, incorrect_reward: float = 0.0, kwargs: Dict[str, Any] = {})

shape(sample: Dict[str, Any]) -> Dict[str, Any]

Shape a single sample with rewards.

batch_shape(samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]

Shape a batch of samples.
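The constructor arguments above suggest a parse-then-compare flow. The following is a minimal, self-contained sketch of that idea, not the library's implementation: `SketchAccuracyShapper`, `parse_answer_tag`, and the sample keys `response`, `ground_truth`, and `accuracy_reward` are all hypothetical names chosen for illustration.

```python
import re
from typing import Any, Callable, Dict, List


def parse_answer_tag(response: str) -> str:
    """Hypothetical answer parser: return the text inside <answer>...</answer>."""
    match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
    return match.group(1).strip() if match else ""


class SketchAccuracyShapper:
    """Illustrative stand-in, not trinity's AccuracyRewardShapper."""

    def __init__(
        self,
        answer_parser: Callable[[str], str],
        correct_reward: float = 1.0,
        incorrect_reward: float = 0.0,
    ) -> None:
        self.answer_parser = answer_parser
        self.correct_reward = correct_reward
        self.incorrect_reward = incorrect_reward

    def shape(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        shaped = dict(sample)  # copy so the caller's dict is not mutated
        predicted = self.answer_parser(shaped["response"])
        # Assumed sample keys: "response" and "ground_truth".
        is_correct = predicted == str(shaped["ground_truth"]).strip()
        shaped["accuracy_reward"] = (
            self.correct_reward if is_correct else self.incorrect_reward
        )
        return shaped

    def batch_shape(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        return [self.shape(sample) for sample in samples]


shaper = SketchAccuracyShapper(answer_parser=parse_answer_tag)
shaped = shaper.batch_shape(
    [
        {"response": "<think>2 + 2</think><answer> 4 </answer>", "ground_truth": "4"},
        {"response": "<answer>5</answer>", "ground_truth": "4"},
    ]
)
```

Exact string comparison is the simplest possible check; a real math task would likely need a more tolerant comparison (for example, numeric equivalence).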

trinity.common.rewards.agents_reward module

trinity.common.rewards.base module

class trinity.common.rewards.base.RewardShapper

Bases: ABC

Abstract base class for reward shapers.

Supports:

1. Rule-based shaping
2. Model-based shaping
3. Tool-based shaping
4. Agent-based shaping
5. Human-in-the-loop shaping

abstract shape(sample: Dict[str, Any]) -> Dict[str, Any]

Shape a single sample with rewards.

abstract batch_shape(samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]

Shape a batch of samples.
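To make the two-method interface concrete, here is a minimal, self-contained sketch of a rule-based subclass. It defines a local stand-in for `RewardShapper` instead of importing the library, and the `LengthPenaltyShapper` class and the `response` and `reward` sample keys are hypothetical, used only for illustration.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class RewardShapper(ABC):
    """Local stand-in mirroring the documented interface (not the library class)."""

    @abstractmethod
    def shape(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Shape a single sample with rewards."""

    @abstractmethod
    def batch_shape(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Shape a batch of samples."""


class LengthPenaltyShapper(RewardShapper):
    """Hypothetical rule-based shaper: penalize responses over a length limit."""

    def __init__(self, max_chars: int = 20, penalty: float = -0.5) -> None:
        self.max_chars = max_chars
        self.penalty = penalty

    def shape(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        shaped = dict(sample)  # copy so the caller's dict is not mutated
        too_long = len(shaped.get("response", "")) > self.max_chars
        shaped["reward"] = self.penalty if too_long else 0.0
        return shaped

    def batch_shape(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        # Simplest batch strategy: apply the per-sample rule to each element.
        return [self.shape(sample) for sample in samples]


shaper = LengthPenaltyShapper(max_chars=5)
shaped_batch = shaper.batch_shape(
    [{"response": "short"}, {"response": "much too long"}]
)
```

Because both methods are abstract, a subclass that omits either one cannot be instantiated, which is what makes the base class usable as a contract.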

trinity.common.rewards.composite_reward module

class trinity.common.rewards.composite_reward.CompositeRewardShapper(shappers: List[Tuple[RewardShapper, float]])

Bases: RewardShapper

Combines multiple reward shapers, each paired with a weight.

__init__(shappers: List[Tuple[RewardShapper, float]])

shape(sample: Dict[str, Any]) -> Dict[str, Any]

Shape a single sample with rewards.
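One plausible reading of combining shapers with weights is a weighted sum of the rewards each child produces. The sketch below illustrates that reading only; the combination rule, the `reward` sample key, and the `ConstantShapper` and `SketchCompositeShapper` classes are assumptions, not the library's documented behavior.

```python
from typing import Any, Dict, List, Tuple


class ConstantShapper:
    """Hypothetical child shaper that always assigns the same reward."""

    def __init__(self, value: float) -> None:
        self.value = value

    def shape(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        shaped = dict(sample)
        shaped["reward"] = self.value
        return shaped


class SketchCompositeShapper:
    """Illustrative stand-in: weighted sum of child rewards (assumed rule)."""

    def __init__(self, shappers: List[Tuple[Any, float]]) -> None:
        self.shappers = shappers

    def shape(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        total = 0.0
        for shapper, weight in self.shappers:
            # Each child sees the original sample, not another child's output.
            total += weight * shapper.shape(sample)["reward"]
        shaped = dict(sample)
        shaped["reward"] = total
        return shaped


composite = SketchCompositeShapper(
    [(ConstantShapper(1.0), 0.75), (ConstantShapper(0.5), 0.5)]
)
combined = composite.shape({"response": "anything"})
```

Passing the untouched sample to every child keeps the children independent of their order in the list.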

trinity.common.rewards.format_reward module

class trinity.common.rewards.format_reward.FormatRewardShapper(pattern: str, correct_format_reward: float = 1.0, incorrect_format_reward: float = 0.0)

Bases: RewardShapper

Reward shaper for format-based rewards.

__init__(pattern: str, correct_format_reward: float = 1.0, incorrect_format_reward: float = 0.0)

shape(sample: Dict[str, Any]) -> Dict[str, Any]

Shape a single sample with rewards.

batch_shape(samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]

Shape a batch of samples.
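Since the constructor takes a `pattern` string, a regular-expression check is a natural way to picture this shaper. The sketch below is a self-contained illustration under that assumption: `SketchFormatShapper`, the example pattern, the use of `fullmatch` with `re.DOTALL`, and the `response` and `format_reward` sample keys are all hypothetical.

```python
import re
from typing import Any, Dict, List


class SketchFormatShapper:
    """Illustrative stand-in, not trinity's FormatRewardShapper."""

    def __init__(
        self,
        pattern: str,
        correct_format_reward: float = 1.0,
        incorrect_format_reward: float = 0.0,
    ) -> None:
        # Assumption: DOTALL so "." also matches newlines inside the tags.
        self.pattern = re.compile(pattern, re.DOTALL)
        self.correct_format_reward = correct_format_reward
        self.incorrect_format_reward = incorrect_format_reward

    def shape(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        shaped = dict(sample)
        # Assumption: the whole response must match, not just a substring.
        matches = self.pattern.fullmatch(shaped.get("response", "")) is not None
        shaped["format_reward"] = (
            self.correct_format_reward if matches else self.incorrect_format_reward
        )
        return shaped

    def batch_shape(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        return [self.shape(sample) for sample in samples]


# Hypothetical pattern: a think block followed by an answer block.
shaper = SketchFormatShapper(r"<think>.*?</think>\s*<answer>.*?</answer>")
shaped = shaper.batch_shape(
    [
        {"response": "<think>a\nb</think>\n<answer>c</answer>"},
        {"response": "c"},
    ]
)
```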

trinity.common.rewards.human_reward module

trinity.common.rewards.reward_fn module

Base Reward Function Class.

class trinity.common.rewards.reward_fn.RewardFn

Bases: ABC

Base Reward Function Class.

class trinity.common.rewards.reward_fn.AccuracyReward(answer_parser: Callable[[str], str] | None = None)

Bases: RewardFn

A reward function that rewards correct answers. Ref: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py

__init__(answer_parser: Callable[[str], str] | None = None)

class trinity.common.rewards.reward_fn.FormatReward(pattern: str | None = None)

Bases: RewardFn

A reward function that checks if the reasoning process is enclosed within <think> and </think> tags, while the final answer is enclosed within <answer> and </answer> tags. Ref: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py

__init__(pattern: str | None = None)

class trinity.common.rewards.reward_fn.MathRewardFn(answer_parser=<function simple_answer_parser>, pattern='.*?<think>.*?</think>\\s*<answer>.*?</answer>\\s*$')

Bases: RewardFn

A reward function for math tasks.

DEFAULT_FORMAT_PATTERN = '.*?<think>.*?</think>\\s*<answer>.*?</answer>\\s*$'

DEFAULT_ANSWER_PARSER() -> str

__init__(answer_parser=<function simple_answer_parser>, pattern='.*?<think>.*?</think>\\s*<answer>.*?</answer>\\s*$') -> None

class trinity.common.rewards.reward_fn.CountDownRewardFn

Bases: RewardFn

A reward function for the countdown task.

__init__()
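
To see what the `DEFAULT_FORMAT_PATTERN` documented for `MathRewardFn` accepts, the pattern can be tried directly with Python's `re` module. This is a sketch under two assumptions: that the value shown in the signature is a Python repr (so each `\\s` stands for a single `\s` in the regex), and that matching is anchored at the start with `re.DOTALL`. The `follows_format` helper is hypothetical and is not part of the library.

```python
import re

# Pattern as documented for MathRewardFn.DEFAULT_FORMAT_PATTERN, written as a
# raw string (assumes the docs display the repr of the default value).
DEFAULT_FORMAT_PATTERN = r".*?<think>.*?</think>\s*<answer>.*?</answer>\s*$"


def follows_format(response: str) -> bool:
    """Hypothetical helper: match from the start, letting "." span newlines."""
    return re.match(DEFAULT_FORMAT_PATTERN, response, re.DOTALL) is not None


good = "<think>\n3 * 4 = 12\n</think>\n<answer>12</answer>\n"
no_think = "<answer>12</answer>"
trailing_text = "<think>ok</think><answer>12</answer> and then more text"

results = [follows_format(r) for r in (good, no_think, trailing_text)]
```

Under these assumptions only the first response matches: the second lacks a think block, and the third has non-whitespace text after the closing answer tag, which the trailing `\s*$` rejects.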

trinity.common.rewards.tool_reward module

Module contents

Reward functions for RFT

class trinity.common.rewards.RewardFn

Bases: ABC

Base Reward Function Class.

class trinity.common.rewards.AccuracyReward(answer_parser: Callable[[str], str] | None = None)

Bases: RewardFn

A reward function that rewards correct answers. Ref: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py

__init__(answer_parser: Callable[[str], str] | None = None)

class trinity.common.rewards.FormatReward(pattern: str | None = None)

Bases: RewardFn

A reward function that checks if the reasoning process is enclosed within <think> and </think> tags, while the final answer is enclosed within <answer> and </answer> tags. Ref: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py

__init__(pattern: str | None = None)