# -*- coding: utf-8 -*-
"""The explorer module"""
from __future__ import annotations
import asyncio
import os
import time
from collections import defaultdict
from typing import List, Optional
import torch
from trinity.algorithm.algorithm_manager import AlgorithmManager
from trinity.buffer import get_buffer_writer
from trinity.buffer.buffer import get_buffer_reader
from trinity.common.config import Config
from trinity.common.constants import (
ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
RunningStatus,
SyncMethod,
)
from trinity.common.models import create_inference_models
from trinity.common.models.utils import (
get_checkpoint_dir_with_step_num,
load_state_dict,
)
from trinity.explorer.runner_pool import RunnerPool
from trinity.manager.manager import CacheManager
from trinity.utils.log import get_logger
from trinity.utils.monitor import MONITOR
class Explorer:
"""Responsible for exploring the taskset."""
def __init__(self, config: Config):
self.logger = get_logger(__name__)
self.cache = CacheManager(config)
explorer_meta = self.cache.load_explorer()
self.explore_step_num = explorer_meta.get("latest_iteration", 0)
self.config = config
self.algorithm_manager = AlgorithmManager(config)
self.models, self.auxiliary_models = create_inference_models(config)
if self.config.mode != "bench":
self.experience_buffer = get_buffer_writer(
self.config.buffer.explorer_output, # type: ignore
self.config.buffer,
)
self.experience_buffer.acquire()
self.config.buffer.explorer_input.taskset.index = explorer_meta.get("latest_task_index", 0)
self.taskset = get_buffer_reader(
self.config.buffer.explorer_input.taskset, self.config.buffer
)
self.runner_pool = self._init_runner_pool()
self.monitor = MONITOR.get(self.config.monitor.monitor_type)(
project=self.config.project,
name=self.config.name,
role=self.config.explorer.name,
config=config,
)
self.batch_size = config.buffer.batch_size
self.update_interval = (
self.config.synchronizer.sync_interval * self.config.buffer.batch_size
)
self.use_checkpoint_weights_update = (
self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT
)
self.eval_explore_step_num = None
        # For checkpoint weights update:
        # the explorer periodically loads the latest model weights and
        # broadcasts them to all rollout models.
if self.use_checkpoint_weights_update:
self.old_checkpoint = None
self.state_dict = {}
else: # nccl mode
self.state_dict_meta = []
self.status = RunningStatus.RUNNING
self.logger.info("Finished initializing Explorer.")
async def setup_weight_sync_group(
        self, master_address: str, master_port: int, state_dict_meta: Optional[List] = None
):
        # In checkpoint mode, the explorer itself stores the model weights and does not occupy a rank
base_offset = 0 if self.use_checkpoint_weights_update else 1
world_size = (
len(self.models) * self.config.explorer.rollout_model.tensor_parallel_size + base_offset
)
self.logger.info(
f"Initialize process group for weight synchronization, "
f"master_address={master_address}, master_port={master_port}, "
f"world_size={world_size}, rank_offset={base_offset}"
)
self.state_dict_meta = state_dict_meta
# TODO: save state_dict in models
refs = [
model.init_process_group.remote(
master_address=master_address,
master_port=master_port,
rank_offset=i * self.config.explorer.rollout_model.tensor_parallel_size
+ base_offset,
world_size=world_size,
group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
explorer_name=self.config.explorer.name,
timeout=self.config.synchronizer.sync_timeout,
update_with_checkpoint=self.use_checkpoint_weights_update,
state_dict_meta=state_dict_meta,
)
for i, model in enumerate(self.models)
]
await asyncio.gather(*refs)
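    # Illustrative rank layout for the process group set up above (assumed
    # values): with 2 rollout engines and tensor_parallel_size = 2 in NCCL mode
    # (base_offset = 1), world_size is 2 * 2 + 1 = 5; engine 0 receives ranks
    # 1-2 and engine 1 receives ranks 3-4, while rank 0 is presumably held by
    # the weight sender on the trainer side. In checkpoint mode base_offset is
    # 0 and the explorer serves the weights without occupying a rank itself.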
def _init_runner_pool(self) -> RunnerPool:
if self.config.explorer.rollout_model.engine_type != "vllm_async":
# sync model requires the same number of runners as the number of models
self.config.explorer.runner_num = self.config.explorer.rollout_model.engine_num
self.logger.info(
"Sync vLLM model requires the same number of runners as the number of models"
)
if self.config.explorer.runner_num < self.config.explorer.rollout_model.engine_num:
self.config.explorer.runner_num = self.config.explorer.rollout_model.engine_num
            self.logger.info(
                f"Number of runners is less than the number of models; setting runner_num to {self.config.explorer.runner_num}"
            )
self.logger.info(f"Setup {self.config.explorer.runner_num} WorkflowRunners")
return RunnerPool(self.config, self.models, self.auxiliary_models)
async def _update_model_weight(self, step_num: int, state_dict: dict) -> None:
# TODO: update model weight
self.state_dict = state_dict
if self.state_dict_meta is None:
update_weight_args_list = []
for name, param in state_dict.items():
update_weight_args_list.append((name, str(param.dtype), tuple(param.shape)))
self.state_dict_meta = update_weight_args_list
else:
update_weight_args_list = None
await asyncio.gather(
*[model.sync_model.remote(step_num, update_weight_args_list) for model in self.models]
)
self.state_dict.clear()
async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> None:
# TODO: support more checkpoint types
try:
checkpoint_dir, checkpoint_step_num = get_checkpoint_dir_with_step_num(
checkpoint_root_path=self.config.checkpoint_job_dir,
trainer_type=self.config.trainer.trainer_type,
step_num=step_num,
)
if checkpoint_dir == self.old_checkpoint:
return
model_weights = load_state_dict(os.path.join(checkpoint_dir, "actor"))
await self._update_model_weight(checkpoint_step_num, model_weights)
self.old_checkpoint = checkpoint_dir
except Exception as e:
self.logger.warning(f"Fail to load checkpoint: {e}")
async def _nccl_weights_update(self):
assert self.state_dict_meta is not None
await asyncio.gather(
*[model.sync_model.remote(self.explore_step_num) for model in self.models]
)
async def prepare(self) -> None:
"""Preparation before running."""
if self.use_checkpoint_weights_update:
master_address, master_port = await self.models[0].get_available_address.remote()
await self.setup_weight_sync_group(master_address, master_port)
async def get_weight(self, name: str) -> torch.Tensor:
"""Get the weight of the loaded model (For checkpoint weights update)."""
return self.state_dict[name]
async def explore(self) -> str:
"""
        The main exploration loop, interleaved with the trainer as follows:
| <----------------------------------------- one period ----------------------------------------------> |
explorer | <-- step_1 --> | <-- step_2 --> | ... | <-- step_n --> | <-- eval --> | <-- [idle] --> | <-- sync --> |
trainer | <-- idle --> | <-- step_1 --> | <-- step_2 --> | ... | <-- step_n --> | <-- [idle] --> | <-- sync --> |
"""
self.eval_explore_step_num = None
while True:
try:
if (
self.eval_explore_step_num is None
and self.explore_step_num % self.config.explorer.eval_interval == 0
):
self.eval_explore_step_num = self.explore_step_num
                explore_continue = self.explore_step()
                if not explore_continue:
break
if self.need_sync():
self.wait_for_workflow_done()
await self.sync_weight()
except Exception as e:
self.logger.error(f"Error in Explorer: {e}")
break
self.logger.info("--------------------\n> Explorer finished.\n--------------------")
return self.config.explorer.name
def explore_step(self) -> bool:
algo_config = self.algorithm_manager.get_current_algorithm_config(self.explore_step_num + 1)
        # skip exploration during SFT warmup steps
if algo_config.algorithm_type == "sft":
self.explore_step_num += 1
return True
try:
tasks = self.taskset.read()
except StopIteration:
self.logger.warning("No more tasks to explore. Stop exploring.")
self.cache.save_explorer(
current_step=self.explore_step_num,
current_task_index=self.explore_step_num * self.config.buffer.batch_size,
)
self.status = RunningStatus.STOPPED
self.wait_for_workflow_done()
self.experience_buffer.release()
return False
self.runner_pool.run_tasks(tasks)
self.explore_step_num += 1
return True
def need_sync(self) -> bool:
if self.explore_step_num <= self.config.synchronizer.sync_offset:
return False
return (
self.explore_step_num - self.config.synchronizer.sync_offset
) % self.config.synchronizer.sync_interval == 0
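    # Worked example for the formula in ``need_sync`` above (assumed values):
    # with synchronizer.sync_offset = 2 and sync_interval = 4, synchronization
    # is triggered after explore steps 6, 10, 14, ... since step > 2 and
    # (step - 2) % 4 == 0.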
def eval(self, eval_explore_step_num: int):
"""Evaluation on all evaluation data samples."""
if len(self.config.buffer.explorer_input.eval_tasksets) == 0:
self.logger.warning("No evaluation data samples. Skip evaluation.")
return
self.logger.info(f"Evaluation at step {eval_explore_step_num} started.")
all_st = time.time()
log_metrics = {}
for eval_taskset_config in self.config.buffer.explorer_input.eval_tasksets:
self.logger.info(
f"Evaluation on {eval_taskset_config.name} at step {eval_explore_step_num} started."
)
eval_taskset = get_buffer_reader(eval_taskset_config, self.config.buffer)
st = time.time()
all_metrics = defaultdict(list)
def wait():
status_list = self.runner_pool.get_next_unorder()
if not isinstance(status_list, list):
status_list = [status_list]
for status in status_list:
if not status.ok:
self.logger.error(f"Error when running task: {status.message}")
else:
for metric_name, metric_value in status.metric.items():
all_metrics[metric_name].append(metric_value)
while True:
if not self.runner_pool.has_free():
wait()
try:
self.runner_pool.run_tasks(eval_taskset.read())
except StopIteration:
break
while self.runner_pool.has_next():
wait()
metrics = self.monitor.calculate_metrics(all_metrics, prefix=f"eval/{eval_taskset.name}") # type: ignore
log_metrics.update(metrics)
log_metrics[f"eval/{eval_taskset.name}/time"] = time.time() - st
log_metrics["eval/total_time"] = time.time() - all_st
self.monitor.log(log_metrics, step=eval_explore_step_num) # type: ignore
self.logger.info(f"Evaluation at step {eval_explore_step_num} finished.")
async def benchmark(self) -> bool:
"""Benchmark the model checkpoints."""
# benchmark on the latest checkpoint
if self.config.explorer.eval_on_latest_checkpoint:
await self._checkpoint_weights_update()
self.eval(self.explore_step_num)
return True
# benchmark on base model
self.eval(0)
        # benchmark on all checkpoints
all_ckp_steps = sorted(
[
int(ckp.split("global_step_")[-1])
for ckp in os.listdir(self.config.checkpoint_job_dir)
if os.path.isdir(os.path.join(self.config.checkpoint_job_dir, ckp))
and ckp.startswith("global_step_")
]
)
for step_num in all_ckp_steps:
await self._checkpoint_weights_update(step_num=step_num)
self.eval(step_num)
return True
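    # Checkpoint layout assumed by ``benchmark`` and ``_checkpoint_weights_update``
    # above (inferred from this file, not an authoritative spec): checkpoints live
    # in subdirectories of ``checkpoint_job_dir`` named ``global_step_<N>``, each
    # containing an ``actor`` directory with the model state dict to load.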
def wait_for_workflow_done(self) -> None:
"""Wait for workflow to finish."""
all_metrics = defaultdict(list)
# wait for all tasks of this step to finish
while self.runner_pool.has_next():
status_list = self.runner_pool.get_next_unorder()
if not isinstance(status_list, list):
status_list = [status_list]
for status in status_list:
if not status.ok:
self.logger.error(f"Error when running task: {status.message}")
# submit another task to replace the failed task
try:
tasks = self.taskset.read(batch_size=1)
except StopIteration:
self.logger.warning("No more tasks in taskset. Stop retrying.")
return
self.runner_pool.run_tasks(tasks)
else:
for metric_name, metric_value in status.metric.items():
all_metrics[metric_name].append(metric_value)
# eval
if self.eval_explore_step_num is not None:
self.eval(self.eval_explore_step_num)
self.eval_explore_step_num = None
# calculate metrics
log_metrics = self.monitor.calculate_metrics(all_metrics, prefix="rollout") # type: ignore
self.monitor.log(log_metrics, step=self.explore_step_num)
self.logger.info(f"Explore step {self.explore_step_num} finished.")
async def sync_weight(self) -> None:
"""Synchronize model weights."""
        # Call this method before training starts to load the latest model weights.
self.logger.info(f"Explorer sync weights at step {self.explore_step_num}.")
self.status = RunningStatus.WAITING_SYNC
if self.use_checkpoint_weights_update:
await self._checkpoint_weights_update()
else: # nccl weights update
await self._nccl_weights_update()
# save explore checkpoint
self.cache.save_explorer(
current_step=self.explore_step_num,
current_task_index=self.explore_step_num * self.config.buffer.batch_size,
)
self.status = RunningStatus.RUNNING
self.logger.info(f"Explorer sync at step {self.explore_step_num} finished")
async def running_status(self) -> RunningStatus:
return self.status
def flush_log(self, step: int) -> None:
"""Flush the log of the current step."""
self.monitor.log({}, step=step, commit=True)
def shutdown(self) -> None:
self.monitor.close()
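

# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes a fully populated ``Config`` object produced by the project's own
# configuration loading, which is not shown here; the surrounding framework may
# drive these calls differently (e.g. as remote actor methods).
def _run_explorer_locally(config: Config) -> str:
    """Run a single Explorer to completion and return its name."""
    explorer = Explorer(config)
    asyncio.run(explorer.prepare())  # set up the weight-sync group if needed
    try:
        return asyncio.run(explorer.explore())  # main exploration loop
    finally:
        explorer.shutdown()  # close the monitor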