# Source code for trinity.common.rewards.reward_fn

# -*- coding: utf-8 -*-
"""Base Reward Function Class."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List

from trinity.common.experience import Experience
from trinity.common.rewards.utils import to_rm_gallery_messages
from trinity.utils.log import get_logger
from trinity.utils.registry import Registry

# Module-level logger, named after this module.
logger = get_logger(__name__)


# Registry that reward function classes add themselves to by name, e.g. via
# ``@REWARD_FUNCTIONS.register_module("rm_gallery_reward")``.
REWARD_FUNCTIONS = Registry("reward_functions")


class RewardFn(ABC):
    """Abstract interface shared by all reward functions.

    Both the constructor and ``__call__`` are abstract, so a concrete
    subclass has to override each of them before it can be instantiated.
    """

    @abstractmethod
    def __init__(self, **kwargs) -> None:
        """Set up the reward function from keyword arguments."""

    @abstractmethod
    def __call__(self, **kwargs) -> Dict[str, float]:
        """Evaluate the reward function on the given keyword arguments.

        Returns:
            A ``Dict[str, float]``, as declared by the return annotation.
        """
@REWARD_FUNCTIONS.register_module("rm_gallery_reward")
class RMGalleryFn(RewardFn):
    """Reward function backed by a reward model from RM-Gallery.

    https://github.com/modelscope/RM-Gallery

    The class is registered in ``REWARD_FUNCTIONS`` under the name
    ``"rm_gallery_reward"``.
    """

    def __init__(
        self,
        reward_name,
        **kwargs,
    ) -> None:
        """Look up and instantiate the RM-Gallery reward model.

        Args:
            reward_name: Key handed to ``RewardRegistry.get`` to select the
                reward model class.
            **kwargs: Forwarded unchanged to that class's constructor.
        """
        # Imported inside the method rather than at module level; presumably
        # so rm_gallery is only required when this reward function is used.
        from rm_gallery.core.reward.registry import RewardRegistry

        self.reward_model = RewardRegistry.get(reward_name)(**kwargs)

    def __call__(self, experience: Experience, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, float]:  # type: ignore
        """Score one experience with the wrapped RM-Gallery reward model.

        Args:
            experience: Experience whose response is to be scored.
            messages: Chat messages used as the sample input.
            **kwargs: Forwarded both to the sample builder (which reads the
                optional ``ground_truth`` entry) and to
                ``reward_model.evaluate``.

        Returns:
            Mapping from reward name to score, as produced by
            ``_extract_reward``.

        Raises:
            ValueError: If no reward can be read from the evaluated sample.
        """
        sample = self._build_sample_from_experience(experience, messages, **kwargs)
        sample_with_reward = self.reward_model.evaluate(sample, **kwargs)
        return self._extract_reward(sample_with_reward)

    def _build_sample_from_experience(
        self, experience: Experience, messages: List[Dict[str, Any]], **kwargs
    ) -> Any:
        """Convert an experience into an RM-Gallery ``DataSample``.

        Ref: https://github.com/modelscope/RM-Gallery/blob/main/rm_gallery/core/data/schema.py

        Args:
            experience: Source of the response text, unique id and metadata.
            messages: Chat messages, converted with ``to_rm_gallery_messages``.
            **kwargs: Only ``ground_truth`` is read; it defaults to ``""``.

        Returns:
            A ``DataSample`` with a single assistant output step.
        """
        from rm_gallery.core.data.schema import DataOutput, DataSample, Step

        # The response becomes the single candidate answer; the ground truth
        # (if any) is attached to it under the "reference" label.
        answer = Step(
            role="assistant",
            content=str(experience.response_text),
            label={"reference": kwargs.get("ground_truth", "")},
        )
        return DataSample(
            unique_id=experience.eid.uid,
            input=to_rm_gallery_messages(messages),
            output=[DataOutput(answer=answer)],
            metadata=experience.info,
        )

    def _extract_reward(self, sample: Any) -> Dict[str, float]:
        """Extract the reward(s) from an evaluated RM-Gallery ``DataSample``.

        Args:
            sample: Sample whose ``output[0].answer.reward`` holds the reward.

        Returns:
            If the reward has details, one entry per scored or ranked detail,
            keyed by the detail's name (details of any other type are
            skipped). Otherwise a single ``"reward"`` entry holding the
            reward's overall score.

        Raises:
            ValueError: If ``sample.output[0].answer.reward`` cannot be read.
        """
        try:
            reward_obj = sample.output[0].answer.reward
        except Exception as e:
            # Chain explicitly so the traceback marks the lookup failure as
            # the direct cause of this error.
            raise ValueError(f"No reward is found in sample: {e}") from e

        from rm_gallery.core.reward.schema import (
            RewardDimensionWithRank,
            RewardDimensionWithScore,
        )

        reward_dict: Dict[str, float] = {}

        if not reward_obj.details:
            # No per-dimension breakdown: fall back to the overall score.
            reward_dict["reward"] = reward_obj.score
            return reward_dict

        for detail in reward_obj.details:
            if isinstance(detail, RewardDimensionWithScore):
                reward_dict[detail.name] = detail.score
            elif isinstance(detail, RewardDimensionWithRank):
                # TODO: support multi-ranked dimension
                # Only the item at index 0 of a ranked dimension is used.
                if detail:
                    top_ranked_item = detail[0]
                    reward_dict[top_ranked_item.name] = top_ranked_item.score

        return reward_dict