Source code for trinity.common.experience

# -*- coding: utf-8 -*-
"""Experience Class."""
from __future__ import annotations

import pickle
import uuid
from dataclasses import asdict, dataclass, field, fields
from typing import Dict, List, Literal, Optional

import torch
from datasets import Dataset
from torch import Tensor


@dataclass
class EID:
    """Experience ID class to uniquely identify an experience.

    To enable the full functionality of experience grouping, users should
    manually set the `run` and `step` fields in custom workflows.
    """

    # TODO: do we need to add project/name here to make it unique across different projects?

    # Batch number, e.g., the explorer step num
    # Automatically set by the workflow runner
    batch: int = 0

    # Task number, e.g., the task sequence in the batch; the first task in the batch has task=0
    # Automatically set by the workflow runner
    task: int = 0

    # Run id, e.g., the first run in the task has run=0
    # Users should set this field in custom workflows when creating experiences
    run: int = 0

    # Step number when running the task, e.g., the first step in the task has step=0
    # Users should set this field in custom workflows when creating experiences
    step: int = 0

    suffix: str = field(
        default_factory=lambda: uuid.uuid4().hex[:6]
    )  # Unique identifier suffix, e.g., a UUID

    @property
    def uid(self) -> str:
        """A unique identifier for the experience."""
        return f"{self.batch}/{self.task}/{self.run}/{self.step}/{self.suffix}"

    @property
    def sid(self) -> str:
        """Step ID of the experience.

        For example, experiences generated by all runs of the same task at the
        same step will have the same sid.
        """
        return f"{self.batch}/{self.task}/{self.step}"

    @property
    def rid(self) -> str:
        """Run ID of the experience.

        For example, experiences generated by one run of a task across all
        steps will have the same rid.
        """
        return f"{self.batch}/{self.task}/{self.run}"

    @property
    def tid(self) -> str:
        """Task ID of the experience.

        For example, experiences generated by all runs of the same task in
        GRPO-like algorithms will have the same tid.
        """
        return f"{self.batch}/{self.task}"

    def __str__(self):
        return self.uid

    def __repr__(self):
        return (
            f"EID(batch={self.batch}, task={self.task}, "
            f"run={self.run}, step={self.step}, suffix={self.suffix})"
        )

    def to_dict(self) -> dict:
        """Convert the EID to a dictionary."""
        return {
            "batch": self.batch,
            "task": self.task,
            "run": self.run,
            "step": self.step,
            "suffix": self.suffix,
        }
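

# Usage sketch (illustrative, not part of the module): the derived ids form a
# hierarchy, so grouping by tid, rid, or sid collapses different subsets of
# experiences.
#
#     eid = EID(batch=2, task=3, run=1, step=0, suffix="abc123")
#     assert eid.uid == "2/3/1/0/abc123"
#     assert eid.tid == "2/3"    # shared by all runs/steps of the task
#     assert eid.rid == "2/3/1"  # shared by all steps of run 1
#     assert eid.sid == "2/3/0"  # shared by all runs at step 0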


@dataclass(frozen=True)
class CustomField:
    """Custom field for Experiences.

    This is used to store additional information in the Experiences class.
    """

    source_field: str  # The source field name in Experience.info
    destination_field: str  # The destination field name in the Experiences class
    data_type: torch.dtype  # The data type of the field, e.g., torch.float32, torch.int64, etc.
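

# Usage sketch (illustrative): a CustomField copies a per-experience value from
# `Experience.info` into a tensor attribute on the gathered batch. The
# "turn_count" key below is a hypothetical example, not defined by this module.
#
#     cf = CustomField(
#         source_field="turn_count",
#         destination_field="turn_counts",
#         data_type=torch.int64,
#     )
#     # Experience.gather(exps, custom_fields=[cf]) then exposes
#     # `batch.turn_counts` as a [batch_size] int64 tensor.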


@dataclass
class Experience:
    eid: EID = field(default_factory=EID)  # Unique identifier for the experience
    tokens: Optional[Tensor] = None  # [seq_length]
    logprobs: Optional[Tensor] = None  # [resp_length]
    reward: Optional[float] = None
    advantages: Optional[Tensor] = None  # [resp_length]
    returns: Optional[Tensor] = None  # [resp_length]
    info: dict = field(
        default_factory=dict
    )  # Additional information about the experience; can also be used to store custom fields
    metrics: dict[str, float] = field(
        default_factory=dict
    )  # Metrics associated with the experience, directly used by the monitor

    # for single-turn experiences
    prompt_length: int = 1  # Length of the prompt in tokens, used for generating attention masks
    response_text: Optional[str] = None  # Text of the response
    prompt_text: Optional[str] = None  # Text of the prompt

    # for multi-turn experiences
    # Action mask which indicates which tokens are generated by the model
    action_mask: Optional[Tensor] = None  # [resp_length]
    messages: Optional[List[dict]] = None  # List of messages
    tools: Optional[List[dict]] = None

    # for DPO experiences
    chosen: Optional[Tensor] = None  # Token ids of the chosen response [resp_length]
    rejected: Optional[Tensor] = None  # Token ids of the rejected response [resp_length]
    chosen_text: Optional[str] = None  # Text of the chosen response
    rejected_text: Optional[str] = None  # Text of the rejected response

    def __init__(  # noqa: C901
        self,
        *,
        eid=None,
        tokens,
        logprobs=None,
        reward=None,
        advantages=None,
        returns=None,
        info=None,
        metrics=None,
        prompt_length=1,
        response_text=None,
        prompt_text=None,
        action_mask=None,
        messages=None,
        tools=None,
        chosen=None,
        rejected=None,
        chosen_text=None,
        rejected_text=None,
    ):
        if action_mask is not None:
            experience_type = "multi_turn"
        elif chosen is not None and rejected is not None:
            experience_type = "dpo"
        else:
            experience_type = "single_turn"
        if experience_type == "single_turn":
            assert (
                prompt_length > 0
            ), "Prompt length must be greater than 0 for single-turn experiences."
            assert (
                len(tokens) > prompt_length
            ), f"Token ids must be longer than the prompt length. Got len(tokens)={len(tokens)}, prompt_length={prompt_length}."
            action_mask = torch.ones(len(tokens) - prompt_length, dtype=torch.bool)
        elif experience_type == "dpo":
            prompt_length = len(tokens)
        if eid is None:
            self.eid = EID()
        elif isinstance(eid, dict):
            self.eid = EID(**eid)
        else:
            self.eid = eid
        if isinstance(tokens, list):
            tokens = torch.tensor(tokens, dtype=torch.int32)
        self.tokens = tokens
        if isinstance(logprobs, list):
            logprobs = torch.tensor(logprobs, dtype=torch.float32)
        self.logprobs = logprobs
        self.reward = reward
        if isinstance(advantages, list):
            advantages = torch.tensor(advantages, dtype=torch.float32)
        self.advantages = advantages
        if isinstance(returns, list):
            returns = torch.tensor(returns, dtype=torch.float32)
        self.returns = returns
        self.experience_type = experience_type
        self.info = info or {}
        self.metrics = metrics or {}
        self.prompt_length = prompt_length
        self.response_text = response_text
        self.prompt_text = prompt_text
        if isinstance(action_mask, list):
            action_mask = torch.tensor(action_mask, dtype=torch.bool)
        self.action_mask = action_mask
        self.messages = messages
        self.tools = tools
        if isinstance(chosen, list):
            chosen = torch.tensor(chosen, dtype=torch.int32)
        self.chosen = chosen
        if isinstance(rejected, list):
            rejected = torch.tensor(rejected, dtype=torch.int32)
        self.rejected = rejected
        self.chosen_text = chosen_text
        self.rejected_text = rejected_text
        if not isinstance(self.tokens, Tensor):
            self.tokens = torch.tensor(self.tokens)
        if self.logprobs is not None and not isinstance(self.logprobs, Tensor):
            self.logprobs = torch.tensor(self.logprobs)
        if self.action_mask is not None and not isinstance(self.action_mask, Tensor):
            self.action_mask = torch.tensor(self.action_mask)
        if self.chosen is not None and not isinstance(self.chosen, Tensor):
            self.chosen = torch.tensor(self.chosen)
        if self.rejected is not None and not isinstance(self.rejected, Tensor):
            self.rejected = torch.tensor(self.rejected)

    def serialize(self) -> bytes:
        """Serialize the experience to bytes."""
        return pickle.dumps(self)

    @classmethod
    def deserialize(cls, data: bytes) -> Experience:
        """Deserialize an experience from bytes."""
        return pickle.loads(data)

    def to_dict(self) -> dict:
        """Convert the experience to a dictionary."""
        res = {
            "eid": self.eid,
            "type": self.experience_type,
            "prompt_length": self.prompt_length,
            "response_length": len(self.tokens) - self.prompt_length,  # type: ignore [arg-type]
            "info": self.info,
            "metrics": self.metrics,
        }
        if self.prompt_text is not None:
            res["prompt_text"] = self.prompt_text
        if self.response_text is not None:
            res["response_text"] = self.response_text
        if self.messages is not None:
            res["messages"] = self.messages
        if self.tools is not None:
            res["tools"] = self.tools
        if self.chosen_text is not None:
            res["chosen_text"] = self.chosen_text
        if self.rejected_text is not None:
            res["rejected_text"] = self.rejected_text
        if self.reward is not None:
            res["reward"] = float(self.reward)
        return res

    @classmethod
    def gather(
        cls,
        experiences: List[Experience],
        pad_token_id: int = 0,
        custom_fields: Optional[List[CustomField]] = None,
    ) -> Experiences:
        """Gather a list of experiences into a padded `Experiences` batch."""
        if len(experiences) == 0:
            return empty_experiences(custom_fields)
        exp_type = experiences[0].experience_type
        if exp_type == "dpo":
            experiences = split_dpo_experience_to_single_turn(experiences)
        max_prompt_length = max([exp.prompt_length for exp in experiences])  # type: ignore [type-var]
        max_response_length = max([len(exp.tokens) - exp.prompt_length for exp in experiences])  # type: ignore [arg-type]
        eids = [exp.eid for exp in experiences]
        # Gather tokens
        tokens = gather_token_ids(experiences, max_prompt_length, max_response_length, pad_token_id)
        # Gather rewards
        if experiences[0].reward is not None:
            rewards = torch.tensor([exp.reward for exp in experiences], dtype=torch.float)
        else:
            rewards = None
        # Gather action masks
        action_masks = gather_action_masks(experiences, max_response_length)
        # Gather attention masks
        attention_masks = gather_attention_masks(
            experiences, max_prompt_length, max_response_length
        )
        # Gather logprobs
        if all(exp.logprobs is not None for exp in experiences):
            logprobs = gather_logprobs(experiences, max_response_length)
        else:
            logprobs = None
        # Gather advantages
        if all(exp.advantages is not None for exp in experiences):
            advantages = gather_advantages(experiences, max_response_length)
        else:
            advantages = None
        # Gather returns
        if all(exp.returns is not None for exp in experiences):
            returns = gather_returns(experiences, max_response_length)
        else:
            returns = None
        exps = Experiences(
            eids=eids,
            tokens=tokens,
            rewards=rewards,
            advantages=advantages,
            returns=returns,
            attention_masks=attention_masks,
            action_masks=action_masks,
            prompt_length=max_prompt_length,
            logprobs=logprobs,
        )
        if custom_fields is not None:
            for custom_field in custom_fields:
                exps.custom_fields.append(custom_field.destination_field)
                setattr(
                    exps,
                    custom_field.destination_field,
                    torch.tensor(
                        [exp.info[custom_field.source_field] for exp in experiences],
                        dtype=custom_field.data_type,
                    ),
                )
        return exps
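

# Construction sketch (illustrative): a single-turn experience. With
# prompt_length=3 and five tokens, __init__ derives an action mask that covers
# the two response tokens.
#
#     exp = Experience(tokens=[1, 2, 3, 4, 5], prompt_length=3, reward=1.0)
#     assert exp.experience_type == "single_turn"
#     assert exp.action_mask.tolist() == [True, True]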


def split_dpo_experience_to_single_turn(experiences: List[Experience]) -> List[Experience]:
    """Split each DPO experience into two single-turn experiences (chosen and rejected)."""
    single_turn_experiences = []
    for exp in experiences:
        single_turn_experiences.append(
            Experience(
                eid=EID(
                    batch=exp.eid.batch,
                    task=exp.eid.task,
                    step=exp.eid.step,
                    run=exp.eid.run,
                ),
                tokens=torch.cat([exp.tokens, exp.chosen]),
                reward=exp.reward,
                info=exp.info,
                metrics=exp.metrics,
                prompt_length=len(exp.tokens),  # type: ignore [arg-type]
                prompt_text=exp.prompt_text,
                response_text=exp.chosen_text,
            )
        )
        single_turn_experiences.append(
            Experience(
                eid=EID(
                    batch=exp.eid.batch,
                    task=exp.eid.task,
                    step=exp.eid.step,
                    run=exp.eid.run,
                ),
                tokens=torch.cat([exp.tokens, exp.rejected]),
                reward=exp.reward,
                info=exp.info,
                metrics=exp.metrics,
                prompt_length=len(exp.tokens),  # type: ignore [arg-type]
                prompt_text=exp.prompt_text,
                response_text=exp.rejected_text,
            )
        )
    return single_turn_experiences
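

# Behavior sketch (illustrative): a DPO experience splits into two single-turn
# experiences that share the prompt tokens, one ending with the chosen response
# and one with the rejected response.
#
#     dpo = Experience(tokens=[1, 2], chosen=[3, 4], rejected=[5, 6])
#     pair = split_dpo_experience_to_single_turn([dpo])
#     assert [e.tokens.tolist() for e in pair] == [[1, 2, 3, 4], [1, 2, 5, 6]]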


@dataclass
class Experiences:
    """A container for a batch of experiences, for high-performance communication usage.

    Example:

    >>>          |<- prompt_length ->|               |
    >>> tokens: ('P' represents prompt, 'O' represents output)
    >>> exp1:    |........PPPPPPPPPPP|OOOOOOOOOO.....|
    >>> exp2:    |......PPPPPPPPPPPPP|OOOOOOO........|
    >>>
    >>> attention_masks: ('.' represents False and '1' represents True)
    >>> exp1:    |........11111111111|1111111111.....|
    >>> exp2:    |......1111111111111|1111111........|
    """

    eids: List[EID]  # Experience IDs of each experience in the batch
    tokens: Tensor  # [batch_size, seq_length]
    rewards: Tensor  # [batch_size]
    advantages: Optional[Tensor]  # [batch_size, response_length]
    returns: Optional[Tensor]  # [batch_size, response_length]
    attention_masks: Tensor  # [batch_size, seq_length]
    action_masks: Optional[Tensor]  # [batch_size, response_length]
    prompt_length: int
    logprobs: Optional[Tensor]  # [batch_size, response_length]
    custom_fields: List[str] = field(
        default_factory=list
    )  # Custom fields to include in the gathered experiences

    @property
    def batch_size(self) -> int:
        """Get the batch size."""
        return self.tokens.size(0)

    @classmethod
    def gather_experiences(
        cls,
        experiences: list[Experience],
        pad_token_id: int = 0,
        custom_fields: Optional[List[CustomField]] = None,
    ) -> Experiences:
        """Gather a batch of experiences from a list of experiences.

        This method will automatically pad the `tokens` and `logprobs` of input
        experiences to the same length.

        Args:
            experiences (list[Experience]): A list of experiences to gather.
            pad_token_id (int): The token ID to use for padding. Default is 0.
            custom_fields (Optional[List[CustomField]]): Custom fields to include
                in the gathered experiences.
        """
        if len(experiences) == 0:
            return empty_experiences(custom_fields)
        return experiences[0].__class__.gather(
            experiences, pad_token_id=pad_token_id, custom_fields=custom_fields
        )
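

# Gathering sketch (illustrative): experiences of different lengths are
# left-padded on the prompt side and right-padded on the response side, as in
# the diagram in the Experiences docstring.
#
#     a = Experience(tokens=[1, 2, 3, 4], prompt_length=2, reward=1.0)
#     b = Experience(tokens=[5, 6, 7], prompt_length=1, reward=0.0)
#     batch = Experiences.gather_experiences([a, b], pad_token_id=0)
#     assert batch.batch_size == 2
#     assert batch.tokens.shape == (2, 4)  # max prompt (2) + max response (2)
#     assert batch.tokens[1].tolist() == [0, 5, 6, 7]  # prompt left-padded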


def empty_experiences(custom_fields: Optional[List[CustomField]]) -> Experiences:
    """Build an empty Experiences batch (used when gathering an empty list)."""
    exps = Experiences(
        eids=[],
        tokens=torch.empty(0, dtype=torch.int32),
        rewards=torch.empty(0, dtype=torch.float32),
        advantages=torch.empty(0, dtype=torch.float32),
        returns=torch.empty(0, dtype=torch.float32),
        attention_masks=torch.empty(0, dtype=torch.bool),
        action_masks=torch.empty(0, dtype=torch.bool),
        logprobs=torch.empty(0, dtype=torch.float32),
        prompt_length=0,
    )
    if custom_fields is not None:
        for custom_field in custom_fields:
            exps.custom_fields.append(custom_field.destination_field)
            setattr(
                exps, custom_field.destination_field, torch.empty(0, dtype=custom_field.data_type)
            )
    return exps


def gather_token_ids(
    experiences, max_prompt_length: int, max_response_length: int, pad_token_id: int
) -> Tensor:
    """Stack token ids, left-padding each prompt and right-padding each response."""
    token_ids_dtype = experiences[0].tokens.dtype
    return torch.stack(
        [
            torch.cat(
                [
                    torch.full(
                        (max_prompt_length - exp.prompt_length,),
                        pad_token_id,
                        dtype=token_ids_dtype,
                    ),
                    exp.tokens,
                    torch.full(
                        (max_response_length + exp.prompt_length - len(exp.tokens),),
                        pad_token_id,
                        dtype=token_ids_dtype,
                    ),
                ]
            )
            for exp in experiences
        ]
    )


def gather_action_masks(experiences, max_response_length: int) -> Tensor:
    return torch.stack(
        [
            torch.cat(
                [
                    exp.action_mask,
                    torch.full(
                        (max_response_length - len(exp.action_mask),),
                        0,
                        dtype=torch.bool,
                    ),
                ]
            )
            for exp in experiences
        ]
    )


def gather_attention_masks(experiences, max_prompt_length: int, max_response_length: int) -> Tensor:
    """Mark the non-padding region of each left/right-padded sequence."""
    attention_masks = torch.zeros(
        (len(experiences), max_prompt_length + max_response_length), dtype=torch.bool
    )
    for i, exp in enumerate(experiences):
        start = max_prompt_length - exp.prompt_length
        end = start + len(exp.tokens)
        attention_masks[i, start:end] = 1
    return attention_masks


def gather_logprobs(experiences, max_response_length: int) -> Tensor:
    logprob_dtype = experiences[0].logprobs.dtype  # type: ignore [union-attr]
    return torch.stack(
        [
            torch.cat(
                [
                    exp.logprobs,
                    torch.full(
                        (max_response_length - len(exp.logprobs),),
                        0.0,
                        dtype=logprob_dtype,
                    ),
                ]
            )
            for exp in experiences
        ]
    )


def gather_advantages(experiences, max_response_length: int) -> Optional[Tensor]:
    if experiences[0].advantages is None:
        return None
    advantages_dtype = experiences[0].advantages.dtype
    return torch.stack(
        [
            torch.cat(
                [
                    exp.advantages,
                    torch.full(
                        (max_response_length - len(exp.advantages),),
                        0.0,
                        dtype=advantages_dtype,
                    ),
                ]
            )
            for exp in experiences
        ]
    )


def gather_returns(experiences, max_response_length: int) -> Optional[Tensor]:
    if experiences[0].returns is None:
        return None
    returns_dtype = experiences[0].returns.dtype
    return torch.stack(
        [
            torch.cat(
                [
                    exp.returns,
                    torch.full(
                        (max_response_length - len(exp.returns),),
                        0.0,
                        dtype=returns_dtype,
                    ),
                ]
            )
            for exp in experiences
        ]
    )


def group_by(
    experiences: List[Experience], id_type: Literal["task", "run", "step"]
) -> Dict[str, List[Experience]]:
    """Group experiences by ID."""
    if id_type == "task":
        attr = "tid"
    elif id_type == "run":
        attr = "rid"
    elif id_type == "step":
        attr = "sid"
    else:
        raise ValueError(f"Unknown id_type: {id_type}")
    grouped: Dict[str, List[Experience]] = {}
    for exp in experiences:
        group_id = getattr(exp.eid, attr)
        if group_id not in grouped:
            grouped[group_id] = []
        grouped[group_id].append(exp)
    return grouped
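

# Grouping sketch (illustrative): grouping by "task" buckets experiences by
# their tid, which is what GRPO-style per-task advantage computation needs.
#
#     e1 = Experience(tokens=[1, 2, 3], prompt_length=1, eid=EID(batch=0, task=0, run=0))
#     e2 = Experience(tokens=[1, 2, 3], prompt_length=1, eid=EID(batch=0, task=0, run=1))
#     groups = group_by([e1, e2], id_type="task")
#     assert list(groups.keys()) == ["0/0"]
#     assert len(groups["0/0"]) == 2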


def to_hf_datasets(experiences: list[Experience]) -> Dataset:
    """Convert a list of Experience objects to a HuggingFace Dataset, preserving all fields."""
    return Dataset.from_list([asdict(exp) for exp in experiences])


def from_hf_datasets(dataset: Dataset) -> List[Experience]:
    """Convert a HuggingFace Dataset back to a list of Experience objects."""

    def dict_to_dataclass(cls, d):
        # Drop keys that are not dataclass fields of `cls` before constructing it.
        valid_keys = {f.name for f in fields(cls)}
        filtered = {k: v for k, v in d.items() if k in valid_keys}
        return cls(**filtered)

    return [dict_to_dataclass(Experience, row) for row in dataset.to_list()]
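

# Round-trip sketch (illustrative, assuming `datasets` can encode the tensor
# fields produced by asdict()): list-valued columns are re-wrapped into tensors
# by Experience.__init__ on the way back.
#
#     ds = to_hf_datasets([Experience(tokens=[1, 2, 3], prompt_length=1)])
#     restored = from_hf_datasets(ds)
#     assert restored[0].tokens.tolist() == [1, 2, 3]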