Source code for trinity.common.rewards.format_reward
import re
from typing import Any, Dict, List
from .base import RewardShapper
[docs]
class FormatRewardShapper(RewardShapper):
"""Shapper for format-based rewards"""
[docs]
def __init__(
self, pattern: str, correct_format_reward: float = 1.0, incorrect_format_reward: float = 0.0
):
self.pattern = re.compile(pattern, re.DOTALL | re.MULTILINE)
self.correct_format_reward = correct_format_reward
self.incorrect_format_reward = incorrect_format_reward
[docs]
def shape(self, sample: Dict[str, Any]) -> Dict[str, Any]:
response = sample["response"]
reward = (
self.correct_format_reward
if self.pattern.match(response)
else self.incorrect_format_reward
)
sample["format_reward"] = reward
return sample
[docs]
def batch_shape(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
return [self.shape(sample) for sample in samples]