"""Filed based buffer reader."""
import copy
from typing import List, Optional
import datasets
import transformers
from datasets import Dataset, load_dataset
from trinity.algorithm.algorithm import DPOAlgorithm, SFTAlgorithm
from trinity.buffer.buffer_reader import BufferReader
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import PromptType, ReadStrategy, TaskType
from trinity.common.experience import Experience
from trinity.common.rewards import REWARD_FUNCTIONS
from trinity.common.workflows import WORKFLOWS, Task
from trinity.utils.registry import Registry
FILE_READERS = Registry("file_readers")
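# Reader classes below register themselves here under an algorithm or task
# name and are later resolved by that name. A minimal usage sketch (assuming
# the caller already holds a StorageConfig ``meta`` and a BufferConfig
# ``config``; the actual call site lives in the buffer layer, not here):
#
#   reader_cls = FILE_READERS.get("rollout")
#   reader = reader_cls(meta, config)
#   tasks = reader.read()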
class DummyProgressBar:
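    """No-op stand-in for the Ray tqdm progress bar, used when progress reporting is disabled."""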
def __init__(self):
pass
    def update(self, num: int):
        pass

    # `close` is called by _HFBatchReader.read_batch, so the dummy needs it too
    def close(self):
        pass
class _HFBatchReader:
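    """Batched iterator over a HuggingFace ``Dataset`` that spans multiple
    epochs and stops after a fixed total number of samples."""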
def __init__(
self,
dataset: Dataset,
name: str,
default_batch_size: int,
total_epochs: int = 1,
offset: int = 0,
drop_last: bool = True,
total_steps: Optional[int] = None,
enable_progress_bar: Optional[bool] = True,
):
self.dataset = dataset
self.dataset_size = len(dataset)
self.name = name
self.current_batch_size = None
self.drop_last = drop_last
self.current_offset = offset
self.iter = iter(self.dataset)
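        # fast-forward past already-consumed samples to resume mid-epoch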
for _ in range(self.current_offset % self.dataset_size):
next(self.iter)
        # convert epochs/steps to a total sample count
if total_steps:
self.total_samples = default_batch_size * total_steps
else:
self.total_samples = self.dataset_size * total_epochs
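        #   e.g. default_batch_size=32, total_steps=100 -> 3200 samples;
        #        dataset_size=1000, total_epochs=2 -> 2000 samples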
if enable_progress_bar:
from ray.experimental.tqdm_ray import tqdm
self.progress_bar = tqdm(
total=self.total_samples,
desc=f"Dataset [{self.name}] Progressing",
)
else:
self.progress_bar = DummyProgressBar()
self.progress_bar.update(self.current_offset)
def read_batch(self, batch_size: int) -> List:
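        """Return the next ``batch_size`` samples, restarting the underlying
        iterator at epoch boundaries; raises ``StopIteration`` once
        ``total_samples`` have been served."""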
if self.current_offset >= self.total_samples:
self.progress_bar.close()
raise StopIteration
batch = []
while len(batch) < batch_size:
try:
item = next(self.iter)
batch.append(item)
self.current_offset += 1
except StopIteration:
if self.current_offset >= self.total_samples:
# No more data to read
if not self.drop_last and len(batch) > 0:
# return last batch
self.progress_bar.update(len(batch))
return batch
else:
self.progress_bar.close()
raise StopIteration
# Step to the next epoch
self.iter = iter(self.dataset)
self.progress_bar.update(batch_size)
return batch
class BaseFileReader(BufferReader):
async def read_async(
self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
):
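        """Async wrapper around :meth:`read` that converts ``StopIteration``
        into ``StopAsyncIteration`` so async iteration terminates cleanly."""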
try:
return self.read(batch_size, strategy)
except StopIteration as e:
raise StopAsyncIteration from e
@FILE_READERS.register_module(SFTAlgorithm.name())
class SFTDataReader(BaseFileReader):
"""Reader for SFT file data."""
def __init__(self, meta: StorageConfig, config: BufferConfig):
self.split = meta.split
subset_name = meta.subset_name
self.prompt_type = meta.format.prompt_type
self.messages_key = meta.format.messages_key
self.prompt_key = meta.format.prompt_key
self.response_key = meta.format.response_key
self.read_batch_size = config.train_batch_size
self.dataset = _HFBatchReader(
load_dataset(meta.path, name=subset_name, split=self.split),
name=meta.name,
default_batch_size=self.read_batch_size,
total_epochs=meta.total_epochs,
drop_last=True,
total_steps=meta.total_steps,
enable_progress_bar=meta.enable_progress_bar,
)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)
def read(
self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
) -> List:
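        """Convert a batch of raw samples into ``Experience`` objects carrying
        the full token sequence and the prompt length."""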
samples = self.dataset.read_batch(batch_size or self.read_batch_size)
exp_list = []
if self.prompt_type == PromptType.MESSAGES:
for sample in samples:
messages = sample[self.messages_key]
tokens = self.tokenizer.apply_chat_template(
messages, add_generation_prompt=False, return_tensors="pt"
)[0]
prompt_tokens_ids = self.tokenizer.apply_chat_template(
messages[:-1], add_generation_prompt=True, return_tensors="pt"
)[0]
experience = Experience(
tokens=tokens,
prompt_length=len(prompt_tokens_ids),
)
exp_list.append(experience)
elif self.prompt_type == PromptType.CHATPAIR:
for sample in samples:
prompt_messages = sample[self.prompt_key]
response_messages = sample[self.response_key]
if not isinstance(prompt_messages, list):
prompt_messages = [prompt_messages]
if not isinstance(response_messages, list):
response_messages = [response_messages]
full_messages = prompt_messages + response_messages
tokens = self.tokenizer.apply_chat_template(
full_messages, add_generation_prompt=False, return_tensors="pt"
)[0]
prompt_tokens_ids = self.tokenizer.apply_chat_template(
prompt_messages, add_generation_prompt=True, return_tensors="pt"
)[0]
experience = Experience(
tokens=tokens,
prompt_length=len(prompt_tokens_ids),
)
exp_list.append(experience)
elif self.prompt_type == PromptType.PLAINTEXT:
# TODO: support HF format without chat template
for sample in samples:
prompt = sample[self.prompt_key]
response = sample[self.response_key]
tokens = self.tokenizer(prompt + response, return_tensors="pt")["input_ids"][0]
prompt_tokens_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"][0]
experience = Experience(
tokens=tokens,
prompt_length=len(prompt_tokens_ids),
)
exp_list.append(experience)
else:
raise ValueError(f"Unknown data format: {self.prompt_type}")
return exp_list
@FILE_READERS.register_module(DPOAlgorithm.name())
class DPODataReader(BaseFileReader):
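    """Reader for DPO file data.

    Each sample is expected to provide a prompt plus a chosen and a rejected
    response under the keys configured in ``meta.format``.
    """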
def __init__(self, meta: StorageConfig, config: BufferConfig):
self.split = meta.split
subset_name = meta.subset_name
self.prompt_type = meta.format.prompt_type
self.prompt_key = meta.format.prompt_key
self.chosen_key = meta.format.chosen_key
self.rejected_key = meta.format.rejected_key
self.read_batch_size = config.train_batch_size
self.dataset = _HFBatchReader(
load_dataset(meta.path, name=subset_name, split=self.split),
name=meta.name,
default_batch_size=self.read_batch_size,
total_epochs=meta.total_epochs,
drop_last=True,
total_steps=meta.total_steps,
enable_progress_bar=meta.enable_progress_bar,
) # TODO: support resume
self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)
def _get_assistant_message(self, item) -> dict:
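        """Normalize a chosen/rejected field into an assistant message dict;
        accepts a plain string, a single-element list, or a message dict."""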
        if isinstance(item, list):
            item = item[0]
if isinstance(item, str):
return {"role": "assistant", "content": item}
else:
return item
def read(
self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
) -> List:
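        """Build ``Experience`` objects holding the prompt tokens together
        with the chosen and rejected response token sequences."""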
batch_data = self.dataset.read_batch(batch_size or self.read_batch_size)
print(f"Read {len(batch_data)} item from dpo dataset.")
exp_list = []
for sample in batch_data:
prompt = sample[self.prompt_key]
chosen = sample[self.chosen_key]
rejected = sample[self.rejected_key]
if self.prompt_type == PromptType.MESSAGES:
prompt_messages = prompt
elif self.prompt_type == PromptType.PLAINTEXT:
prompt_messages = [
{
"role": "user",
"content": prompt,
}
]
else:
raise ValueError(f"Unknown prompt type: {self.prompt_type}")
prompt_tokens = self.tokenizer.apply_chat_template(
prompt_messages, add_generation_prompt=True, return_tensors="pt"
)[0]
prompt_length = len(prompt_tokens)
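            # Response tokens are the suffix after the rendered prompt; this
            # assumes the chat template renders the prompt prefix identically
            # with and without the appended assistant message.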
messages_with_chosen = prompt_messages + [self._get_assistant_message(chosen)]
chosen_tokens = self.tokenizer.apply_chat_template(
messages_with_chosen,
add_generation_prompt=False,
return_tensors="pt",
)[0][prompt_length:]
messages_with_rejected = prompt_messages + [self._get_assistant_message(rejected)]
rejected_tokens = self.tokenizer.apply_chat_template(
messages_with_rejected,
add_generation_prompt=False,
return_tensors="pt",
)[0][prompt_length:]
experience = Experience(
tokens=prompt_tokens,
chosen=chosen_tokens,
rejected=rejected_tokens,
)
exp_list.append(experience)
return exp_list
@FILE_READERS.register_module("rollout")
class RolloutDataReader(BaseFileReader):
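    """Reader that turns raw dataset rows into rollout ``Task`` objects for
    explore/eval workflows."""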
def __init__(self, meta: StorageConfig, config: BufferConfig):
self.meta = meta
self.name = meta.name
self.split = meta.split
subset_name = meta.subset_name
        self.epoch = 0
        # disable datasets caching to avoid reusing a stale cached dataset
        datasets.disable_caching()
self.read_batch_size = config.batch_size
self.dataset = _HFBatchReader(
load_dataset(meta.path, name=subset_name, split=self.split),
name=meta.name,
default_batch_size=self.read_batch_size,
total_epochs=self.meta.total_epochs if meta.task_type == TaskType.EXPLORE else 1,
offset=self.meta.index,
drop_last=self.meta.task_type == TaskType.EXPLORE,
total_steps=meta.total_steps,
enable_progress_bar=meta.enable_progress_bar,
)
self.prompt_key = meta.format.prompt_key
self.response_key = meta.format.response_key
self.workflow_key = meta.format.workflow_key
self.reward_fn_key = meta.format.reward_fn_key
self.task_type = meta.task_type
self.default_workflow_cls = WORKFLOWS.get(meta.default_workflow_type) # type: ignore
self.default_eval_workflow_cls = None
if getattr(meta, "default_eval_workflow_type", None):
self.default_eval_workflow_cls = WORKFLOWS.get(meta.default_eval_workflow_type)
self.default_reward_fn_cls = REWARD_FUNCTIONS.get(meta.default_reward_fn_type) # type: ignore
def read(
self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
) -> List:
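        """Wrap each raw sample in a ``Task`` with its resolved workflow class
        and reward function."""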
batch_size = batch_size or self.read_batch_size
tasks = []
samples = self.dataset.read_batch(batch_size)
for sample in samples:
if self.task_type == TaskType.EVAL and self.default_eval_workflow_cls:
workflow_class = self.default_eval_workflow_cls
else:
workflow_class = (
WORKFLOWS.get(sample[self.workflow_key])
if self.workflow_key in sample
else self.default_workflow_cls
)
reward_fn = (
REWARD_FUNCTIONS.get(sample[self.reward_fn_key])
if self.reward_fn_key in sample
else self.default_reward_fn_cls
)
assert (
workflow_class is not None
), "`default_workflow_type` or `workflow_key` is required"
task = Task(
workflow=workflow_class,
repeat_times=self.meta.repeat_times,
format_args=copy.deepcopy(self.meta.format),
rollout_args=copy.deepcopy(self.meta.rollout_args),
workflow_args=copy.deepcopy(self.meta.workflow_args),
reward_fn_args=copy.deepcopy(self.meta.reward_fn_args),
is_eval=self.meta.task_type == TaskType.EVAL,
reward_fn=reward_fn,
raw_task=sample,
)
tasks.append(task)
return tasks
@FILE_READERS.register_module("raw")
class RawDataReader(BaseFileReader):
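    """Reader that returns the whole dataset as a list of raw samples, exactly once."""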
def __init__(self, meta: StorageConfig, config: Optional[BufferConfig]):
self.returned = False
self.dataset = load_dataset(meta.path, name=meta.subset_name, split=meta.split)
def __len__(self):
return len(self.dataset)
def read(
self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
) -> List:
if self.returned:
raise StopIteration
self.returned = True
return self.dataset.to_list()