from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple
from trinity.common.config import OperatorConfig
from trinity.common.experience import Experience
from trinity.utils.registry import Registry
# Registry that `ExperienceOperator.create_operators` queries by operator name
# (via `.get(config.name)`) to resolve the class to instantiate.
EXPERIENCE_OPERATORS = Registry("experience_operators")
# [docs]  (Sphinx viewcode link marker left over from HTML extraction)
class ExperienceOperator(ABC):
    """Base class for all experience operators in the Trinity framework.

    Operators are used to process experiences and perform some transformations
    based on them. Subclasses must implement ``process``; ``close`` is an
    optional hook that does nothing unless overridden.
    """

    @abstractmethod
    def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
        """Process a list of experiences and return a transformed list.

        Args:
            exps (List[Experience]): List of experiences to process, which
                contains all experiences generated by the Explorer in one
                explore step.

        Returns:
            Tuple[List[Experience], Dict]: A tuple containing the processed
            list of experiences and a dictionary of metrics.
        """

    @classmethod
    def create_operators(cls, operator_configs: List[OperatorConfig]) -> List[ExperienceOperator]:
        """Create a list of ExperienceOperator instances from operator configurations.

        Each configuration's ``name`` is looked up in ``EXPERIENCE_OPERATORS``
        and the resulting class is called with the configuration's ``args``
        expanded as keyword arguments. Operators are returned in the same
        order as their configurations.

        Args:
            operator_configs (List[OperatorConfig]): List of operator
                configurations. An empty list yields an empty result.

        Returns:
            List[ExperienceOperator]: List of instantiated ExperienceOperator
            objects.

        Raises:
            ValueError: If the registry lookup for a configuration's ``name``
                returns a falsy value (i.e. the operator is unknown).
        """
        operators = []
        for config in operator_configs:
            operator_class = EXPERIENCE_OPERATORS.get(config.name)
            if not operator_class:
                raise ValueError(f"Unknown operator: {config.name}")
            # `args` may be None/empty when an operator is configured without
            # arguments; `**None` would raise TypeError, so fall back to {}.
            operators.append(operator_class(**(config.args or {})))
        return operators

    def close(self) -> None:
        """Close the operator if it has any resources to release."""
        pass