# Source code for trinity.utils.dlc_utils

import os
import subprocess
import sys
import time

import ray

from trinity.utils.log import get_logger

logger = get_logger(__name__)

CLUSTER_ACTOR_NAME = "cluster_status"


@ray.remote
class ClusterStatus:
    """Ray actor that holds a single "finished" flag for the cluster."""

    def __init__(self) -> None:
        # Set to True by finish(); nothing in this class resets it.
        self.finished = False

    def finish(self) -> None:
        """Mark the cluster as finished."""
        self.finished = True

    def running(self) -> bool:
        """Return True as long as finish() has not been called."""
        still_running = not self.finished
        return still_running


def get_dlc_env_vars() -> dict:
    """Read the distributed-launch settings from the process environment.

    Returns:
        A dict with keys ``RANK`` and ``WORLD_SIZE`` (converted to ``int``)
        and ``MASTER_ADDR`` and ``MASTER_PORT`` (raw environment strings).

    Raises:
        ValueError: If any of the four variables is unset, if ``RANK`` or
            ``WORLD_SIZE`` equals ``-1`` (the "unset" sentinel), or if
            ``RANK`` / ``WORLD_SIZE`` is not a valid integer literal.
    """
    envs: dict = {}

    for key in ("RANK", "WORLD_SIZE"):
        raw = os.environ.get(key)
        if raw is None:
            # Keep -1 as the "unset" sentinel checked below.
            envs[key] = -1
            continue
        try:
            envs[key] = int(raw)
        except ValueError:
            # Name the offending variable instead of surfacing a bare
            # "invalid literal for int()" error.
            message = f"DLC env var `{key}` must be an integer, got {raw!r}."
            logger.error(message)
            raise ValueError(message) from None

    for key in ("MASTER_ADDR", "MASTER_PORT"):
        envs[key] = os.environ.get(key)

    for key, value in envs.items():
        if value is None or value == -1:
            logger.error(f"DLC env var `{key}` is not set.")
            raise ValueError(f"DLC env var `{key}` is not set.")
    return envs


def is_running() -> bool:
    """Check if ray cluster is running.

    Runs ``ray status`` and treats a zero exit status as "running".

    Returns:
        True if the command exits with status 0. False if it exits with a
        non-zero status or if the ``ray`` executable cannot be launched.
    """
    try:
        # A fixed command needs no shell; pass the argv list directly.
        ret = subprocess.run(["ray", "status"], capture_output=True)
    except OSError:
        # Without a shell, a missing or non-executable `ray` raises instead
        # of yielding a non-zero exit status; report it as "not running".
        return False
    return ret.returncode == 0


def wait_for_ray_setup() -> None:
    """Block until ``is_running()`` reports that the ray cluster is up.

    Polls once per second and logs a message before each wait. There is no
    timeout: the call returns only once the check succeeds.
    """
    while not is_running():
        logger.info("Waiting for ray cluster to be ready...")
        time.sleep(1)


def wait_for_ray_worker_nodes(world_size: int) -> None:
    """Block until at least ``world_size`` ray nodes are alive.

    Re-queries ``ray.nodes()`` once per second, logging how many nodes have
    joined and how many are still missing. There is no timeout.

    Args:
        world_size: Number of alive nodes required before returning.
    """
    while True:
        joined = sum(1 for node in ray.nodes() if node["Alive"])
        if joined >= world_size:
            return
        logger.info(
            f"{joined} nodes have joined so far, waiting for {world_size - joined} nodes..."
        )
        time.sleep(1)


def setup_ray_cluster(namespace: str):
    """Join or start the ray cluster described by the DLC environment.

    The process whose ``RANK`` is 0 is treated as the master; every other
    rank is a worker.

    * If ``is_running()`` already reports a cluster, the master connects to
      it with ``ray.init`` and nothing else is done.
    * Otherwise the master starts a head node and workers join it through
      ``MASTER_ADDR:MASTER_PORT``. Every process then waits for the cluster
      and calls ``ray.init``. The master additionally waits until
      ``WORLD_SIZE`` nodes are alive. A worker polls the cluster-status
      actor every 5 seconds and calls ``sys.exit(0)`` once it reports
      finished, so a worker does not return from this branch.

    Args:
        namespace: Ray namespace passed to ``ray.init``.

    Raises:
        ValueError: Propagated from ``get_dlc_env_vars`` when a required
            environment variable is missing.
        SystemExit: With code 1 if ``ray start`` fails; with code 0 on a
            worker once the cluster is marked finished.
    """
    # NOTE(review): block nesting was reconstructed from a flattened listing;
    # confirm that everything after the `ray start` error check belongs to
    # the "no cluster running yet" branch.
    env_vars = get_dlc_env_vars()
    is_master = env_vars["RANK"] == 0

    if is_running():
        # reuse existing ray cluster
        if is_master:
            ray.init(namespace=namespace, ignore_reinit_error=True)
    else:
        # NOTE(review): MASTER_ADDR / MASTER_PORT are interpolated into a
        # shell command; presumably they come from a trusted launcher.
        if is_master:
            cmd = f"ray start --head --port={env_vars['MASTER_PORT']}"
        else:
            cmd = f"ray start --address={env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}"
        # Log before launching so the message precedes the blocking call and
        # is still emitted if the command hangs.
        logger.info(f"Starting ray cluster: {cmd}")
        ret = subprocess.run(cmd, shell=True, capture_output=True)
        if ret.returncode != 0:
            logger.error(f"Failed to start ray cluster: {cmd}")
            logger.error(f"ret.stdout: {ret.stdout!r}")
            logger.error(f"ret.stderr: {ret.stderr!r}")
            sys.exit(1)

        wait_for_ray_setup()
        ray.init(
            address=f"{env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}",
            namespace=namespace,
            ignore_reinit_error=True,
        )
        if is_master:
            # master waits for worker nodes to join
            wait_for_ray_worker_nodes(env_vars["WORLD_SIZE"])
        else:
            # worker waits on the cluster status actor
            cluster_status = ClusterStatus.options(
                name=CLUSTER_ACTOR_NAME,
                get_if_exists=True,
            ).remote()
            while True:
                if ray.get(cluster_status.running.remote()):
                    time.sleep(5)
                else:
                    break
            sys.exit(0)


def stop_ray_cluster():
    """
    Stop the ray cluster by sending a signal to the cluster status actor.

    Looks up (or creates, via ``get_if_exists=True``) the actor named
    ``CLUSTER_ACTOR_NAME``, calls its ``finish`` method and blocks until
    that call has completed.
    """
    status_actor = ClusterStatus.options(
        name=CLUSTER_ACTOR_NAME,
        get_if_exists=True,
    ).remote()
    finish_ref = status_actor.finish.remote()
    ray.get(finish_ref)