Source code for trinity.explorer.workflow_runner

# -*- coding: utf-8 -*-
"""The Workflow Runner Moudle."""
import time
import traceback
import uuid
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional

import ray

from trinity.buffer import get_buffer_writer
from trinity.common.config import Config
from trinity.common.experience import Experience
from trinity.common.models.model import InferenceModel, ModelWrapper
from trinity.common.workflows import Task
from trinity.utils.log import get_logger


@dataclass(frozen=True)
class Status:
    """Status of the task running result."""

    ok: bool
    metric: dict[str, float]
    message: Optional[str] = None
@ray.remote(scheduling_strategy="SPREAD") class WorkflowRunner: """A Ray remote actor to run the workflow and put the returned experiences into the buffer.""" def __init__( self, config: Config, model: InferenceModel, auxiliary_models: Optional[List[InferenceModel]] = None, ) -> None: self.config = config self.experience_buffer = get_buffer_writer( self.config.buffer.explorer_output, # type: ignore self.config.buffer, ) self.model = model self.model_wrapper = ModelWrapper( model, config.explorer.rollout_model.engine_type, ) self.auxiliary_models = [] if auxiliary_models is not None: for model in auxiliary_models: api_client = ModelWrapper( model, "vllm_async", ).get_openai_client() self.auxiliary_models.append(api_client) self.logger = get_logger(__name__) self.workflow_instance = None def is_alive(self): return True def _run_task(self, task: Task) -> List[Experience]: """Init workflow from the task and run it.""" if task.workflow is None: raise ValueError("Workflow is not set in the task.") if ( self.workflow_instance is None or not self.workflow_instance.__class__ == task.workflow or not self.workflow_instance.resettable ): self.workflow_instance = task.to_workflow(self.model_wrapper, self.auxiliary_models) else: self.workflow_instance.reset(task) return self.workflow_instance.run() def run_task(self, task: Task) -> Status: """Run the task and return the states.""" try: st = time.time() exps = self._run_task(task) assert exps is not None and len(exps) > 0, "An empty experience is generated" metrics: dict[str, List[float]] = defaultdict(list) # add run_id run_id = str(uuid.uuid4()) for exp in exps: setattr(exp, "run_id", run_id) if not hasattr(exp, "info") or exp.info is None: exp.info = {} exp.info["model_version"] = self.model_wrapper.get_ckp_version() if not hasattr(exp, "metrics") or exp.metrics is None: exp.metrics = {} for k, v in exp.metrics.items(): metrics[k].append(v) # We get the average of metrics into the state metric = {} metric["time_per_task"] = time.time() - st if metrics: for k, v in metrics.items(): metric[k] = sum(v) / len(v) # type: ignore if not task.is_eval: self.experience_buffer.write(exps) return Status(True, metric=metric) except Exception as e: error_trace_back = traceback.format_exc() self.logger.error(f"WorkflowRunner run task error: {e}\nTraceback:\n{error_trace_back}") return Status(False, metric={"time_per_task": time.time() - st}, message=str(e))