Source code for trinity.explorer.proxy.service

import asyncio
import time
from collections import deque
from typing import Dict, List, Tuple

import torch

from trinity.common.constants import RunningStatus
from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.explorer.explorer import Explorer
from trinity.explorer.proxy.recorder import HistoryRecorder
from trinity.utils.log import get_logger


class ExplorerService:
    """Manages the lifecycle and operations of the Explorer API service."""
    def __init__(self, explorer: Explorer, listen_address: str = "localhost", port: int = 8010):
        self.logger = get_logger(__name__)
        self.explorer = explorer
        self.app = None
        self.port = port
        self.listen_address = listen_address
        self.running = False
        self.models: List[ModelWrapper] = [ModelWrapper(model) for model in explorer.models]
        self.min_running_model_num = explorer.config.explorer.min_running_model_num
        self.check_interval = explorer.config.explorer.service_status_check_interval
        self.max_timeout = explorer.config.explorer.max_timeout
        self.running_model_ids: deque[int] = deque()  # indices of running models
        self.model_version_map: Dict[int, int] = {}  # model index -> model version
        self.sync_task_map: Dict[asyncio.Future, int] = {}  # sync task -> model index
        self.latest_model_version = 0
        self.session_level_experience_queue: Dict[int, deque[Experience]] = {}
        self.commit_lock = asyncio.Lock()
        self.ready_experiences: deque[Experience] = deque()
        self.recorder = HistoryRecorder(
            db_url=explorer.config.explorer.db_url
            or f"sqlite:///{explorer.config.buffer.cache_dir}/proxy_history.db",
            table_name="proxy_history",
        )
        self.total_experience_count = 0
        self.ready_experience_count = 0
    async def serve(self) -> None:
        """Start the API server and the model weights synchronization loop."""
        from trinity.explorer.proxy.app import run_app

        if self.running:
            self.logger.warning("Server is already running.")
            return
        self.running = True
        await asyncio.gather(*[model.prepare() for model in self.models])
        for i, _ in enumerate(self.models):
            self.running_model_ids.append(i)
        self.serve_task = asyncio.create_task(
            run_app(service=self, listen_address=self.listen_address, port=self.port)
        )
        self.sync_model_weights_task = asyncio.create_task(self.model_weights_sync_loop())
    async def model_weights_sync_loop(self) -> None:
        """Periodically check running models and schedule stale ones for weight synchronization."""
        self.logger.info("Starting model weights synchronization loop.")
        while self.running:
            for idx in list(self.running_model_ids):
                self.model_version_map[idx] = await self.models[idx].model_version_async
                if (
                    len(self.running_model_ids)
                    > self.explorer.config.explorer.min_running_model_num
                    and self.model_version_map[idx] < self.latest_model_version
                ):
                    self.logger.info(f"Model {idx} scheduled for synchronization.")
                    self.models[idx].status = RunningStatus.REQUIRE_SYNC
                    self.running_model_ids.remove(idx)
                    asyncio.create_task(self._sync_model_weights(idx))
            # wait half interval
            await asyncio.sleep(self.check_interval / 2)
        self.logger.info("Model weights synchronization loop stopped.")
    def set_latest_model_version(self, version: int) -> None:
        """Update the latest known model version (monotonically increasing)."""
        if version > self.latest_model_version:
            self.latest_model_version = version
            self.logger.info(f"Updated latest model version to {version}.")
    async def _sync_model_weights(self, index: int) -> None:
        """Synchronize model weights for the given model index."""
        # wait until the model is free
        start_time = time.time()
        timeout_flag = True
        current_load = -1
        while time.time() - start_time < self.max_timeout:
            current_load = await self.models[index].get_current_load()
            if current_load == 0:
                self.models[index].status = RunningStatus.WAITING_SYNC
                self.logger.info(f"Model {index} begins synchronization.")
                timeout_flag = False
                break
            else:
                self.logger.info(
                    "Waiting for model %d to be free. Current load: %d", index, current_load
                )
                await asyncio.sleep(1)
        if timeout_flag:
            raise asyncio.TimeoutError(
                f"Timeout waiting for model {index} to be free for synchronization. "
                f"Current load: {current_load}"
            )
        latest_version = self.latest_model_version  # capture the latest version
        # perform synchronization
        await self.models[index].sync_model_weights(latest_version)
        self.logger.info(f"Model {index} synchronized to version {latest_version}.")
        self.model_version_map[index] = await self.models[index].model_version_async
        self.models[index].status = RunningStatus.RUNNING
        self.running_model_ids.append(index)
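    # Model status lifecycle implied by the two methods above (a summary,
    # not part of the original source):
    #   RUNNING --(stale version found in sync loop)--> REQUIRE_SYNC
    #   REQUIRE_SYNC --(current load reaches 0)-------> WAITING_SYNC
    #   WAITING_SYNC --(weights synced to latest)-----> RUNNING (re-enqueued)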
    async def allocate_model(self, increase_count: bool = True) -> Tuple[str, int]:
        """Allocate a model for handling a request.

        Returns:
            A tuple of (model_api_address, model_version).
        """
        model_id = self.running_model_ids[0]
        model = self.models[model_id]
        if increase_count:
            model.request_count += 1
        self.running_model_ids.rotate(-1)
        if model.api_address is None:
            raise ValueError(
                "Model does not have a valid API address, please set `enable_openai_api` to `True`."
            )
        return model.api_address, self.model_version_map[model_id]
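    # Usage sketch for the round-robin allocation above (hypothetical caller,
    # assuming a running service instance `service`):
    #
    #     addr, version = await service.allocate_model()
    #     # `addr` is the model's OpenAI-compatible endpoint; rotating the
    #     # deque after each call spreads requests across running models.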
    def collect_metrics(self) -> Dict:
        """Collect per-model request counts and experience counters for logging."""
        metrics = {}
        for i, model in enumerate(self.models):
            metrics[f"rollout/model_{i}/total_request_count"] = model.request_count
            metrics[f"rollout/model_{i}/model_version"] = model.model_version
        metrics["rollout/total_experience_count"] = self.total_experience_count
        metrics["rollout/ready_experience_count"] = self.ready_experience_count
        return metrics
    async def record_experience(self, response, model_version: int) -> None:
        """Convert a completion response into `Experience` objects and record them."""
        experiences = []
        for choice in response["choices"]:
            exp = Experience(
                tokens=torch.cat(
                    (
                        torch.tensor(response["prompt_token_ids"], dtype=torch.int32),
                        torch.tensor(choice["token_ids"], dtype=torch.int32),
                    )
                ),
                logprobs=(
                    torch.tensor(
                        [logprob["logprob"] for logprob in choice["logprobs"]["content"]],
                        dtype=torch.float32,
                    )
                    if "logprobs" in choice and choice["logprobs"] is not None
                    else torch.tensor([], dtype=torch.float32)
                ),
                prompt_length=len(response["prompt_token_ids"]),
            )
            exp.eid.suffix = response["id"]
            exp.info["model_version"] = model_version
            experiences.append(exp)
        self.total_experience_count += len(experiences)
        self.recorder.record_history(experiences)
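    # Expected `response` shape, inferred from the field accesses above
    # (an assumption; matches a vLLM-style OpenAI-compatible payload):
    #   {
    #       "id": str,
    #       "prompt_token_ids": List[int],
    #       "choices": [
    #           {
    #               "token_ids": List[int],
    #               "logprobs": {"content": [{"logprob": float}, ...]} or None,
    #           },
    #           ...
    #       ],
    #   }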
    async def submit_experiences(self) -> None:
        """Flush all ready experiences to the explorer's experience pipeline and log metrics."""
        async with self.commit_lock:
            experiences = list(self.ready_experiences)
            self.ready_experiences.clear()
            metrics = await self.explorer.experience_pipeline.process.remote(experiences)
            metrics.update(self.collect_metrics())
            self.explorer.explore_step_num += 1
            self.explorer.monitor.log(metrics, self.explorer.explore_step_num)
    async def record_feedback(self, reward: float, msg_ids: List[str], task_id: str, run_id: int):
        """Attach a reward to previously recorded experiences and mark them ready for submission."""
        exps = self.recorder.update_reward(
            reward=reward,
            msg_ids=msg_ids,
            task_id=task_id,
            run_id=run_id,
        )
        self.ready_experience_count += len(exps)
        self.ready_experiences.extend(exps)
    async def shutdown(self):
        """Stop the API server and the model weights synchronization loop."""
        if not self.running:
            self.logger.warning("Server is not running.")
            return
        self.sync_model_weights_task.cancel()
        self.serve_task.cancel()
        try:
            await self.serve_task
        except asyncio.CancelledError:
            pass
        self.running = False
        self.logger.info("API server shut down.")
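A minimal usage sketch of the service lifecycle (not part of the original
source; assumes a configured `Explorer` instance named `explorer`):

    import asyncio

    async def main() -> None:
        service = ExplorerService(explorer, listen_address="localhost", port=8010)
        await service.serve()  # starts the HTTP app and the weight-sync loop
        try:
            # Route a request to the least recently used running model.
            addr, version = await service.allocate_model()
            print(f"Serving via {addr} (model version {version})")
        finally:
            await service.shutdown()

    asyncio.run(main())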