Source code for trinity.common.rewards.accuracy_reward

from typing import Any, Callable, Dict, List, Optional

from .base import RewardShapper


class AccuracyRewardShapper(RewardShapper):
    """Shapper for accuracy-based rewards.

    Parses the response stored in a sample with ``answer_parser``, compares
    the parsed answer to the sample's ground truth using ``==``, and stores
    the resulting reward on the sample under the ``"accuracy_reward"`` key.
    """

    def __init__(
        self,
        answer_parser: Callable[[str], str],
        correct_reward: float = 1.0,
        incorrect_reward: float = 0.0,
        kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Configure the parser, the reward values, and the sample keys.

        Args:
            answer_parser: Callable applied to the raw response; its return
                value is compared against the ground truth.
            correct_reward: Reward assigned when the parsed response equals
                the ground truth.
            incorrect_reward: Reward assigned otherwise.
            kwargs: Optional mapping of key overrides. Its ``"response"``
                entry names the sample key holding the response and its
                ``"ground_truth"`` entry names the sample key holding the
                ground truth; each defaults to the same string as its own
                name. ``None`` (the default) means no overrides.
        """
        # A ``None`` default replaces the former shared ``{}`` default so no
        # mutable object is reused across calls.
        if kwargs is None:
            kwargs = {}
        # NOTE(review): the base-class ``__init__`` is not invoked here;
        # confirm RewardShapper requires no initialization of its own.
        self.answer_parser = answer_parser
        self.correct_reward = correct_reward
        self.incorrect_reward = incorrect_reward
        self.response_key = kwargs.get("response", "response")
        self.truth_key = kwargs.get("ground_truth", "ground_truth")

    def shape(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Attach an accuracy reward to a single sample.

        Args:
            sample: Mapping that contains the configured response key and
                ground-truth key.

        Returns:
            The same ``sample`` object, modified in place with an
            ``"accuracy_reward"`` entry.

        Raises:
            KeyError: If the response key or ground-truth key is missing
                from ``sample``.
        """
        response = sample[self.response_key]
        truth = sample[self.truth_key]
        parsed_response = self.answer_parser(response)
        # Exact equality: the parser's output must match the stored ground
        # truth as-is; no normalization is applied here.
        reward = self.correct_reward if parsed_response == truth else self.incorrect_reward
        sample["accuracy_reward"] = reward
        return sample

    def batch_shape(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Apply :meth:`shape` to every sample, preserving order.

        Args:
            samples: Samples to score; each one is modified in place.

        Returns:
            A new list holding the same (now updated) sample objects.
        """
        return [self.shape(sample) for sample in samples]