Source code for trinity.buffer.pipelines.experience_pipeline

import traceback
from typing import Dict, List, Optional

from trinity.buffer.buffer import BufferWriter, get_buffer_reader, get_buffer_writer
from trinity.buffer.operators.experience_operator import ExperienceOperator
from trinity.buffer.ray_wrapper import is_database_url, is_json_file
from trinity.common.config import (
    AlgorithmConfig,
    BufferConfig,
    Config,
    ExperiencePipelineConfig,
    StorageConfig,
)
from trinity.common.constants import StorageType
from trinity.common.experience import Experience
from trinity.utils.log import get_logger
from trinity.utils.plugin_loader import load_plugins


def get_input_buffers(
    pipeline_config: ExperiencePipelineConfig, buffer_config: BufferConfig
) -> Dict:
    """Get input buffers for the experience pipeline."""
    input_buffers = {}
    for input_name, input_config in pipeline_config.inputs.items():
        buffer_reader = get_buffer_reader(input_config, buffer_config)
        input_buffers[input_name] = buffer_reader
    return input_buffers

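# --- Illustrative usage sketch (not part of the original module) ---
# Shows how a caller might consume the readers returned by `get_input_buffers`.
# The helper name `_example_drain_inputs` is hypothetical, and `reader.read()`
# is assumed to be the synchronous read method of the objects produced by
# `get_buffer_reader`; verify against the BufferReader interface before use.
def _example_drain_inputs(
    pipeline_config: ExperiencePipelineConfig, buffer_config: BufferConfig
) -> Dict[str, List[Experience]]:
    readers = get_input_buffers(pipeline_config, buffer_config)
    # One batch of experiences per configured input name.
    return {name: reader.read() for name, reader in readers.items()}
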
class ExperiencePipeline:
    """A pipeline that processes experiences and writes the results to the trainer's input buffer."""

    def __init__(self, config: Config):
        self.logger = get_logger(
            f"{config.explorer.name}_experience_pipeline", in_ray_actor=True
        )
        load_plugins()
        pipeline_config = config.data_processor.experience_pipeline
        buffer_config = config.buffer
        self.input_store = self._init_input_storage(pipeline_config, buffer_config)  # type: ignore [arg-type]
        try:
            self.operators = ExperienceOperator.create_operators(pipeline_config.operators)
        except Exception:
            self.logger.error(f"Failed to create experience operators: {traceback.format_exc()}")
            raise  # re-raise with the original traceback after logging
        self._set_algorithm_operators(config.algorithm)
        self.output = get_buffer_writer(
            buffer_config.trainer_input.experience_buffer,  # type: ignore [arg-type]
            buffer_config,
        )

    def _init_input_storage(
        self,
        pipeline_config: ExperiencePipelineConfig,
        buffer_config: BufferConfig,
    ) -> Optional[BufferWriter]:
        """Create a writer that persists raw input experiences when ``save_input`` is enabled."""
        if not pipeline_config.save_input:
            return None
        if pipeline_config.input_save_path is None:
            raise ValueError("input_save_path must be set when save_input is True.")
        if is_json_file(pipeline_config.input_save_path):
            return get_buffer_writer(
                StorageConfig(
                    storage_type=StorageType.FILE,
                    path=pipeline_config.input_save_path,
                    wrap_in_ray=False,
                ),
                buffer_config,
            )
        if is_database_url(pipeline_config.input_save_path):
            return get_buffer_writer(
                StorageConfig(
                    storage_type=StorageType.SQL,
                    path=pipeline_config.input_save_path,
                    wrap_in_ray=False,
                ),
                buffer_config,
            )
        raise ValueError(
            f"Unsupported input_save_path: {pipeline_config.input_save_path}. "
            "Only a JSON file path or a SQLite URL is supported."
        )

    def _set_algorithm_operators(self, algorithm_config: AlgorithmConfig) -> None:
        """Add algorithm-specific operators to the pipeline."""
        from trinity.algorithm import ADVANTAGE_FN, ALGORITHM_TYPE

        algorithm = ALGORITHM_TYPE.get(algorithm_config.algorithm_type)
        if not algorithm.compute_advantage_in_trainer and algorithm_config.advantage_fn:
            advantage_fn_cls = ADVANTAGE_FN.get(algorithm_config.advantage_fn)
            assert (
                advantage_fn_cls is not None
            ), f"AdvantageFn {algorithm_config.advantage_fn} not found."
            assert not advantage_fn_cls.compute_in_trainer(), (
                f"AdvantageFn {algorithm_config.advantage_fn} can only be computed in the "
                "trainer; please check your implementation."
            )
            self.operators.append(advantage_fn_cls(**algorithm_config.advantage_fn_args))

    async def prepare(self) -> None:
        """Acquire the output buffer writer before processing begins."""
        await self.output.acquire()

    async def process(self, exps: List[Experience]) -> Dict:
        """Process a batch of experiences.

        Args:
            exps (List[Experience]): List of experiences to process. These
                experiences are typically generated by an explorer in one step.

        Returns:
            Dict: A dictionary containing metrics collected during the
                processing of experiences.
        """
        if self.input_store is not None:
            await self.input_store.write_async(exps)
        metrics = {}
        # Run the batch through the operator chain; each operator may
        # transform or filter the experiences and report its own metrics.
        for operator in self.operators:
            exps, metric = operator.process(exps)
            metrics.update(metric)
        metrics["experience_count"] = len(exps)
        # Write the processed experiences to the output buffer
        await self.output.write_async(exps)
        # Prefix metric keys with 'pipeline/', keeping only numeric values
        result_metrics = {}
        for key, value in metrics.items():
            if isinstance(value, (int, float)):
                result_metrics[f"pipeline/{key}"] = float(value)
        return result_metrics

    async def close(self) -> None:
        """Release the output buffer writer and close all operators."""
        await self.output.release()
        for operator in self.operators:
            operator.close()
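
# --- Illustrative lifecycle sketch (not part of the original module) ---
# A minimal driver showing the intended call order: prepare() acquires the
# output writer, process() handles one explorer step's experiences, and
# close() releases resources. The helper name `_example_run` is hypothetical,
# `load_config` is assumed to be the config loader in `trinity.common.config`,
# and the batches would come from an explorer in real use (often inside a
# Ray actor rather than a loop like this).
async def _example_run(config_path: str, batches: List[List[Experience]]) -> None:
    from trinity.common.config import load_config  # assumed loader; verify before use

    pipeline = ExperiencePipeline(load_config(config_path))
    await pipeline.prepare()
    try:
        for exps in batches:
            metrics = await pipeline.process(exps)
            print(metrics)  # numeric, prefixed keys, e.g. {"pipeline/experience_count": 8.0}
    finally:
        await pipeline.close()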