import traceback
from typing import Dict, List, Optional
from trinity.buffer.buffer import BufferWriter, get_buffer_reader, get_buffer_writer
from trinity.buffer.operators.experience_operator import ExperienceOperator
from trinity.buffer.ray_wrapper import is_database_url, is_json_file
from trinity.common.config import (
AlgorithmConfig,
BufferConfig,
Config,
ExperiencePipelineConfig,
StorageConfig,
)
from trinity.common.constants import StorageType
from trinity.common.experience import Experience
from trinity.utils.log import get_logger
from trinity.utils.plugin_loader import load_plugins
class ExperiencePipeline:
    """Process batches of experiences produced by an explorer.

    The pipeline optionally records the raw input batch to a side store, runs
    the batch through a chain of ``ExperienceOperator`` instances, and writes
    the result to the trainer's experience buffer. If the configured algorithm
    does not compute advantages in the trainer, an advantage operator is
    appended to the chain.
    """

    def __init__(self, config: Config):
        """Build the pipeline from the global configuration.

        Args:
            config (Config): Global configuration. Reads ``explorer.name``,
                ``data_processor.experience_pipeline``, ``buffer`` and
                ``algorithm``.

        Raises:
            ValueError: If input saving is enabled with a missing or
                unsupported ``input_save_path``, or if the configured
                algorithm type is not registered.
            Exception: Any error raised while creating the configured
                operators is logged with its traceback and re-raised.
        """
        self.logger = get_logger(f"{config.explorer.name}_experience_pipeline", in_ray_actor=True)
        load_plugins()
        pipeline_config = config.data_processor.experience_pipeline
        buffer_config = config.buffer
        # Optional writer that records each batch *before* any operator runs.
        self.input_store = self._init_input_storage(pipeline_config, buffer_config)  # type: ignore [arg-type]
        try:
            self.operators = ExperienceOperator.create_operators(pipeline_config.operators)
        except Exception:
            self.logger.error(f"Failed to create experience operators: {traceback.format_exc()}")
            raise
        self._set_algorithm_operators(config.algorithm)
        self.output = get_buffer_writer(
            buffer_config.trainer_input.experience_buffer,  # type: ignore [arg-type]
            buffer_config,
        )

    def _init_input_storage(
        self,
        pipeline_config: ExperiencePipelineConfig,
        buffer_config: BufferConfig,
    ) -> Optional[BufferWriter]:
        """Create a writer that records raw input experiences, if enabled.

        Args:
            pipeline_config (ExperiencePipelineConfig): Supplies ``save_input``
                and ``input_save_path``.
            buffer_config (BufferConfig): Forwarded to ``get_buffer_writer``.

        Returns:
            Optional[BufferWriter]: A FILE-storage writer for a JSON file path,
            an SQL-storage writer for a database URL, or ``None`` when
            ``save_input`` is False.

        Raises:
            ValueError: If ``save_input`` is True and ``input_save_path`` is
                unset or is neither a JSON file path nor a database URL.
        """
        if not pipeline_config.save_input:
            return None
        save_path = pipeline_config.input_save_path
        if save_path is None:
            raise ValueError("input_save_path must be set when save_input is True.")
        if is_json_file(save_path):
            storage_type = StorageType.FILE
        elif is_database_url(save_path):
            storage_type = StorageType.SQL
        else:
            # Report the offending path. The boolean ``save_input`` flag is
            # always True here and tells the user nothing.
            raise ValueError(
                f"Unsupported input_save_path format: {save_path}. "
                "Only JSON file path or SQLite URL is supported."
            )
        return get_buffer_writer(
            StorageConfig(
                storage_type=storage_type,
                path=save_path,
                wrap_in_ray=False,
            ),
            buffer_config,
        )

    def _set_algorithm_operators(self, algorithm_config: AlgorithmConfig) -> None:
        """Append the algorithm's advantage operator when it runs in the pipeline.

        The operator is added only when the algorithm does not compute
        advantages in the trainer and ``algorithm_config.advantage_fn`` is set.

        Args:
            algorithm_config (AlgorithmConfig): Supplies ``algorithm_type``,
                ``advantage_fn`` and ``advantage_fn_args``.

        Raises:
            ValueError: If ``algorithm_type`` is not registered.
            AssertionError: If ``advantage_fn`` is not registered, or if it
                declares that it must be computed in the trainer.
        """
        # Imported locally; presumably avoids an import cycle with the
        # algorithm package (TODO confirm).
        from trinity.algorithm import ADVANTAGE_FN, ALGORITHM_TYPE

        algorithm = ALGORITHM_TYPE.get(algorithm_config.algorithm_type)
        if algorithm is None:
            raise ValueError(f"Algorithm type {algorithm_config.algorithm_type} not found.")
        if algorithm.compute_advantage_in_trainer or not algorithm_config.advantage_fn:
            return
        advantage_fn_cls = ADVANTAGE_FN.get(algorithm_config.advantage_fn)
        assert (
            advantage_fn_cls is not None
        ), f"AdvantageFn {algorithm_config.advantage_fn} not found."
        assert (
            not advantage_fn_cls.compute_in_trainer()
        ), f"AdvantageFn {algorithm_config.advantage_fn} can only be computed in the trainer, please check your implementation."
        self.operators.append(advantage_fn_cls(**algorithm_config.advantage_fn_args))

    async def prepare(self) -> None:
        """Acquire the output writer and, if configured, the input store."""
        await self.output.acquire()
        # The input store comes from the same ``get_buffer_writer`` factory as
        # ``self.output``, so it gets the same acquire/release lifecycle.
        if self.input_store is not None:
            await self.input_store.acquire()

    async def process(self, exps: List[Experience]) -> Dict:
        """Process a batch of experiences.

        Args:
            exps (List[Experience]): List of experiences to process. These experiences are typically generated by an explorer in one step.

        Returns:
            Dict: Numeric metrics collected during processing, cast to float
            and keyed as ``"pipeline/<name>"``. Always includes
            ``"pipeline/experience_count"``, the size of the batch written to
            the output buffer.
        """
        if self.input_store is not None:
            # Record the raw batch before any operator modifies or filters it.
            await self.input_store.write_async(exps)
        metrics = {}
        for operator in self.operators:
            # Each operator returns the (possibly filtered) batch and its
            # metrics. A later operator overwrites an earlier metric with the
            # same key.
            exps, metric = operator.process(exps)
            metrics.update(metric)
        metrics["experience_count"] = len(exps)
        await self.output.write_async(exps)
        # Drop non-numeric metrics and namespace the rest under "pipeline/".
        return {
            f"pipeline/{key}": float(value)
            for key, value in metrics.items()
            if isinstance(value, (int, float))
        }

    async def close(self) -> None:
        """Release the writers and close every operator.

        Operators are closed even if releasing a writer raises.
        """
        try:
            await self.output.release()
            if self.input_store is not None:
                await self.input_store.release()
        finally:
            for operator in self.operators:
                operator.close()