# -*- coding: utf-8 -*-
"""The explorer module"""
from __future__ import annotations
import asyncio
import os
import time
import traceback
from collections import deque
from typing import List, Optional
import ray
import torch
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
from trinity.buffer.buffer import get_buffer_reader
from trinity.buffer.pipelines.experience_pipeline import ExperiencePipeline
from trinity.common.config import Config
from trinity.common.constants import (
ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
RunningStatus,
SyncMethod,
SyncStyle,
)
from trinity.common.models import create_inference_models
from trinity.common.models.utils import get_checkpoint_dir_with_step_num
from trinity.explorer.scheduler import Scheduler
from trinity.manager.state_manager import StateManager
from trinity.manager.synchronizer import Synchronizer
from trinity.utils.annotations import Experimental
from trinity.utils.log import get_logger
from trinity.utils.monitor import MONITOR, gather_metrics
from trinity.utils.plugin_loader import load_plugins
class Explorer:
"""Responsible for exploring the taskset."""
def __init__(self, config: Config):
self.logger = get_logger(config.explorer.name, in_ray_actor=True)
load_plugins()
self.state = StateManager(
path=config.checkpoint_job_dir, explorer_name=config.explorer.name, config=config
)
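        # Resume exploration progress recorded by a previous run, if any.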
explorer_state = self.state.load_explorer()
self.explore_step_num = explorer_state.get("latest_iteration", 0)
self.last_sync_step = self.explore_step_num if self.explore_step_num > 0 else -1
self.synchronizer = Synchronizer.get_actor(config)
self.config = config
self.models, self.auxiliary_models = create_inference_models(config)
self.experience_pipeline = self._init_experience_pipeline()
self.config.buffer.explorer_input.taskset.index = explorer_state.get("latest_task_index", 0)
self.taskset = (
get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer)
if self.config.mode != "serve"
else None
)
self.scheduler = None
self.monitor = MONITOR.get(self.config.monitor.monitor_type)(
project=self.config.project,
group=self.config.group,
name=self.config.name,
role=self.config.explorer.name,
config=config,
)
self.batch_size = config.buffer.batch_size
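        # Number of tasks between two weight syncs (sync_interval steps * batch_size tasks per step).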
self.update_interval = (
self.config.synchronizer.sync_interval * self.config.buffer.batch_size
)
self.use_nccl_sync = self.config.synchronizer.sync_method == SyncMethod.NCCL
self.pending_eval_tasks = deque()
# For checkpoint weights update
        # Use the explorer to periodically load the latest model weights and
        # broadcast them to all rollout models.
self.enable_lora = self.config.explorer.rollout_model.enable_lora
self.model_version = -1
self.last_sync_successful = True
self.logger.info("Finished initializing Explorer.")
async def setup_weight_sync_group(
        self, master_address: str, master_port: int, state_dict_meta: Optional[List] = None
):
        # In checkpoint mode the explorer serves the model weights itself and takes no rank
        # in the group, so the rollout workers start at rank 0; in NCCL mode they are offset by 1.
base_offset = 1 if self.use_nccl_sync else 0
world_size = (
len(self.models) * self.config.explorer.rollout_model.tensor_parallel_size + base_offset
)
self.logger.info(
f"Initialize process group for weight synchronization, "
f"master_address={master_address}, master_port={master_port}, "
f"world_size={world_size}, rank_offset={base_offset}"
)
# TODO: save state_dict in models
refs = [
model.init_process_group.remote(
master_address=master_address,
master_port=master_port,
rank_offset=i * self.config.explorer.rollout_model.tensor_parallel_size
+ base_offset,
world_size=world_size,
group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
explorer_name=self.config.explorer.name,
timeout=self.config.synchronizer.sync_timeout,
state_dict_meta=state_dict_meta,
)
for i, model in enumerate(self.models)
]
await asyncio.gather(*refs)
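    # Checkpoint-based update: the Synchronizer loads the requested checkpoint, then every rollout model syncs from it.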
async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> int:
step_num = await self.synchronizer.set_model_state_dict_with_step_num.remote(step_num)
await asyncio.gather(*[model.sync_model.remote(step_num) for model in self.models])
return step_num # type: ignore
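    # Pull-based update: wait until the Synchronizer holds a state dict newer than the current model version, then sync.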
async def _pull_latest_weights(self):
self.logger.info("Start to pull latest model weights.")
new_version = await self.synchronizer.wait_new_model_state_dict.remote(self.model_version)
if new_version > self.model_version:
if self.model_version != -1:
self.logger.info(f"New model weights version: {new_version}")
await asyncio.gather(
*[model.sync_model.remote(new_version) for model in self.models]
)
self.model_version = new_version
self.last_sync_step = self.explore_step_num
self.last_sync_successful = True
else:
self.logger.warning(
f"No new model weights found, current version: {self.model_version}"
)
self.last_sync_successful = False
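    # NCCL-based update: rendezvous with the trainer via the Synchronizer; skip this round if the trainer is not ready.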
async def _nccl_weights_update(self):
new_version = await self.synchronizer.ready_to_nccl_sync.remote(
"explorer", self.model_version
)
if new_version is None:
self.logger.info("Trainer is not ready to sync weight. Skipping sync weight.")
self.last_sync_successful = False
return
self.model_version = new_version
await asyncio.gather(
*[model.sync_model.remote(self.model_version) for model in self.models]
)
self.last_sync_step = self.explore_step_num
self.last_sync_successful = True
async def prepare(self) -> None:
"""Preparation before running."""
try:
await self.experience_pipeline.prepare.remote()
self.logger.info("Experience pipeline is ready.")
# make sure all rollout models are ready
model_ready_ref = [model.__ray_ready__.remote() for model in self.models]
await asyncio.gather(*model_ready_ref)
self.logger.info("All rollout models are ready.")
if not self.use_nccl_sync:
master_address, master_port = await self.models[0].get_available_address.remote()
await self.setup_weight_sync_group(master_address, master_port)
if self.config.mode != "serve":
self.scheduler = Scheduler(self.config, self.models, self.auxiliary_models)
await self.scheduler.start()
if self.config.explorer.eval_on_startup and self.explore_step_num == 0:
await self.eval()
await self.synchronizer.set_explorer_status.remote(RunningStatus.REQUIRE_SYNC)
except Exception as e:
self.logger.error(f"Error during explorer preparation: {traceback.format_exc()}")
await self.shutdown()
raise e
async def get_weight(self, name: str) -> torch.Tensor:
"""Get the weight of the loaded model (For checkpoint weights update)."""
return self.state_dict[name]
async def explore(self) -> str:
"""
The timeline of the exploration process:
| <--------------------------------- one period -------------------------------------> |
explorer | <---------------- step_1 --------------> | |
| | <---------------- step_2 --------------> | |
| ... |
| | <---------------- step_n ---------------> | |
| | <---------------------- eval --------------------> | <-- sync --> |
|--------------------------------------------------------------------------------------|
trainer | <-- idle --> | <-- step_1 --> | <-- step_2 --> | ... | <-- step_n --> | <-- sync --> |
"""
while True:
try:
self.logger.info(f"Explore step {self.explore_step_num + 1} started.")
                explore_continue = await self.explore_step()
                if not explore_continue:
# TODO: support eval on last checkpoint
break
if self.need_eval():
await self.eval()
if await self.need_sync():
await self.sync_weight()
except Exception:
self.logger.error(f"Error in Explorer: {traceback.format_exc()}")
break
self.logger.info(
f"--------------------\n> Explorer ({self.config.explorer.name}) finished.\n--------------------"
)
return self.config.explorer.name
async def explore_step(self) -> bool:
try:
tasks = await self.taskset.read_async()
except StopAsyncIteration:
self.logger.warning("No more tasks to explore. Stop exploring.")
await self.save_checkpoint(sync_weight=False)
await self.synchronizer.set_explorer_status.remote(
RunningStatus.STOPPED,
old_status=RunningStatus.RUNNING
if self.last_sync_successful
else RunningStatus.REQUIRE_SYNC,
)
await self.shutdown()
return False
self.scheduler.schedule(tasks, batch_id=self.explore_step_num + 1)
self.explore_step_num += 1
return True
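    # Decide whether a weight sync is due, according to the configured SyncStyle.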
async def need_sync(self) -> bool:
if self.config.synchronizer.sync_style == SyncStyle.FIXED:
if self.explore_step_num <= self.config.synchronizer.sync_offset:
return False
require_sync = (
self.explore_step_num - self.config.synchronizer.sync_offset
) % self.config.synchronizer.sync_interval == 0
else:
require_sync = False
if self.config.synchronizer.sync_style == SyncStyle.DYNAMIC_BY_EXPLORER:
delta = self.explore_step_num - self.last_sync_step
if delta >= self.config.synchronizer.sync_interval:
require_sync = True
else:
                require_sync = (
                    await self.synchronizer.get_trainer_status.remote()
                    == RunningStatus.REQUIRE_SYNC
                )
if require_sync and self.last_sync_successful:
await self.synchronizer.set_explorer_status.remote(
RunningStatus.REQUIRE_SYNC, old_status=RunningStatus.RUNNING
)
return require_sync
def need_eval(self) -> bool:
return self.explore_step_num % self.config.explorer.eval_interval == 0
async def eval(self):
"""Evaluation on all evaluation data samples."""
if len(self.config.buffer.explorer_input.eval_tasksets) == 0:
self.logger.warning("No evaluation data samples. Skip evaluation.")
return
self.logger.info(f"Evaluation at step {self.explore_step_num} started.")
if self.config.buffer.explorer_input.default_eval_workflow_type:
self.logger.info(
f"Use '{self.config.buffer.explorer_input.default_eval_workflow_type}' for evaluation."
)
for eval_taskset_config in self.config.buffer.explorer_input.eval_tasksets:
self.logger.info(
f"Evaluation on {eval_taskset_config.name} at step {self.explore_step_num} started."
)
eval_taskset = get_buffer_reader(eval_taskset_config, self.config.buffer)
eval_batch_id = f"{self.explore_step_num}/{eval_taskset.name}"
self.pending_eval_tasks.append((self.explore_step_num, eval_taskset.name))
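            # Schedule every task in this evaluation taskset under the same batch id;
            # results are collected later in _finish_eval_step.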
while True:
try:
data = await eval_taskset.read_async()
self.scheduler.schedule(data, batch_id=eval_batch_id)
except StopAsyncIteration:
break
async def benchmark(self) -> bool:
"""Benchmark the model checkpoints."""
# benchmark on the latest checkpoint
if self.config.explorer.bench_on_latest_checkpoint:
self.explore_step_num = await self._checkpoint_weights_update()
await self.eval()
await self._finish_eval_step(prefix="bench")
return True
# benchmark on base model
if self.config.explorer.eval_on_startup:
await self._finish_eval_step(prefix="bench")
# benchmark on all checkpoints
all_ckp_steps = sorted(
[
int(ckp.split("global_step_")[-1])
for ckp in os.listdir(self.config.checkpoint_job_dir)
if os.path.isdir(os.path.join(self.config.checkpoint_job_dir, ckp))
and ckp.startswith("global_step_")
]
)
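        # Evaluate each saved checkpoint in ascending global_step order.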
for step_num in all_ckp_steps:
self.explore_step_num = await self._checkpoint_weights_update(step_num=step_num)
await self.eval()
await self._finish_eval_step(prefix="bench")
return True
async def save_checkpoint(self, sync_weight: bool = False) -> None:
if self.scheduler:
await self._finish_steps(
self.last_sync_step + 1, self.explore_step_num, self.model_version
)
if sync_weight:
# sync weights
self.logger.info(f"Explorer sync_weights at step {self.explore_step_num} started.")
if self.use_nccl_sync:
await self._nccl_weights_update()
else: # pull weights from Synchronizer
await self._pull_latest_weights()
self.logger.info(
f"Explorer sync_weights at step {self.explore_step_num} finished, model version = {self.model_version}."
)
# save explore checkpoint
self.state.save_explorer(
current_step=self.explore_step_num,
current_task_index=self.explore_step_num * self.config.buffer.batch_size,
)
async def sync_weight(self) -> None:
"""Synchronize model weights."""
# call this method before training start to load the latest model weights
await self.save_checkpoint(sync_weight=True)
async def _finish_steps(self, start_step: int, end_step: int, model_version: int) -> None:
for step in range(start_step, end_step + 1):
self.logger.info(f"Log metrics of step {step}")
await self._finish_explore_step(step=step, model_version=model_version)
await self._finish_eval_step(step=step)
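    # Collect rollout results for one step, feed them through the experience pipeline, and log the merged metrics.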
async def _finish_explore_step(self, step: int, model_version: int) -> None:
statuses, exps = await self.scheduler.get_results(batch_id=step)
metric = {"rollout/model_version": model_version}
pipeline_metrics = await self.experience_pipeline.process.remote(exps)
metric.update(pipeline_metrics)
if statuses:
metric.update(gather_metrics([status.metric for status in statuses], "rollout"))
self.monitor.log(metric, step=step)
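    # Log metrics for all evaluation batches that belong to the given step.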
async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval") -> None:
if not self.pending_eval_tasks:
return
step = step or self.explore_step_num
st = time.time()
metric = {}
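        # pending_eval_tasks is ordered by step, so stop at the first entry scheduled for a later step.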
while self.pending_eval_tasks:
eval_step, eval_task_name = self.pending_eval_tasks[0]
if eval_step != step:
return
self.pending_eval_tasks.popleft()
eval_results, _ = await self.scheduler.get_results(f"{step}/{eval_task_name}")
metric.update(
gather_metrics(
[status.metric for status in eval_results], f"{prefix}/{eval_task_name}"
)
)
metric[f"{prefix}/total_time"] = time.time() - st
self.monitor.log(metric, step)
async def shutdown(self) -> None:
if self.scheduler:
await self.scheduler.stop()
self.scheduler = None
if self.experience_pipeline:
await self.experience_pipeline.close.remote()
self.experience_pipeline = None
if self.monitor:
self.monitor.close()
self.monitor = None
self.logger.info(
f"Explorer ({self.config.explorer.name}) shutdown successfully at step {self.explore_step_num}."
)
async def is_alive(self) -> bool:
"""Check if the explorer is alive."""
return True
def _init_experience_pipeline(self) -> ray.actor.ActorHandle:
"""Init experience pipeline for the explorer."""
node_id = ray.get_runtime_context().get_node_id()
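        # Pin the experience pipeline actor to the same node as the explorer.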
return (
ray.remote(ExperiencePipeline)
.options(
name=f"{self.config.explorer.name}_pipeline",
namespace=ray.get_runtime_context().namespace,
scheduling_strategy=NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,
),
)
.remote(self.config)
)
@Experimental
async def serve(self) -> None:
"""Run the explorer in serving mode.
In serving mode, the explorer starts an OpenAI compatible server to handle requests.
Agent applications can be deployed separately and interact with the explorer via the API.
.. code-block:: python
import openai
client = openai.OpenAI(
base_url=f"{explorer_server_url}/v1",
api_key="EMPTY",
)
response = client.chat.completions.create(
model=config.model.model_path,
messages=[{"role": "user", "content": "Hello!"}]
)
"""
from trinity.explorer.api.service import ExplorerService
self.service = ExplorerService(
self,
listen_address=self.config.explorer.listen_address,
port=self.config.explorer.api_port,
)
await self.service.serve()
self.server_url = f"http://{ray.util.get_node_ip_address()}:{self.service.port}"
self.logger.info(
f"Explorer API Server is started on {self.server_url} and listening to {self.service.listen_address}."
)
self.state.save_explorer_server_url(self.server_url)
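        # Serving loop: periodically flush generated experiences through the pipeline,
        # log metrics, and refresh the model version from the newest checkpoint.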
while True:
self.explore_step_num += 1
await asyncio.sleep(self.config.explorer.service_status_check_interval)
# process experiences generated in the last interval
exps = await self.service.get_all_experiences()
metrics = await self.experience_pipeline.process.remote(exps)
metrics.update(self.service.collect_metrics())
self.monitor.log(metrics, self.explore_step_num)
# get the latest checkpoint
_, step_num = get_checkpoint_dir_with_step_num(
self.config.checkpoint_job_dir, raise_error=False
)
self.service.set_latest_model_version(step_num)
@classmethod
def get_actor(cls, config: Config):
"""Get a Ray actor for the explorer."""
return (
ray.remote(cls)
.options(
name=config.explorer.name,
namespace=ray.get_runtime_context().namespace,
)
.remote(config)
)