# -*- coding: utf-8 -*-
"""The explorer module"""
from __future__ import annotations
import asyncio
import os
import time
import traceback
from collections import deque
from typing import List, Optional
import torch
from trinity.algorithm import ADD_STRATEGY
from trinity.algorithm.algorithm_manager import AlgorithmManager
from trinity.buffer import get_buffer_writer
from trinity.buffer.buffer import get_buffer_reader
from trinity.common.config import Config
from trinity.common.constants import (
ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
RunningStatus,
SyncMethod,
SyncStyle,
)
from trinity.common.models import create_inference_models
from trinity.explorer.scheduler import Scheduler
from trinity.manager.manager import CacheManager
from trinity.manager.synchronizer import Synchronizer
from trinity.utils.log import get_logger
from trinity.utils.monitor import MONITOR, gather_metrics
class Explorer:
"""Responsible for exploring the taskset."""
def __init__(self, config: Config):
self.logger = get_logger(__name__)
self.cache = CacheManager(config)
explorer_meta = self.cache.load_explorer()
self.explore_step_num = explorer_meta.get("latest_iteration", 0)
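        # ``last_sync_step`` is -1 on a fresh run (no weight sync has happened yet);
        # otherwise it resumes from the restored step.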
self.last_sync_step = self.explore_step_num if self.explore_step_num > 0 else -1
self.synchronizer = Synchronizer.get_actor(config)
self.config = config
self.algorithm_manager = AlgorithmManager(config)
self.models, self.auxiliary_models = create_inference_models(config)
self.experience_buffer = None
if self.config.mode != "bench":
self.experience_buffer = get_buffer_writer(
self.config.buffer.explorer_output, # type: ignore
self.config.buffer,
)
self.config.buffer.explorer_input.taskset.index = explorer_meta.get("latest_task_index", 0)
self.taskset = get_buffer_reader(
self.config.buffer.explorer_input.taskset, self.config.buffer
)
self.scheduler = self._init_scheduler()
self.monitor = MONITOR.get(self.config.monitor.monitor_type)(
project=self.config.project,
group=self.config.group,
name=self.config.name,
role=self.config.explorer.name,
config=config,
)
self.batch_size = config.buffer.batch_size
self.update_interval = (
self.config.synchronizer.sync_interval * self.config.buffer.batch_size
)
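        # With NCCL sync the rollout models receive weights over the collective group;
        # otherwise the explorer pulls state dicts from the Synchronizer (checkpoint-based update).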
self.use_nccl_sync = self.config.synchronizer.sync_method == SyncMethod.NCCL
self.pending_eval_tasks = deque()
        # For checkpoint-based weight updates: the explorer periodically loads
        # the latest model weights and broadcasts them to all rollout models.
self.model_version = -1
self.last_sync_successful = True
self.logger.info("Finished initializing Explorer.")
self.collect_experiences = self.config.explorer.collect_experiences
self.generated_experience_cnt = 0
if self.collect_experiences:
assert (
self.experience_buffer is not None
), "Experience buffer is required when collect_experiences is True."
self.add_strategy = ADD_STRATEGY.get(self.config.algorithm.add_strategy)(
self.experience_buffer, **self.config.algorithm.add_strategy_args
)
async def setup_weight_sync_group(
        self, master_address: str, master_port: int, state_dict_meta: Optional[List] = None
):
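        """Set up the process group used to synchronize weights to all rollout models."""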
        # In checkpoint mode, the explorer stores the model weights itself and does not
        # occupy a rank in the sync group, so no rank offset is needed; with NCCL sync,
        # rank 0 is reserved for the weight sender.
base_offset = 1 if self.use_nccl_sync else 0
world_size = (
len(self.models) * self.config.explorer.rollout_model.tensor_parallel_size + base_offset
)
self.logger.info(
f"Initialize process group for weight synchronization, "
f"master_address={master_address}, master_port={master_port}, "
f"world_size={world_size}, rank_offset={base_offset}"
)
# TODO: save state_dict in models
refs = [
model.init_process_group.remote(
master_address=master_address,
master_port=master_port,
rank_offset=i * self.config.explorer.rollout_model.tensor_parallel_size
+ base_offset,
world_size=world_size,
group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
explorer_name=self.config.explorer.name,
timeout=self.config.synchronizer.sync_timeout,
state_dict_meta=state_dict_meta,
)
for i, model in enumerate(self.models)
]
await asyncio.gather(*refs)
def _init_scheduler(self) -> Scheduler:
if self.config.explorer.rollout_model.engine_type != "vllm_async":
            # A synchronous rollout engine requires exactly one runner per model
            self.config.explorer.runner_per_model = 1
            self.logger.info(
                "Synchronous vLLM engine requires one runner per model; setting `runner_per_model` to 1."
            )
return Scheduler(self.config, self.models, self.auxiliary_models)
async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> int:
step_num = await self.synchronizer.set_model_state_dict_with_step_num.remote(step_num)
await asyncio.gather(*[model.sync_model.remote(step_num) for model in self.models])
return step_num # type: ignore
async def _pull_latest_weights(self):
self.logger.info("Start to pull latest model weights.")
new_version = await self.synchronizer.wait_new_model_state_dict.remote(self.model_version)
if new_version > self.model_version:
if self.model_version != -1:
self.logger.info(f"New model weights version: {new_version}")
await asyncio.gather(
*[model.sync_model.remote(new_version) for model in self.models]
)
self.model_version = new_version
self.last_sync_step = self.explore_step_num
self.last_sync_successful = True
else:
self.logger.warning(
f"No new model weights found, current version: {self.model_version}"
)
self.last_sync_successful = False
async def _nccl_weights_update(self):
new_version = await self.synchronizer.ready_to_nccl_sync.remote(
"explorer", self.model_version
)
if new_version is None:
self.logger.info("Trainer is not ready to sync weight. Skipping sync weight.")
self.last_sync_successful = False
return
self.model_version = new_version
await asyncio.gather(
*[model.sync_model.remote(self.model_version) for model in self.models]
)
self.last_sync_step = self.explore_step_num
self.last_sync_successful = True
async def prepare(self) -> None:
"""Preparation before running."""
futures = [
asyncio.create_task(self.scheduler.start()),
]
if self.experience_buffer:
futures.append(asyncio.create_task(self.experience_buffer.acquire())) # type: ignore
if not self.use_nccl_sync:
master_address, master_port = await self.models[0].get_available_address.remote()
futures.append(
asyncio.create_task(self.setup_weight_sync_group(master_address, master_port))
)
await asyncio.gather(*futures, return_exceptions=True)
if self.config.explorer.eval_on_startup and self.explore_step_num == 0:
await self.eval()
await self.synchronizer.set_explorer_status.remote(RunningStatus.REQUIRE_SYNC)
async def get_weight(self, name: str) -> torch.Tensor:
"""Get the weight of the loaded model (For checkpoint weights update)."""
return self.state_dict[name]
async def explore(self) -> str:
"""
The timeline of the exploration process:
| <--------------------------------- one period -------------------------------------> |
explorer | <---------------- step_1 --------------> | |
| | <---------------- step_2 --------------> | |
| ... |
| | <---------------- step_n ---------------> | |
| | <---------------------- eval --------------------> | <-- sync --> |
|--------------------------------------------------------------------------------------|
trainer | <-- idle --> | <-- step_1 --> | <-- step_2 --> | ... | <-- step_n --> | <-- sync --> |
"""
while True:
try:
self.logger.info(f"Explore step {self.explore_step_num + 1} started.")
                explore_continue = await self.explore_step()
                if not explore_continue:
# TODO: support eval on last checkpoint
break
if self.need_eval():
await self.eval()
if await self.need_sync():
await self.sync_weight()
except Exception:
self.logger.error(f"Error in Explorer: {traceback.format_exc()}")
break
self.logger.info(
f"--------------------\n> Explorer ({self.config.explorer.name}) finished.\n--------------------"
)
return self.config.explorer.name
async def explore_step(self) -> bool:
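        """Run a single exploration step. Returns False when the taskset is exhausted."""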
algo_config = self.algorithm_manager.get_current_algorithm_config(self.explore_step_num + 1)
        # Skip exploration during SFT warmup steps; only advance the step counter.
if algo_config.algorithm_type == "sft":
self.explore_step_num += 1
return True
try:
tasks = await self.taskset.read_async()
except StopAsyncIteration:
self.logger.warning("No more tasks to explore. Stop exploring.")
await self.save_checkpoint(sync_weight=False)
await self.synchronizer.set_explorer_status.remote(
RunningStatus.STOPPED,
old_status=RunningStatus.RUNNING
if self.last_sync_successful
else RunningStatus.REQUIRE_SYNC,
)
await self.experience_buffer.release()
return False
self.scheduler.schedule(tasks, batch_id=self.explore_step_num + 1)
self.explore_step_num += 1
return True
async def need_sync(self) -> bool:
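        """Check whether model weights need to be synchronized after the current step."""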
if self.config.synchronizer.sync_style == SyncStyle.FIXED:
if self.explore_step_num <= self.config.synchronizer.sync_offset:
return False
require_sync = (
self.explore_step_num - self.config.synchronizer.sync_offset
) % self.config.synchronizer.sync_interval == 0
else:
require_sync = False
if self.config.synchronizer.sync_style == SyncStyle.DYNAMIC_BY_EXPLORER:
delta = self.explore_step_num - self.last_sync_step
if delta >= self.config.synchronizer.sync_interval:
require_sync = True
else:
                require_sync = (
                    await self.synchronizer.get_trainer_status.remote()
                ) == RunningStatus.REQUIRE_SYNC
if require_sync and self.last_sync_successful:
await self.synchronizer.set_explorer_status.remote(
RunningStatus.REQUIRE_SYNC, old_status=RunningStatus.RUNNING
)
return require_sync
def need_eval(self) -> bool:
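        """Check whether evaluation should run at the current step."""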
return self.explore_step_num % self.config.explorer.eval_interval == 0
async def eval(self):
"""Evaluation on all evaluation data samples."""
if len(self.config.buffer.explorer_input.eval_tasksets) == 0:
self.logger.warning("No evaluation data samples. Skip evaluation.")
return
self.logger.info(f"Evaluation at step {self.explore_step_num} started.")
if self.config.buffer.explorer_input.default_eval_workflow_type:
self.logger.info(
f"Use '{self.config.buffer.explorer_input.default_eval_workflow_type}' for evaluation."
)
for eval_taskset_config in self.config.buffer.explorer_input.eval_tasksets:
self.logger.info(
f"Evaluation on {eval_taskset_config.name} at step {self.explore_step_num} started."
)
eval_taskset = get_buffer_reader(eval_taskset_config, self.config.buffer)
eval_batch_id = f"{self.explore_step_num}/{eval_taskset.name}"
self.pending_eval_tasks.append((self.explore_step_num, eval_taskset.name))
while True:
try:
data = await eval_taskset.read_async()
self.scheduler.schedule(data, batch_id=eval_batch_id)
except StopAsyncIteration:
break
async def benchmark(self) -> bool:
"""Benchmark the model checkpoints."""
# benchmark on the latest checkpoint
if self.config.explorer.bench_on_latest_checkpoint:
self.explore_step_num = await self._checkpoint_weights_update()
await self.eval()
await self._finish_eval_step(prefix="bench")
return True
# benchmark on base model
if self.config.explorer.eval_on_startup:
await self._finish_eval_step(prefix="bench")
# benchmark on all checkpoints
all_ckp_steps = sorted(
[
int(ckp.split("global_step_")[-1])
for ckp in os.listdir(self.config.checkpoint_job_dir)
if os.path.isdir(os.path.join(self.config.checkpoint_job_dir, ckp))
and ckp.startswith("global_step_")
]
)
for step_num in all_ckp_steps:
self.explore_step_num = await self._checkpoint_weights_update(step_num=step_num)
await self.eval()
await self._finish_eval_step(prefix="bench")
return True
async def save_checkpoint(self, sync_weight: bool = False) -> None:
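        """Flush metrics for finished steps, optionally synchronize weights, and persist explorer progress."""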
if not self.config.explorer.collect_experiences:
# wait for all tasks to complete
self.logger.info("Waiting for all tasks to complete")
await self.scheduler.wait_all()
self.logger.info(f"All tasks before step {self.explore_step_num} have completed.")
await self._finish_steps(self.last_sync_step + 1, self.explore_step_num, self.model_version)
if sync_weight:
# sync weights
self.logger.info(f"Explorer sync_weights at step {self.explore_step_num} started.")
if self.use_nccl_sync:
await self._nccl_weights_update()
else: # pull weights from Synchronizer
await self._pull_latest_weights()
self.logger.info(
f"Explorer sync_weights at step {self.explore_step_num} finished, model version = {self.model_version}."
)
# save explore checkpoint
self.cache.save_explorer(
current_step=self.explore_step_num,
current_task_index=self.explore_step_num * self.config.buffer.batch_size,
)
async def sync_weight(self) -> None:
"""Synchronize model weights."""
# call this method before training start to load the latest model weights
await self.save_checkpoint(sync_weight=True)
async def _finish_steps(self, start_step: int, end_step: int, model_version: int) -> None:
for step in range(start_step, end_step + 1):
self.logger.info(f"Log metrics of step {step}")
await self._finish_explore_step(step=step, model_version=model_version)
await self._finish_eval_step(step=step)
async def _finish_explore_step(self, step: int, model_version: int) -> None:
statuses, exps = await self.scheduler.get_results(batch_id=step)
metric = {"rollout/model_version": model_version}
if self.config.explorer.collect_experiences:
exp_cnt, add_strategy_metric = await self.add_strategy.add(exps, step)
self.generated_experience_cnt += exp_cnt
metric.update(add_strategy_metric)
metric["rollout/experience_count"] = exp_cnt
if statuses:
metric.update(gather_metrics([status.metric for status in statuses], "rollout"))
self.monitor.log(metric, step=step)
async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval") -> None:
if not self.pending_eval_tasks:
return
step = step or self.explore_step_num
st = time.time()
metric = {}
while self.pending_eval_tasks:
eval_step, eval_task_name = self.pending_eval_tasks[0]
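            # Only flush eval results that belong to this step; later entries stay queued.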
if eval_step != step:
return
self.pending_eval_tasks.popleft()
eval_results, _ = await self.scheduler.get_results(f"{step}/{eval_task_name}")
metric.update(
gather_metrics(
[status.metric for status in eval_results], f"{prefix}/{eval_task_name}"
)
)
metric[f"{prefix}/total_time"] = time.time() - st
self.monitor.log(metric, step)
async def shutdown(self) -> None:
await self.scheduler.stop()
self.monitor.close()
def is_alive(self) -> bool:
"""Check if the explorer is alive."""
return True
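

# Minimal usage sketch (illustrative only): it assumes a fully populated ``Config``
# and a running Ray cluster. The real entry point is the project launcher, and
# ``load_config`` below is a hypothetical helper, not part of this module.
#
#     import asyncio
#
#     async def main(config: Config) -> None:
#         explorer = Explorer(config)
#         await explorer.prepare()
#         await explorer.explore()
#         await explorer.shutdown()
#
#     asyncio.run(main(load_config("explorer.yaml")))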