"""Formatters that turn raw dataset samples into Tasks and Experiences (trinity.buffer.schema.formatter)."""
import json
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
import transformers
from trinity.common.config import FormatConfig, StorageConfig
from trinity.common.constants import PromptType
from trinity.common.experience import Experience
from trinity.common.models.utils import get_action_mask_method
from trinity.common.rewards import REWARD_FUNCTIONS
from trinity.common.workflows import WORKFLOWS, Task
from trinity.utils.log import get_logger
from trinity.utils.registry import Registry
# Registry of formatter classes; the classes below add themselves via
# `@FORMATTER.register_module(<name>)` under the names "task", "sft" and "dpo".
FORMATTER = Registry("formatter")
[docs]
class ExperienceFormatter(ABC):
    """Abstract interface for formatters that produce an :class:`Experience`.

    Concrete subclasses implement :meth:`format` for one specific data layout.
    """

    @abstractmethod
    def format(self, sample: Dict) -> Experience:
        """Convert a single raw sample into an :class:`Experience`.

        Args:
            sample (Dict): One raw sample dictionary.

        Returns:
            Experience: The formatted experience.
        """
[docs]
@FORMATTER.register_module("task")
class TaskFormatter:
    """Formatter for task data.

    Each raw sample is wrapped, unchanged, into a :class:`Task`. The Task also
    carries the workflow/reward-function classes and the storage-level arguments.

    Example Input:

    .. code-block:: python

        {
            "input": "Hello",
            "output": "Hi"
        }
    """

    def __init__(self, config: StorageConfig):
        """Resolve the default workflow/reward classes from the storage config.

        Args:
            config (StorageConfig): Storage configuration. Its ``format`` section
                provides the optional per-sample ``workflow_key`` / ``reward_fn_key``.
        """
        self.config = config
        self.is_eval = config.is_eval
        self.default_workflow_cls = WORKFLOWS.get(config.default_workflow_type)  # type: ignore
        # For evaluation storages, an explicit eval workflow type replaces the default one.
        if self.is_eval and config.default_eval_workflow_type:
            self.default_workflow_cls = WORKFLOWS.get(config.default_eval_workflow_type)
        self.default_reward_fn_cls = REWARD_FUNCTIONS.get(config.default_reward_fn_type)  # type: ignore
        self.workflow_key = config.format.workflow_key
        self.reward_fn_key = config.format.reward_fn_key

    def format(self, sample: Dict) -> Task:
        """Format a raw sample dict into a Task.

        A workflow / reward function named in the sample (under ``workflow_key`` /
        ``reward_fn_key``) takes precedence. The configured default is used when:

        - the key is unset,
        - the sample does not contain the key, or
        - the registry lookup yields nothing.

        Args:
            sample (Dict): The raw task sample; stored unchanged as ``raw_task``.

        Returns:
            Task: The task built from the sample and the storage configuration.

        Raises:
            ValueError: If no workflow class can be resolved.
        """
        workflow_name = sample.get(self.workflow_key) if self.workflow_key else None
        reward_fn_name = sample.get(self.reward_fn_key) if self.reward_fn_key else None
        workflow_cls = (
            WORKFLOWS.get(workflow_name) if workflow_name else None
        ) or self.default_workflow_cls
        reward_fn_cls = (
            REWARD_FUNCTIONS.get(reward_fn_name) if reward_fn_name else None
        ) or self.default_reward_fn_cls
        # An explicit check rather than `assert`: asserts are stripped under `python -O`,
        # which would silently build a Task without a workflow.
        if workflow_cls is None:
            raise ValueError("`default_workflow_type` or `workflow_key` is required")
        return Task(
            workflow=workflow_cls,
            reward_fn=reward_fn_cls,
            format_args=self.config.format,
            repeat_times=self.config.repeat_times,
            rollout_args=self.config.rollout_args,
            workflow_args=self.config.workflow_args,
            reward_fn_args=self.config.reward_fn_args,
            is_eval=self.is_eval,
            raw_task=sample,
        )
[docs]
@FORMATTER.register_module("sft")
class SFTFormatter(ExperienceFormatter):
    """Formatter for SFT data, supporting both message-list and plaintext formats.

    ``format_config.prompt_type`` selects between ``PromptType.MESSAGES`` and
    ``PromptType.PLAINTEXT``.

    Example input of MESSAGES:

    .. code-block:: python

        {
            "messages": [
                {"role": "user", "content": "Hello, how are you?"},
                {"role": "assistant", "content": "I'm fine, thank you!"}
            ]
        }

    Example input of PLAINTEXT, with ``system_prompt_key="system"``,
    ``prompt_key="prompt"`` and ``response_key="response"`` configured:

    .. code-block:: python

        {
            "system": "You are a helpful assistant.",
            "prompt": "Hello, how are you?",
            "response": "I'm fine, thank you!"
        }
    """

    def __init__(self, tokenizer_path: str, format_config: FormatConfig):
        """Load the tokenizer (or processor) and read the keys used by ``format``.

        Args:
            tokenizer_path (str): Path or name passed to ``from_pretrained``.
            format_config (FormatConfig): Format options (prompt type, keys, chat template, ...).

        Raises:
            ValueError: If ``format_config.prompt_type`` is neither MESSAGES nor PLAINTEXT.
        """
        self.logger = get_logger("sft_dataset_formatter", in_ray_actor=True)
        self.prompt_type = format_config.prompt_type
        self.enable_concatenated_multi_turn = format_config.enable_concatenated_multi_turn
        self.tools_key = format_config.tools_key
        self.image_key = format_config.image_key
        self.video_key = format_config.video_key
        # The "tools given as a list" warning concerns the dataset as a whole,
        # so it is emitted only once per formatter instead of once per sample.
        self._warned_tools_as_list = False

        if self.image_key is not None or self.video_key is not None:
            # Multi-modal samples go through an AutoProcessor; its tokenizer is reused for text.
            assert (
                self.enable_concatenated_multi_turn is False
            ), "Concatenated multi-turn not supported for multi-modal data yet."
            self.processor = transformers.AutoProcessor.from_pretrained(tokenizer_path)
            self.tokenizer = self.processor.tokenizer
        else:
            self.processor = None
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_path)
        # A chat template from the config overrides the tokenizer's own template.
        self.chat_template = format_config.chat_template or self.tokenizer.chat_template

        if self.prompt_type == PromptType.MESSAGES:
            self.messages_key = format_config.messages_key
        elif self.prompt_type == PromptType.PLAINTEXT:
            self.prompt_key = format_config.prompt_key
            self.response_key = format_config.response_key
            self.system_prompt_key = format_config.system_prompt_key
            self.system_prompt = format_config.system_prompt
        else:
            raise ValueError(f"Unsupported prompt_type: {self.prompt_type}")

        # `_messages_to_experience` uses this for every prompt type. Resolving it
        # only for MESSAGES made PLAINTEXT + concatenated multi-turn fail with
        # AttributeError on the first sample.
        if self.enable_concatenated_multi_turn:
            self.action_mask_method = get_action_mask_method(self.chat_template)

    def _messages_to_experience(
        self,
        messages: List[Dict],
        tools: Optional[List[Dict] | str] = None,
        mm_data: Optional[Dict] = None,
    ) -> Experience:
        """Convert messages and tools into an Experience object.

        Args:
            messages (List[Dict]): The list of message dictionaries; the last one is
                treated as the response (everything before it forms the prompt).
            tools (Optional[List[Dict]|str], optional): The list of tool dictionaries
                or a JSON string. Defaults to None.
            mm_data (Optional[Dict], optional): Multi-modal data such as images or
                videos. Defaults to None.

        Returns:
            Experience: The resulting Experience object.

        Raises:
            ValueError: If ``tools`` is a string that is not valid JSON.
        """
        if isinstance(tools, list) and not getattr(self, "_warned_tools_as_list", False):
            self._warned_tools_as_list = True
            self.logger.warning(
                "[SFT Data Warning] 'tools' is provided as a list of dictionaries. "
                "When loading with Huggingface Datasets, schema auto-alignment may set unmatched fields to null, "
                "potentially causing undesired behavior. "
                "It is recommended to pre-process 'tools' objects with json.dumps before saving/loading, "
                "and to restore them with json.loads in this function."
            )
        if isinstance(tools, str):
            try:
                tools = json.loads(tools)
            except json.JSONDecodeError as err:
                self.logger.error(
                    "[SFT Data Error] Failed to decode 'tools' JSON. Please check your data format."
                )
                raise ValueError("Invalid JSON format for tools") from err

        if self.enable_concatenated_multi_turn:
            token_ids, action_mask, prompt_length = self.action_mask_method(
                tokenizer=self.tokenizer,
                messages=messages,
                tools=tools,
                chat_template=self.chat_template,
            )
            # Only the part of the mask after the prompt is kept.
            return Experience(
                tokens=token_ids,
                action_mask=action_mask[prompt_length:],
                prompt_length=prompt_length,
                messages=messages,
            )

        if mm_data:
            # NOTE: `tools` is not forwarded on the multi-modal path.
            return self.convert_mm_data_to_experiences(messages=messages, mm_data=mm_data)

        token_ids = self.tokenizer.apply_chat_template(
            messages,
            tools=tools,
            add_generation_prompt=False,
            return_tensors="pt",
            chat_template=self.chat_template,
        )[0]
        # The prompt is every message except the last, followed by the generation prompt.
        prompt_tokens_ids = self.tokenizer.apply_chat_template(
            messages[:-1],
            tools=tools,
            add_generation_prompt=True,
            return_tensors="pt",
            chat_template=self.chat_template,
        )[0]
        return Experience(
            tokens=token_ids,
            prompt_length=len(prompt_tokens_ids),
            messages=messages,
        )

    def load_mm_data(self, sample: Dict) -> Dict:
        """Load multi-modal data such as images or videos.

        NOTE: You can override this method for custom data loading.

        Args:
            sample (Dict): The raw sample dictionary containing multi-modal data.

        Returns:
            Dict: A dictionary containing multi-modal data. Specifically, it may contain:
                - images: A list of `PIL.Image.Image` if `self.image_key` is set
                - videos: A list of `numpy.ndarray` if `self.video_key` is set
        """
        from verl.utils.dataset.vision_utils import process_image, process_video

        mm_data = {}
        if self.image_key:
            mm_data["images"] = [process_image(img) for img in sample[self.image_key]]
        if self.video_key:
            mm_data["videos"] = [process_video(vid).numpy() for vid in sample[self.video_key]]
        return mm_data

    def convert_mm_data_to_experiences(
        self,
        messages: List[Dict],
        mm_data: Dict,
    ) -> Experience:
        """Build an Experience for a multi-modal conversation.

        The steps are:

        1. Convert ``messages`` with ``convert_messages_to_mm_format``.
        2. Render two strings with the processor's chat template: the full
           conversation, and the prompt (all but the last message, plus a
           generation prompt).
        3. Tokenize both together with ``mm_data["images"]`` / ``mm_data["videos"]``.

        Args:
            messages (List[Dict]): The conversation; the last message is the response.
            mm_data (Dict): Output of :meth:`load_mm_data`.

        Returns:
            Experience: The full-sequence tokens, with the prompt's token count as
            ``prompt_length``, the converted messages, and the multi-modal inputs.
        """
        from trinity.common.models.mm_utils import (
            build_multi_modal_inputs,
            convert_messages_to_mm_format,
        )

        messages = convert_messages_to_mm_format(messages)
        sequence: str = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=False,
            chat_template=self.chat_template,
        )
        prompt: str = self.processor.apply_chat_template(
            messages[:-1],
            add_generation_prompt=True,
            chat_template=self.chat_template,
        )
        sequence_data = build_multi_modal_inputs(
            prompt=sequence,
            images=mm_data.get("images", None),
            videos=mm_data.get("videos", None),
            processor=self.processor,
        )
        prompt_data = build_multi_modal_inputs(
            prompt=prompt,
            images=mm_data.get("images", None),
            videos=mm_data.get("videos", None),
            processor=self.processor,
        )
        return Experience(
            tokens=sequence_data["prompt_token_ids"],
            prompt_length=len(prompt_data["prompt_token_ids"]),
            messages=messages,
            multi_modal_inputs=sequence_data["multi_modal_inputs"],
        )

    def format(self, sample: Dict) -> Experience:
        """Format one raw SFT sample into an Experience.

        MESSAGES samples provide the conversation under ``messages_key``; the value
        may also be a JSON string. PLAINTEXT samples are assembled into an optional
        system message, then a user message, then an assistant message.

        Args:
            sample (Dict): The raw sample.

        Returns:
            Experience: The tokenized sample.

        Raises:
            ValueError: If the messages (or tools) are an invalid JSON string, or the
                prompt type is unsupported.
            KeyError: If a configured key is missing from ``sample``.
        """
        if self.prompt_type == PromptType.MESSAGES:
            messages = sample[self.messages_key]
            # load messages from json string if needed
            if isinstance(messages, str):
                try:
                    messages = json.loads(messages)
                except json.JSONDecodeError as err:
                    self.logger.error(
                        "[SFT Data Error] Failed to decode 'messages' JSON. please check your data format."
                    )
                    raise ValueError("Invalid JSON format for messages") from err
        elif self.prompt_type == PromptType.PLAINTEXT:
            messages = []
            # A per-sample system prompt takes precedence over the configured one.
            if self.system_prompt_key is not None:
                system_message = {"role": "system", "content": sample[self.system_prompt_key]}
                messages.append(system_message)
            elif self.system_prompt is not None:
                system_message = {"role": "system", "content": self.system_prompt}
                messages.append(system_message)
            messages.append({"role": "user", "content": sample[self.prompt_key]})
            messages.append({"role": "assistant", "content": sample[self.response_key]})
        else:
            raise ValueError(f"Unsupported prompt_type: {self.prompt_type}")
        tools = sample.get(self.tools_key, None)
        mm_data = self.load_mm_data(sample) if self.image_key or self.video_key else None
        return self._messages_to_experience(messages, tools, mm_data)
[docs]
@FORMATTER.register_module("dpo")
class DPOFormatter(ExperienceFormatter):
    """Formatter for DPO preference data, in plaintext or message-list form.

    ``format_config.prompt_type`` selects the layout. Tools are not supported yet.

    Example Input for PLAINTEXT:

    .. code-block:: python

        {
            "prompt": "What is your name?",
            "chosen": "My name is Assistant.",
            "rejected": "I don't have a name."
        }

    Example Input for MESSAGES. Each of the three fields may also be stored as a
    JSON-encoded string:

    .. code-block:: python

        {
            "messages": [
                {"role": "user", "content": "What is your name?"},
            ],
            "chosen": [
                {"role": "assistant", "content": "My name is Assistant."},
            ],
            "rejected": [
                {"role": "assistant", "content": "I don't have a name."}
            ]
        }
    """

    def __init__(self, tokenizer_path: str, format_config: FormatConfig):
        """Load the tokenizer and read the keys used by ``format``.

        Args:
            tokenizer_path (str): Path or name passed to ``AutoTokenizer.from_pretrained``.
            format_config (FormatConfig): Format options (prompt type, keys, chat template).

        Raises:
            ValueError: If ``format_config.prompt_type`` is neither PLAINTEXT nor MESSAGES.
        """
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_path)
        self.prompt_type = format_config.prompt_type
        # None lets `apply_chat_template` fall back to the tokenizer's own template.
        self.chat_template = format_config.chat_template
        if self.prompt_type == PromptType.PLAINTEXT:
            self.prompt_key = format_config.prompt_key
            self.chosen_key = format_config.chosen_key
            self.rejected_key = format_config.rejected_key
            self.system_prompt_key = format_config.system_prompt_key
            self.system_prompt = format_config.system_prompt
        elif self.prompt_type == PromptType.MESSAGES:
            self.messages_key = format_config.messages_key
            self.chosen_key = format_config.chosen_key
            self.rejected_key = format_config.rejected_key
        else:
            raise ValueError(f"Unsupported prompt_type: {self.prompt_type}")

    @staticmethod
    def _load_messages(value, field: str):
        """Return ``value`` as a message list, decoding it first if it is a JSON string.

        Raises:
            ValueError: If ``value`` is a string that is not valid JSON.
        """
        if not isinstance(value, str):
            return value
        try:
            return json.loads(value)
        except json.JSONDecodeError as err:
            raise ValueError(f"Invalid JSON format for {field}") from err

    # DPO does not currently support tools.
    def _messages_to_experience(
        self, prompt_messages, chosen_messages, rejected_messages
    ) -> Experience:
        """Tokenize the shared prompt and both continuations into one Experience.

        ``chosen`` / ``rejected`` hold only the tokens that follow the prompt. They
        are obtained by slicing ``len(prompt_tokens)`` off the full tokenization,
        which assumes the prompt tokenization is a prefix of it.
        """
        prompt_tokens = self.tokenizer.apply_chat_template(
            prompt_messages,
            add_generation_prompt=True,
            return_tensors="pt",
            chat_template=self.chat_template,
        )[0]
        chosen_tokens = self.tokenizer.apply_chat_template(
            prompt_messages + chosen_messages,
            add_generation_prompt=False,
            return_tensors="pt",
            chat_template=self.chat_template,
        )[0][len(prompt_tokens) :]
        rejected_tokens = self.tokenizer.apply_chat_template(
            prompt_messages + rejected_messages,
            add_generation_prompt=False,
            return_tensors="pt",
            chat_template=self.chat_template,
        )[0][len(prompt_tokens) :]
        return Experience(
            tokens=prompt_tokens,
            prompt_length=len(prompt_tokens),
            chosen=chosen_tokens,
            rejected=rejected_tokens,
            chosen_messages=prompt_messages + chosen_messages,
            rejected_messages=prompt_messages + rejected_messages,
        )

    def format(self, sample: Dict) -> Experience:
        """Format one raw DPO sample into an Experience.

        PLAINTEXT samples become the prompt messages, which are an optional system
        message followed by a user message. Chosen and rejected each become a single
        assistant message. In MESSAGES mode the three fields may be lists or JSON
        strings, mirroring the SFT formatter's handling of ``messages``.

        Args:
            sample (Dict): The raw sample.

        Returns:
            Experience: The tokenized preference pair.

        Raises:
            ValueError: If a MESSAGES field is an invalid JSON string, or the prompt
                type is unsupported.
            KeyError: If a configured key is missing from ``sample``.
        """
        if self.prompt_type == PromptType.PLAINTEXT:
            messages = []
            # A per-sample system prompt takes precedence over the configured one.
            if self.system_prompt_key is not None:
                system_message = {"role": "system", "content": sample[self.system_prompt_key]}
                messages.append(system_message)
            elif self.system_prompt is not None:
                system_message = {"role": "system", "content": self.system_prompt}
                messages.append(system_message)
            messages.append({"role": "user", "content": sample[self.prompt_key]})
            chosen = [{"role": "assistant", "content": sample[self.chosen_key]}]
            rejected = [{"role": "assistant", "content": sample[self.rejected_key]}]
        elif self.prompt_type == PromptType.MESSAGES:
            messages = self._load_messages(sample[self.messages_key], "messages")
            chosen = self._load_messages(sample[self.chosen_key], "chosen")
            rejected = self._load_messages(sample[self.rejected_key], "rejected")
        else:
            raise ValueError(f"Unsupported prompt_type: {self.prompt_type}")
        return self._messages_to_experience(
            prompt_messages=messages,
            chosen_messages=chosen,
            rejected_messages=rejected,
        )