trinity.common.models.utils module

trinity.common.models.utils.tokenize_and_mask_messages_hf(tokenizer: Any, messages: List[dict], chat_template: str | None = None) → Tuple[Tensor, Tensor, int]

Calculate the assistant token mask with chat_template.

Parameters:
  • tokenizer (Any) – The tokenizer.

  • chat_template (Optional[str]) – The chat template containing the {% generation %} marker.

  • messages (List[dict]) – Messages with role and content fields.

Returns:

A tuple of three values: the token IDs (torch.Tensor of shape (sequence_length)), the assistant masks (torch.Tensor of shape (sequence_length)), and the prompt length (int).

Return type:

Tuple[torch.Tensor, torch.Tensor, int]
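
As a rough illustration of what a chat template with generation markers can look like, the sketch below defines a ChatML-style template string. The role delimiters (<|im_start|>, <|im_end|>) and the names CHAT_TEMPLATE and messages are illustrative assumptions, not taken from Trinity. The {% generation %} / {% endgeneration %} pair is the marker Hugging Face chat templates use to delimit the text whose tokens belong to the assistant.

```python
# Hypothetical ChatML-style template; the delimiters are assumptions for
# illustration. Only the text between {% generation %} and {% endgeneration %}
# is meant to be counted as assistant tokens.
CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'assistant' %}"
    "{{ '<|im_start|>assistant\\n' }}"
    "{% generation %}"
    "{{ message['content'] + '<|im_end|>' }}"
    "{% endgeneration %}"
    "{{ '\\n' }}"
    "{% else %}"
    "{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n' }}"
    "{% endif %}"
    "{% endfor %}"
)

# Messages carry the role and content fields described above.
messages = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
]

# With a loaded tokenizer, the documented call would then be:
# token_ids, assistant_masks, prompt_length = tokenize_and_mask_messages_hf(
#     tokenizer, messages, CHAT_TEMPLATE
# )
```

The template is written as adjacent string literals so that no stray newlines from the Python source end up in the rendered prompt.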

trinity.common.models.utils.tokenize_and_mask_messages_default(tokenizer: Any, messages: List[dict], chat_template: str | None = None) → Tuple[Tensor, Tensor, int]

Calculate the assistant token mask.

Parameters:
  • tokenizer (Any) – The tokenizer.

  • chat_template (Optional[str]) – The chat template containing the {% generation %} marker.

  • messages (List[dict]) – Messages with role and content fields.

Returns:

A tuple of three values: the token IDs (torch.Tensor of shape (sequence_length)), the assistant masks (torch.Tensor of shape (sequence_length)), and the prompt length (int).

Return type:

Tuple[torch.Tensor, torch.Tensor, int]

Note

This method assumes that, as the number of chat rounds increases, the tokens of each round are exactly a prefix of the tokens of the next round. If the chat template does not satisfy this assumption, the function may produce incorrect results. Check the chat template before using this method.
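
The prefix assumption can be illustrated with a toy example. The sketch below is not Trinity's implementation: render and toy_tokenize are hypothetical stand-ins for a chat template and a tokenizer, and the prompt length is taken here to be the index of the first assistant token, which is an assumption. It marks assistant tokens by comparing the tokens of each conversation prefix, and raises an error when the prefix property does not hold.

```python
from typing import List, Tuple


def render(messages: List[dict], add_generation_prompt: bool = False) -> str:
    """Hypothetical chat template: '<role> content <end>' per message."""
    parts = [f"<{m['role']}> {m['content']} <end>" for m in messages]
    if add_generation_prompt:
        parts.append("<assistant>")  # header that opens the assistant turn
    return " ".join(parts)


def toy_tokenize(text: str) -> List[str]:
    """Hypothetical tokenizer: split on whitespace."""
    return text.split()


def mask_by_prefix(messages: List[dict]) -> Tuple[List[str], List[int], int]:
    tokens = toy_tokenize(render(messages))
    mask = [0] * len(tokens)
    for idx, message in enumerate(messages):
        if message["role"] != "assistant":
            continue
        # Tokens up to and including the assistant header for this round.
        prompt = toy_tokenize(render(messages[:idx], add_generation_prompt=True))
        # Tokens up to and including this assistant reply.
        upto = toy_tokenize(render(messages[: idx + 1]))
        # The prefix assumption: earlier rounds tokenize to an exact prefix.
        if upto[: len(prompt)] != prompt or tokens[: len(upto)] != upto:
            raise ValueError(f"prefix assumption violated at message {idx}")
        for pos in range(len(prompt), len(upto)):
            mask[pos] = 1
    # Assumed definition: index of the first assistant token.
    prompt_length = mask.index(1) if 1 in mask else len(tokens)
    return tokens, mask, prompt_length


messages = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello there"},
    {"role": "user", "content": "bye"},
    {"role": "assistant", "content": "goodbye"},
]
tokens, mask, prompt_length = mask_by_prefix(messages)
print(mask)           # [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1]
print(prompt_length)  # 4
```

A template that rewrites earlier turns when new ones are appended (for example, by stripping content from previous assistant messages) would fail the prefix check in this sketch, which is the situation the note warns about.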

trinity.common.models.utils.get_checkpoint_dir_with_step_num(checkpoint_root_path: str, trainer_type: str = 'verl', step_num: int | None = None) → Tuple[str, int]

Get the checkpoint directory from a root checkpoint directory.

Parameters:
  • checkpoint_root_path (str) – The root checkpoint directory.

  • trainer_type (str) – The trainer type. Only “verl” is supported for now.

  • step_num (Optional[int], optional) – The step number. If specified, load the checkpoint with the specified step number. If None, load the latest checkpoint. Defaults to None.

Returns:

The checkpoint directory and the step number of the checkpoint.

Return type:

Tuple[str, int]
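
The step selection described above (a specific step if given, otherwise the latest one) can be sketched as follows. This is not the library's code: the function name find_checkpoint_dir and the global_step_<N> directory naming are assumptions made for illustration, and the actual on-disk layout may differ.

```python
import os
import re
import tempfile
from typing import Optional, Tuple

# Assumed naming scheme for per-step checkpoint directories.
STEP_DIR = re.compile(r"^global_step_(\d+)$")


def find_checkpoint_dir(root: str, step_num: Optional[int] = None) -> Tuple[str, int]:
    """Return (checkpoint_dir, step) for step_num, or for the latest step if None."""
    steps = {}
    for name in os.listdir(root):
        match = STEP_DIR.match(name)
        path = os.path.join(root, name)
        if match and os.path.isdir(path):
            steps[int(match.group(1))] = path
    if not steps:
        raise FileNotFoundError(f"no checkpoints found under {root}")
    if step_num is None:
        step_num = max(steps)  # compare numerically, so 100 > 20
    if step_num not in steps:
        raise FileNotFoundError(f"no checkpoint for step {step_num} under {root}")
    return steps[step_num], step_num


# Usage against a throwaway directory tree.
with tempfile.TemporaryDirectory() as root:
    for step in (10, 20, 100):
        os.makedirs(os.path.join(root, f"global_step_{step}"))
    latest_dir, latest_step = find_checkpoint_dir(root)
    chosen_dir, chosen_step = find_checkpoint_dir(root, step_num=20)
    latest_name = os.path.basename(latest_dir)
    chosen_name = os.path.basename(chosen_dir)
```

Parsing the step as an integer matters here: a plain string sort would rank global_step_20 above global_step_100.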

trinity.common.models.utils.load_state_dict(checkpoint_dir: str, trainer_type: str = 'verl') → dict

Load the state dict from a checkpoint directory.

Parameters:
  • checkpoint_dir (str) – The checkpoint directory.

  • trainer_type (str) – The trainer type. Only “verl” is supported for now.

trinity.common.models.utils.merge_by_placement(tensors: List[Tensor], placement: Placement)

trinity.common.models.utils.get_verl_checkpoint_info(checkpoint_path: str, step_num: int | None = None) → Tuple[str, int]

Get the checkpoint directory from a Verl root checkpoint directory.

Parameters:
  • checkpoint_path (str) – The root checkpoint directory.

  • step_num (Optional[int], optional) – The step number. If specified, load the checkpoint with the specified step number. If None, load the latest checkpoint. Defaults to None.

Returns:

The checkpoint directory and the step number of the checkpoint.

Return type:

Tuple[str, int]

trinity.common.models.utils.load_state_dict_from_verl_checkpoint(checkpoint_path: str) → dict

Load the state dict from a Verl checkpoint.