# -*- coding: utf-8 -*-
"""Experience Class."""
from __future__ import annotations
import pickle
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional
import torch
from torch import Tensor
@dataclass
class EID:
"""Experience ID class to uniquely identify an experience.
To enable the full functionality of experience grouping, users should manually set the `run` and `step` fields in custom workflows.
"""
# TODO: do we need to add project/name here to make it unique across different projects?
# Batch number, e.g., the explorer step num
# Automatically set by the workflow runner
batch: int = 0
# Task number, e.g., the task's index within the batch; the first task in the batch has task=0
# Automatically set by the workflow runner
task: int = 0
# Run id, e.g., the first run in the task has run=0
# User should set this field in custom workflows when creating experiences
run: int = 0
# Step number when running the task, e.g., the first step in the task has step=0
# User should set this field in custom workflows when creating experiences
step: int = 0
suffix: str = field(
default_factory=lambda: uuid.uuid4().hex[:6]
) # Unique identifier suffix, e.g., a UUID
@property
def uid(self) -> str:
"""An unique identifier for the experience."""
return f"{self.batch}/{self.task}/{self.run}/{self.step}/{self.suffix}"
@property
def sid(self) -> str:
"""Step ID of the experience.
For example, experiences generated by all runs of the same task at the same step share the same sid.
"""
return f"{self.batch}/{self.task}/{self.step}"
@property
def rid(self) -> str:
"""Run ID of the experience.
For example, experiences generated by one run of a task across all steps share the same rid.
"""
return f"{self.batch}/{self.task}/{self.run}"
@property
def tid(self) -> str:
"""Task ID for the experience.
For example, experiences generated by all runs of the same task (as in GRPO-like algorithms) share the same tid.
"""
return f"{self.batch}/{self.task}"
def __str__(self):
return self.uid
def __repr__(self):
return f"EID(batch={self.batch}, task={self.task}, run={self.run}, step={self.step}, uuid={self.suffix})"
def to_dict(self) -> dict:
"""Convert the EID to a dictionary."""
return {
"batch": self.batch,
"task": self.task,
"run": self.run,
"step": self.step,
"suffix": self.suffix,
}
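# Example (values chosen for illustration): given the properties above,
# EID(batch=1, task=2, run=0, step=3) yields
#   uid -> "1/2/0/3/<suffix>"  (unique per experience)
#   sid -> "1/2/3"             (shared by all runs of the task at this step)
#   rid -> "1/2/0"             (shared by all steps of this run)
#   tid -> "1/2"               (shared by all runs and steps of the task)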
class ExperienceType(Enum):
"""Enum for experience types."""
SINGLE_TURN = "single_turn" # Single-turn experience, e.g., a prompt-response pair
MULTI_TURN = "multi_turn" # Multi-turn experience, e.g., a conversation history
DPO = "dpo" # DPO experience, e.g., a chosen and rejected response pair
@dataclass(frozen=True)
class CustomField:
"""Custom field for Experiences.
This is used to copy additional information from `Experience.info` into the gathered Experiences batch.
"""
source_field: str # The source field name in the Experience.info
destination_field: str # The destination field name in the Experiences class
data_type: torch.dtype # The data type of the field, e.g., torch.float32, torch.int64, etc.
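# Example (the "turn_count" key is hypothetical): copy a value stored in
# `Experience.info` into the gathered batch as an int64 tensor field.
#
#   cf = CustomField(
#       source_field="turn_count",
#       destination_field="turn_counts",
#       data_type=torch.int64,
#   )
#   # Passing `custom_fields=[cf]` to `Experience.gather` makes the gathered
#   # `Experiences` object expose an `exps.turn_counts` tensor.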
@dataclass
class Experience:
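# A single experience (rollout sample). Depending on which optional fields are set,
# it represents a single-turn prompt/response pair, a multi-turn conversation with an
# action mask, or a DPO chosen/rejected pair; see ExperienceType and __init__ below.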
eid: EID = field(default_factory=EID) # Unique identifier for the experience
tokens: Optional[Tensor] = None # [seq_length]
logprobs: Optional[Tensor] = None # [resp_length]
reward: Optional[float] = None
advantages: Optional[Tensor] = None # [resp_length]
returns: Optional[Tensor] = None # [resp_length]
# Type of the experience, automatically set based on the presence of action_mask or chosen/rejected
experience_type: ExperienceType = ExperienceType.SINGLE_TURN
info: dict = field(
default_factory=dict
) # Additional information about the experience, can also be used to store custom fields
metrics: dict[str, float] = field(
default_factory=dict
) # Metrics associated with the experience, directly used by the monitor
# for single-turn experiences
prompt_length: int = 1 # Length of the prompt in tokens, used for generating attention masks
response_text: Optional[str] = None # Text of the response
prompt_text: Optional[str] = None # Text of the prompt
# for multi-turn experiences
# Action mask which indicates which tokens are generated by the model
action_mask: Optional[Tensor] = None # [resp_length]
messages: Optional[List[dict]] = None # List of messages
# for dpo experiences
chosen: Optional[Tensor] = None # Token ids of the chosen response [resp_length]
rejected: Optional[Tensor] = None # Token ids of the rejected response [resp_length]
chosen_text: Optional[str] = None # Text of the chosen response
rejected_text: Optional[str] = None # Text of the rejected response
def __init__(
self,
*,
eid=None,
tokens,
logprobs=None,
reward=None,
advantages=None,
returns=None,
info=None,
metrics=None,
prompt_length=1,
response_text=None,
prompt_text=None,
action_mask=None,
messages=None,
chosen=None,
rejected=None,
chosen_text=None,
rejected_text=None,
):
if action_mask is not None:
experience_type = ExperienceType.MULTI_TURN
elif chosen is not None and rejected is not None:
experience_type = ExperienceType.DPO
else:
experience_type = ExperienceType.SINGLE_TURN
if experience_type == ExperienceType.SINGLE_TURN:
assert (
prompt_length > 0
), "Prompt length must be greater than 0 for single-turn experiences."
assert (
len(tokens) > prompt_length
), f"Token ids must be longer than the prompt length. Got len(tokens)={len(tokens)}, prompt_length={prompt_length}."
action_mask = torch.ones(len(tokens) - prompt_length, dtype=torch.bool)
elif experience_type == ExperienceType.DPO:
prompt_length = len(tokens)
self.eid = eid or EID()
self.tokens = tokens
self.logprobs = logprobs
self.reward = reward
self.advantages = advantages
self.returns = returns
self.experience_type = experience_type
self.info = info or {}
self.metrics = metrics or {}
self.prompt_length = prompt_length
self.response_text = response_text
self.prompt_text = prompt_text
self.action_mask = action_mask
self.messages = messages
self.chosen = chosen
self.rejected = rejected
self.chosen_text = chosen_text
self.rejected_text = rejected_text
if not isinstance(self.tokens, Tensor):
self.tokens = torch.tensor(self.tokens)
if self.logprobs is not None and not isinstance(self.logprobs, Tensor):
self.logprobs = torch.tensor(self.logprobs)
if self.action_mask is not None and not isinstance(self.action_mask, Tensor):
self.action_mask = torch.tensor(self.action_mask)
if self.chosen is not None and not isinstance(self.chosen, Tensor):
self.chosen = torch.tensor(self.chosen)
if self.rejected is not None and not isinstance(self.rejected, Tensor):
self.rejected = torch.tensor(self.rejected)
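# Example (token ids are placeholders): construct a single-turn experience whose
# first three tokens form the prompt. __init__ infers ExperienceType.SINGLE_TURN and
# builds a boolean action mask over the response tokens.
#
#   exp = Experience(tokens=[1, 2, 3, 4, 5, 6], prompt_length=3, reward=1.0)
#   assert exp.experience_type == ExperienceType.SINGLE_TURN
#   assert exp.action_mask.shape == (3,)  # len(tokens) - prompt_length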
def serialize(self) -> bytes:
"""Serialize the experience to bytes."""
return pickle.dumps(self)
@classmethod
def deserialize(cls, data: bytes) -> Experience:
"""Deserialize an experience from bytes."""
return pickle.loads(data)
def to_dict(self) -> dict:
"""Convert the experience to a dictionary."""
res = {
"eid": self.eid,
"type": self.experience_type.value,
"info": self.info,
"metrics": self.metrics,
}
if self.prompt_text is not None:
res["prompt_text"] = self.prompt_text
if self.response_text is not None:
res["response_text"] = self.response_text
if self.messages is not None:
res["messages"] = self.messages
if self.chosen_text is not None:
res["chosen_text"] = self.chosen_text
if self.rejected_text is not None:
res["rejected_text"] = self.rejected_text
if self.reward is not None:
res["reward"] = float(self.reward)
return res
@classmethod
def gather(
cls,
experiences: List[Experience],
pad_token_id: int = 0,
custom_fields: Optional[List[CustomField]] = None,
) -> Experiences:
"""Gather a list of experiences into a padded Experiences batch."""
if len(experiences) == 0:
return empty_experiences(custom_fields)
exp_type = experiences[0].experience_type
if exp_type == ExperienceType.DPO:
experiences = split_dpo_experience_to_single_turn(experiences)
max_prompt_length = max([exp.prompt_length for exp in experiences]) # type: ignore [type-var]
max_response_length = max([len(exp.tokens) - exp.prompt_length for exp in experiences]) # type: ignore [arg-type]
eids = [exp.eid for exp in experiences]
# Gather tokens
tokens = gather_token_ids(experiences, max_prompt_length, max_response_length, pad_token_id)
# Gather rewards
if experiences[0].reward is not None:
rewards = torch.tensor([exp.reward for exp in experiences], dtype=torch.float)
else:
rewards = None
# gather action_masks
action_masks = gather_action_masks(experiences, max_response_length)
# gather attention_masks
attention_masks = gather_attention_masks(
experiences, max_prompt_length, max_response_length
)
# gather logprobs
if all(exp.logprobs is not None for exp in experiences):
logprobs = gather_logprobs(experiences, max_response_length)
else:
logprobs = None
# gather advantages
if all(exp.advantages is not None for exp in experiences):
advantages = gather_advantages(experiences, max_response_length)
else:
advantages = None
# gather returns
if all(exp.returns is not None for exp in experiences):
returns = gather_returns(experiences, max_response_length)
else:
returns = None
exps = Experiences(
eids=eids,
tokens=tokens,
rewards=rewards,
advantages=advantages,
returns=returns,
attention_masks=attention_masks,
action_masks=action_masks,
prompt_length=max_prompt_length,
logprobs=logprobs,
)
if custom_fields is not None:
for custom_field in custom_fields:
exps.custom_fields.append(custom_field.destination_field)
setattr(
exps,
custom_field.destination_field,
torch.tensor(
[exp.info[custom_field.source_field] for exp in experiences],
dtype=custom_field.data_type,
),
)
return exps
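# Layout of the gathered batch (derived from the padding helpers below): prompts are
# left-padded to max_prompt_length and responses are right-padded to max_response_length, so
#   exps.tokens.shape          == (batch_size, max_prompt_length + max_response_length)
#   exps.action_masks.shape    == (batch_size, max_response_length)
#   exps.attention_masks.shape == (batch_size, max_prompt_length + max_response_length)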
def split_dpo_experience_to_single_turn(experiences: List[Experience]) -> List[Experience]:
single_turn_experiences = []
for exp in experiences:
single_turn_experiences.append(
Experience(
eid=EID(
batch=exp.eid.batch,
task=exp.eid.task,
step=exp.eid.step,
run=exp.eid.run,
),
tokens=torch.cat([exp.tokens, exp.chosen]),
reward=exp.reward,
info=exp.info,
metrics=exp.metrics,
prompt_length=len(exp.tokens), # type: ignore [arg-type]
prompt_text=exp.prompt_text,
response_text=exp.chosen_text,
)
)
single_turn_experiences.append(
Experience(
eid=EID(
batch=exp.eid.batch,
task=exp.eid.task,
step=exp.eid.step,
run=exp.eid.run,
),
tokens=torch.cat([exp.tokens, exp.rejected]),
reward=exp.reward,
info=exp.info,
metrics=exp.metrics,
prompt_length=len(exp.tokens), # type: ignore [arg-type]
prompt_text=exp.prompt_text,
response_text=exp.rejected_text,
)
)
return single_turn_experiences
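# Note: each DPO experience is expanded into two single-turn experiences that share the
# same batch/task/run/step fields but get fresh EID suffixes: one appends the chosen
# response to the prompt tokens, the other appends the rejected response.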
@dataclass
class Experiences:
"""A container for a batch of experiences, for high performance communication usage.
Example:
>>> |<- prompt_length ->| |
>>> tokens: ('P' represents prompt, 'O' represents output)
>>> exp1: |........PPPPPPPPPPP|OOOOOOOOOO.....|
>>> exp2: |......PPPPPPPPPPPPP|OOOOOOO........|
>>>
>>> attention_masks: ('.' represents False and '1' represents True)
>>> exp1: |........11111111111|1111111111.....|
>>> exp2: |......1111111111111|1111111........|
"""
eids: List[EID] # Experience IDs of each experience in the batch
tokens: Tensor # [batch_size, seq_length]
rewards: Tensor # [batch_size]
advantages: Optional[Tensor] # [batch_size, response_length]
returns: Optional[Tensor] # [batch_size, response_length]
attention_masks: Tensor # [batch_size, sequence_length]
action_masks: Optional[Tensor] # [batch_size, response_length]
prompt_length: int
logprobs: Optional[Tensor] # [batch_size, response_length]
custom_fields: List[str] = field(
default_factory=list
) # Custom fields to include in the gathered experiences
@property
def batch_size(self) -> int:
"""Get the batch size."""
return self.tokens.size(0)
@classmethod
def gather_experiences(
cls,
experiences: list[Experience],
pad_token_id: int = 0,
custom_fields: Optional[List[CustomField]] = None,
) -> Experiences:
"""Gather a batch of experiences from a list of experiences.
This method automatically pads the `tokens`, `logprobs`, `advantages`, and `returns` of the input experiences to the same length.
Args:
experiences (list[Experience]): A list of experiences to gather.
pad_token_id (int): The token ID to use for padding. Default is 0.
custom_fields (Optional[List[CustomField]]): Custom fields to include in the gathered experiences.
"""
if len(experiences) == 0:
return empty_experiences(custom_fields)
return experiences[0].__class__.gather(
experiences, pad_token_id=pad_token_id, custom_fields=custom_fields
)
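# Example usage (a tokenizer and a list of experiences are assumed to exist):
#
#   batch = Experiences.gather_experiences(
#       [exp_a, exp_b],
#       pad_token_id=tokenizer.pad_token_id,
#   )
#   print(batch.batch_size, batch.tokens.shape)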
def empty_experiences(custom_fields: Optional[List[CustomField]]) -> Experiences:
"""Create an empty Experiences batch, used when gathering an empty list."""
exps = Experiences(
tokens=torch.empty(0, dtype=torch.int32),
rewards=torch.empty(0, dtype=torch.float32),
advantages=torch.empty(0, dtype=torch.float32),
returns=torch.empty(0, dtype=torch.float32),
attention_masks=torch.empty(0, dtype=torch.bool),
action_masks=torch.empty(0, dtype=torch.bool),
logprobs=torch.empty(0, dtype=torch.float32),
prompt_length=torch.empty(0, dtype=torch.int32),
eids=[],
)
if custom_fields is not None:
for custom_field in custom_fields:
exps.custom_fields.append(custom_field.destination_field)
setattr(
exps, custom_field.destination_field, torch.empty(0, dtype=custom_field.data_type)
)
return exps
def gather_token_ids(
experiences, max_prompt_length: int, max_response_length: int, pad_token_id: int
) -> Tensor:
token_ids_dtype = experiences[0].tokens.dtype
return torch.stack(
[
torch.cat(
[
torch.full(
(max_prompt_length - exp.prompt_length,),
pad_token_id,
dtype=token_ids_dtype,
),
exp.tokens,
torch.full(
(max_response_length + exp.prompt_length - len(exp.tokens),),
pad_token_id,
dtype=token_ids_dtype,
),
]
)
for exp in experiences
]
)
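# Worked example: with max_prompt_length=3, max_response_length=2, and pad_token_id=0,
# an experience with tokens=[11, 12, 21] and prompt_length=2 becomes
#   [0, 11, 12, 21, 0]
# i.e., one left pad aligns the prompt and one right pad aligns the response.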
def gather_action_masks(experiences, max_response_length: int) -> Tensor:
return torch.stack(
[
torch.cat(
[
exp.action_mask,
torch.full(
(max_response_length - len(exp.action_mask),),
0,
dtype=torch.bool,
),
]
)
for exp in experiences
]
)
def gather_attention_masks(experiences, max_prompt_length: int, max_response_length: int) -> Tensor:
attention_masks = torch.zeros(
(len(experiences), max_prompt_length + max_response_length), dtype=torch.bool
)
for i, exp in enumerate(experiences):
start = max_prompt_length - exp.prompt_length
end = start + len(exp.tokens)
attention_masks[i, start:end] = 1
return attention_masks
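# For the same example as above (max_prompt_length=3, max_response_length=2,
# prompt_length=2, len(tokens)=3), the attention mask row is
#   [False, True, True, True, False]
# i.e., True exactly over the un-padded prompt and response span.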
def gather_logprobs(experiences, max_response_length: int) -> Tensor:
logprob_dtype = experiences[0].logprobs.dtype # type: ignore [union-attr]
return torch.stack(
[
torch.cat(
[
exp.logprobs,
torch.full(
(max_response_length - len(exp.logprobs),),
0.0,
dtype=logprob_dtype,
),
]
)
for exp in experiences
]
)
def gather_advantages(experiences, max_response_length: int) -> Optional[Tensor]:
if experiences[0].advantages is None:
return None
advantages_dtype = experiences[0].advantages.dtype
return torch.stack(
[
torch.cat(
[
exp.advantages,
torch.full(
(max_response_length - len(exp.advantages),),
0.0,
dtype=advantages_dtype,
),
]
)
for exp in experiences
]
)
def gather_returns(experiences, max_response_length: int) -> Optional[Tensor]:
if experiences[0].returns is None:
return None
returns_dtype = experiences[0].returns.dtype
return torch.stack(
[
torch.cat(
[
exp.returns,
torch.full(
(max_response_length - len(exp.returns),),
0.0,
dtype=returns_dtype,
),
]
)
for exp in experiences
]
)