Source code for trinity.buffer.reader.file_reader

"""File-based buffer reader."""

from typing import List, Optional

import datasets
import transformers
from datasets import Dataset, load_dataset
from ray.experimental.tqdm_ray import tqdm

from trinity.algorithm.algorithm import DPOAlgorithm, SFTAlgorithm
from trinity.buffer.buffer_reader import BufferReader
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import PromptType, ReadStrategy, TaskType
from trinity.common.experience import Experience
from trinity.common.rewards import REWARD_FUNCTIONS
from trinity.common.workflows import WORKFLOWS, Task
from trinity.utils.registry import Registry

# Registry of file-based reader classes. The reader classes in this module add
# themselves to it through the ``FILE_READERS.register_module(...)`` decorator.
FILE_READERS = Registry("file_readers")


class _HFBatchReader:
    """Read a dataset sequentially in batches, wrapping around for several epochs.

    The reader keeps track of the current epoch and the offset inside that
    epoch, and reports progress through a ``tqdm`` progress bar whose total is
    ``len(dataset) * max_epoch``.
    """

    def __init__(
        self,
        dataset: Dataset,
        name: str,
        max_epoch: int = 1,
        offset: int = 0,
    ):
        """Create the reader and fast-forward it to ``offset``.

        Args:
            dataset: The dataset to iterate over. It must support ``len()``
                and ``iter()``.
            name: Name shown in the progress bar description.
            max_epoch: Number of full passes over ``dataset`` before the
                reader is exhausted.
            offset: Number of samples already consumed. Values greater than
                or equal to the dataset size are split into whole epochs plus
                an offset inside the current epoch.
        """
        self.dataset = dataset
        self.dataset_size = len(dataset)
        self.name = name
        self.current_batch_size = None
        self.max_epoch = max_epoch
        if self.dataset_size == 0:
            # Nothing can be skipped in an empty dataset; this also avoids a
            # division by zero below. `read_batch` reports it as exhausted.
            self.current_epoch = 0
            self.current_offset = 0
        elif offset >= self.dataset_size:
            self.current_epoch, self.current_offset = divmod(offset, self.dataset_size)
        else:
            self.current_epoch = 0
            self.current_offset = offset
        self.iter = iter(self.dataset)

        # Skip the samples of the current epoch that were already consumed.
        for _ in range(self.current_offset):
            next(self.iter)

        # Initialize tqdm progress bar
        self.total_steps = self.dataset_size * self.max_epoch
        self.progress_bar = self._new_progress_bar()

    def _new_progress_bar(self):
        """Create a progress bar positioned at the current epoch/offset."""
        progress_bar = tqdm(
            total=self.total_steps,
            desc=f"Dataset [{self.name}] Progressing",
        )
        initial = self.current_epoch * self.dataset_size + self.current_offset
        progress_bar.update(initial)
        return progress_bar

    def read_batch(self, batch_size: int) -> List:
        """Return the next ``batch_size`` samples.

        A batch may span an epoch boundary. The final batch of the last epoch
        may contain fewer than ``batch_size`` samples.

        Args:
            batch_size: Number of samples requested.

        Returns:
            A list with the samples that were read.

        Raises:
            StopIteration: If all ``max_epoch`` epochs have been consumed (or
                the dataset is empty) and no sample could be read.
        """
        if self.dataset_size == 0 or self.current_epoch >= self.max_epoch:
            self.progress_bar.close()
            raise StopIteration
        batch = []

        while len(batch) < batch_size:
            try:
                item = next(self.iter)
            except StopIteration:
                self.current_epoch += 1
                self.current_offset = 0

                if self.current_epoch >= self.max_epoch:
                    if batch:
                        # Hand out the partial last batch; the next call
                        # raises StopIteration through the check above.
                        return batch
                    self.progress_bar.close()
                    raise StopIteration
                # Step to the next epoch
                self.iter = iter(self.dataset)
                continue

            batch.append(item)
            self.current_offset += 1
            # Advance the bar only after a sample was actually read, so that
            # hitting the end of an epoch is not counted as progress.
            self.progress_bar.update(1)
        return batch

    def reset(self) -> None:
        """Rewind the reader to the first sample of the first epoch.

        The initial ``offset`` passed to the constructor is not re-applied.
        The progress bar is closed and replaced by a fresh one.
        """
        self.current_epoch = 0
        self.current_offset = 0
        self.iter = iter(self.dataset)
        self.progress_bar.close()
        self.progress_bar = self._new_progress_bar()


@FILE_READERS.register_module(SFTAlgorithm.name())
class SFTDataReader(BufferReader):
    """Read SFT samples from a dataset file and turn them into experiences."""

    def __init__(self, meta: StorageConfig, config: BufferConfig):
        """Load the dataset and tokenizer described by ``meta`` and ``config``.

        Args:
            meta: Storage settings; provides the dataset path, subset, split,
                name, number of epochs and the sample format keys.
            config: Buffer settings; provides the default read batch size and
                the tokenizer path.
        """
        sample_format = meta.format
        self.split = meta.split
        subset_name = meta.subset_name
        self.prompt_type = sample_format.prompt_type
        self.messages_key = sample_format.messages_key
        self.prompt_key = sample_format.prompt_key
        self.response_key = sample_format.response_key
        self.read_batch_size = config.read_batch_size
        hf_dataset = load_dataset(
            meta.path, name=subset_name, split=self.split, trust_remote_code=True
        )
        # TODO: support resume
        self.dataset = _HFBatchReader(
            hf_dataset,
            name=meta.name,
            max_epoch=meta.total_epochs,
        )
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)

    def _messages_to_experience(self, sample) -> Experience:
        """Build an experience from a sample holding one list of chat messages."""
        conversation = sample[self.messages_key]
        full_ids = self.tokenizer.apply_chat_template(
            conversation, add_generation_prompt=False, return_tensors="pt"
        )[0]
        # Everything except the last message is treated as the prompt.
        prompt_ids = self.tokenizer.apply_chat_template(
            conversation[:-1], add_generation_prompt=True, return_tensors="pt"
        )[0]
        return Experience(tokens=full_ids, prompt_length=len(prompt_ids))

    def _chatpair_to_experience(self, sample) -> Experience:
        """Build an experience from separate prompt and response messages."""
        prompt_part = sample[self.prompt_key]
        response_part = sample[self.response_key]
        # Single messages are wrapped so both parts can be concatenated.
        if not isinstance(prompt_part, list):
            prompt_part = [prompt_part]
        if not isinstance(response_part, list):
            response_part = [response_part]
        conversation = prompt_part + response_part
        full_ids = self.tokenizer.apply_chat_template(
            conversation, add_generation_prompt=False, return_tensors="pt"
        )[0]
        prompt_ids = self.tokenizer.apply_chat_template(
            prompt_part, add_generation_prompt=True, return_tensors="pt"
        )[0]
        return Experience(tokens=full_ids, prompt_length=len(prompt_ids))

    def _plaintext_to_experience(self, sample) -> Experience:
        """Build an experience from plain prompt and response strings."""
        prompt_text = sample[self.prompt_key]
        response_text = sample[self.response_key]
        full_ids = self.tokenizer(prompt_text + response_text, return_tensors="pt")[
            "input_ids"
        ][0]
        prompt_ids = self.tokenizer(prompt_text, return_tensors="pt")["input_ids"][0]
        return Experience(tokens=full_ids, prompt_length=len(prompt_ids))

    def read(
        self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
    ) -> List:
        """Read one batch of samples and convert each into an ``Experience``.

        Args:
            batch_size: Number of samples to read; falls back to the
                configured read batch size when not given.
            strategy: Not used by this reader.

        Returns:
            A list with one ``Experience`` per sample read.

        Raises:
            ValueError: If the configured prompt type is not supported.
        """
        samples = self.dataset.read_batch(batch_size or self.read_batch_size)
        # The prompt type is fixed for the reader, so pick the converter once.
        if self.prompt_type == PromptType.MESSAGES:
            to_experience = self._messages_to_experience
        elif self.prompt_type == PromptType.CHATPAIR:
            to_experience = self._chatpair_to_experience
        elif self.prompt_type == PromptType.PLAINTEXT:
            # TODO: support HF format without chat template
            to_experience = self._plaintext_to_experience
        else:
            raise ValueError(f"Unknown data format: {self.prompt_type}")
        return [to_experience(sample) for sample in samples]
@FILE_READERS.register_module(DPOAlgorithm.name())
class DPODataReader(BufferReader):
    """Read preference samples (prompt, chosen, rejected) from a dataset file."""

    def __init__(self, meta: StorageConfig, config: BufferConfig):
        """Load the dataset and tokenizer described by ``meta`` and ``config``.

        Args:
            meta: Storage settings; provides the dataset path, subset, split,
                name, number of epochs and the sample format keys.
            config: Buffer settings; provides the default read batch size and
                the tokenizer path.
        """
        sample_format = meta.format
        self.split = meta.split
        subset_name = meta.subset_name
        self.prompt_type = sample_format.prompt_type
        self.prompt_key = sample_format.prompt_key
        self.chosen_key = sample_format.chosen_key
        self.rejected_key = sample_format.rejected_key
        self.read_batch_size = config.read_batch_size
        hf_dataset = load_dataset(
            meta.path, name=subset_name, split=self.split, trust_remote_code=True
        )
        # TODO: support resume
        self.dataset = _HFBatchReader(
            hf_dataset,
            name=meta.name,
            max_epoch=meta.total_epochs,
        )
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)

    def _get_assistant_message(self, item) -> dict:
        """Normalize a chosen/rejected entry into a single message.

        A list is reduced to its first element; a string is wrapped as an
        assistant message; anything else is returned unchanged.
        """
        candidate = item[0] if isinstance(item, list) else item
        if not isinstance(candidate, str):
            return candidate
        return {"role": "assistant", "content": candidate}

    def _completion_tokens(self, prompt_messages, completion, prompt_length):
        """Tokenize prompt plus completion and drop the first ``prompt_length`` tokens."""
        conversation = prompt_messages + [self._get_assistant_message(completion)]
        token_ids = self.tokenizer.apply_chat_template(
            conversation,
            add_generation_prompt=False,
            return_tensors="pt",
        )[0]
        return token_ids[prompt_length:]

    def read(
        self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
    ) -> List:
        """Read one batch of samples and convert each into an ``Experience``.

        Args:
            batch_size: Number of samples to read; falls back to the
                configured read batch size when not given.
            strategy: Not used by this reader.

        Returns:
            A list with one ``Experience`` per sample, holding the prompt
            tokens plus the chosen and rejected completion tokens.

        Raises:
            ValueError: If a sample is processed while the configured prompt
                type is neither ``MESSAGES`` nor ``PLAINTEXT``.
        """
        experiences = []
        for sample in self.dataset.read_batch(batch_size or self.read_batch_size):
            prompt = sample[self.prompt_key]
            chosen = sample[self.chosen_key]
            rejected = sample[self.rejected_key]

            if self.prompt_type == PromptType.MESSAGES:
                prompt_messages = prompt
            elif self.prompt_type == PromptType.PLAINTEXT:
                prompt_messages = [{"role": "user", "content": prompt}]
            else:
                raise ValueError(f"Unknown prompt type: {self.prompt_type}")

            prompt_tokens = self.tokenizer.apply_chat_template(
                prompt_messages, add_generation_prompt=True, return_tensors="pt"
            )[0]
            prompt_length = len(prompt_tokens)

            chosen_tokens = self._completion_tokens(prompt_messages, chosen, prompt_length)
            rejected_tokens = self._completion_tokens(prompt_messages, rejected, prompt_length)

            experiences.append(
                Experience(
                    tokens=prompt_tokens,
                    prompt_length=len(prompt_tokens),
                    chosen=chosen_tokens,
                    rejected=rejected_tokens,
                )
            )
        return experiences
@FILE_READERS.register_module("rollout")
class RolloutDataReader(BufferReader):
    """Read raw samples from a dataset file and wrap them into ``Task`` objects."""

    def __init__(self, meta: StorageConfig, config: BufferConfig):
        """Load the dataset and resolve the default workflow and reward function.

        Args:
            meta: Storage settings; provides the dataset location, the resume
                index, the task type, the format keys and the default
                workflow / reward function types.
            config: Buffer settings; provides the default batch size.
        """
        self.meta = meta
        self.name = meta.name
        self.split = meta.split
        subset_name = meta.subset_name
        self.epoch = 0
        # disable datasets caching to avoid reuse old-version dataset
        datasets.disable_caching()
        hf_dataset = load_dataset(
            meta.path, name=subset_name, split=self.split, trust_remote_code=True
        )
        # Only explore tasks iterate for several epochs; others read one pass.
        if meta.task_type == TaskType.EXPLORE:
            epochs = self.meta.total_epochs
        else:
            epochs = 1
        self.dataset = _HFBatchReader(
            hf_dataset,
            name=meta.name,
            max_epoch=epochs,
            offset=self.meta.index,
        )
        self.read_batch_size = config.batch_size

        sample_format = meta.format
        self.prompt_key = sample_format.prompt_key
        self.response_key = sample_format.response_key
        self.workflow_key = sample_format.workflow_key
        self.reward_fn_key = sample_format.reward_fn_key

        self.task_type = meta.task_type
        self.default_workflow_cls = WORKFLOWS.get(meta.default_workflow_type)  # type: ignore
        self.default_reward_fn_cls = REWARD_FUNCTIONS.get(meta.default_reward_fn_type)  # type: ignore

    def read(
        self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
    ) -> List:
        """Read one batch of samples and build a ``Task`` for each of them.

        A sample may name its own workflow and reward function through the
        configured keys; otherwise the reader's defaults are used.

        Args:
            batch_size: Number of samples to read; falls back to the
                configured batch size when not given.
            strategy: Not used by this reader.

        Returns:
            A list with one ``Task`` per sample read.
        """
        batch_size = batch_size or self.read_batch_size
        samples = self.dataset.read_batch(batch_size)
        is_eval = self.meta.task_type == TaskType.EVAL

        tasks = []
        for sample in samples:
            if self.workflow_key in sample:
                workflow_class = WORKFLOWS.get(sample[self.workflow_key])
            else:
                workflow_class = self.default_workflow_cls

            if self.reward_fn_key in sample:
                reward_fn = REWARD_FUNCTIONS.get(sample[self.reward_fn_key])
            else:
                reward_fn = self.default_reward_fn_cls

            assert (
                workflow_class is not None
            ), "`default_workflow_type` or `workflow_key` is required"

            tasks.append(
                Task(
                    workflow=workflow_class,
                    format_args=self.meta.format,
                    rollout_args=self.meta.rollout_args,
                    workflow_args=self.meta.workflow_args,
                    is_eval=is_eval,
                    reward_fn=reward_fn,
                    raw_task=sample,
                )
            )
        return tasks

    def reset(self):
        """Delegate to the ``reset()`` method of the underlying batch reader."""
        self.dataset.reset()
@FILE_READERS.register_module("raw")
class RawDataReader(BufferReader):
    """Return the whole dataset as a list of raw samples, exactly once."""

    def __init__(self, meta: StorageConfig, config: Optional[BufferConfig]):
        """Load the dataset described by ``meta``.

        Args:
            meta: Storage settings; provides the dataset path, subset and split.
            config: Not used by this reader.
        """
        # Flag that flips to True after the first (and only) successful read.
        self.returned = False
        self.dataset = load_dataset(
            meta.path,
            name=meta.subset_name,
            split=meta.split,
            trust_remote_code=True,
        )

    def __len__(self):
        """Return the number of samples in the loaded dataset."""
        return len(self.dataset)

    def read(
        self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
    ) -> List:
        """Return every sample of the dataset on the first call.

        Args:
            batch_size: Not used by this reader.
            strategy: Not used by this reader.

        Returns:
            The dataset converted with ``to_list()``.

        Raises:
            StopIteration: On every call after the first one.
        """
        if not self.returned:
            self.returned = True
            return self.dataset.to_list()
        raise StopIteration