# -*- coding: utf-8 -*-
"""The Workflow Runner Module."""
import asyncio
import time
import traceback
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from trinity.buffer import get_buffer_reader
from trinity.common.config import Config
from trinity.common.experience import Experience
from trinity.common.models import get_debug_inference_model
from trinity.common.models.model import InferenceModel, ModelWrapper
from trinity.common.workflows import Task, Workflow
from trinity.utils.log import get_logger
@dataclass(frozen=True)
class Status:
    """Status of the task running result."""

    # Whether the task ran to completion without raising.
    ok: bool
    metrics: List[Dict[str, float]]
    # A list of metric dictionaries, where each dictionary is from a single run.
    # Human-readable error description; None when ``ok`` is True.
    message: Optional[str] = None
def calculate_run_level_metrics(experiences: List[Experience]) -> Dict[str, float]:
    """Average per-run metrics collected from a list of experiences.

    For non-repeatable workflows, each experience carries the metrics of one
    run, so averaging over experiences is equivalent to computing run-level
    metrics. For repeatable workflows, please do not use this function.
    """
    collected: Dict[str, List[float]] = defaultdict(list)
    for experience in experiences:
        # Experiences without metrics (None or empty) contribute nothing.
        for name, value in (experience.metrics or {}).items():
            collected[name].append(value)
    return {name: sum(values) / len(values) for name, values in collected.items()}
class WorkflowRunner:
    """A Ray remote actor to run the workflow and generate experiences."""

    def __init__(
        self,
        config: Config,
        model: InferenceModel,
        auxiliary_models: Optional[List[InferenceModel]] = None,
        runner_id: Optional[int] = None,
    ) -> None:
        """Initialize the runner.

        Args:
            config: Global configuration; rollout-model settings are read from
                ``config.explorer.rollout_model``.
            model: The main rollout inference model.
            auxiliary_models: Optional extra inference models exposed to
                workflows through OpenAI-compatible clients.
            runner_id: Identifier used for logging and bookkeeping.
        """
        self.logger = get_logger(f"{config.explorer.name}_runner_{runner_id}", in_ray_actor=True)
        self.config = config
        self.model = model
        self.model_wrapper = ModelWrapper(
            model,
            config.explorer.rollout_model.engine_type,
            enable_lora=config.explorer.rollout_model.enable_lora,
            enable_history=config.explorer.rollout_model.enable_history,
        )
        # Use a distinct loop variable so the `model` parameter is not shadowed.
        self.auxiliary_models = [
            ModelWrapper(aux_model) for aux_model in (auxiliary_models or [])
        ]
        # OpenAI-compatible clients; populated in `prepare()`.
        self.auxiliary_model_clients: List = []
        self.auxiliary_model_async_clients: List = []
        # Cached workflow instance, reused across tasks when resettable.
        self.workflow_instance: Optional[Workflow] = None
        self.runner_id = runner_id
        # Lightweight state snapshot exposed via `get_runner_state()`.
        self.runner_state = {
            "workflow_id": None,
            "model_version": None,
            "begin_time": 0,
            "terminate_time": 0,
        }

    async def prepare(self) -> None:
        """Prepare the rollout model and all auxiliary models concurrently."""
        await asyncio.gather(
            self.model_wrapper.prepare(),
            *(aux_model.prepare() for aux_model in self.auxiliary_models),
        )
        for aux_model in self.auxiliary_models:
            self.auxiliary_model_clients.append(aux_model.get_openai_client())
            self.auxiliary_model_async_clients.append(aux_model.get_openai_async_client())

    def is_alive(self) -> bool:
        """Liveness probe for the Ray actor; always returns True."""
        return True

    def _create_workflow_instance(self, task: Task) -> None:
        """Build (or reset) the workflow instance for ``task``.

        The cached instance is reused via ``reset`` only when it exists, is of
        exactly the same workflow class as the task, and declares itself
        resettable; otherwise a fresh instance is created.

        Raises:
            ValueError: If the task has no workflow set.
        """
        if task.workflow is None:
            raise ValueError("Workflow is not set in the task.")
        reusable = (
            self.workflow_instance is not None
            # Identity check: the workflow class must match exactly.
            and self.workflow_instance.__class__ is task.workflow
            and self.workflow_instance.resettable
        )
        if reusable:
            self.workflow_instance.reset(task)
        else:
            self.workflow_instance = task.to_workflow(
                self.model_wrapper,
                (
                    self.auxiliary_model_async_clients
                    if task.workflow.is_async
                    else self.auxiliary_model_clients
                ),
            )

    async def _run_workflow(self, workflow_instance: Workflow) -> List[Experience]:
        """Run a workflow, dispatching on its sync/async flavor."""
        if workflow_instance.asynchronous:
            return await workflow_instance.run_async()
        return workflow_instance.run()

    async def _run_task(
        self, task: Task, repeat_times: int, run_id_base: int
    ) -> Tuple[List[Experience], List[Dict]]:
        """Init workflow from the task and run it ``repeat_times`` times.

        Returns:
            A tuple of (all generated experiences, one metrics dict per run).
        """
        self._create_workflow_instance(task)
        if self.workflow_instance.repeatable:
            # The workflow handles all repeats internally in a single run.
            self.workflow_instance.set_repeat_times(repeat_times, run_id_base)
            st = time.time()
            await self.model_wrapper.clean_workflow_state()
            self.runner_state["workflow_id"] = f"{task.batch_id}/{task.task_id}/{run_id_base}"
            self.runner_state["terminate_time"] = None
            self.runner_state["begin_time"] = st
            exps = await self._run_workflow(self.workflow_instance)
            et = time.time()
            self.runner_state["terminate_time"] = et
            # Repeatable workflows cannot calculate run-level metrics; we use
            # experience-level metrics directly.
            run_metrics = [exp.metrics for exp in exps if exp.metrics]
            for metric in run_metrics:
                metric["time/run_execution"] = et - st
        else:
            exps = []
            run_metrics = []
            for i in range(repeat_times):
                st = time.time()
                await self.model_wrapper.clean_workflow_state()
                self.runner_state["workflow_id"] = f"{task.batch_id}/{task.task_id}/{i}"
                self.runner_state["terminate_time"] = None
                self.runner_state["begin_time"] = st
                new_exps = await self._run_workflow(self.workflow_instance)
                et = time.time()
                self.runner_state["terminate_time"] = et
                run_metric = calculate_run_level_metrics(new_exps)
                run_metric["time/run_execution"] = et - st
                run_metrics.append(run_metric)
                for exp in new_exps:
                    exp.eid.run = run_id_base + i
                exps.extend(new_exps)
                if i < repeat_times - 1:
                    # Re-create (or reset) the workflow for the next run.
                    self._create_workflow_instance(task)
        return exps, run_metrics

    async def get_runner_state(self) -> Dict:
        """Get the runner state, merged with the model's workflow state."""
        runner_state = self.runner_state.copy()
        runner_state.update(await self.model_wrapper.get_workflow_state())
        return runner_state

    async def run_task(
        self,
        task: Task,
        repeat_times: int = 1,
        run_id_base: int = 0,
    ) -> Tuple[Status, List[Experience]]:
        """Run the task and return the states.

        Returns:
            A ``(Status, experiences)`` tuple. On failure the status carries
            the error message and the experience list is empty. Eval tasks
            also return an empty list since their experiences are not
            recorded to the buffer.
        """
        # TODO: avoid sending the experiences back to the scheduler to reduce the communication overhead
        # Bound before `try` so the except-path timing below never NameErrors.
        st = time.time()
        try:
            model_version = await self.model_wrapper.model_version_async
            self.runner_state["model_version"] = model_version
            exps, metrics = await self._run_task(task, repeat_times, run_id_base)
            # Raise (not assert, which is stripped under -O); caught below and
            # converted into a failed Status, as before.
            if not exps:
                raise RuntimeError("An empty experience is generated")
            # set eid and bookkeeping info for each experience
            for exp in exps:
                exp.eid.batch = task.batch_id
                # keep exp.eid.task if it has been set before (e.g., in workflow)
                if exp.eid.task == "":  # "" is the default value
                    exp.eid.task = task.task_id
                if not hasattr(exp, "info") or exp.info is None:
                    exp.info = {}
                exp.info["model_version"] = model_version
                exp.info["use_count"] = 0
                exp.info["task_index"] = task.index
                if not hasattr(exp, "metrics") or exp.metrics is None:
                    exp.metrics = {}
            if task.is_eval:
                # If the task is an evaluation task, we do not record the experiences to the buffer
                return Status(True, metrics=metrics), []
            else:
                return Status(True, metrics=metrics), exps
        except Exception as e:
            error_trace_back = traceback.format_exc()
            self.logger.error(f"WorkflowRunner run task error: {e}\nTraceback:\n{error_trace_back}")
            return (
                Status(False, metrics=[{"time/run_execution": time.time() - st}], message=str(e)),
                [],
            )
class DebugWorkflowRunner(WorkflowRunner):
    """A WorkflowRunner for debugging a single task locally under a profiler."""

    def __init__(
        self,
        config: Config,
        output_file: str,
    ) -> None:
        """Initialize the debug runner.

        Args:
            config: Global configuration; the first taskset in
                ``config.buffer.explorer_input.tasksets`` is used.
            output_file: Path where the VizTracer profiling report is written.
        """
        model, auxiliary_models = get_debug_inference_model(config)
        # Runner id is fixed to 0 for the single debug runner.
        super().__init__(config, model, auxiliary_models, 0)
        self.taskset = get_buffer_reader(config.buffer.explorer_input.tasksets[0])
        self.output_file = output_file

    async def debug(self) -> None:
        """Read one task from the taskset and run it under VizTracer."""
        # Imported lazily: viztracer is a third-party tool only needed here.
        from viztracer import VizTracer

        await self.prepare()
        tasks = await self.taskset.read_async(batch_size=1)
        task = tasks[0]
        self.logger.info(f"Read task: {task.task_id}, repeat_times: {task.repeat_times}")
        # Only the task execution itself is profiled.
        with VizTracer(output_file=self.output_file):
            status, exps = await self.run_task(task, task.repeat_times, 0)
        if status.ok:
            print(f"Task {task.task_id} completed successfully with metrics:\n{status.metrics}")
            for exp in exps:
                print(f"Generated experience:\n{exp}")
        else:
            self.logger.error(f"Task {task.task_id} failed with message: {status.message}")
        self.logger.info("Debugging completed.")