# Source code for trinity.explorer.workflow_runner

# -*- coding: utf-8 -*-
"""The Workflow Runner Module."""
import time
import traceback
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional, Tuple

from trinity.buffer import get_buffer_writer
from trinity.common.config import Config
from trinity.common.experience import Experience
from trinity.common.models.model import InferenceModel, ModelWrapper
from trinity.common.workflows import Task
from trinity.utils.log import get_logger


@dataclass(frozen=True)
class Status:
    """Status of the task running result.

    Instances are immutable (``frozen=True``).
    """

    # True when the task ran to completion, False when it raised.
    ok: bool
    # Metric name -> value reported for the task.
    metric: dict[str, float]
    # Optional human-readable detail, e.g. the error text of a failed run.
    message: Optional[str] = None
class WorkflowRunner:
    """A Ray remote actor to run the workflow and generate experiences."""

    def __init__(
        self,
        config: Config,
        model: InferenceModel,
        auxiliary_models: Optional[List[InferenceModel]] = None,
        runner_id: Optional[int] = None,
    ) -> None:
        """Set up the rollout model wrapper, the auxiliary clients and the buffer writer.

        Args:
            config: Global configuration; ``config.buffer`` and
                ``config.explorer`` are read here.
            model: The rollout inference model, wrapped in a ``ModelWrapper``.
            auxiliary_models: Optional extra inference models. Each one is
                wrapped and exposed to workflows as the client returned by
                ``ModelWrapper.get_openai_client()``.
            runner_id: Optional identifier stored on the runner.
        """
        self.config = config
        self.experience_buffer = get_buffer_writer(
            self.config.buffer.explorer_output,  # type: ignore
            self.config.buffer,
        )
        self.model = model
        self.model_wrapper = ModelWrapper(
            model,
            config.explorer.rollout_model.engine_type,
            enable_history=config.explorer.rollout_model.enable_history,
        )
        self.auxiliary_models = []
        if auxiliary_models is not None:
            # Use a distinct loop name so the `model` parameter is not rebound.
            for aux_model in auxiliary_models:
                api_client = ModelWrapper(
                    aux_model,
                    "vllm_async",
                ).get_openai_client()
                self.auxiliary_models.append(api_client)
        self.logger = get_logger(__name__)
        # Created lazily by `_create_workflow_instance` on the first task.
        self.workflow_instance = None
        self.runner_id = runner_id
        self.return_experiences = self.config.explorer.collect_experiences

    def is_alive(self) -> bool:
        """Return True; a successful call shows the runner is responsive."""
        return True

    def _create_workflow_instance(self, task: Task) -> None:
        """Make ``self.workflow_instance`` ready to run ``task``.

        The current instance is reset in place when it has the class requested
        by the task and reports itself as resettable; otherwise a new instance
        is built through ``task.to_workflow``.

        Raises:
            ValueError: If ``task.workflow`` is None.
        """
        if task.workflow is None:
            raise ValueError("Workflow is not set in the task.")
        if (
            self.workflow_instance is None
            or self.workflow_instance.__class__ != task.workflow
            or not self.workflow_instance.resettable
        ):
            self.workflow_instance = task.to_workflow(self.model_wrapper, self.auxiliary_models)
        else:
            self.workflow_instance.reset(task)

    def _run_task(self, task: Task, repeat_times: int, run_id_base: int) -> List[Experience]:
        """Init workflow from the task and run it.

        A repeatable workflow is told the repeat count and run once. Any other
        workflow is run ``repeat_times`` times, and the experiences of the
        i-th run get ``eid.run = run_id_base + i``.
        """
        self._create_workflow_instance(task)
        if self.workflow_instance.repeatable:
            self.workflow_instance.set_repeat_times(repeat_times, run_id_base)
            return self.workflow_instance.run()

        exps = []
        for i in range(repeat_times):
            new_exps = self.workflow_instance.run()
            for exp in new_exps:
                exp.eid.run = run_id_base + i
            exps.extend(new_exps)
            # Prepare the workflow again before every run except the last one.
            if i < repeat_times - 1:
                self._create_workflow_instance(task)
        return exps

    def run_task(
        self,
        task: Task,
        repeat_times: int = 1,
        run_id_base: int = 0,
    ) -> Tuple[Status, List[Experience]]:
        """Run the task and return the states.

        Args:
            task: The task to run.
            repeat_times: How many rollouts to produce for the task.
            run_id_base: Offset used when numbering the runs.

        Returns:
            A ``(status, experiences)`` pair. The list is empty for evaluation
            tasks, on failure, and when experiences are written to the buffer
            here; it holds the experiences only when ``return_experiences`` is
            set. Exceptions are not propagated: they are logged and reported
            as ``Status(ok=False, ...)`` with the error text as the message.
        """
        # TODO: avoid sending the experiences back to the scheduler to reduce the communication overhead
        # Taken before the `try` so the error path can always report a duration.
        st = time.time()
        try:
            exps = self._run_task(task, repeat_times, run_id_base)
            # Explicit raise instead of `assert`: asserts are stripped under
            # `python -O`, which would let an empty rollout count as success.
            if not exps:
                raise ValueError("An empty experience is generated")
            metrics: dict[str, List[float]] = defaultdict(list)
            # set eid for each experience
            for exp in exps:
                exp.eid.batch = task.batch_id
                exp.eid.task = task.task_id
                if not hasattr(exp, "info") or exp.info is None:
                    exp.info = {}
                exp.info["model_version"] = self.model_wrapper.model_version
                exp.info["use_count"] = 0

                if not hasattr(exp, "metrics") or exp.metrics is None:
                    exp.metrics = {}
                for k, v in exp.metrics.items():
                    metrics[k].append(v)
            # We get the average of metrics into the state
            metric = {}
            metric["time_per_task"] = time.time() - st
            for k, v in metrics.items():
                metric[k] = sum(v) / len(v)  # type: ignore
            if task.is_eval:
                # If the task is an evaluation task, we do not record the experiences to the buffer
                return Status(True, metric=metric), []
            elif self.return_experiences:
                return Status(True, metric=metric), exps
            else:
                self.experience_buffer.write(exps)
                return Status(True, metric=metric), []
        except Exception as e:
            error_trace_back = traceback.format_exc()
            self.logger.error(f"WorkflowRunner run task error: {e}\nTraceback:\n{error_trace_back}")
            return Status(False, metric={"time_per_task": time.time() - st}, message=str(e)), []