Source code for trinity.common.experience

# -*- coding: utf-8 -*-
"""Experience Class."""
from __future__ import annotations

import pickle
from dataclasses import dataclass
from itertools import chain, repeat
from typing import List, Optional

import torch
from torch import Tensor


@dataclass
class Experience:
    """A single experience."""

    tokens: Tensor  # [seq]
    prompt_length: int
    logprobs: Optional[Tensor] = None  # [seq]
    reward: Optional[float] = None
    prompt_text: Optional[str] = None
    response_text: Optional[str] = None
    action_mask: Optional[Tensor] = None
    chosen: Optional[Tensor] = None  # for DPO
    rejected: Optional[Tensor] = None  # for DPO
    info: Optional[dict] = None
    metrics: Optional[dict[str, float]] = None
    run_id: str = ""

    def __post_init__(self):
        if self.action_mask is not None:
            assert (
                self.action_mask.shape == self.tokens.shape
            ), "The provided action_mask must have the same shape as tokens."

    def serialize(self) -> bytes:
        """Serialize the experience to bytes."""
        return pickle.dumps(self)

    @staticmethod
    def deserialize(data: bytes) -> Experience:
        """Deserialize the experience from bytes."""
        return pickle.loads(data)

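
# Example (illustrative sketch, not part of the module): constructing an
# `Experience` and round-tripping it through `serialize`/`deserialize`.
# The token ids and reward below are arbitrary placeholder values.
#
# >>> exp = Experience(
# ...     tokens=torch.tensor([1, 2, 3, 4, 5]),  # prompt + response token ids
# ...     prompt_length=2,                       # first 2 tokens are the prompt
# ...     reward=1.0,
# ... )
# >>> restored = Experience.deserialize(exp.serialize())
# >>> torch.equal(restored.tokens, exp.tokens)
# True
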
@dataclass(frozen=True)
class Experiences:
    """A container for a batch of experiences, for high-performance communication usage.

    Example:

    >>>          |<- prompt_length ->|                |
    >>> tokens: ('P' represents prompt, 'O' represents output)
    >>> exp1:    |........PPPPPPPPPPP|OOOOOOOOOO.....|
    >>> exp2:    |......PPPPPPPPPPPPP|OOOOOOO........|
    >>>
    >>> attention_masks: ('.' represents False and '1' represents True)
    >>> exp1:    |........11111111111|1111111111.....|
    >>> exp2:    |......1111111111111|1111111........|
    """

    tokens: Tensor
    rewards: Optional[Tensor]
    attention_masks: Tensor
    action_masks: Optional[Tensor]
    prompt_length: int
    logprobs: Optional[Tensor]
    run_ids: List[str]

    @property
    def batch_size(self) -> int:
        """Get the batch size."""
        return self.tokens.size(0)

    @classmethod
    def gather_experiences(
        cls, experiences: list[Experience], pad_token_id: int = 0
    ) -> Experiences:
        """Gather a batch of experiences from a list of experiences.

        This method automatically pads the `tokens`, `action_mask`, and `logprobs`
        of the input experiences to the same length.
        """
        if len(experiences) == 0:
            return Experiences(
                tokens=torch.empty(0, dtype=torch.int32),
                rewards=torch.empty(0, dtype=torch.float32),
                attention_masks=torch.empty(0, dtype=torch.bool),
                action_masks=torch.empty(0, dtype=torch.bool),
                logprobs=torch.empty(0, dtype=torch.float32),
                prompt_length=0,
                run_ids=[],
            )
        max_prompt_length = max([exp.prompt_length for exp in experiences])
        max_response_length = max([len(exp.tokens) - exp.prompt_length for exp in experiences])
        run_ids = [exp.run_id for exp in experiences]
        tokens_dtype = experiences[0].tokens.dtype
        # Left-pad the prompt part and right-pad the response part so that all
        # rows share the same prompt/response boundary.
        tokens = torch.stack(
            [
                torch.cat(
                    [
                        torch.full(
                            (max_prompt_length - exp.prompt_length,),
                            pad_token_id,
                            dtype=tokens_dtype,
                        ),
                        exp.tokens,
                        torch.full(
                            (max_response_length + exp.prompt_length - len(exp.tokens),),
                            pad_token_id,
                            dtype=tokens_dtype,
                        ),
                    ]
                )
                for exp in experiences
            ]
        )
        if experiences[0].reward is not None:
            rewards = torch.tensor([exp.reward for exp in experiences], dtype=torch.float)
        else:
            rewards = None
        # Calculate the action_masks according to the provided experience.action_mask
        if experiences[0].action_mask is not None:
            action_mask_dtype = experiences[0].action_mask.dtype
            action_masks = torch.stack(
                [
                    torch.cat(
                        [
                            torch.full(
                                (max_prompt_length - exp.prompt_length,),
                                0,
                                dtype=action_mask_dtype,
                            ),
                            exp.action_mask,
                            torch.full(
                                (max_response_length + exp.prompt_length - len(exp.tokens),),
                                0,
                                dtype=action_mask_dtype,
                            ),
                        ]
                    )
                    for exp in experiences
                ]
            )
        else:
            action_masks = None
        attention_masks = torch.zeros(
            (len(experiences), max_prompt_length + max_response_length), dtype=torch.bool
        )
        for i, exp in enumerate(experiences):
            start = max_prompt_length - exp.prompt_length
            end = start + len(exp.tokens)
            attention_masks[i, start:end] = 1
        if all(exp.logprobs is not None for exp in experiences):
            logprob_dtype = experiences[0].logprobs.dtype  # type: ignore [union-attr]
            logprobs = torch.stack(
                [
                    torch.cat(
                        [
                            torch.full(
                                (max_prompt_length - exp.prompt_length,),
                                0.0,
                                dtype=logprob_dtype,
                            ),
                            exp.logprobs,
                            torch.full(
                                (max_response_length + exp.prompt_length - len(exp.tokens),),
                                0.0,
                                dtype=logprob_dtype,
                            ),
                        ]
                    )
                    for exp in experiences
                ]
            )
        else:
            logprobs = None
        return cls(
            run_ids=run_ids,
            tokens=tokens,
            rewards=rewards,
            attention_masks=attention_masks,
            action_masks=action_masks,
            prompt_length=max_prompt_length,
            logprobs=logprobs,
        )

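    # Example (illustrative sketch, not part of the module): gathering two
    # experiences of different lengths. Prompts are left-padded to the longest
    # prompt and responses are right-padded to the longest response, so every
    # row shares the same prompt/response boundary. Token values are arbitrary.
    #
    # >>> exp_a = Experience(tokens=torch.tensor([1, 2, 3, 4]), prompt_length=2)
    # >>> exp_b = Experience(tokens=torch.tensor([5, 6, 7]), prompt_length=1)
    # >>> batch = Experiences.gather_experiences([exp_a, exp_b], pad_token_id=0)
    # >>> batch.tokens
    # tensor([[1, 2, 3, 4],
    #         [0, 5, 6, 7]])
    # >>> batch.prompt_length
    # 2
    # >>> batch.attention_masks
    # tensor([[ True,  True,  True,  True],
    #         [False,  True,  True,  True]])
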
    @classmethod
    def gather_dpo_experiences(
        cls, experiences: list[Experience], pad_token_id: int = 0
    ) -> Experiences:
        """Gather a batch of DPO experiences from a list of experiences.

        Reference: https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L849

        Note:
            We arrange inputs in the order of (chosen, rejected, chosen, rejected, ...)
            to ensure that each pair of (chosen, rejected) is not split by subsequent
            operations.

        Args:
            experiences (list[Experience]):
                - `"tokens"`: token ids of the prompt
                - `"chosen"`: token ids of the chosen response
                - `"rejected"`: token ids of the rejected response
            pad_token_id (int): The pad token id.

        Returns:
            Experiences:
                - `"tokens"`: Concatenated chosen and rejected completion input IDs of shape
                  `(2 * batch_size, max_prompt_length + max_response_length)`.
                - `"attention_masks"`: Concatenated chosen and rejected attention masks of shape
                  `(2 * batch_size, max_prompt_length + max_response_length)`.
        """
        if len(experiences) == 0:
            return Experiences(
                tokens=torch.empty(0, dtype=torch.int32),
                rewards=torch.empty(0, dtype=torch.float32),
                attention_masks=torch.empty(0, dtype=torch.bool),
                action_masks=torch.empty(0, dtype=torch.bool),
                logprobs=torch.empty(0, dtype=torch.float32),
                prompt_length=0,
                run_ids=[],
            )
        # TODO: exp.tokens in DPO are prompt tokens
        # Each experience contributes two rows: one for the chosen and one for
        # the rejected response, interleaved as (chosen, rejected, ...).
        prompt_tokens = list(chain.from_iterable([repeat(exp.tokens, 2) for exp in experiences]))
        max_prompt_length = max([exp.prompt_length for exp in experiences])
        chosen_tokens = [exp.chosen for exp in experiences]
        rejected_tokens = [exp.rejected for exp in experiences]
        response_tokens = list(chain.from_iterable(zip(chosen_tokens, rejected_tokens)))
        max_response_length = max([len(response) for response in response_tokens])  # type: ignore
        run_ids = list(chain.from_iterable([repeat(exp.run_id, 2) for exp in experiences]))
        tokens_dtype = experiences[0].tokens.dtype
        tokens = torch.stack(
            [
                torch.cat(
                    [
                        torch.full(
                            (max_prompt_length - len(prompt),),
                            pad_token_id,
                            dtype=tokens_dtype,
                        ),
                        prompt,
                        response,
                        torch.full(
                            (max_response_length - len(response),),  # type: ignore
                            pad_token_id,
                            dtype=tokens_dtype,
                        ),
                    ]
                )
                for prompt, response in zip(prompt_tokens, response_tokens)
            ]
        )
        attention_masks = torch.zeros(
            (len(tokens), max_prompt_length + max_response_length), dtype=torch.bool
        )
        for (i, prompt), response in zip(enumerate(prompt_tokens), response_tokens):
            start = max_prompt_length - len(prompt)
            end = max_prompt_length + len(response)  # type: ignore
            attention_masks[i, start:end] = 1
        assert len(tokens) == 2 * len(experiences)
        return cls(
            run_ids=run_ids,
            tokens=tokens,
            attention_masks=attention_masks,
            prompt_length=max_prompt_length,
            rewards=None,
            action_masks=None,
            logprobs=None,
        )

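    # Example (illustrative sketch, not part of the module): gathering DPO
    # experiences. Here `tokens` holds only the prompt ids, while `chosen` and
    # `rejected` hold the two candidate responses; each experience yields two
    # rows arranged as (chosen, rejected). Token values are arbitrary.
    #
    # >>> exp = Experience(
    # ...     tokens=torch.tensor([1, 2]),   # prompt only
    # ...     prompt_length=2,
    # ...     chosen=torch.tensor([3, 4]),
    # ...     rejected=torch.tensor([5]),
    # ... )
    # >>> batch = Experiences.gather_dpo_experiences([exp], pad_token_id=0)
    # >>> batch.tokens
    # tensor([[1, 2, 3, 4],
    #         [1, 2, 5, 0]])
    # >>> batch.batch_size
    # 2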