trinity.common.experience module

Experience Class.

class trinity.common.experience.EID(batch: int = 0, task: int = 0, run: int = 0, step: int = 0, suffix: str = <factory>)

Bases: object

Experience ID class to uniquely identify an experience.

To enable the full functionality of experience grouping, users should manually set the run and step fields in custom workflows.

batch: int = 0
task: int = 0
run: int = 0
step: int = 0
suffix: str
property uid: str

A unique identifier for the experience.

property sid: str

Step ID of the experience.

For example, experiences generated by all runs of the same task at the same step have the same sid.

property rid: str

Run ID of the experience.

For example, experiences generated by one run of a task at all steps have the same rid.

property tid: str

Task ID for the experience.

For example, experiences generated by all runs of the same task in GRPO-like algorithms have the same tid.

to_dict() dict

Convert the EID to a dictionary.

__init__(batch: int = 0, task: int = 0, run: int = 0, step: int = 0, suffix: str = <factory>) None
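The uid, sid, rid and tid properties above behave like keys built from different subsets of the (batch, task, run, step) fields. The sketch below uses a hypothetical stand-in class, SketchEID, not the real EID, and the "/"-joined string format is an assumption made for illustration. What it is meant to show is the documented grouping behaviour: all runs of one task at one step share a sid, one run across all steps shares a rid, all runs of a task share a tid, and it only works if the workflow sets run and step.

```python
import uuid
from dataclasses import dataclass, field


@dataclass
class SketchEID:
    """Hypothetical stand-in for trinity.common.experience.EID (ID string format assumed)."""

    batch: int = 0
    task: int = 0
    run: int = 0
    step: int = 0
    # Assumed: a random suffix keeps otherwise identical IDs distinct.
    suffix: str = field(default_factory=lambda: uuid.uuid4().hex[:6])

    @property
    def tid(self) -> str:
        # Same task -> same tid, regardless of run or step.
        return f"{self.batch}/{self.task}"

    @property
    def rid(self) -> str:
        # Same task and run -> same rid, regardless of step.
        return f"{self.batch}/{self.task}/{self.run}"

    @property
    def sid(self) -> str:
        # Same task and step -> same sid, regardless of run.
        return f"{self.batch}/{self.task}/{self.step}"

    @property
    def uid(self) -> str:
        # All fields plus the suffix, so every experience gets its own ID.
        return f"{self.batch}/{self.task}/{self.run}/{self.step}/{self.suffix}"


# A custom multi-turn workflow sets run and step explicitly:
# 2 runs of task 3 in batch 1, each producing 2 steps.
eids = [SketchEID(batch=1, task=3, run=r, step=s) for r in range(2) for s in range(2)]
```

With these four IDs there is one distinct tid, two distinct rids, two distinct sids and four distinct uids. If run and step were left at their defaults, rid and sid would collapse to a single value, which is why the docstring asks custom workflows to set them.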
class trinity.common.experience.CustomField(source_field: str, destination_field: str, data_type: dtype)

Bases: object

Custom field for Experiences.

This is used to store additional information in the Experiences class.

source_field: str
destination_field: str
data_type: dtype
__init__(source_field: str, destination_field: str, data_type: dtype) None
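One plausible reading of the three CustomField attributes is a column mapping: take a value named source_field from each experience, convert it to data_type, and expose the batched column as destination_field. The sketch below assumes the value lives in each experience's info dict; that assumption, the helper collect_custom_fields, and the stand-in class SketchCustomField are all hypothetical, and a plain Python callable replaces the torch dtype so the example runs without PyTorch.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List


@dataclass
class SketchCustomField:
    """Hypothetical stand-in for CustomField; data_type is a Python callable, not a torch.dtype."""

    source_field: str
    destination_field: str
    data_type: Callable[[Any], Any]


def collect_custom_fields(
    infos: List[Dict[str, Any]], custom_fields: List[SketchCustomField]
) -> Dict[str, List[Any]]:
    """Assumed behaviour: read source_field from every info dict, convert it,
    and store the resulting column under destination_field."""
    batch: Dict[str, List[Any]] = {}
    for cf in custom_fields:
        batch[cf.destination_field] = [cf.data_type(info[cf.source_field]) for info in infos]
    return batch


infos = [{"difficulty": "2"}, {"difficulty": "5"}]  # one info dict per experience
cols = collect_custom_fields(infos, [SketchCustomField("difficulty", "difficulties", int)])
```

Here cols is {"difficulties": [2, 5]}: the per-experience strings become one typed column keyed by the destination name.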
class trinity.common.experience.Experience(*, eid=None, tokens, logprobs=None, reward=None, advantages=None, returns=None, info=None, metrics=None, prompt_length=1, response_text=None, prompt_text=None, action_mask=None, messages=None, tools=None, chosen=None, rejected=None, chosen_text=None, rejected_text=None)

Bases: object

__init__(*, eid=None, tokens, logprobs=None, reward=None, advantages=None, returns=None, info=None, metrics=None, prompt_length=1, response_text=None, prompt_text=None, action_mask=None, messages=None, tools=None, chosen=None, rejected=None, chosen_text=None, rejected_text=None)
eid: EID
reward: float | None = None
advantages: Tensor | None = None
returns: Tensor | None = None
info: dict
metrics: dict[str, float]
prompt_length: int = 1
response_text: str | None = None
prompt_text: str | None = None
messages: List[dict] | None = None
tools: List[dict] | None = None
chosen_text: str | None = None
rejected_text: str | None = None
tokens: Tensor | None = None
logprobs: Tensor | None = None
action_mask: Tensor | None = None
chosen: Tensor | None = None
rejected: Tensor | None = None
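An Experience stores the prompt and the response in a single tokens sequence, with prompt_length marking where the prompt ends. The gather helpers further down collect logprobs and action masks with only a max_response_length, which suggests (an inference, not stated in the docstrings) that those fields hold one entry per response token. A plain-list sketch of that layout; split_prompt_response is a hypothetical helper and the token IDs are made up.

```python
from typing import List, Tuple


def split_prompt_response(tokens: List[int], prompt_length: int) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: split a flat token sequence at prompt_length."""
    if not 1 <= prompt_length <= len(tokens):
        raise ValueError("prompt_length must be in [1, len(tokens)]")
    return tokens[:prompt_length], tokens[prompt_length:]


tokens = [101, 7592, 2088, 102, 2023, 2003, 1037]  # 4 prompt tokens + 3 response tokens
prompt, response = split_prompt_response(tokens, prompt_length=4)

# Assumed: response-aligned fields carry one entry per response token.
logprobs = [-0.3, -1.2, -0.05]
action_mask = [True, True, False]  # e.g. False for a token the model did not generate (assumed meaning)
assert len(logprobs) == len(action_mask) == len(response)
```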
serialize() bytes

Serialize the experience to bytes.

classmethod deserialize(data: bytes) Experience
to_dict() dict

Convert the experience to a dictionary.

classmethod gather(experiences: List[Experience], pad_token_id: int = 0, custom_fields: List[CustomField] | None = None) Experiences
trinity.common.experience.split_dpo_experience_to_single_turn(experiences: List[Experience]) List[Experience]
class trinity.common.experience.Experiences(eids: List[EID], tokens: Tensor, rewards: Tensor, advantages: Tensor | None, returns: Tensor | None, attention_masks: Tensor, action_masks: Tensor | None, prompt_length: int, logprobs: Tensor | None, custom_fields: List[str] = <factory>)

Bases: object

A container for a batch of experiences, designed for high-performance communication.

Example

                |<- prompt_length ->|               |
    tokens: ('P' represents prompt, 'O' represents output, '.' represents padding)
    exp1:       |........PPPPPPPPPPP|OOOOOOOOOO.....|
    exp2:       |......PPPPPPPPPPPPP|OOOOOOO........|

    attention_masks: ('.' represents False and '1' represents True)
    exp1:       |........11111111111|1111111111.....|
    exp2:       |......1111111111111|1111111........|
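The diagram shows prompts padded on the left and responses padded on the right, so every prompt ends at the same column and a single prompt_length describes the whole batch; gather_token_ids and gather_attention_masks accordingly take both a max_prompt_length and a max_response_length. A minimal pure-Python sketch of that layout, assuming each experience is given as a (prompt, response) pair of token lists; pad_batch is a hypothetical helper, and the real fields are torch tensors rather than lists.

```python
from typing import List, Tuple


def pad_batch(
    pairs: List[Tuple[List[int], List[int]]], pad_token_id: int = 0
) -> Tuple[List[List[int]], List[List[bool]], int]:
    """Hypothetical helper: left-pad prompts and right-pad responses so every row
    has the layout |<- max prompt ->|<- max response ->|."""
    max_prompt = max(len(p) for p, _ in pairs)
    max_response = max(len(r) for _, r in pairs)
    tokens, attention_masks = [], []
    for prompt, response in pairs:
        left = max_prompt - len(prompt)  # padding before the prompt
        right = max_response - len(response)  # padding after the response
        tokens.append([pad_token_id] * left + prompt + response + [pad_token_id] * right)
        real = len(prompt) + len(response)
        attention_masks.append([False] * left + [True] * real + [False] * right)
    # Every prompt now ends at column max_prompt, so it serves as the batch prompt_length.
    return tokens, attention_masks, max_prompt


pairs = [([11, 12, 13], [21, 22, 23, 24]), ([31, 32, 33, 34, 35], [41, 42])]
tokens, masks, prompt_length = pad_batch(pairs, pad_token_id=0)
```

Both rows come out with length 9 (5 prompt columns plus 4 response columns), and prompt_length is 5 for the whole batch.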
__init__(eids: List[EID], tokens: Tensor, rewards: Tensor, advantages: Tensor | None, returns: Tensor | None, attention_masks: Tensor, action_masks: Tensor | None, prompt_length: int, logprobs: Tensor | None, custom_fields: List[str] = <factory>) None
eids: List[EID]
tokens: Tensor
rewards: Tensor
advantages: Tensor | None
returns: Tensor | None
attention_masks: Tensor
action_masks: Tensor | None
prompt_length: int
logprobs: Tensor | None
custom_fields: List[str]
property batch_size: int

Get the batch size.

classmethod gather_experiences(experiences: list[Experience], pad_token_id: int = 0, custom_fields: List[CustomField] | None = None) Experiences

Gather a batch of experiences from a list of experiences.

This method automatically pads the tokens and logprobs of the input experiences to the same length.

Parameters:
  • experiences (list[Experience]) – A list of experiences to gather.

  • pad_token_id (int) – The token ID to use for padding. Default is 0.

  • custom_fields (Optional[List[CustomField]]) – Custom fields to include in the gathered experiences.

trinity.common.experience.empty_experiences(custom_fields: List[CustomField] | None) Experiences
trinity.common.experience.gather_token_ids(experiences, max_prompt_length: int, max_response_length: int, pad_token_id: int) Tensor
trinity.common.experience.gather_action_masks(experiences, max_response_length: int) Tensor
trinity.common.experience.gather_attention_masks(experiences, max_prompt_length: int, max_response_length: int) Tensor
trinity.common.experience.gather_logprobs(experiences, max_response_length: int) Tensor
trinity.common.experience.gather_advantages(experiences, max_response_length: int) Tensor | None
trinity.common.experience.gather_returns(experiences, max_response_length: int) Tensor | None
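gather_action_masks, gather_logprobs, gather_advantages and gather_returns take only max_response_length, which suggests the fields they collect are aligned with the response and padded on the right, matching the response half of the Experiences diagram. A pure-Python sketch of that right padding; the helper right_pad and the pad values 0.0 and False are assumptions, and the real helpers return torch tensors.

```python
from typing import Any, List, Sequence


def right_pad(rows: Sequence[Sequence[Any]], max_len: int, pad_value: Any) -> List[List[Any]]:
    """Hypothetical helper: right-pad each response-aligned row to max_len."""
    out = []
    for row in rows:
        if len(row) > max_len:
            raise ValueError(f"row of length {len(row)} exceeds max_len={max_len}")
        out.append(list(row) + [pad_value] * (max_len - len(row)))
    return out


logprobs = [[-0.1, -0.7, -0.2], [-1.5]]
action_masks = [[True, True, False], [True]]
max_response_length = 4
padded_logprobs = right_pad(logprobs, max_response_length, 0.0)  # assumed pad value
padded_masks = right_pad(action_masks, max_response_length, False)  # padding is never an action
```

Padding action masks with False keeps padded positions out of any loss that is masked by action_masks, whatever value the padded logprobs hold.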
trinity.common.experience.group_by(experiences: List[Experience], id_type: Literal['task', 'run', 'step']) Dict[str, List[Experience]]

Group experiences by their task, run, or step ID, as selected by id_type.
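A sketch of what grouping by id_type amounts to, assuming that 'task', 'run' and 'step' select the tid, rid and sid properties of each experience's EID. The Item class is a hypothetical stand-in for Experience that flattens those three IDs, and group_by_sketch is not the library function. The usage computes a per-task mean reward, the kind of group statistic a GRPO-like algorithm needs.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Literal


@dataclass
class Item:
    """Hypothetical stand-in for an Experience: only its IDs and reward matter here."""

    tid: str
    rid: str
    sid: str
    reward: float


def group_by_sketch(items: List[Item], id_type: Literal["task", "run", "step"]) -> Dict[str, List[Item]]:
    # Assumed mapping from id_type to the EID property used as the group key.
    key = {"task": "tid", "run": "rid", "step": "sid"}[id_type]
    groups: Dict[str, List[Item]] = defaultdict(list)
    for item in items:
        groups[getattr(item, key)].append(item)
    return dict(groups)


items = [
    Item(tid="0/0", rid="0/0/0", sid="0/0/0", reward=1.0),
    Item(tid="0/0", rid="0/0/1", sid="0/0/0", reward=0.0),
    Item(tid="0/1", rid="0/1/0", sid="0/1/0", reward=0.5),
    Item(tid="0/1", rid="0/1/1", sid="0/1/0", reward=0.5),
]
by_task = group_by_sketch(items, "task")
mean_reward = {tid: sum(i.reward for i in g) / len(g) for tid, g in by_task.items()}
```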

trinity.common.experience.to_hf_datasets(experiences: list[Experience]) Dataset

Convert a list of Experience objects to a HuggingFace Dataset, preserving all fields.

trinity.common.experience.from_hf_datasets(dataset: Dataset) List[Experience]

Convert a HuggingFace Dataset back to a list of Experience objects.
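Converting to and from a HuggingFace Dataset is essentially a transpose between a list of row objects and a dict of columns: datasets.Dataset.from_dict accepts such a column dict, and Dataset.to_dict returns one. The sketch below stops at the column dict so it runs without the datasets package. MiniExperience is a hypothetical, trimmed-down stand-in for Experience; it omits the tensor fields and the nested EID that the real functions, which preserve all fields, also have to handle.

```python
from dataclasses import asdict, dataclass, fields
from typing import Any, Dict, List, Optional


@dataclass
class MiniExperience:
    """Hypothetical stand-in for Experience, using plain lists instead of tensors."""

    tokens: List[int]
    prompt_length: int
    reward: Optional[float] = None
    response_text: Optional[str] = None


def to_columns(exps: List[MiniExperience]) -> Dict[str, List[Any]]:
    """Rows -> column dict (the shape datasets.Dataset.from_dict accepts)."""
    rows = [asdict(e) for e in exps]
    return {f.name: [row[f.name] for row in rows] for f in fields(MiniExperience)}


def from_columns(columns: Dict[str, List[Any]]) -> List[MiniExperience]:
    """Column dict (e.g. the result of Dataset.to_dict()) -> rows."""
    n = len(next(iter(columns.values())))
    return [MiniExperience(**{name: col[i] for name, col in columns.items()}) for i in range(n)]


exps = [MiniExperience([1, 2, 3], 2, 1.0, "hi"), MiniExperience([4, 5], 1)]
cols = to_columns(exps)
restored = from_columns(cols)
```

Missing optional values become None entries in their column, so the round trip restores the original objects unchanged.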