# -*- coding: utf-8 -*-
"""Experience Class."""
from __future__ import annotations
import pickle
import uuid
from dataclasses import asdict, dataclass, field, fields
from typing import Dict, List, Literal, Optional
import torch
from datasets import Dataset
from torch import Tensor
[docs]
@dataclass
class EID:
    """Identifier that uniquely tags a single experience.

    ``batch`` and ``task`` are filled in by the workflow runner.  To get the
    full benefit of experience grouping, custom workflows should set ``run``
    and ``step`` themselves when they create experiences.
    """

    # TODO: do we need to add project/name here to make it unique across different projects?

    # Batch number (e.g. the explorer step number); set automatically by the workflow runner.
    batch: int = 0
    # Position of the task inside the batch (first task is 0); set automatically by the
    # workflow runner.
    task: int = 0
    # Index of the run inside the task (first run is 0); set by the user in custom workflows.
    run: int = 0
    # Index of the step while running the task (first step is 0); set by the user in
    # custom workflows.
    step: int = 0
    # Six hex characters taken from a fresh UUID4, used to tell IDs apart.
    suffix: str = field(default_factory=lambda: uuid.uuid4().hex[:6])

    @property
    def uid(self) -> str:
        """Unique identifier of the experience: ``batch/task/run/step/suffix``."""
        parts = (self.batch, self.task, self.run, self.step, self.suffix)
        return "/".join(str(part) for part in parts)

    @property
    def sid(self) -> str:
        """Step ID: ``batch/task/step``.

        Experiences produced at the same step by every run of one task share
        the same ``sid``.
        """
        return "/".join(str(part) for part in (self.batch, self.task, self.step))

    @property
    def rid(self) -> str:
        """Run ID: ``batch/task/run``.

        Experiences produced at every step of a single run of a task share
        the same ``rid``.
        """
        return "/".join(str(part) for part in (self.batch, self.task, self.run))

    @property
    def tid(self) -> str:
        """Task ID: ``batch/task``.

        Experiences produced by all runs of the same task (as in GRPO-like
        algorithms) share the same ``tid``.
        """
        return "/".join(str(part) for part in (self.batch, self.task))

    def __str__(self):
        # The string form is simply the unique identifier.
        return self.uid

    def __repr__(self):
        # The suffix is shown under the label "uuid" in the debug representation.
        return (
            f"EID(batch={self.batch}, task={self.task}, "
            f"run={self.run}, step={self.step}, uuid={self.suffix})"
        )

    def to_dict(self) -> dict:
        """Return the five ID components as a plain dictionary."""
        names = ("batch", "task", "run", "step", "suffix")
        return {name: getattr(self, name) for name in names}
[docs]
@dataclass(frozen=True)
class CustomField:
    """Immutable description of one extra value to carry into ``Experiences``.

    It names where an additional piece of information is read from and where
    it is stored on the batched container.

    Attributes:
        source_field: Name of the source entry in ``Experience.info``.
        destination_field: Name of the destination field on the
            ``Experiences`` class.
        data_type: Torch dtype of the field, e.g. ``torch.float32`` or
            ``torch.int64``.
    """

    source_field: str
    destination_field: str
    data_type: torch.dtype
[docs]
@dataclass
class Experience:
    """A single experience (one sequence plus its training signals).

    The hand-written ``__init__`` below is the one in effect (``dataclass``
    does not replace an ``__init__`` defined in the class body); the field
    declarations are still what ``dataclasses.fields`` / ``asdict`` report.

    ``experience_type`` is derived in ``__init__`` and is one of:

    * ``"multi_turn"``  - an ``action_mask`` was supplied;
    * ``"dpo"``         - no ``action_mask``, but both ``chosen`` and
      ``rejected`` were supplied;
    * ``"single_turn"`` - everything else.
    """

    eid: EID = field(default_factory=EID)  # Unique identifier for the experience
    tokens: Optional[Tensor] = None  # [seq_length]
    logprobs: Optional[Tensor] = None  # [resp_length]
    reward: Optional[float] = None
    advantages: Optional[Tensor] = None  # [resp_length]
    returns: Optional[Tensor] = None  # [resp_length]
    # Additional information about the experience, can also be used to store custom fields
    info: dict = field(default_factory=dict)
    # Metrics associated with the experience, directly used by the monitor
    metrics: dict[str, float] = field(default_factory=dict)

    # for single-turn experiences
    prompt_length: int = 1  # Length of the prompt in tokens, used for generating attention masks
    response_text: Optional[str] = None  # Text of the response
    prompt_text: Optional[str] = None  # Text of the prompt

    # for multi-turn experiences
    # Action mask which indicates which tokens are generated by the model
    action_mask: Optional[Tensor] = None  # [resp_length]
    messages: Optional[List[dict]] = None  # List of messages
    tools: Optional[List[dict]] = None

    # for dpo experiences
    chosen: Optional[Tensor] = None  # Token ids of the chosen response [resp_length]
    rejected: Optional[Tensor] = None  # Token ids of the rejected response [resp_length]
    chosen_text: Optional[str] = None  # Text of the chosen response
    rejected_text: Optional[str] = None  # Text of the rejected response

    def __init__(  # noqa: C901
        self,
        *,
        eid=None,
        tokens,
        logprobs=None,
        reward=None,
        advantages=None,
        returns=None,
        info=None,
        metrics=None,
        prompt_length=1,
        response_text=None,
        prompt_text=None,
        action_mask=None,
        messages=None,
        tools=None,
        chosen=None,
        rejected=None,
        chosen_text=None,
        rejected_text=None,
    ):
        """Build an experience and normalise its sequence inputs to tensors.

        All arguments are keyword-only and ``tokens`` is required.

        Args:
            eid: ``None`` (a fresh ``EID`` is created), a dict of ``EID``
                keyword arguments, or an object stored as-is.
            tokens: Token ids of the sequence.
            logprobs, advantages, returns: Optional per-token values.
            reward: Optional scalar reward.
            info, metrics: Optional dicts; a falsy value becomes a new ``{}``.
            prompt_length: Prompt length in tokens.  Overwritten with
                ``len(tokens)`` for DPO experiences.
            action_mask: Optional mask; its presence marks the experience as
                multi-turn.  For single-turn experiences an all-``True`` mask
                of length ``len(tokens) - prompt_length`` is generated.
            chosen, rejected: Token ids of the DPO responses.
            response_text, prompt_text, messages, tools, chosen_text,
                rejected_text: Stored unchanged.

        Raises:
            AssertionError: For single-turn experiences when
                ``prompt_length <= 0`` or ``len(tokens) <= prompt_length``.
        """
        if action_mask is not None:
            experience_type = "multi_turn"
        elif chosen is not None and rejected is not None:
            experience_type = "dpo"
        else:
            experience_type = "single_turn"

        if experience_type == "single_turn":
            assert (
                prompt_length > 0
            ), "Prompt length must be greater than 0 for single-turn experiences."
            assert (
                len(tokens) > prompt_length
            ), f"Token ids must be longer than the prompt length. Got len(tokens)={len(tokens)}, prompt_length={prompt_length}."
            # Every response token counts as a model action in the single-turn case.
            action_mask = torch.ones(len(tokens) - prompt_length, dtype=torch.bool)
        elif experience_type == "dpo":
            # For DPO, `tokens` holds only the prompt.
            prompt_length = len(tokens)

        if eid is None:
            self.eid = EID()
        elif isinstance(eid, dict):
            self.eid = EID(**eid)
        else:
            self.eid = eid

        # `tokens` is mandatory, so it is coerced unconditionally (None is rejected by torch).
        self.tokens = self._as_tensor(tokens, torch.int32)
        self.logprobs = self._optional_tensor(logprobs, torch.float32)
        self.reward = reward
        # advantages/returns now get the same non-list coercion as the other sequence fields.
        self.advantages = self._optional_tensor(advantages, torch.float32)
        self.returns = self._optional_tensor(returns, torch.float32)
        # Intentionally a plain attribute, not a dataclass field: __init__ has no
        # matching parameter, so it is always re-derived from the other arguments.
        self.experience_type = experience_type
        self.info = info or {}
        self.metrics = metrics or {}
        self.prompt_length = prompt_length
        self.response_text = response_text
        self.prompt_text = prompt_text
        self.action_mask = self._optional_tensor(action_mask, torch.bool)
        self.messages = messages
        self.tools = tools
        self.chosen = self._optional_tensor(chosen, torch.int32)
        self.rejected = self._optional_tensor(rejected, torch.int32)
        self.chosen_text = chosen_text
        self.rejected_text = rejected_text

    @staticmethod
    def _as_tensor(value, list_dtype):
        """Return ``value`` as a tensor; only Python lists are cast to ``list_dtype``."""
        if isinstance(value, Tensor):
            return value
        if isinstance(value, list):
            return torch.tensor(value, dtype=list_dtype)
        # Other inputs (tuples, arrays, ...) keep whatever dtype torch infers.
        return torch.tensor(value)

    @classmethod
    def _optional_tensor(cls, value, list_dtype):
        """Like ``_as_tensor`` but passes ``None`` through untouched."""
        return None if value is None else cls._as_tensor(value, list_dtype)

    def serialize(self) -> bytes:
        """Serialize the experience to bytes using ``pickle``."""
        return pickle.dumps(self)

    @classmethod
    def deserialize(cls, data: bytes) -> Experience:
        """Restore an experience from bytes produced by ``serialize``.

        WARNING: this calls ``pickle.loads``, which can execute arbitrary code.
        Only feed it data from a trusted source.
        """
        return pickle.loads(data)

    def to_dict(self) -> dict:
        """Convert the experience to a dictionary.

        Always present: ``eid`` (the ``EID`` object itself), ``type``,
        ``prompt_length``, ``response_length``, ``info`` and ``metrics``.
        The text fields, ``messages``, ``tools`` and ``reward`` (as a float)
        are added only when they are not ``None``.  Tensors are not included.
        """
        res = {
            "eid": self.eid,
            "type": self.experience_type,
            "prompt_length": self.prompt_length,
            "response_length": len(self.tokens) - self.prompt_length,  # type: ignore [arg-type]
            "info": self.info,
            "metrics": self.metrics,
        }
        optional_entries = (
            ("prompt_text", self.prompt_text),
            ("response_text", self.response_text),
            ("messages", self.messages),
            ("tools", self.tools),
            ("chosen_text", self.chosen_text),
            ("rejected_text", self.rejected_text),
        )
        for key, value in optional_entries:
            if value is not None:
                res[key] = value
        if self.reward is not None:
            res["reward"] = float(self.reward)
        return res

    @classmethod
    def gather(
        cls,
        experiences: List[Experience],
        pad_token_id: int = 0,
        custom_fields: Optional[List[CustomField]] = None,
    ) -> Experiences:
        """Batch a list of experiences into one padded ``Experiences`` object.

        If the first experience is of type ``"dpo"``, the list is first
        expanded with ``split_dpo_experience_to_single_turn``.  Padding and
        stacking are delegated to the module-level ``gather_*`` helpers.

        Args:
            experiences: Experiences to batch; an empty list yields
                ``empty_experiences(custom_fields)``.
            pad_token_id: Token id used to pad ``tokens``.
            custom_fields: Optional extra fields; each is read from
                ``exp.info[source_field]`` of every experience and attached to
                the result as a tensor named ``destination_field``.

        Returns:
            The gathered batch.  ``rewards``, ``logprobs``, ``advantages`` and
            ``returns`` are ``None`` unless *every* experience provides them.
        """
        if not experiences:
            return empty_experiences(custom_fields)
        if experiences[0].experience_type == "dpo":
            experiences = split_dpo_experience_to_single_turn(experiences)

        max_prompt_length = max(exp.prompt_length for exp in experiences)  # type: ignore [type-var]
        max_response_length = max(
            len(exp.tokens) - exp.prompt_length for exp in experiences  # type: ignore [arg-type]
        )
        eids = [exp.eid for exp in experiences]

        tokens = gather_token_ids(experiences, max_prompt_length, max_response_length, pad_token_id)

        # Require a reward on every experience (same rule as the optional tensors
        # below); checking only the first one crashed on a later None reward.
        if all(exp.reward is not None for exp in experiences):
            rewards = torch.tensor([exp.reward for exp in experiences], dtype=torch.float)
        else:
            rewards = None

        action_masks = gather_action_masks(experiences, max_response_length)
        attention_masks = gather_attention_masks(
            experiences, max_prompt_length, max_response_length
        )

        if all(exp.logprobs is not None for exp in experiences):
            logprobs = gather_logprobs(experiences, max_response_length)
        else:
            logprobs = None

        if all(exp.advantages is not None for exp in experiences):
            advantages = gather_advantages(experiences, max_response_length)
        else:
            advantages = None

        if all(exp.returns is not None for exp in experiences):
            returns = gather_returns(experiences, max_response_length)
        else:
            returns = None

        exps = Experiences(
            eids=eids,
            tokens=tokens,
            rewards=rewards,
            advantages=advantages,
            returns=returns,
            attention_masks=attention_masks,
            action_masks=action_masks,
            prompt_length=max_prompt_length,
            logprobs=logprobs,
        )
        if custom_fields is not None:
            for custom_field in custom_fields:
                exps.custom_fields.append(custom_field.destination_field)
                setattr(
                    exps,
                    custom_field.destination_field,
                    torch.tensor(
                        [exp.info[custom_field.source_field] for exp in experiences],
                        dtype=custom_field.data_type,
                    ),
                )
        return exps
[docs]
def split_dpo_experience_to_single_turn(experiences: List[Experience]) -> List[Experience]:
    """Expand each DPO experience into two single-turn experiences.

    For every input, the result contains first the "chosen" and then the
    "rejected" variant.  Each variant's tokens are the original ``tokens``
    followed by the respective response, its ``prompt_length`` is
    ``len(exp.tokens)``, and it gets a new ``EID`` copying ``batch``,
    ``task``, ``step`` and ``run`` (the suffix is regenerated).  ``reward``,
    ``info``, ``metrics`` and ``prompt_text`` are passed through unchanged, so
    the two variants share the same ``info`` and ``metrics`` objects.
    """
    flattened = []
    for dpo_exp in experiences:
        variants = (
            (dpo_exp.chosen, dpo_exp.chosen_text),
            (dpo_exp.rejected, dpo_exp.rejected_text),
        )
        for response_ids, response_text in variants:
            variant_eid = EID(
                batch=dpo_exp.eid.batch,
                task=dpo_exp.eid.task,
                step=dpo_exp.eid.step,
                run=dpo_exp.eid.run,
            )
            flattened.append(
                Experience(
                    eid=variant_eid,
                    tokens=torch.cat([dpo_exp.tokens, response_ids]),
                    reward=dpo_exp.reward,
                    info=dpo_exp.info,
                    metrics=dpo_exp.metrics,
                    prompt_length=len(dpo_exp.tokens),  # type: ignore [arg-type]
                    prompt_text=dpo_exp.prompt_text,
                    response_text=response_text,
                )
            )
    return flattened
[docs]
@dataclass
class Experiences:
    """A container for a batch of experiences, for high performance communication usage.

    Example:

        >>>             |<- prompt_length ->|               |
        >>> tokens: ('P' represents prompt, 'O' represents output)
        >>> exp1:       |........PPPPPPPPPPP|OOOOOOOOOO.....|
        >>> exp2:       |......PPPPPPPPPPPPP|OOOOOOO........|
        >>>
        >>> attention_masks: ('.' represents False and '1' represents True)
        >>> exp1:       |........11111111111|1111111111.....|
        >>> exp2:       |......1111111111111|1111111........|
    """

    eids: List[EID]  # Experience IDs of each experience in the batch
    tokens: Tensor  # [batch_size, seq_length]
    # [batch_size]; Optional because `Experience.gather` passes None when rewards are missing
    rewards: Optional[Tensor]
    advantages: Optional[Tensor]  # [batch_size, response_length]
    returns: Optional[Tensor]  # [batch_size, response_length]
    attention_masks: Tensor  # [batch_size, sequence_length]
    action_masks: Optional[Tensor]  # [batch_size, response_length]
    prompt_length: int
    logprobs: Optional[Tensor]  # [batch_size, response_length]
    # Names of the custom fields attached to this batch as extra attributes
    custom_fields: List[str] = field(default_factory=list)

    @property
    def batch_size(self) -> int:
        """Number of experiences in the batch (size of ``tokens`` along dim 0)."""
        return self.tokens.size(0)

    @classmethod
    def gather_experiences(
        cls,
        experiences: list[Experience],
        pad_token_id: int = 0,
        custom_fields: Optional[List[CustomField]] = None,
    ) -> Experiences:
        """Gather a batch of experiences from a list of experiences.

        This method will automatically pad the `tokens` and `logprobs` of
        input experiences to the same length.  The actual work is delegated to
        the ``gather`` classmethod of the first experience's class.

        Args:
            experiences (list[Experience]): A list of experiences to gather.
            pad_token_id (int): The token ID to use for padding. Default is 0.
            custom_fields (Optional[List[CustomField]]): Custom fields to
                include in the gathered experiences.

        Returns:
            The gathered batch, or ``empty_experiences(custom_fields)`` when
            ``experiences`` is empty.
        """
        if not experiences:
            return empty_experiences(custom_fields)
        # Dispatch on the concrete class so subclasses can provide their own `gather`.
        experience_cls = type(experiences[0])
        return experience_cls.gather(
            experiences, pad_token_id=pad_token_id, custom_fields=custom_fields
        )
[docs]
def empty_experiences(custom_fields: Optional[List[CustomField]]) -> Experiences:
    """Build an ``Experiences`` batch that contains no experiences.

    Every tensor field is a zero-element 1-D tensor of the matching dtype and
    ``eids`` is an empty list.  ``prompt_length`` is the integer ``0`` so that
    it matches the field's declared ``int`` type (it used to be an empty
    tensor).

    Args:
        custom_fields: Optional custom fields; for each one, the destination
            name is recorded in ``custom_fields`` and an empty tensor of the
            requested dtype is attached under that name.

    Returns:
        The empty batch.
    """
    exps = Experiences(
        tokens=torch.empty(0, dtype=torch.int32),
        rewards=torch.empty(0, dtype=torch.float32),
        advantages=torch.empty(0, dtype=torch.float32),
        returns=torch.empty(0, dtype=torch.float32),
        attention_masks=torch.empty(0, dtype=torch.bool),
        action_masks=torch.empty(0, dtype=torch.bool),
        logprobs=torch.empty(0, dtype=torch.float32),
        prompt_length=0,
        eids=[],
    )
    for custom_field in custom_fields or ():
        exps.custom_fields.append(custom_field.destination_field)
        setattr(
            exps, custom_field.destination_field, torch.empty(0, dtype=custom_field.data_type)
        )
    return exps
[docs]
def gather_token_ids(
    experiences, max_prompt_length: int, max_response_length: int, pad_token_id: int
) -> Tensor:
    """Stack the token ids of all experiences into one padded 2-D tensor.

    Each row is left-padded so that its prompt ends at column
    ``max_prompt_length`` and right-padded up to a total width of
    ``max_prompt_length + max_response_length``.  Padding uses
    ``pad_token_id`` and the dtype of the first experience's tokens.
    """
    pad_dtype = experiences[0].tokens.dtype
    row_width = max_prompt_length + max_response_length
    rows = []
    for exp in experiences:
        left_count = max_prompt_length - exp.prompt_length
        right_count = row_width - left_count - len(exp.tokens)
        left_pad = torch.full((left_count,), pad_token_id, dtype=pad_dtype)
        right_pad = torch.full((right_count,), pad_token_id, dtype=pad_dtype)
        rows.append(torch.cat([left_pad, exp.tokens, right_pad]))
    return torch.stack(rows)
[docs]
def gather_action_masks(experiences, max_response_length: int) -> Tensor:
    """Stack the action masks, right-padding each with ``False``.

    Every mask is extended to ``max_response_length`` entries before stacking.
    """
    padded_masks = []
    for exp in experiences:
        mask = exp.action_mask
        missing = max_response_length - len(mask)
        filler = torch.zeros((missing,), dtype=torch.bool)
        padded_masks.append(torch.cat([mask, filler]))
    return torch.stack(padded_masks)
[docs]
def gather_attention_masks(experiences, max_prompt_length: int, max_response_length: int) -> Tensor:
    """Build the boolean attention mask for a padded batch.

    The result has one row per experience and
    ``max_prompt_length + max_response_length`` columns.  In each row the
    ``len(exp.tokens)`` positions starting at
    ``max_prompt_length - exp.prompt_length`` are ``True``; everything else
    (the padding) is ``False``.
    """
    total_length = max_prompt_length + max_response_length
    attention_masks = torch.zeros((len(experiences), total_length), dtype=torch.bool)
    for row_index, exp in enumerate(experiences):
        first_real = max_prompt_length - exp.prompt_length
        attention_masks[row_index, first_real : first_real + len(exp.tokens)] = True
    return attention_masks
[docs]
def gather_logprobs(experiences, max_response_length: int) -> Tensor:
    """Stack the log-probabilities, right-padding each row with zeros.

    Rows are extended to ``max_response_length`` entries; the padding uses
    the dtype of the first experience's ``logprobs``.
    """
    pad_dtype = experiences[0].logprobs.dtype  # type: ignore [union-attr]
    rows = []
    for exp in experiences:
        values = exp.logprobs
        tail = torch.zeros((max_response_length - len(values),), dtype=pad_dtype)
        rows.append(torch.cat([values, tail]))
    return torch.stack(rows)
[docs]
def gather_advantages(experiences, max_response_length: int) -> Optional[Tensor]:
    """Stack the advantages, right-padding each row with zeros.

    Returns ``None`` when the first experience has no advantages; otherwise
    every row is extended to ``max_response_length`` entries using the dtype
    of the first experience's ``advantages``.
    """
    first = experiences[0].advantages
    if first is None:
        return None
    pad_dtype = first.dtype
    rows = []
    for exp in experiences:
        pad_count = max_response_length - len(exp.advantages)
        zeros_tail = torch.zeros((pad_count,), dtype=pad_dtype)
        rows.append(torch.cat([exp.advantages, zeros_tail]))
    return torch.stack(rows)
[docs]
def gather_returns(experiences, max_response_length: int) -> Optional[Tensor]:
    """Stack the returns, right-padding each row with zeros.

    Returns ``None`` when the first experience has no returns; otherwise
    every row is extended to ``max_response_length`` entries using the dtype
    of the first experience's ``returns``.
    """
    head = experiences[0].returns
    if head is None:
        return None
    pad_dtype = head.dtype

    def pad_row(values):
        # Append as many zeros as needed to reach the common response length.
        filler = torch.zeros((max_response_length - len(values),), dtype=pad_dtype)
        return torch.cat([values, filler])

    return torch.stack([pad_row(exp.returns) for exp in experiences])
[docs]
def group_by(
    experiences: List[Experience], id_type: Literal["task", "run", "step"]
) -> Dict[str, List[Experience]]:
    """Group experiences by one of their ID kinds.

    Args:
        experiences: Experiences to group; each must expose an ``eid`` with
            ``tid``, ``rid`` and ``sid`` attributes.
        id_type: ``"task"`` groups by ``eid.tid``, ``"run"`` by ``eid.rid``
            and ``"step"`` by ``eid.sid``.

    Returns:
        A dict mapping each group ID to its experiences.  Groups appear in
        order of first occurrence and members keep their input order.

    Raises:
        ValueError: If ``id_type`` is not one of the three supported values.
    """
    # Look the attribute up in a table instead of rebinding `id_type` to a
    # value outside its declared Literal type.
    eid_attr_by_id_type = {"task": "tid", "run": "rid", "step": "sid"}
    try:
        eid_attr = eid_attr_by_id_type[id_type]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments, which were also rejected before.
        raise ValueError(f"Unknown id_type: {id_type}") from None

    grouped: Dict[str, List[Experience]] = {}
    for exp in experiences:
        grouped.setdefault(getattr(exp.eid, eid_attr), []).append(exp)
    return grouped
[docs]
def to_hf_datasets(experiences: list[Experience]) -> Dataset:
    """Convert a list of Experience objects to a HuggingFace Dataset.

    Each experience becomes one row built with ``dataclasses.asdict``, so the
    row contains the declared dataclass fields of the experience.
    """
    rows = []
    for exp in experiences:
        rows.append(asdict(exp))
    return Dataset.from_list(rows)
[docs]
def from_hf_datasets(dataset: Dataset) -> List[Experience]:
    """Convert a HuggingFace Dataset back to a list of Experience objects.

    Every row of ``dataset.to_list()`` is reduced to the keys that are
    dataclass fields of ``Experience`` and passed to the ``Experience``
    constructor as keyword arguments; any other columns are ignored.

    NOTE(review): ``experience_type`` is re-derived by the constructor from
    the row contents, so a row that carries an ``action_mask`` will presumably
    be typed as multi-turn after the round trip - confirm this is intended.
    """
    accepted_keys = {f.name for f in fields(Experience)}
    restored = []
    for row in dataset.to_list():
        kwargs = {key: value for key, value in row.items() if key in accepted_keys}
        restored.append(Experience(**kwargs))
    return restored