Source code for trinity.buffer.operators.mappers.pass_rate_calculator
from collections import defaultdict
from typing import Dict, List, Tuple
import numpy as np
from trinity.buffer.operators.experience_operator import (
EXPERIENCE_OPERATORS,
ExperienceOperator,
)
from trinity.common.constants import SELECTOR_METRIC
from trinity.common.experience import Experience
[docs]
@EXPERIENCE_OPERATORS.register_module("pass_rate_calculator")
class PassRateCalculator(ExperienceOperator):
[docs]
def __init__(self, **kwargs):
pass
[docs]
def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
raw_metric = defaultdict(lambda: defaultdict(list))
for exp in exps:
task_index = exp.info["task_index"]
assert "taskset_id" in task_index
assert "index" in task_index
raw_metric[task_index["taskset_id"]][task_index["index"]].append(exp.reward)
metric = {}
for taskset_id, taskset_metric in raw_metric.items():
indices = []
reward_means = []
for index, rewards in taskset_metric.items():
indices.append(index)
reward_means.append(float(np.mean(rewards)))
metric[taskset_id] = {
"indices": indices,
"values": reward_means,
}
return exps, {SELECTOR_METRIC: metric}