Source code for trinity.buffer.reader.file_reader

"""File-based buffer readers.

Readers in this module load data with ``datasets.load_dataset`` and are
registered in the ``FILE_READERS`` registry.
"""

from typing import List, Optional

import datasets
import transformers
from datasets import load_dataset

from trinity.buffer.buffer_reader import BufferReader
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import AlgorithmType, PromptType, ReadStrategy, TaskType
from trinity.common.experience import Experience
from trinity.common.rewards import REWARD_FUNCTIONS
from trinity.common.workflows import WORKFLOWS, Task
from trinity.utils.registry import Registry

# Registry of file-based reader classes. The readers in this module register
# themselves under an algorithm-type value (SFT, DPO) or under "rollout".
FILE_READERS = Registry("file_readers")


@FILE_READERS.register_module(AlgorithmType.SFT.value)
class SFTDataReader(BufferReader):
    """Reader for SFT file data.

    Loads a dataset with ``datasets.load_dataset`` and converts each batch into
    ``Experience`` objects. ``tokens`` holds the tokenized sample and
    ``prompt_length`` holds the number of tokens of the tokenized prompt.
    """

    def __init__(self, meta: StorageConfig, config: BufferConfig):
        self.split = meta.split
        subset_name = meta.subset_name
        self.prompt_type = meta.format.prompt_type
        self.messages_key = meta.format.messages_key
        self.prompt_key = meta.format.prompt_key
        self.response_key = meta.format.response_key
        self.read_batch_size = config.read_batch_size
        self.dataset = load_dataset(
            meta.path, name=subset_name, split=self.split
        )  # TODO: support resume
        self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)

    def _next_batch(self) -> dict:
        """Return the next full batch, reshuffling and restarting after the last one.

        Incomplete trailing batches are dropped (``drop_last_batch=True``). If the
        dataset has fewer rows than ``read_batch_size``, the restarted iterator is
        empty too, and ``StopIteration`` propagates to the caller.
        """
        try:
            return next(self.data_iter)
        except StopIteration:
            self.dataset = self.dataset.shuffle()
            self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True)
            return next(self.data_iter)

    @staticmethod
    def _as_messages(item, role: str) -> list:
        """Normalize a CHATPAIR field into a list of chat messages.

        A list is returned unchanged. A bare string becomes a single message with
        the given ``role``. Any other single value (presumably one message dict)
        is wrapped in a list.
        """
        if isinstance(item, list):
            return item
        if isinstance(item, str):
            return [{"role": role, "content": item}]
        return [item]

    def _to_experience(self, prompt_messages: list, full_messages: list) -> Experience:
        """Tokenize a conversation with the chat template and wrap it in an Experience.

        ``prompt_length`` is the token count of ``prompt_messages`` rendered with
        the generation prompt appended. The chat template is assumed to make this
        a prefix of the full conversation's tokens.
        """
        tokens = self.tokenizer.apply_chat_template(
            full_messages, add_generation_prompt=False, return_tensors="pt"
        )[0]
        prompt_tokens = self.tokenizer.apply_chat_template(
            prompt_messages, add_generation_prompt=True, return_tensors="pt"
        )[0]
        return Experience(tokens=tokens, prompt_length=len(prompt_tokens))

    def read(self, strategy: Optional[ReadStrategy] = None) -> List:
        """Read one batch and convert every sample into an ``Experience``.

        Args:
            strategy: Not used by this reader.

        Returns:
            A list with one ``Experience`` per sample in the batch.

        Raises:
            ValueError: If ``prompt_type`` is not MESSAGES, CHATPAIR or PLAINTEXT.
        """
        batch_data = self._next_batch()
        exp_list = []
        if self.prompt_type == PromptType.MESSAGES:
            for messages in batch_data[self.messages_key]:
                # The final message is the response; everything before it is the prompt.
                exp_list.append(self._to_experience(messages[:-1], messages))
        elif self.prompt_type == PromptType.CHATPAIR:
            for prompt, response in zip(
                batch_data[self.prompt_key], batch_data[self.response_key]
            ):
                prompt_messages = self._as_messages(prompt, role="user")
                response_messages = self._as_messages(response, role="assistant")
                exp_list.append(
                    self._to_experience(prompt_messages, prompt_messages + response_messages)
                )
        elif self.prompt_type == PromptType.PLAINTEXT:
            # TODO: support HF format without chat template
            for prompt, response in zip(
                batch_data[self.prompt_key], batch_data[self.response_key]
            ):
                tokens = self.tokenizer(prompt + response, return_tensors="pt")["input_ids"][0]
                prompt_tokens = self.tokenizer(prompt, return_tensors="pt")["input_ids"][0]
                exp_list.append(Experience(tokens=tokens, prompt_length=len(prompt_tokens)))
        else:
            raise ValueError(f"Unknown data format: {self.prompt_type}")
        return exp_list
@FILE_READERS.register_module(AlgorithmType.DPO.value)
class DPODataReader(BufferReader):
    """Reader for DPO preference data (prompt, chosen response, rejected response).

    Each sample becomes an ``Experience``:
    - ``tokens`` holds the tokenized prompt (with generation prompt).
    - ``chosen`` and ``rejected`` hold only the response tokens that follow it.
    """

    def __init__(self, meta: StorageConfig, config: BufferConfig):
        self.split = meta.split
        subset_name = meta.subset_name
        self.prompt_type = meta.format.prompt_type
        self.prompt_key = meta.format.prompt_key
        self.chosen_key = meta.format.chosen_key
        self.rejected_key = meta.format.rejected_key
        self.read_batch_size = config.read_batch_size
        self.dataset = load_dataset(
            meta.path, name=subset_name, split=self.split
        )  # TODO: support resume
        self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)

    def _next_batch(self) -> dict:
        """Return the next full batch, reshuffling and restarting after the last one.

        Incomplete trailing batches are dropped (``drop_last_batch=True``). If the
        dataset has fewer rows than ``read_batch_size``, the restarted iterator is
        empty too, and ``StopIteration`` propagates to the caller.
        """
        try:
            return next(self.data_iter)
        except StopIteration:
            self.dataset = self.dataset.shuffle()
            self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True)
            return next(self.data_iter)

    def _get_assistant_message(self, item) -> dict:
        """Normalize a chosen/rejected field into a single assistant message.

        A list is reduced to its first element. A bare string is wrapped as an
        assistant message. Anything else (presumably already a message dict) is
        returned unchanged.
        """
        if isinstance(item, list):
            item = item[0]
        if isinstance(item, str):
            return {"role": "assistant", "content": item}
        return item

    def _get_prompt_messages(self, prompt) -> list:
        """Convert a raw prompt field into chat messages according to ``prompt_type``.

        Raises:
            ValueError: If ``prompt_type`` is neither MESSAGES nor PLAINTEXT.
        """
        if self.prompt_type == PromptType.MESSAGES:
            return prompt
        if self.prompt_type == PromptType.PLAINTEXT:
            return [
                {
                    "role": "user",
                    "content": prompt,
                }
            ]
        raise ValueError(f"Unknown prompt type: {self.prompt_type}")

    def _response_tokens(self, prompt_messages: list, response, prompt_length: int):
        """Tokenize prompt + response and return only the tokens after ``prompt_length``."""
        messages = prompt_messages + [self._get_assistant_message(response)]
        return self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=False,
            return_tensors="pt",
        )[0][prompt_length:]

    def read(self, strategy: Optional[ReadStrategy] = None) -> List:
        """Read one batch and convert every preference triple into an ``Experience``.

        Args:
            strategy: Not used by this reader.

        Returns:
            A list with one ``Experience`` per sample in the batch.

        Raises:
            ValueError: If ``prompt_type`` is neither MESSAGES nor PLAINTEXT.
        """
        batch_data = self._next_batch()
        exp_list = []
        for prompt, chosen, rejected in zip(
            batch_data[self.prompt_key], batch_data[self.chosen_key], batch_data[self.rejected_key]
        ):
            prompt_messages = self._get_prompt_messages(prompt)
            prompt_tokens = self.tokenizer.apply_chat_template(
                prompt_messages, add_generation_prompt=True, return_tensors="pt"
            )[0]
            prompt_length = len(prompt_tokens)
            # Slicing by prompt_length assumes the chat template renders the prompt
            # as a prefix of prompt + response -- TODO confirm for every template used.
            exp_list.append(
                Experience(
                    tokens=prompt_tokens,
                    prompt_length=prompt_length,
                    chosen=self._response_tokens(prompt_messages, chosen, prompt_length),
                    rejected=self._response_tokens(prompt_messages, rejected, prompt_length),
                )
            )
        return exp_list
@FILE_READERS.register_module("rollout")
class RolloutDataReader(BufferReader):
    """Reader that turns dataset rows into rollout ``Task`` objects, one per ``read``.

    - For EXPLORE tasks the dataset is traversed ``meta.total_epochs`` times,
      after which ``read`` raises ``StopIteration``.
    - For other task types a single pass is allowed.
    - For EVAL tasks the index wraps back to 0 after the last row.
    """

    def __init__(self, meta: StorageConfig, config: BufferConfig):
        self.meta = meta
        self.name = meta.name
        self.split = meta.split
        subset_name = meta.subset_name
        # Disable datasets caching while loading to avoid reusing an old version of
        # the dataset. Restore the previous global setting afterwards, even if loading
        # fails, so the process-wide datasets state is not left modified.
        caching_was_enabled = datasets.is_caching_enabled()
        datasets.disable_caching()
        try:
            self.dataset = load_dataset(meta.path, name=subset_name, split=self.split)
        finally:
            if caching_was_enabled:
                datasets.enable_caching()
        # TODO: for non-eval tasks, optionally load from a database when config.db_url
        # is set, e.g. Dataset.from_sql(RftDatasetModel.__tablename__, f"{db_type}:///{db_name}").
        self.index = meta.index  # TODO: apply shuffle
        self.prompt_key = meta.format.prompt_key
        self.response_key = meta.format.response_key
        self.workflow_key = meta.format.workflow_key
        self.reward_fn_key = meta.format.reward_fn_key

        self.task_type = meta.task_type
        self.default_workflow_cls = WORKFLOWS.get(meta.default_workflow_type)
        self.default_reward_fn_cls = REWARD_FUNCTIONS.get(meta.default_reward_fn_type)
        self.total_epochs = meta.total_epochs if self.task_type == TaskType.EXPLORE else 1

    def __len__(self):
        """Number of rows in the underlying dataset (a single epoch)."""
        return len(self.dataset)

    def read(self, strategy: Optional[ReadStrategy] = None):
        """Build the ``Task`` for the current row and advance the index.

        Per-row values under ``workflow_key`` / ``reward_fn_key`` take precedence
        over the configured defaults.

        Args:
            strategy: Not used by this reader.

        Returns:
            A ``Task`` wrapping the current dataset row.

        Raises:
            StopIteration: When ``len(self.dataset) * self.total_epochs`` rows have
                been read.
            AssertionError: When no workflow class could be resolved.
        """
        if self.index >= len(self.dataset) * self.total_epochs:
            raise StopIteration
        sample = self.dataset[self.index % len(self.dataset)]
        workflow_class = (
            WORKFLOWS.get(sample[self.workflow_key])
            if self.workflow_key in sample
            else self.default_workflow_cls
        )
        reward_fn = (
            REWARD_FUNCTIONS.get(sample[self.reward_fn_key])
            if self.reward_fn_key in sample
            else self.default_reward_fn_cls
        )
        assert workflow_class is not None, "`default_workflow_type` or `workflow_key` is required"
        task = Task(
            workflow=workflow_class,
            format_args=self.meta.format,
            rollout_args=self.meta.rollout_args,
            is_eval=self.task_type == TaskType.EVAL,
            reward_fn=reward_fn,
            raw_task=sample,
        )
        self.index += 1
        # Eval datasets are replayed: wrap to the first row after the last one.
        if self.task_type == TaskType.EVAL and self.index == len(self.dataset):
            self.index = 0
        return task