Source code for trinity.explorer.workflow_runner

# -*- coding: utf-8 -*-
"""The Workflow Runner Module."""
import asyncio
import time
import traceback
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

from trinity.buffer import get_buffer_reader
from trinity.common.config import Config
from trinity.common.experience import Experience
from trinity.common.models import get_debug_inference_model
from trinity.common.models.model import InferenceModel, ModelWrapper
from trinity.common.workflows import Task, Workflow
from trinity.utils.log import get_logger


@dataclass(frozen=True)
class Status:
    """Immutable outcome of running a single task.

    Instances are frozen, so fields cannot be reassigned after construction.
    """

    # True when the task finished without raising.
    ok: bool
    # Metric dictionaries collected for the task; for non-repeatable workflows
    # there is one dictionary per run.
    metrics: List[Dict[str, float]]
    # Optional human-readable detail; in this module it carries the exception
    # text when a task fails.
    message: Optional[str] = None
def calculate_run_level_metrics(experiences: List[Experience]) -> Dict[str, float]:
    """Average each metric over the experiences produced by one run.

    Meant for non-repeatable workflows, where a single run yields a batch of
    experiences and the run-level value of a metric is its mean over that
    batch. Repeatable workflows should not use this helper.

    Args:
        experiences: Experiences from one run. Entries whose ``metrics`` is
            empty or None contribute nothing.

    Returns:
        A mapping from metric name to the arithmetic mean of every value
        recorded under that name. Empty when no experience carries metrics.
    """
    collected: Dict[str, List[float]] = defaultdict(list)
    for experience in experiences:
        for name, value in (experience.metrics or {}).items():
            collected[name].append(value)
    return {name: sum(values) / len(values) for name, values in collected.items()}
class WorkflowRunner:
    """A Ray remote actor to run the workflow and generate experiences.

    The runner wraps a rollout model (plus optional auxiliary models), builds or
    reuses a workflow instance for each task, executes it ``repeat_times`` times
    and returns the produced experiences together with their metrics.
    """

    def __init__(
        self,
        config: Config,
        model: InferenceModel,
        auxiliary_models: Optional[List[InferenceModel]] = None,
        runner_id: Optional[int] = None,
    ) -> None:
        """Wrap the models and initialise the runner's bookkeeping state.

        Args:
            config: Global config. ``config.explorer`` provides the logger name
                and the rollout model's engine type and LoRA/history flags.
            model: Rollout model handed to workflows through ``ModelWrapper``.
            auxiliary_models: Extra models; their OpenAI-style clients are
                created in :meth:`prepare` and passed to workflows.
            runner_id: Identifier embedded in the logger name.
        """
        self.logger = get_logger(f"{config.explorer.name}_runner_{runner_id}", in_ray_actor=True)
        self.config = config
        self.model = model
        self.model_wrapper = ModelWrapper(
            model,
            config.explorer.rollout_model.engine_type,
            enable_lora=config.explorer.rollout_model.enable_lora,
            enable_history=config.explorer.rollout_model.enable_history,
        )
        self.auxiliary_models = [ModelWrapper(aux_model) for aux_model in (auxiliary_models or [])]
        # Populated by prepare(), one entry per auxiliary model, in the same order.
        self.auxiliary_model_clients = []
        self.auxiliary_model_async_clients = []
        # Cached workflow; reused across tasks when its class matches and it is resettable.
        self.workflow_instance: Optional[Workflow] = None
        self.runner_id = runner_id
        # Snapshot of the current run, reported by get_runner_state().
        self.runner_state = {
            "workflow_id": None,
            "model_version": None,
            "begin_time": 0,
            "terminate_time": 0,
        }

    async def prepare(self) -> None:
        """Prepare every model wrapper, then create the auxiliary model clients.

        All wrappers are prepared concurrently; afterwards one sync and one async
        OpenAI client are collected per auxiliary model, in model order.
        """
        await asyncio.gather(
            self.model_wrapper.prepare(),
            *(aux_model.prepare() for aux_model in self.auxiliary_models),
        )
        for aux_model in self.auxiliary_models:
            api_client = aux_model.get_openai_client()
            async_api_client = aux_model.get_openai_async_client()
            self.auxiliary_model_clients.append(api_client)
            self.auxiliary_model_async_clients.append(async_api_client)

    def is_alive(self):
        """Always return True (presumably polled by callers as a liveness check)."""
        return True

    def _create_workflow_instance(self, task: Task) -> None:
        """Reset the cached workflow for ``task``, or build a new one.

        The cached instance is reused (via ``reset``) only when it exists, its
        class equals ``task.workflow`` and it is resettable; otherwise a fresh
        instance is created with the sync or async auxiliary clients, depending
        on ``task.workflow.is_async``.

        Raises:
            ValueError: If ``task.workflow`` is None.
        """
        if task.workflow is None:
            raise ValueError("Workflow is not set in the task.")
        cached = self.workflow_instance
        if cached is not None and cached.__class__ == task.workflow and cached.resettable:
            cached.reset(task)
            return
        clients = (
            self.auxiliary_model_async_clients
            if task.workflow.is_async
            else self.auxiliary_model_clients
        )
        self.workflow_instance = task.to_workflow(self.model_wrapper, clients)

    async def _run_workflow(self, workflow_instance: Workflow) -> List[Experience]:
        """Execute ``workflow_instance`` once, awaiting it when it is asynchronous."""
        if workflow_instance.asynchronous:
            return await workflow_instance.run_async()
        return workflow_instance.run()

    async def _timed_run(self, workflow_id: str) -> Tuple[List[Experience], float]:
        """Run the cached workflow once while recording it in ``runner_state``.

        Returns:
            The experiences produced and the wall-clock seconds the run took
            (including the model wrapper's workflow-state cleanup).
        """
        st = time.time()
        await self.model_wrapper.clean_workflow_state()
        self.runner_state["workflow_id"] = workflow_id
        # None marks the run as in progress until it finishes.
        self.runner_state["terminate_time"] = None
        self.runner_state["begin_time"] = st
        exps = await self._run_workflow(self.workflow_instance)
        et = time.time()
        self.runner_state["terminate_time"] = et
        return exps, et - st

    async def _run_task(
        self, task: Task, repeat_times: int, run_id_base: int
    ) -> Tuple[List[Experience], List[Dict]]:
        """Init workflow from the task and run it ``repeat_times`` times.

        Returns:
            All experiences produced and a list of metric dictionaries.
        """
        self._create_workflow_instance(task)
        if self.workflow_instance.repeatable:
            # A repeatable workflow performs all repeats in a single call.
            self.workflow_instance.set_repeat_times(repeat_times, run_id_base)
            exps, elapsed = await self._timed_run(
                f"{task.batch_id}/{task.task_id}/{run_id_base}"
            )
            # Run-level metrics cannot be derived here, so each experience's own
            # metrics dict is used directly (and annotated in place with the time).
            run_metrics = [exp.metrics for exp in exps if exp.metrics]
            for metric in run_metrics:
                metric["time/run_execution"] = elapsed
            return exps, run_metrics

        exps = []
        run_metrics = []
        for i in range(repeat_times):
            run_id = run_id_base + i
            # Use the same run id for the state tracker as for the experiences' eid.
            new_exps, elapsed = await self._timed_run(f"{task.batch_id}/{task.task_id}/{run_id}")
            run_metric = calculate_run_level_metrics(new_exps)
            run_metric["time/run_execution"] = elapsed
            run_metrics.append(run_metric)
            for exp in new_exps:
                exp.eid.run = run_id
            exps.extend(new_exps)
            if i < repeat_times - 1:
                # Give the next run a fresh (or reset) workflow instance.
                self._create_workflow_instance(task)
        return exps, run_metrics

    async def get_runner_state(self) -> Dict:
        """Get the runner state.

        Returns a copy of ``runner_state`` updated with the model wrapper's
        workflow state; keys from the wrapper override same-named runner keys.
        """
        runner_state = self.runner_state.copy()
        runner_state.update(await self.model_wrapper.get_workflow_state())
        return runner_state

    async def run_task(
        self,
        task: Task,
        repeat_times: int = 1,
        run_id_base: int = 0,
    ) -> Tuple[Status, List[Experience]]:
        """Run the task and return the status together with its experiences.

        Every experience is stamped with the batch id, task id (unless already
        set), model version, ``use_count=0`` and the task index.

        Returns:
            ``(Status(True, metrics), exps)`` on success, with ``exps`` empty for
            evaluation tasks. Any exception is logged and reported as
            ``Status(False, ...)`` with the elapsed time and the error message,
            alongside an empty experience list.
        """
        # TODO: avoid sending the experiences back to the scheduler to reduce the communication overhead
        st = time.time()
        try:
            model_version = await self.model_wrapper.model_version_async
            self.runner_state["model_version"] = model_version
            exps, metrics = await self._run_task(task, repeat_times, run_id_base)
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if not exps:
                raise RuntimeError("An empty experience is generated")
            # set eid for each experience
            for exp in exps:
                exp.eid.batch = task.batch_id
                # keep exp.eid.task if it has been set before (e.g., in workflow)
                if exp.eid.task == "":  # "" is the default value
                    exp.eid.task = task.task_id
                if not hasattr(exp, "info") or exp.info is None:
                    exp.info = {}
                exp.info["model_version"] = model_version
                exp.info["use_count"] = 0
                exp.info["task_index"] = task.index
                if not hasattr(exp, "metrics") or exp.metrics is None:
                    exp.metrics = {}
            if task.is_eval:
                # If the task is an evaluation task, we do not record the experiences to the buffer
                return Status(True, metrics=metrics), []
            return Status(True, metrics=metrics), exps
        except Exception as e:
            error_trace_back = traceback.format_exc()
            self.logger.error(f"WorkflowRunner run task error: {e}\nTraceback:\n{error_trace_back}")
            return (
                Status(False, metrics=[{"time/run_execution": time.time() - st}], message=str(e)),
                [],
            )
class DebugWorkflowRunner(WorkflowRunner):
    """Workflow runner that profiles one task from the first explorer taskset.

    The run is traced with VizTracer and the results are printed or logged.
    """

    def __init__(
        self,
        config: Config,
        output_file: str,
    ) -> None:
        """Build debug models from ``config`` and open the first explorer taskset.

        Args:
            config: Global config used to obtain the debug models and taskset.
            output_file: Destination passed to VizTracer for the trace.
        """
        model, auxiliary_models = get_debug_inference_model(config)
        super().__init__(config, model, auxiliary_models, runner_id=0)
        first_taskset = config.buffer.explorer_input.tasksets[0]
        self.taskset = get_buffer_reader(first_taskset)
        self.output_file = output_file

    async def debug(self) -> None:
        """Read one task, run it under VizTracer, then report the outcome."""
        from viztracer import VizTracer

        await self.prepare()
        batch = await self.taskset.read_async(batch_size=1)
        task = batch[0]
        self.logger.info(f"Read task: {task.task_id}, repeat_times: {task.repeat_times}")
        with VizTracer(output_file=self.output_file):
            status, exps = await self.run_task(task, task.repeat_times, 0)
        if not status.ok:
            self.logger.error(f"Task {task.task_id} failed with message: {status.message}")
        else:
            print(f"Task {task.task_id} completed successfully with metrics:\n{status.metrics}")
            for experience in exps:
                print(f"Generated experience:\n{experience}")
        self.logger.info("Debugging completed.")