trinity.common.experience module
Experience classes and related utilities.
- class trinity.common.experience.EID(batch: int = 0, task: int = 0, run: int = 0, step: int = 0, suffix: str = <factory>)[source]
Bases:
object
Experience ID class to uniquely identify an experience.
To enable the full functionality of experience grouping, users should manually set the run and step fields in custom workflows (see the sketch following this class listing).
- batch: int = 0
- task: int = 0
- run: int = 0
- step: int = 0
- suffix: str
- property uid: str
A unique identifier for the experience.
- property sid: str
Step ID of the experience.
For example, experiences generated by all runs of the same task at the same step will have the same sid.
- property rid: str
Run ID of the experience.
For example, experiences generated by one run of a task across all steps will have the same rid.
- property tid: str
Task ID for the experience.
For example, experiences generated by all runs of the same task in GRPO-like algorithms will have the same tid.
- __init__(batch: int = 0, task: int = 0, run: int = 0, step: int = 0, suffix: str = <factory>) None
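A minimal sketch of how the grouping IDs behave, based on the docstrings above; the exact string format of each ID and the generated suffix are implementation details:
>>> from trinity.common.experience import EID
>>> e1 = EID(batch=0, task=3, run=0, step=2)
>>> e2 = EID(batch=0, task=3, run=1, step=2)
>>> e1.tid == e2.tid   # same task -> same task ID
True
>>> e1.sid == e2.sid   # same task and step -> same step ID
True
>>> e1.rid == e2.rid   # different runs -> different run IDs
False
>>> e1.uid == e2.uid   # uid is always unique (includes the suffix)
False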
- class trinity.common.experience.CustomField(source_field: str, destination_field: str, data_type: dtype)[source]
Bases:
object
Custom field for Experiences.
This is used to store additional information in the Experiences class.
- source_field: str
- destination_field: str
- data_type: dtype
- __init__(source_field: str, destination_field: str, data_type: dtype) None
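A hedged sketch of declaring a CustomField to carry an extra per-experience value into the gathered batch; the field names used here are hypothetical, and the assumption that source_field is read from each experience's info dict is not confirmed by the signature alone:
>>> import torch
>>> from trinity.common.experience import CustomField
>>> cf = CustomField(
...     source_field="temperature",        # hypothetical key to read from each experience
...     destination_field="temperatures",  # attribute name on the gathered Experiences batch
...     data_type=torch.float32,           # torch dtype used for the gathered tensor
... )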
- class trinity.common.experience.Experience(*, eid=None, tokens, logprobs=None, reward=None, advantages=None, returns=None, info=None, metrics=None, prompt_length=1, response_text=None, prompt_text=None, action_mask=None, messages=None, tools=None, chosen=None, rejected=None, chosen_text=None, rejected_text=None)[source]
Bases:
object
- __init__(*, eid=None, tokens, logprobs=None, reward=None, advantages=None, returns=None, info=None, metrics=None, prompt_length=1, response_text=None, prompt_text=None, action_mask=None, messages=None, tools=None, chosen=None, rejected=None, chosen_text=None, rejected_text=None)[source]
- reward: float | None = None
- advantages: Tensor | None = None
- returns: Tensor | None = None
- info: dict
- metrics: dict[str, float]
- prompt_length: int = 1
- response_text: str | None = None
- prompt_text: str | None = None
- messages: List[dict] | None = None
- tools: List[dict] | None = None
- chosen_text: str | None = None
- rejected_text: str | None = None
- tokens: Tensor | None = None
- logprobs: Tensor | None = None
- action_mask: Tensor | None = None
- chosen: Tensor | None = None
- rejected: Tensor | None = None
- classmethod deserialize(data: bytes) Experience [source]
- classmethod gather(experiences: List[Experience], pad_token_id: int = 0, custom_fields: List[CustomField] | None = None) Experiences [source]
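A minimal sketch of constructing a single Experience; the assumption that tokens holds the concatenated prompt and response token IDs, with prompt_length marking the boundary, is inferred from the Experiences layout example below:
>>> import torch
>>> from trinity.common.experience import Experience
>>> exp = Experience(
...     tokens=torch.tensor([101, 2009, 2003, 102, 7592, 999]),  # prompt + response token IDs
...     prompt_length=4,             # first 4 tokens are the prompt
...     reward=1.0,
...     response_text="hello!",      # optional decoded response text
... )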
- trinity.common.experience.split_dpo_experience_to_single_turn(experiences: List[Experience]) List[Experience] [source]
- class trinity.common.experience.Experiences(eids: ~typing.List[~trinity.common.experience.EID], tokens: ~torch.Tensor, rewards: ~torch.Tensor, advantages: ~torch.Tensor | None, returns: ~torch.Tensor | None, attention_masks: ~torch.Tensor, action_masks: ~torch.Tensor | None, prompt_length: int, logprobs: ~torch.Tensor | None, custom_fields: ~typing.List[str] = <factory>)[source]
Bases:
object
A container for a batch of experiences, used for high-performance communication.
Example
>>>        |<- prompt_length ->|               |
>>> tokens: ('P' represents prompt, 'O' represents output)
>>> exp1:  |........PPPPPPPPPPP|OOOOOOOOOO.....|
>>> exp2:  |......PPPPPPPPPPPPP|OOOOOOO........|
>>>
>>> attention_masks: ('.' represents False and '1' represents True)
>>> exp1:  |........11111111111|1111111111.....|
>>> exp2:  |......1111111111111|1111111........|
- __init__(eids: ~typing.List[~trinity.common.experience.EID], tokens: ~torch.Tensor, rewards: ~torch.Tensor, advantages: ~torch.Tensor | None, returns: ~torch.Tensor | None, attention_masks: ~torch.Tensor, action_masks: ~torch.Tensor | None, prompt_length: int, logprobs: ~torch.Tensor | None, custom_fields: ~typing.List[str] = <factory>) None
- tokens: Tensor
- rewards: Tensor
- advantages: Tensor | None
- returns: Tensor | None
- attention_masks: Tensor
- action_masks: Tensor | None
- prompt_length: int
- logprobs: Tensor | None
- custom_fields: List[str]
- property batch_size: int
Get the batch size.
- classmethod gather_experiences(experiences: list[Experience], pad_token_id: int = 0, custom_fields: List[CustomField] | None = None) Experiences [source]
Gather a batch of experiences from a list of experiences.
This method automatically pads the tokens and logprobs of the input experiences to the same length.
- Parameters:
experiences (list[Experience]) – A list of experiences to gather.
pad_token_id (int) – The token ID to use for padding. Default is 0.
custom_fields (Optional[List[CustomField]]) – Custom fields to include in the gathered experiences.
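A hedged usage sketch of gathering a list of experiences into a padded batch; the token IDs, rewards, and pad_token_id below are placeholders:
>>> import torch
>>> from trinity.common.experience import Experience, Experiences
>>> exps = [
...     Experience(tokens=torch.tensor([1, 2, 3, 4, 5]), prompt_length=2, reward=0.5),
...     Experience(tokens=torch.tensor([1, 2, 3]), prompt_length=2, reward=1.0),
... ]
>>> batch = Experiences.gather_experiences(exps, pad_token_id=0)  # pads to a common length
>>> batch.batch_size
2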
- trinity.common.experience.empty_experiences(custom_fields: List[CustomField] | None) Experiences [source]
- trinity.common.experience.gather_token_ids(experiences, max_prompt_length: int, max_response_length: int, pad_token_id: int) Tensor [source]
- trinity.common.experience.gather_action_masks(experiences, max_response_length: int) Tensor [source]
- trinity.common.experience.gather_attention_masks(experiences, max_prompt_length: int, max_response_length: int) Tensor [source]
- trinity.common.experience.gather_advantages(experiences, max_response_length: int) Tensor | None [source]
- trinity.common.experience.gather_returns(experiences, max_response_length: int) Tensor | None [source]
- trinity.common.experience.group_by(experiences: List[Experience], id_type: Literal['task', 'run', 'step']) Dict[str, List[Experience]] [source]
Group experiences by ID.
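For example, grouping by "step" collects together all experiences that share the same sid; a sketch continuing the exps list above, assuming the workflow has set eid.run and eid.step:
>>> from trinity.common.experience import group_by
>>> groups = group_by(exps, id_type="step")
>>> for sid, grouped in groups.items():
...     print(sid, len(grouped))   # one entry per step ID, with its list of experiences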
- trinity.common.experience.to_hf_datasets(experiences: list[Experience]) Dataset [source]
Convert a list of Experience objects to a HuggingFace Dataset, preserving all fields.
- trinity.common.experience.from_hf_datasets(dataset: Dataset) List[Experience] [source]
Convert a HuggingFace Dataset back to a list of Experience objects.
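A hedged round-trip sketch, again continuing the exps list above; the exact column layout of the resulting Dataset is an implementation detail:
>>> from trinity.common.experience import to_hf_datasets, from_hf_datasets
>>> ds = to_hf_datasets(exps)          # datasets.Dataset with one row per experience
>>> restored = from_hf_datasets(ds)    # back to a list of Experience objects
>>> len(restored) == len(exps)
True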