Source code for trinity.common.experience

# -*- coding: utf-8 -*-
"""Experience Class."""
from __future__ import annotations

import pickle
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional

import torch
from torch import Tensor


@dataclass
class EID:
    """Experience ID class to uniquely identify an experience.

    To enable the full functionality of experience grouping, users should manually set
    the `run` and `step` fields in custom workflows.
    """

    # TODO: do we need to add project/name here to make it unique across different projects?

    # Batch number, e.g., the explorer step num
    # Automatically set by the workflow runner
    batch: int = 0

    # Task number, e.g., the task sequence in the batch; the first task in the batch has task=0
    # Automatically set by the workflow runner
    task: int = 0

    # Run id, e.g., the first run in the task has run=0
    # Users should set this field in custom workflows when creating experiences
    run: int = 0

    # Step number when running the task, e.g., the first step in the task has step=0
    # Users should set this field in custom workflows when creating experiences
    step: int = 0

    suffix: str = field(
        default_factory=lambda: uuid.uuid4().hex[:6]
    )  # Unique identifier suffix, e.g., a short UUID fragment

    @property
    def uid(self) -> str:
        """A unique identifier for the experience."""
        return f"{self.batch}/{self.task}/{self.run}/{self.step}/{self.suffix}"

    @property
    def sid(self) -> str:
        """Step ID of the experience.

        For example, experiences generated by all runs of the same task at the same step
        will have the same sid.
        """
        return f"{self.batch}/{self.task}/{self.step}"

    @property
    def rid(self) -> str:
        """Run ID of the experience.

        For example, experiences generated by one run of a task at all steps will have
        the same rid.
        """
        return f"{self.batch}/{self.task}/{self.run}"

    @property
    def tid(self) -> str:
        """Task ID of the experience.

        For example, experiences generated by all runs of the same task in GRPO-like
        algorithms will have the same tid.
        """
        return f"{self.batch}/{self.task}"

    def __str__(self):
        return self.uid

    def __repr__(self):
        return (
            f"EID(batch={self.batch}, task={self.task}, "
            f"run={self.run}, step={self.step}, suffix={self.suffix})"
        )
    def to_dict(self) -> dict:
        """Convert the EID to a dictionary."""
        return {
            "batch": self.batch,
            "task": self.task,
            "run": self.run,
            "step": self.step,
            "suffix": self.suffix,
        }
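
# Usage sketch (illustrative, not part of the module): how the grouping identifiers of
# an EID relate to each other. The concrete field values below are made up.
#
# >>> eid = EID(batch=3, task=1, run=2, step=0, suffix="ab12cd")
# >>> eid.uid
# '3/1/2/0/ab12cd'
# >>> eid.sid   # shared by all runs of the same task at the same step
# '3/1/0'
# >>> eid.rid   # shared by all steps of the same run
# '3/1/2'
# >>> eid.tid   # shared by all runs of the same task (GRPO-style grouping)
# '3/1'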

class ExperienceType(Enum):
    """Enum for experience types."""

    SINGLE_TURN = "single_turn"  # Single-turn experience, e.g., a prompt-response pair
    MULTI_TURN = "multi_turn"  # Multi-turn experience, e.g., a conversation history
    DPO = "dpo"  # DPO experience, e.g., a chosen and rejected response pair

@dataclass(frozen=True)
class CustomField:
    """Custom field for Experiences.

    This is used to store additional information in the Experiences class.
    """

    source_field: str  # The source field name in Experience.info
    destination_field: str  # The destination field name in the Experiences class
    data_type: torch.dtype  # The data type of the field, e.g., torch.float32, torch.int64, etc.
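
# Usage sketch (illustrative, not part of the module): a CustomField copies a per-experience
# value from `Experience.info` into a tensor attribute on the gathered `Experiences` batch.
# The field names "turn_count"/"turn_counts" are hypothetical examples.
#
# >>> cf = CustomField(
# ...     source_field="turn_count",        # read from exp.info["turn_count"]
# ...     destination_field="turn_counts",  # exposed as exps.turn_counts after gathering
# ...     data_type=torch.int64,
# ... )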

@dataclass
class Experience:
    eid: EID = field(default_factory=EID)  # Unique identifier for the experience
    tokens: Optional[Tensor] = None  # [seq_length]
    logprobs: Optional[Tensor] = None  # [resp_length]
    reward: Optional[float] = None
    advantages: Optional[Tensor] = None  # [resp_length]
    returns: Optional[Tensor] = None  # [resp_length]

    # Type of the experience, automatically set based on the presence of
    # action_mask or chosen/rejected
    experience_type: ExperienceType = ExperienceType.SINGLE_TURN

    info: dict = field(
        default_factory=dict
    )  # Additional information about the experience, can also be used to store custom fields
    metrics: dict[str, float] = field(
        default_factory=dict
    )  # Metrics associated with the experience, directly used by the monitor

    # For single-turn experiences
    prompt_length: int = 1  # Length of the prompt in tokens, used for generating attention masks
    response_text: Optional[str] = None  # Text of the response
    prompt_text: Optional[str] = None  # Text of the prompt

    # For multi-turn experiences
    # Action mask which indicates which tokens are generated by the model
    action_mask: Optional[Tensor] = None  # [resp_length]
    messages: Optional[List[dict]] = None  # List of messages

    # For DPO experiences
    chosen: Optional[Tensor] = None  # Token ids of the chosen response [resp_length]
    rejected: Optional[Tensor] = None  # Token ids of the rejected response [resp_length]
    chosen_text: Optional[str] = None  # Text of the chosen response
    rejected_text: Optional[str] = None  # Text of the rejected response

    def __init__(
        self,
        *,
        eid=None,
        tokens,
        logprobs=None,
        reward=None,
        advantages=None,
        returns=None,
        info=None,
        metrics=None,
        prompt_length=1,
        response_text=None,
        prompt_text=None,
        action_mask=None,
        messages=None,
        chosen=None,
        rejected=None,
        chosen_text=None,
        rejected_text=None,
    ):
        # Infer the experience type from the provided fields; see the usage sketch after
        # this class for examples.
        if action_mask is not None:
            experience_type = ExperienceType.MULTI_TURN
        elif chosen is not None and rejected is not None:
            experience_type = ExperienceType.DPO
        else:
            experience_type = ExperienceType.SINGLE_TURN

        if experience_type == ExperienceType.SINGLE_TURN:
            assert (
                prompt_length > 0
            ), "Prompt length must be greater than 0 for single-turn experiences."
            assert len(tokens) > prompt_length, (
                f"Token ids must be longer than the prompt length. "
                f"Got len(tokens)={len(tokens)}, prompt_length={prompt_length}."
            )
            action_mask = torch.ones(len(tokens) - prompt_length, dtype=torch.bool)
        elif experience_type == ExperienceType.DPO:
            # For DPO, `tokens` holds only the prompt; chosen/rejected hold the responses.
            prompt_length = len(tokens)

        self.eid = eid or EID()
        self.tokens = tokens
        self.logprobs = logprobs
        self.reward = reward
        self.advantages = advantages
        self.returns = returns
        self.experience_type = experience_type
        self.info = info or {}
        self.metrics = metrics or {}
        self.prompt_length = prompt_length
        self.response_text = response_text
        self.prompt_text = prompt_text
        self.action_mask = action_mask
        self.messages = messages
        self.chosen = chosen
        self.rejected = rejected
        self.chosen_text = chosen_text
        self.rejected_text = rejected_text

        # Convert list-like inputs to tensors
        if not isinstance(self.tokens, Tensor):
            self.tokens = torch.tensor(self.tokens)
        if self.logprobs is not None and not isinstance(self.logprobs, Tensor):
            self.logprobs = torch.tensor(self.logprobs)
        if self.action_mask is not None and not isinstance(self.action_mask, Tensor):
            self.action_mask = torch.tensor(self.action_mask)
        if self.chosen is not None and not isinstance(self.chosen, Tensor):
            self.chosen = torch.tensor(self.chosen)
        if self.rejected is not None and not isinstance(self.rejected, Tensor):
            self.rejected = torch.tensor(self.rejected)

    def serialize(self) -> bytes:
        """Serialize the experience to bytes."""
        return pickle.dumps(self)

    @classmethod
    def deserialize(cls, data: bytes) -> Experience:
        """Deserialize bytes back into an Experience."""
        return pickle.loads(data)

    def to_dict(self) -> dict:
        """Convert the experience to a dictionary."""
        res = {
            "eid": self.eid,
            "type": self.experience_type.value,
            "info": self.info,
            "metrics": self.metrics,
        }
        if self.prompt_text is not None:
            res["prompt_text"] = self.prompt_text
        if self.response_text is not None:
            res["response_text"] = self.response_text
        if self.messages is not None:
            res["messages"] = self.messages
        if self.chosen_text is not None:
            res["chosen_text"] = self.chosen_text
        if self.rejected_text is not None:
            res["rejected_text"] = self.rejected_text
        if self.reward is not None:
            res["reward"] = float(self.reward)
        return res

    @classmethod
    def gather(
        cls,
        experiences: List[Experience],
        pad_token_id: int = 0,
        custom_fields: Optional[List[CustomField]] = None,
    ) -> Experiences:
        if len(experiences) == 0:
            return empty_experiences(custom_fields)
        exp_type = experiences[0].experience_type
        if exp_type == ExperienceType.DPO:
            experiences = split_dpo_experience_to_single_turn(experiences)
        max_prompt_length = max([exp.prompt_length for exp in experiences])  # type: ignore [type-var]
        max_response_length = max([len(exp.tokens) - exp.prompt_length for exp in experiences])  # type: ignore [arg-type]
        eids = [exp.eid for exp in experiences]
        # Gather tokens
        tokens = gather_token_ids(experiences, max_prompt_length, max_response_length, pad_token_id)
        # Gather rewards
        if experiences[0].reward is not None:
            rewards = torch.tensor([exp.reward for exp in experiences], dtype=torch.float)
        else:
            rewards = None
        # Gather action_masks
        action_masks = gather_action_masks(experiences, max_response_length)
        # Gather attention_masks
        attention_masks = gather_attention_masks(
            experiences, max_prompt_length, max_response_length
        )
        # Gather logprobs
        if all(exp.logprobs is not None for exp in experiences):
            logprobs = gather_logprobs(experiences, max_response_length)
        else:
            logprobs = None
        # Gather advantages
        if all(exp.advantages is not None for exp in experiences):
            advantages = gather_advantages(experiences, max_response_length)
        else:
            advantages = None
        # Gather returns
        if all(exp.returns is not None for exp in experiences):
            returns = gather_returns(experiences, max_response_length)
        else:
            returns = None
        exps = Experiences(
            eids=eids,
            tokens=tokens,
            rewards=rewards,
            advantages=advantages,
            returns=returns,
            attention_masks=attention_masks,
            action_masks=action_masks,
            prompt_length=max_prompt_length,
            logprobs=logprobs,
        )
        if custom_fields is not None:
            for custom_field in custom_fields:
                exps.custom_fields.append(custom_field.destination_field)
                setattr(
                    exps,
                    custom_field.destination_field,
                    torch.tensor(
                        [exp.info[custom_field.source_field] for exp in experiences],
                        dtype=custom_field.data_type,
                    ),
                )
        return exps
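
# Usage sketch (illustrative, not part of the module): constructing experiences. For a
# single-turn experience, `tokens` holds prompt + response ids and `prompt_length` marks
# the boundary; for DPO, `tokens` holds only the prompt while `chosen`/`rejected` hold
# the two candidate responses. Token ids below are arbitrary.
#
# >>> single = Experience(tokens=[1, 2, 3, 4, 5], prompt_length=2, reward=1.0)
# >>> single.experience_type
# <ExperienceType.SINGLE_TURN: 'single_turn'>
# >>> dpo = Experience(tokens=[1, 2, 3], chosen=[7, 8], rejected=[9])
# >>> dpo.experience_type
# <ExperienceType.DPO: 'dpo'>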

def split_dpo_experience_to_single_turn(experiences: List[Experience]) -> List[Experience]:
    """Split each DPO experience into two single-turn experiences (chosen and rejected)."""
    single_turn_experiences = []
    for exp in experiences:
        single_turn_experiences.append(
            Experience(
                eid=EID(
                    batch=exp.eid.batch,
                    task=exp.eid.task,
                    step=exp.eid.step,
                    run=exp.eid.run,
                ),
                tokens=torch.cat([exp.tokens, exp.chosen]),
                reward=exp.reward,
                info=exp.info,
                metrics=exp.metrics,
                prompt_length=len(exp.tokens),  # type: ignore [arg-type]
                prompt_text=exp.prompt_text,
                response_text=exp.chosen_text,
            )
        )
        single_turn_experiences.append(
            Experience(
                eid=EID(
                    batch=exp.eid.batch,
                    task=exp.eid.task,
                    step=exp.eid.step,
                    run=exp.eid.run,
                ),
                tokens=torch.cat([exp.tokens, exp.rejected]),
                reward=exp.reward,
                info=exp.info,
                metrics=exp.metrics,
                prompt_length=len(exp.tokens),  # type: ignore [arg-type]
                prompt_text=exp.prompt_text,
                response_text=exp.rejected_text,
            )
        )
    return single_turn_experiences
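
# Usage sketch (illustrative, not part of the module): each DPO experience expands into
# two single-turn experiences that share the prompt, one ending in the chosen response
# and one in the rejected response, so they can be padded and batched like any other.
#
# >>> dpo = Experience(tokens=[1, 2, 3], chosen=[7, 8], rejected=[9])
# >>> pair = split_dpo_experience_to_single_turn([dpo])
# >>> [exp.tokens.tolist() for exp in pair]
# [[1, 2, 3, 7, 8], [1, 2, 3, 9]]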

@dataclass
class Experiences:
    """A container for a batch of experiences, used for high-performance communication.

    Example:

    >>>         |<- prompt_length ->|               |
    >>> tokens: ('P' represents prompt, 'O' represents output)
    >>>   exp1: |........PPPPPPPPPPP|OOOOOOOOOO.....|
    >>>   exp2: |......PPPPPPPPPPPPP|OOOOOOO........|
    >>>
    >>> attention_masks: ('.' represents False and '1' represents True)
    >>>   exp1: |........11111111111|1111111111.....|
    >>>   exp2: |......1111111111111|1111111........|
    """

    eids: List[EID]  # Experience IDs of each experience in the batch
    tokens: Tensor  # [batch_size, seq_length]
    rewards: Tensor  # [batch_size]
    advantages: Optional[Tensor]  # [batch_size, response_length]
    returns: Optional[Tensor]  # [batch_size, response_length]
    attention_masks: Tensor  # [batch_size, seq_length]
    action_masks: Optional[Tensor]  # [batch_size, response_length]
    prompt_length: int
    logprobs: Optional[Tensor]  # [batch_size, response_length]
    custom_fields: List[str] = field(
        default_factory=list
    )  # Custom fields to include in the gathered experiences

    @property
    def batch_size(self) -> int:
        """Get the batch size."""
        return self.tokens.size(0)

    @classmethod
    def gather_experiences(
        cls,
        experiences: list[Experience],
        pad_token_id: int = 0,
        custom_fields: Optional[List[CustomField]] = None,
    ) -> Experiences:
        """Gather a batch of experiences from a list of experiences.

        This method automatically pads the `tokens` and `logprobs` of the input
        experiences to the same length.

        Args:
            experiences (list[Experience]): A list of experiences to gather.
            pad_token_id (int): The token ID to use for padding. Default is 0.
            custom_fields (Optional[List[CustomField]]): Custom fields to include in the
                gathered experiences.
        """
        if len(experiences) == 0:
            return empty_experiences(custom_fields)
        return experiences[0].__class__.gather(
            experiences, pad_token_id=pad_token_id, custom_fields=custom_fields
        )
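
# Usage sketch (illustrative, not part of the module): batching a list of experiences.
# Prompts are left-padded and responses right-padded so the resulting tensors are
# rectangular; the numbers below are arbitrary.
#
# >>> exps = Experiences.gather_experiences(
# ...     [
# ...         Experience(tokens=[1, 2, 3, 4], prompt_length=2, reward=1.0),
# ...         Experience(tokens=[5, 6, 7], prompt_length=1, reward=0.0),
# ...     ],
# ...     pad_token_id=0,
# ... )
# >>> exps.batch_size
# 2
# >>> exps.tokens.shape  # max_prompt_length (2) + max_response_length (2)
# torch.Size([2, 4])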

def empty_experiences(custom_fields: Optional[List[CustomField]]) -> Experiences:
    """Create an empty Experiences container (used when there is nothing to gather)."""
    exps = Experiences(
        tokens=torch.empty(0, dtype=torch.int32),
        rewards=torch.empty(0, dtype=torch.float32),
        advantages=torch.empty(0, dtype=torch.float32),
        returns=torch.empty(0, dtype=torch.float32),
        attention_masks=torch.empty(0, dtype=torch.bool),
        action_masks=torch.empty(0, dtype=torch.bool),
        logprobs=torch.empty(0, dtype=torch.float32),
        prompt_length=torch.empty(0, dtype=torch.int32),
        eids=[],
    )
    if custom_fields is not None:
        for custom_field in custom_fields:
            exps.custom_fields.append(custom_field.destination_field)
            setattr(
                exps, custom_field.destination_field, torch.empty(0, dtype=custom_field.data_type)
            )
    return exps

def gather_token_ids(
    experiences, max_prompt_length: int, max_response_length: int, pad_token_id: int
) -> Tensor:
    """Left-pad prompts and right-pad responses to build a rectangular token id tensor."""
    token_ids_dtype = experiences[0].tokens.dtype
    return torch.stack(
        [
            torch.cat(
                [
                    torch.full(
                        (max_prompt_length - exp.prompt_length,),
                        pad_token_id,
                        dtype=token_ids_dtype,
                    ),
                    exp.tokens,
                    torch.full(
                        (max_response_length + exp.prompt_length - len(exp.tokens),),
                        pad_token_id,
                        dtype=token_ids_dtype,
                    ),
                ]
            )
            for exp in experiences
        ]
    )
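
# Worked example (illustrative) of the padding scheme used by `gather_token_ids`:
# prompts are left-padded to max_prompt_length and responses right-padded to
# max_response_length, so the prompt/response boundary is aligned across the batch.
#
#   exp1: prompt_length=2, tokens=[1, 2, 3, 4]  ->  [1, 2, 3, 4]
#   exp2: prompt_length=1, tokens=[5, 6, 7]     ->  [0, 5, 6, 7]   (pad_token_id=0)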

def gather_action_masks(experiences, max_response_length: int) -> Tensor:
    return torch.stack(
        [
            torch.cat(
                [
                    exp.action_mask,
                    torch.full(
                        (max_response_length - len(exp.action_mask),),
                        0,
                        dtype=torch.bool,
                    ),
                ]
            )
            for exp in experiences
        ]
    )

def gather_attention_masks(
    experiences, max_prompt_length: int, max_response_length: int
) -> Tensor:
    attention_masks = torch.zeros(
        (len(experiences), max_prompt_length + max_response_length), dtype=torch.bool
    )
    for i, exp in enumerate(experiences):
        start = max_prompt_length - exp.prompt_length
        end = start + len(exp.tokens)
        attention_masks[i, start:end] = 1
    return attention_masks

def gather_logprobs(experiences, max_response_length: int) -> Tensor:
    logprob_dtype = experiences[0].logprobs.dtype  # type: ignore [union-attr]
    return torch.stack(
        [
            torch.cat(
                [
                    exp.logprobs,
                    torch.full(
                        (max_response_length - len(exp.logprobs),),
                        0.0,
                        dtype=logprob_dtype,
                    ),
                ]
            )
            for exp in experiences
        ]
    )

def gather_advantages(experiences, max_response_length: int) -> Optional[Tensor]:
    if experiences[0].advantages is None:
        return None
    advantages_dtype = experiences[0].advantages.dtype
    return torch.stack(
        [
            torch.cat(
                [
                    exp.advantages,
                    torch.full(
                        (max_response_length - len(exp.advantages),),
                        0.0,
                        dtype=advantages_dtype,
                    ),
                ]
            )
            for exp in experiences
        ]
    )

def gather_returns(experiences, max_response_length: int) -> Optional[Tensor]:
    if experiences[0].returns is None:
        return None
    returns_dtype = experiences[0].returns.dtype
    return torch.stack(
        [
            torch.cat(
                [
                    exp.returns,
                    torch.full(
                        (max_response_length - len(exp.returns),),
                        0.0,
                        dtype=returns_dtype,
                    ),
                ]
            )
            for exp in experiences
        ]
    )