Source code for trinity.explorer.api.service
import asyncio
import time
from collections import deque
from typing import Dict, List
import torch
from trinity.common.constants import RunningStatus
from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.explorer.explorer import Explorer
from trinity.utils.log import get_logger
[docs]
class ExplorerService:
    """Serve the explorer's rollout models behind one API endpoint.

    Each model of the given ``Explorer`` is wrapped in a ``ModelWrapper``.
    The service hands out model API addresses in round-robin order, queues
    the experiences reported back to it, and runs a background loop that
    takes outdated models out of rotation so their weights can be
    synchronized to the latest version.
    """

    def __init__(self, explorer: "Explorer", listen_address: str = "localhost", port: int = 8010):
        """Build the service state; nothing is started until ``serve()``.

        Args:
            explorer: Explorer providing the models and the
                ``config.explorer`` settings read below.
            listen_address: Address the API app is asked to listen on.
            port: Port the API app is asked to listen on.
        """
        self.logger = get_logger(__name__)
        self.explorer = explorer
        self.app = None
        self.port = port
        self.listen_address = listen_address
        self.running = False
        self.models: List[ModelWrapper] = [ModelWrapper(model) for model in explorer.models]
        self.min_running_model_num = explorer.config.explorer.min_running_model_num
        self.check_interval = explorer.config.explorer.service_status_check_interval
        self.max_timeout = explorer.config.explorer.max_timeout
        self.running_models: deque[int] = deque()  # indices of running models
        self.sync_task_map: Dict[asyncio.Future, int] = {}  # sync task -> model index
        self.latest_model_version = 0
        self.experience_queue = asyncio.Queue()
        self.experience_count = 0
        # Background task handles: created by serve(), cancelled by shutdown().
        # Defined here so shutdown() never hits a missing attribute.
        self.serve_task = None
        self.sync_model_weights_task = None
        # Strong references to in-flight weight-sync tasks; the event loop only
        # keeps weak references, so untracked tasks could be garbage collected.
        self._sync_tasks = set()

    async def serve(self):
        """Prepare all models, then start the API app and the sync loop.

        Returns immediately (with a warning) if the service is already
        running. If preparing the models fails, the ``running`` flag is reset
        so that ``serve()`` can be retried, and the error is re-raised.
        """
        # Imported lazily, as in the original code (presumably to avoid an
        # import cycle between the service and the API module -- confirm).
        from trinity.explorer.api.api import run_app

        if self.running:
            self.logger.warning("Server is already running.")
            return
        self.running = True
        try:
            await asyncio.gather(*[model.prepare() for model in self.models])
        except BaseException:
            # Do not stay stuck in the "running" state when startup failed.
            self.running = False
            raise
        # Start from a clean rotation so that a serve() after shutdown() does
        # not insert every model index a second time.
        self.running_models.clear()
        self.running_models.extend(range(len(self.models)))
        self.serve_task = asyncio.create_task(
            run_app(service=self, listen_address=self.listen_address, port=self.port)
        )
        self.sync_model_weights_task = asyncio.create_task(self.model_weights_sync_loop())

    async def model_weights_sync_loop(self):
        """Periodically schedule outdated models for weight synchronization.

        Runs until ``self.running`` becomes false or the task is cancelled,
        sleeping half of ``check_interval`` between scheduling passes.
        """
        self.logger.info("Starting model weights synchronization loop.")
        try:
            while self.running:
                self._schedule_outdated_models()
                # wait half interval
                await asyncio.sleep(self.check_interval / 2)
        finally:
            # Also reached when the loop task is cancelled by shutdown().
            self.logger.info("Model weights synchronization loop stopped.")

    def _schedule_outdated_models(self) -> None:
        """Take outdated models out of rotation and start their sync wait.

        A model is only removed while more than ``min_running_model_num``
        models remain in ``running_models``. Must be called from within the
        running event loop because it creates tasks.
        """
        for idx in list(self.running_models):
            if len(self.running_models) <= self.min_running_model_num:
                # The rotation only shrinks during this pass, so no later
                # model can be removed either.
                break
            if self.models[idx].model_version >= self.latest_model_version:
                continue
            self.running_models.remove(idx)
            self.models[idx].status = RunningStatus.REQUIRE_SYNC
            self.logger.info(f"Model {idx} scheduled for synchronization.")
            wait_task = asyncio.create_task(self._wait_for_sync_start(idx))
            self.sync_task_map[wait_task] = idx
            wait_task.add_done_callback(self._on_sync_ready)

    def _on_sync_ready(self, task: asyncio.Future) -> None:
        """Done-callback of a wait task: start the actual weight sync.

        ``add_done_callback`` invokes plain callables, so a coroutine function
        registered directly would only produce a never-awaited coroutine. This
        wrapper schedules ``_sync_model_weights`` as a task instead.
        """
        sync_task = task.get_loop().create_task(self._sync_model_weights(task))
        self._sync_tasks.add(sync_task)
        sync_task.add_done_callback(self._sync_tasks.discard)

    def set_latest_model_version(self, version: int) -> None:
        """Record ``version`` as the latest one; lower or equal values are ignored."""
        if version > self.latest_model_version:
            self.latest_model_version = version
            self.logger.info(f"Updated latest model version to {version}.")

    async def _wait_for_sync_start(self, index: int):
        """Wait until model ``index`` reports zero load, then mark it WAITING_SYNC.

        The load is polled every 2 seconds for at most ``max_timeout`` seconds.

        Raises:
            asyncio.TimeoutError: If the model is still busy at the deadline.
        """
        # A monotonic clock keeps the timeout immune to wall-clock adjustments.
        deadline = time.monotonic() + self.max_timeout
        # Pre-set so the error message below is valid even if the loop body
        # never runs (max_timeout <= 0).
        current_load = None
        while time.monotonic() < deadline:
            current_load = await self.models[index].get_current_load()
            if current_load == 0:
                self.models[index].status = RunningStatus.WAITING_SYNC
                self.logger.info(f"Model {index} begins synchronization.")
                return
            await asyncio.sleep(2)
        raise asyncio.TimeoutError(
            f"Timeout waiting for model {index} to be free for synchronization. Current load: {current_load}"
        )

    async def _sync_model_weights(self, task: asyncio.Future):
        """Synchronize the model whose wait ``task`` finished and resume it.

        Args:
            task: A finished task created for ``_wait_for_sync_start``; it must
                be a key of ``sync_task_map``.

        The weights are only synchronized when the wait succeeded. In every
        case (success, cancellation, timeout, or a failing sync) the model is
        put back into rotation with status RUNNING; a model that is still
        outdated is picked up again by a later scheduling pass.
        """
        index = self.sync_task_map.pop(task)
        latest_version = self.latest_model_version  # capture the latest version
        try:
            if task.cancelled():
                self.logger.warning(f"Synchronization of model {index} was cancelled.")
            elif task.exception():
                self.logger.error(
                    f"Error during synchronization of model {index}: {task.exception()}"
                )
            else:
                await self.models[index].sync_model_weights(latest_version)
                self.logger.info(f"Model {index} synchronized to version {latest_version}.")
        except Exception as exc:
            # This coroutine runs as a detached task, so nobody would observe
            # the exception; log it instead of losing it.
            self.logger.error(f"Error during synchronization of model {index}: {exc}")
        finally:
            self.running_models.append(index)
            self.models[index].status = RunningStatus.RUNNING

    async def allocate_model(self, increase_count: bool = True) -> str:
        """Return the API address of the next running model (round-robin).

        Args:
            increase_count: Whether to count this allocation in the model's
                ``request_count``.

        Raises:
            IndexError: If no model is currently in rotation.
        """
        model = self.models[self.running_models[0]]
        if increase_count:
            model.request_count += 1
        # Move the model just handed out to the back of the rotation.
        self.running_models.rotate(-1)
        return model.api_address

    def collect_metrics(self) -> Dict:
        """Return per-model request counts and versions plus the experience total."""
        metrics = {}
        for i, model in enumerate(self.models):
            metrics[f"rollout/model_{i}/total_request_count"] = model.request_count
            metrics[f"rollout/model_{i}/model_version"] = model.model_version
        metrics["rollout/total_experience_count"] = self.experience_count
        return metrics

    async def check_requiring_sync_models(self):
        """Schedule outdated models now and wait for in-flight synchronizations.

        Performs one scheduling pass immediately instead of waiting for the
        background loop, then waits for the wait tasks that are currently
        pending and for the weight synchronizations they trigger. Failures of
        individual models are handled (logged) by ``_sync_model_weights`` and
        are not raised here.
        """
        if not self.running:
            self.logger.warning("Server is not running.")
            return
        self._schedule_outdated_models()
        await asyncio.gather(*list(self.sync_task_map), return_exceptions=True)
        # Each finished wait task has registered its follow-up sync task in
        # ``_sync_tasks`` by the time the gather above returns.
        await asyncio.gather(*list(self._sync_tasks), return_exceptions=True)

    async def record_experience(self, response):
        """Convert each choice of a completion response into an ``Experience``.

        Args:
            response: Mapping with ``"prompt_token_ids"`` and ``"choices"``;
                every choice must provide ``"token_ids"`` and may provide
                ``"logprobs"`` and ``"message"``/``"content"``.

        The experiences are counted in ``experience_count`` and appended to
        ``experience_queue``.
        """
        experiences = []
        for choice in response["choices"]:
            exp = Experience(
                tokens=torch.cat(
                    (
                        torch.tensor(response["prompt_token_ids"], dtype=torch.int32),
                        torch.tensor(choice["token_ids"], dtype=torch.int32),
                    )
                ),
                logprobs=choice.get("logprobs", None),
                prompt_length=len(response["prompt_token_ids"]),
                response_text=choice.get("message", {}).get("content", ""),
            )
            experiences.append(exp)
        self.experience_count += len(experiences)
        for exp in experiences:
            await self.experience_queue.put(exp)

    async def get_all_experiences(self) -> List:
        """Remove and return every experience currently in the queue."""
        experiences = []
        # get_nowait() cannot block, unlike ``await get()`` after an emptiness
        # check.
        while not self.experience_queue.empty():
            experiences.append(self.experience_queue.get_nowait())
        return experiences

    async def shutdown(self):
        """Stop the sync loop and the API app, and wait until both have ended.

        Returns immediately (with a warning) if the service is not running.
        Exceptions other than cancellation raised by the background tasks
        propagate to the caller.
        """
        if not self.running:
            self.logger.warning("Server is not running.")
            return
        # Clear the flag first so the sync loop cannot start another pass.
        self.running = False
        tasks = [
            task for task in (self.sync_model_weights_task, self.serve_task) if task is not None
        ]
        for task in tasks:
            task.cancel()
        for task in tasks:
            try:
                await task
            except asyncio.CancelledError:
                pass
        self.logger.info("API server shut down.")