"""Source code for trinity.buffer.schema.formatter: formatters that turn raw dataset samples into Tasks and Experiences."""

import json
from abc import ABC, abstractmethod
from typing import Dict, List, Optional

from trinity.common.config import FormatConfig, StorageConfig
from trinity.common.constants import PromptType
from trinity.common.experience import Experience
from trinity.common.models.utils import get_action_mask_method
from trinity.common.rewards import REWARD_FUNCTIONS
from trinity.common.workflows import WORKFLOWS, Task
from trinity.utils.log import get_logger
from trinity.utils.registry import Registry

# Registry of sample formatters. Formatter classes in this module register
# themselves under string keys ("task", "sft", "dpo") via
# ``@FORMATTER.register_module(...)``.
FORMATTER = Registry("formatter")


class ExperienceFormatter(ABC):
    """Abstract interface for turning one raw dataset sample into an Experience.

    Concrete subclasses implement :meth:`format`.
    """

    @abstractmethod
    def format(self, sample: Dict) -> Experience:
        """Convert a single raw sample into an :class:`Experience`.

        Args:
            sample (Dict): One raw record read from the dataset.

        Returns:
            Experience: The formatted experience.
        """
@FORMATTER.register_module("task")
class TaskFormatter:
    """Formatter for task data.

    Each raw sample becomes a :class:`Task`. The workflow and reward-function
    classes are looked up per sample through the configured ``workflow_key`` /
    ``reward_fn_key`` columns. The storage-config defaults are used when the
    sample does not name one, or when the registry lookup yields nothing.

    Example Input:

    .. code-block:: python

        {
            "input": "Hello",
            "output": "Hi"
        }
    """

    def __init__(self, config: StorageConfig):
        """Resolve the default workflow and reward-function classes.

        Args:
            config (StorageConfig): Storage settings. ``config.format.workflow_key``
                and ``config.format.reward_fn_key`` name the per-sample override
                columns.
        """
        self.config = config
        self.is_eval = config.is_eval
        self.default_workflow_cls = WORKFLOWS.get(config.default_workflow_type)  # type: ignore
        if self.is_eval and config.default_eval_workflow_type:
            # Evaluation data may run under its own default workflow.
            self.default_workflow_cls = WORKFLOWS.get(config.default_eval_workflow_type)
        self.default_reward_fn_cls = REWARD_FUNCTIONS.get(config.default_reward_fn_type)  # type: ignore
        fmt = config.format
        self.workflow_key = fmt.workflow_key
        self.reward_fn_key = fmt.reward_fn_key

    def format(self, sample: Dict) -> Task:
        """Build a :class:`Task` from one raw sample.

        Args:
            sample (Dict): One raw record. It is stored unchanged as ``raw_task``.

        Returns:
            Task: A task carrying the resolved workflow/reward classes plus the
            rollout, workflow and reward arguments taken from the storage config.

        Raises:
            AssertionError: If no workflow class could be resolved.
        """
        workflow_cls = (
            self._lookup(WORKFLOWS, sample, self.workflow_key) or self.default_workflow_cls
        )
        reward_fn_cls = (
            self._lookup(REWARD_FUNCTIONS, sample, self.reward_fn_key)
            or self.default_reward_fn_cls
        )
        assert workflow_cls is not None, "`default_workflow_type` or `workflow_key` is required"
        cfg = self.config
        return Task(
            workflow=workflow_cls,
            reward_fn=reward_fn_cls,
            format_args=cfg.format,
            repeat_times=cfg.repeat_times,
            rollout_args=cfg.rollout_args,
            workflow_args=cfg.workflow_args,
            reward_fn_args=cfg.reward_fn_args,
            is_eval=self.is_eval,
            raw_task=sample,
        )

    @staticmethod
    def _lookup(registry, sample: Dict, key):
        """Return ``registry.get(sample[key])`` when both ``key`` and that value are set, else None."""
        name = sample.get(key, None) if key else None
        return registry.get(name) if name else None
@FORMATTER.register_module("sft")
class SFTFormatter(ExperienceFormatter):
    """Formatter for SFT data, supporting both message list and plaintext formats.

    Uses ``format_config.prompt_type`` to distinguish between 'messages' and
    'plaintext'.

    Example input of MESSAGES:

    .. code-block:: python

        {
            "messages": [
                {"role": "user", "content": "Hello, how are you?"},
                {"role": "assistant", "content": "I'm fine, thank you!"}
            ]
        }

    Example input of PLAINTEXT, assuming ``system_prompt_key="system"``,
    ``prompt_key="prompt"`` and ``response_key="response"`` are configured:

    .. code-block:: python

        {
            "system": "You are a helpful assistant.",
            "prompt": "Hello, how are you?",
            "response": "I'm fine, thank you!",
        }
    """

    def __init__(self, tokenizer, format_config: FormatConfig):
        """Read the keys and options used to format samples.

        Args:
            tokenizer: Tokenizer exposing ``apply_chat_template`` and ``chat_template``.
            format_config (FormatConfig): Format settings (prompt type, keys, template).

        Raises:
            ValueError: If ``format_config.prompt_type`` is neither MESSAGES nor PLAINTEXT.
        """
        self.logger = get_logger("sft_dataset_formatter", in_ray_actor=True)
        self.tokenizer = tokenizer
        self.prompt_type = format_config.prompt_type
        self.enable_concatenated_multi_turn = format_config.enable_concatenated_multi_turn
        # Fall back to the tokenizer's own template when none is configured.
        self.chat_template = format_config.chat_template or tokenizer.chat_template
        # Both prompt types may carry a tools column.
        self.tools_key = format_config.tools_key
        if self.prompt_type == PromptType.MESSAGES:
            self.messages_key = format_config.messages_key
        elif self.prompt_type == PromptType.PLAINTEXT:
            self.prompt_key = format_config.prompt_key
            self.response_key = format_config.response_key
            self.system_prompt_key = format_config.system_prompt_key
            self.system_prompt = format_config.system_prompt
        else:
            raise ValueError(f"Unsupported prompt_type: {self.prompt_type}")
        # _messages_to_experience uses this whenever concatenated multi-turn is
        # enabled, regardless of prompt type. It used to be set only for MESSAGES,
        # so PLAINTEXT with the flag on failed with AttributeError.
        if self.enable_concatenated_multi_turn:
            self.action_mask_method = get_action_mask_method(self.chat_template)

    def _messages_to_experience(
        self,
        messages: List[Dict] | str,  # or could be str from json dumps
        tools: Optional[List[Dict] | str] = None,  # or could also be str from json dumps
    ) -> Experience:
        """Convert messages and tools into an Experience object.

        Args:
            messages (List[Dict]|str): The list of message dictionaries or a JSON string.
            tools (Optional[List[Dict]|str], optional): The list of tool dictionaries
                or a JSON string. An empty/whitespace-only string is treated as
                "no tools". Defaults to None.

        Returns:
            Experience: The resulting Experience object.

        Raises:
            ValueError: If ``messages`` or a non-empty ``tools`` string is not valid JSON.
        """
        if isinstance(messages, str):
            try:
                messages = json.loads(messages)
            except json.JSONDecodeError as err:
                self.logger.error(
                    "[SFT Data Error] Failed to decode 'messages' JSON. please check your data format."
                )
                raise ValueError("Invalid JSON format for messages") from err

        # Warning if tools is accidentally provided as list of dicts
        # (with Huggingface datasets this may cause schema issues).
        if isinstance(tools, list):
            self.logger.warning(
                "[SFT Data Warning] 'tools' is provided as a list of dictionaries. "
                "When loading with Huggingface Datasets, schema auto-alignment may set unmatched fields to null, "
                "potentially causing undesired behavior. "
                "It is recommended to pre-process 'tools' objects with json.dumps before saving/loading, "
                "and to restore them with json.loads in this function."
            )
        if isinstance(tools, str):
            if not tools.strip():
                # An empty cell means "no tools", not malformed JSON.
                tools = None
            else:
                try:
                    tools = json.loads(tools)
                except json.JSONDecodeError as err:
                    self.logger.error(
                        "[SFT Data Error] Failed to decode 'tools' JSON. Please check your data format."
                    )
                    raise ValueError("Invalid JSON format for tools") from err

        if self.enable_concatenated_multi_turn:
            token_ids, action_mask, prompt_length = self.action_mask_method(
                tokenizer=self.tokenizer,
                messages=messages,
                tools=tools,
                chat_template=self.chat_template,
            )
            return Experience(
                tokens=token_ids,
                action_mask=action_mask[prompt_length:],
                prompt_length=prompt_length,
                messages=messages,
            )

        # Only this path needs the plain full-conversation tokenization.
        tokens = self.tokenizer.apply_chat_template(
            messages,
            tools=tools,
            add_generation_prompt=False,
            return_tensors="pt",
            chat_template=self.chat_template,
        )[0]
        # The prompt is everything before the final message, rendered with the
        # generation prompt so that its length marks where the response starts.
        prompt_tokens_ids = self.tokenizer.apply_chat_template(
            messages[:-1],
            tools=tools,
            add_generation_prompt=True,
            return_tensors="pt",
            chat_template=self.chat_template,
        )[0]
        return Experience(
            tokens=tokens,
            prompt_length=len(prompt_tokens_ids),
            messages=messages,
        )

    def format(self, sample: Dict) -> Experience:
        """Format one raw SFT sample into an :class:`Experience`.

        Args:
            sample (Dict): Raw record. MESSAGES reads ``messages_key``; PLAINTEXT
                reads ``prompt_key``, ``response_key`` and optionally
                ``system_prompt_key``. ``tools_key`` is read for both.

        Returns:
            Experience: The tokenized experience.

        Raises:
            ValueError: If the prompt type is unsupported or the JSON fields are invalid.
        """
        if self.prompt_type == PromptType.MESSAGES:
            messages = sample[self.messages_key]
        elif self.prompt_type == PromptType.PLAINTEXT:
            messages = []
            # A per-sample system prompt takes precedence over the configured one.
            if self.system_prompt_key is not None:
                messages.append({"role": "system", "content": sample[self.system_prompt_key]})
            elif self.system_prompt is not None:
                messages.append({"role": "system", "content": self.system_prompt})
            messages.append({"role": "user", "content": sample[self.prompt_key]})
            messages.append({"role": "assistant", "content": sample[self.response_key]})
        else:
            raise ValueError(f"Unsupported prompt_type: {self.prompt_type}")

        tools = sample.get(self.tools_key, None)
        return self._messages_to_experience(messages, tools)
@FORMATTER.register_module("dpo")
class DPOFormatter(ExperienceFormatter):
    """Formatter for DPO preference data in plaintext or message-list form.

    Uses ``format_config.prompt_type`` to choose between the two layouts. The
    examples below assume the configured keys are ``prompt`` / ``messages``,
    ``chosen`` and ``rejected``.

    Example Input for PLAINTEXT:

    .. code-block:: python

        {
            "prompt": "What is your name?",
            "chosen": "My name is Assistant.",
            "rejected": "I don't have a name."
        }

    Example Input for MESSAGES:

    .. code-block:: python

        {
            "messages": [
                {"role": "user", "content": "What is your name?"},
            ],
            "chosen": [
                {"role": "assistant", "content": "My name is Assistant."},
            ],
            "rejected": [
                {"role": "assistant", "content": "I don't have a name."}
            ]
        }
    """

    def __init__(self, tokenizer, format_config: FormatConfig):
        """Read the keys and template used to format preference samples.

        Args:
            tokenizer: Tokenizer exposing ``apply_chat_template``.
            format_config (FormatConfig): Format settings (prompt type, keys, template).

        Raises:
            ValueError: If ``format_config.prompt_type`` is neither PLAINTEXT nor MESSAGES.
        """
        self.tokenizer = tokenizer
        self.prompt_type = format_config.prompt_type
        self.chat_template = format_config.chat_template
        if self.prompt_type == PromptType.PLAINTEXT:
            self.prompt_key = format_config.prompt_key
            self.system_prompt_key = format_config.system_prompt_key
            self.system_prompt = format_config.system_prompt
        elif self.prompt_type == PromptType.MESSAGES:
            self.messages_key = format_config.messages_key
        else:
            raise ValueError(f"Unsupported prompt_type: {self.prompt_type}")
        # Both layouts read the preferred / dispreferred responses from these keys.
        self.chosen_key = format_config.chosen_key
        self.rejected_key = format_config.rejected_key

    # Tools are not supported for DPO yet.
    def _messages_to_experience(
        self, prompt_messages, chosen_messages, rejected_messages
    ) -> Experience:
        """Tokenize a prompt and its chosen/rejected continuations.

        The chosen and rejected token sequences are obtained by rendering
        ``prompt + response`` and dropping the first ``len(prompt_tokens)``
        tokens.

        NOTE(review): this assumes the prompt rendered with
        ``add_generation_prompt=True`` is a token-level prefix of the full
        conversation for the chat template in use. Confirm this for new templates.

        Args:
            prompt_messages (list): Messages forming the prompt.
            chosen_messages (list): Preferred response messages.
            rejected_messages (list): Dispreferred response messages.

        Returns:
            Experience: Prompt tokens plus the chosen/rejected response tokens and
            the full chosen/rejected message lists.
        """
        prompt_tokens = self.tokenizer.apply_chat_template(
            prompt_messages,
            add_generation_prompt=True,
            return_tensors="pt",
            chat_template=self.chat_template,
        )[0]
        chosen_tokens = self.tokenizer.apply_chat_template(
            prompt_messages + chosen_messages,
            add_generation_prompt=False,
            return_tensors="pt",
            chat_template=self.chat_template,
        )[0][len(prompt_tokens) :]
        rejected_tokens = self.tokenizer.apply_chat_template(
            prompt_messages + rejected_messages,
            add_generation_prompt=False,
            return_tensors="pt",
            chat_template=self.chat_template,
        )[0][len(prompt_tokens) :]
        return Experience(
            tokens=prompt_tokens,
            prompt_length=len(prompt_tokens),
            chosen=chosen_tokens,
            rejected=rejected_tokens,
            chosen_messages=prompt_messages + chosen_messages,
            rejected_messages=prompt_messages + rejected_messages,
        )

    def format(self, sample: Dict) -> Experience:
        """Format one raw DPO sample into an :class:`Experience`.

        Args:
            sample (Dict): Raw record in the PLAINTEXT or MESSAGES layout shown
                in the class docstring.

        Returns:
            Experience: The tokenized preference pair.

        Raises:
            ValueError: If the prompt type is unsupported.
        """
        if self.prompt_type == PromptType.PLAINTEXT:
            messages = []
            # A per-sample system prompt takes precedence over the configured one.
            if self.system_prompt_key is not None:
                system_message = {"role": "system", "content": sample[self.system_prompt_key]}
                messages.append(system_message)
            elif self.system_prompt is not None:
                system_message = {"role": "system", "content": self.system_prompt}
                messages.append(system_message)
            messages.append({"role": "user", "content": sample[self.prompt_key]})
            chosen = [{"role": "assistant", "content": sample[self.chosen_key]}]
            rejected = [{"role": "assistant", "content": sample[self.rejected_key]}]
        elif self.prompt_type == PromptType.MESSAGES:
            messages = sample[self.messages_key]
            chosen = sample[self.chosen_key]
            rejected = sample[self.rejected_key]
        else:
            raise ValueError(f"Unsupported prompt_type: {self.prompt_type}")
        return self._messages_to_experience(
            prompt_messages=messages,
            chosen_messages=chosen,
            rejected_messages=rejected,
        )